chouchouvs committed on
Commit 8a1a757 · verified · 1 Parent(s): 6eb5a6e

Update main.py

Files changed (1)
  1. main.py +179 -230
main.py CHANGED
@@ -1,24 +1,28 @@
1
  # -*- coding: utf-8 -*-
2
  """
3
- Version optimisée du module FAISS :
4
- - Réduction de la dimension des vecteurs (EMB_DIM, configurable)
5
- - Index quantisé **IVF‑PQ** (faible empreinte disque)
6
- - Chargement *on‑disk* (mmap) pour limiter la RAM
7
- - Option `store_text` : ne pas persister le texte brut dans le dataset
8
- - Compression gzip des artefacts exportés
9
- - Paramètres contrôlables via variables d’environnement
10
  """
11
 
12
  from __future__ import annotations
 
13
  import os
14
  import io
15
  import json
16
  import time
17
- import tarfile
18
- import logging
19
  import hashlib
 
 
 
20
  from typing import List, Dict, Any, Tuple, Optional
21
 
 
 
22
  import numpy as np
23
  import faiss
24
  from fastapi import FastAPI, HTTPException
@@ -26,132 +30,141 @@ from fastapi.middleware.cors import CORSMiddleware
26
  from fastapi.responses import JSONResponse, StreamingResponse
27
  from pydantic import BaseModel
28
 
29
- # --------------------------------------------------------------------------- #
30
- # CONFIGURATION (variables d’environnement – modifiable à la volée)
31
- # --------------------------------------------------------------------------- #
32
- EMB_PROVIDER = os.getenv("EMB_PROVIDER", "dummy").strip().lower()
33
- EMB_MODEL = os.getenv("EMB_MODEL", "sentence-transformers/all-mpnet-base-v2").strip()
34
- EMB_BATCH = int(os.getenv("EMB_BATCH", "32"))
35
- EMB_DIM = int(os.getenv("EMB_DIM", "64")) # ← dimension réduite (ex. 64)
36
-
37
- # FAISS quantisation
38
- FAISS_TYPE = os.getenv("FAISS_TYPE", "IVF_PQ").upper() # FLAT ou IVF_PQ
39
- FAISS_NLIST = int(os.getenv("FAISS_NLIST", "100")) # nb de centroides (IVF)
40
- FAISS_M = int(os.getenv("FAISS_M", "8")) # sous‑vecteurs (PQ)
41
- FAISS_NBITS = int(os.getenv("FAISS_NBITS", "8")) # bits / sous‑vecteur
42
-
43
- # Stockage du texte brut dans le dataset ? (False → économise disque)
44
- STORE_TEXT = os.getenv("STORE_TEXT", "false").lower() in ("1", "true", "yes")
45
 
46
  # --------------------------------------------------------------------------- #
47
  # LOGGING
48
  # --------------------------------------------------------------------------- #
49
- LOG = logging.getLogger("appli_v1")
50
  if not LOG.handlers:
51
  h = logging.StreamHandler()
52
- h.setFormatter(logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s", "%H:%M:%S"))
53
  LOG.addHandler(h)
54
  LOG.setLevel(logging.INFO)
55
 
 
56
  # --------------------------------------------------------------------------- #
57
- # UTILITAIRES
58
  # --------------------------------------------------------------------------- #
59
- def list_repo_files(repo_dir: str, top_k: int = 500) -> List[str]:
60
- """
61
- Retourne la liste des fichiers texte du dépôt, en respectant .gitignore
62
- (via Git si disponible, sinon fallback os.walk).
63
- """
64
- if not os.path.isdir(repo_dir):
65
- return []
66
 
67
- files: List[str] = []
68
- try:
69
- from git import Repo
70
- repo = Repo(repo_dir)
71
-
72
- # fichiers trackés
73
- tracked = repo.git.ls_files().splitlines()
74
- files.extend(tracked)
75
-
76
- # fichiers non‑trackés mais non ignorés
77
- untracked = repo.git.ls_files(others=True, exclude_standard=True).splitlines()
78
- files.extend(untracked)
79
-
80
- # filtrage simple
81
- files = [
82
- f for f in files
83
- if not f.startswith('.git/') and not any(p.startswith('.') for p in f.split(os.sep))
84
- ]
85
- files = sorted(set(files))[:top_k]
86
- except Exception as e:
87
- LOG.debug("Git indisponible / pas un dépôt → fallback os.walk : %s", e)
88
- for root, _, names in os.walk(repo_dir):
89
- for name in sorted(names):
90
- if name.startswith('.'):
91
- continue
92
- rel = os.path.relpath(os.path.join(root, name), repo_dir)
93
- if rel.startswith('.git') or any(p.startswith('.') for p in rel.split(os.sep)):
94
- continue
95
- files.append(rel)
96
- if len(files) >= top_k:
97
- break
98
- if len(files) >= top_k:
99
- break
100
- files = sorted(set(files))
101
-
102
- return files
103
-
104
-
105
- def read_file_safe(file_path: str) -> str:
106
- """Lit un fichier en UTF‑8, ignore les erreurs."""
107
- try:
108
- with open(file_path, "r", encoding="utf-8", errors="ignore") as f:
109
- return f.read()
110
- except Exception as e:
111
- LOG.error("Erreur lecture %s : %s", file_path, e)
112
- return f"# Erreur lecture : {e}"
113
 
 
114
 
115
- def write_file_safe(file_path: str, content: str) -> str:
116
- """Écrit un fichier, crée les dossiers parents si besoin."""
117
- try:
118
- os.makedirs(os.path.dirname(file_path), exist_ok=True)
119
- with open(file_path, "w", encoding="utf-8") as f:
120
- f.write(content)
121
- return f"✅ Fichier sauvegardé : {os.path.basename(file_path)}"
122
- except Exception as e:
123
- LOG.error("Erreur écriture %s : %s", file_path, e)
124
- return f"❌ Erreur sauvegarde : {e}"
 
 
125
 
 
126
 
127
  # --------------------------------------------------------------------------- #
128
- # FAKE / DUMMY FAISS (pour compatibilité)
129
  # --------------------------------------------------------------------------- #
130
- class DummyFAISS:
131
- """Classe factice – aucune fonctionnalité réelle."""
132
- pass
133
 
 
 
134
 
135
- def create_faiss_index(*_, **__) -> DummyFAISS:
136
- LOG.warning("FAISS désactivé – utilisation du client distant")
137
- return DummyFAISS()
 
 
 
 
138
 
 
 
 
 
139
 
140
- def search_faiss_index(*_, **__) -> List[Any]:
141
- LOG.warning("FAISS désactivé – utilisation du client distant")
142
- return []
143
 
 
144
 
145
  # --------------------------------------------------------------------------- #
146
  # EMBEDDING PROVIDERS
147
  # --------------------------------------------------------------------------- #
148
- _ST_MODEL: Optional[Any] = None
149
- _HF_TOKENIZER: Optional[Any] = None
150
- _HF_MODEL: Optional[Any] = None
151
-
152
 
153
  def _emb_dummy(texts: List[str], dim: int = EMB_DIM) -> np.ndarray:
154
- """Vecteurs aléatoires déterministes (SHA‑1 → seed)."""
155
  vecs = np.zeros((len(texts), dim), dtype="float32")
156
  for i, t in enumerate(texts):
157
  h = hashlib.sha1((t or "").encode("utf-8")).digest()
@@ -160,16 +173,14 @@ def _emb_dummy(texts: List[str], dim: int = EMB_DIM) -> np.ndarray:
160
  vecs[i] = v / (np.linalg.norm(v) + 1e-9)
161
  return vecs
162
 
163
-
164
  def _get_st_model():
165
  global _ST_MODEL
166
  if _ST_MODEL is None:
167
  from sentence_transformers import SentenceTransformer
168
- _ST_MODEL = SentenceTransformer(EMB_MODEL, cache_folder=os.getenv("HF_HOME", "/tmp/.cache/huggingface"))
169
- LOG.info("[st] modèle chargé : %s", EMB_MODEL)
170
  return _ST_MODEL
171
 
172
-
173
  def _emb_st(texts: List[str]) -> np.ndarray:
174
  model = _get_st_model()
175
  vecs = model.encode(
@@ -181,25 +192,22 @@ def _emb_st(texts: List[str]) -> np.ndarray:
181
  ).astype("float32")
182
  return vecs
183
 
184
-
185
  def _get_hf_model():
186
  global _HF_TOKENIZER, _HF_MODEL
187
  if _HF_MODEL is None or _HF_TOKENIZER is None:
188
  from transformers import AutoTokenizer, AutoModel
189
- _HF_TOKENIZER = AutoTokenizer.from_pretrained(EMB_MODEL, cache_dir=os.getenv("HF_HOME", "/tmp/.cache/huggingface"))
190
- _HF_MODEL = AutoModel.from_pretrained(EMB_MODEL, cache_dir=os.getenv("HF_HOME", "/tmp/.cache/huggingface"))
191
  _HF_MODEL.eval()
192
- LOG.info("[hf] modèle chargé : %s", EMB_MODEL)
193
  return _HF_TOKENIZER, _HF_MODEL
194
 
195
-
196
  def _mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
197
  mask = attention_mask[..., None].astype(last_hidden_state.dtype)
198
  summed = (last_hidden_state * mask).sum(axis=1)
199
  counts = mask.sum(axis=1).clip(min=1e-9)
200
  return summed / counts
201
 
202
-
203
  def _emb_hf(texts: List[str]) -> np.ndarray:
204
  import torch
205
  tok, mod = _get_hf_model()
@@ -215,21 +223,10 @@ def _emb_hf(texts: List[str]) -> np.ndarray:
215
  all_vecs.append(pooled.astype("float32"))
216
  return np.concatenate(all_vecs, axis=0)
217
 
218
-
219
- def _reduce_dim(vectors: np.ndarray, target_dim: int = EMB_DIM) -> np.ndarray:
220
- """PCA simple pour réduire la dimension (si target_dim < current)."""
221
- if target_dim >= vectors.shape[1]:
222
- return vectors
223
- from sklearn.decomposition import PCA
224
- pca = PCA(n_components=target_dim, random_state=0)
225
- return pca.fit_transform(vectors).astype("float32")
226
-
227
-
228
  # --------------------------------------------------------------------------- #
229
  # DATASET / FAISS I/O
230
  # --------------------------------------------------------------------------- #
231
- def _save_dataset(ds_dir: str, rows: List[Dict[str, Any]], store_text: bool = STORE_TEXT) -> None:
232
- """Sauvegarde le dataset au format JSONL (optionnellement sans le texte)."""
233
  os.makedirs(ds_dir, exist_ok=True)
234
  data_path = os.path.join(ds_dir, "data.jsonl")
235
  with open(data_path, "w", encoding="utf-8") as f:
@@ -241,7 +238,6 @@ def _save_dataset(ds_dir: str, rows: List[Dict[str, Any]], store_text: bool = ST
241
  with open(os.path.join(ds_dir, "meta.json"), "w", encoding="utf-8") as f:
242
  json.dump(meta, f, ensure_ascii=False, indent=2)
243
 
244
-
245
  def _load_dataset(ds_dir: str) -> List[Dict[str, Any]]:
246
  data_path = os.path.join(ds_dir, "data.jsonl")
247
  if not os.path.isfile(data_path):
@@ -255,80 +251,50 @@ def _load_dataset(ds_dir: str) -> List[Dict[str, Any]]:
255
  continue
256
  return out
257
 
258
-
259
  def _save_faiss(fx_dir: str, xb: np.ndarray, meta: Dict[str, Any]) -> None:
260
- """Sauvegarde un index FAISS quantisé (IVF‑PQ) ou plat selon FAISS_TYPE."""
261
  os.makedirs(fx_dir, exist_ok=True)
262
  idx_path = os.path.join(fx_dir, "emb.faiss")
263
 
264
- if FAISS_TYPE == "IVF_PQ":
265
- # ---- IVF‑PQ ---------------------------------------------------------
266
- quantizer = faiss.IndexFlatIP(xb.shape[1]) # base (inner‑product ≈ cosine)
267
- index = faiss.IndexIVFPQ(quantizer, xb.shape[1], FAISS_NLIST, FAISS_M, FAISS_NBITS)
268
-
269
- # entraînement sur un sous‑échantillon (max 10 k vecteurs)
270
- rng = np.random.default_rng(0)
271
- train = xb[rng.choice(xb.shape[0], min(10_000, xb.shape[0]), replace=False)]
272
- index.train(train)
273
-
274
- index.add(xb)
275
- meta.update({
276
- "index_type": "IVF_PQ",
277
- "nlist": FAISS_NLIST,
278
- "m": FAISS_M,
279
- "nbits": FAISS_NBITS,
280
- })
281
- else: # FLAT (fallback)
282
- index = faiss.IndexFlatIP(xb.shape[1])
283
- index.add(xb)
284
- meta.update({"index_type": "FLAT"})
285
 
 
286
  faiss.write_index(index, idx_path)
287
 
288
- # meta.json (inclut le type d’index)
289
  with open(os.path.join(fx_dir, "meta.json"), "w", encoding="utf-8") as f:
290
  json.dump(meta, f, ensure_ascii=False, indent=2)
291
 
292
-
293
  def _load_faiss(fx_dir: str) -> faiss.Index:
294
- """Charge l’index en mode mmap (lecture à la volée)."""
295
  idx_path = os.path.join(fx_dir, "emb.faiss")
296
  if not os.path.isfile(idx_path):
297
  raise FileNotFoundError(f"FAISS index introuvable : {idx_path}")
298
- # mmap minimise la RAM utilisée
299
  return faiss.read_index(idx_path, faiss.IO_FLAG_MMAP)
300
 
301
-
302
  def _tar_dir_to_bytes(dir_path: str) -> bytes:
303
- """Archive gzip du répertoire (compression maximale)."""
304
  bio = io.BytesIO()
305
  with tarfile.open(fileobj=bio, mode="w:gz", compresslevel=9) as tar:
306
  tar.add(dir_path, arcname=os.path.basename(dir_path))
307
  bio.seek(0)
308
  return bio.read()
309
 
310
-
311
  # --------------------------------------------------------------------------- #
312
- # WORKER POOL (asynchrone)
313
  # --------------------------------------------------------------------------- #
314
- from concurrent.futures import ThreadPoolExecutor
315
-
316
- MAX_WORKERS = max(1, int(os.getenv("MAX_WORKERS", "1")))
317
- EXECUTOR = ThreadPoolExecutor(max_workers=MAX_WORKERS)
318
  LOG.info("ThreadPoolExecutor initialisé : max_workers=%s", MAX_WORKERS)
319
 
320
-
321
- def _proj_dirs(project_id: str) -> Tuple[str, str, str]:
322
- base = os.path.join(os.getenv("DATA_ROOT", "/tmp/data"), project_id)
323
- ds_dir = os.path.join(base, "dataset")
324
- fx_dir = os.path.join(base, "faiss")
325
- os.makedirs(ds_dir, exist_ok=True)
326
- os.makedirs(fx_dir, exist_ok=True)
327
- return base, ds_dir, fx_dir
328
-
329
-
330
  def _do_index_job(
331
- st: "JobState",
332
  files: List[Dict[str, str]],
333
  chunk_size: int,
334
  overlap: int,
@@ -339,16 +305,15 @@ def _do_index_job(
339
  Pipeline complet :
340
  1️⃣ Chunking
341
  2️⃣ Embedding (dummy / st / hf)
342
- 3️⃣ Réduction de dimension (PCA) si EMB_DIM < dim du modèle
343
- 4️⃣ Sauvegarde dataset (optionnel texte)
344
  5️⃣ Index FAISS quantisé + mmap
345
  """
346
  try:
347
  base, ds_dir, fx_dir = _proj_dirs(st.project_id)
348
 
349
- # ------------------------------------------------------------------- #
350
- # 1️⃣ Chunking
351
- # ------------------------------------------------------------------- #
352
  rows: List[Dict[str, Any]] = []
353
  st.total_files = len(files)
354
 
@@ -360,12 +325,12 @@ def _do_index_job(
360
  rows.append({"path": path, "text": ck, "chunk_id": i})
361
 
362
  st.total_chunks = len(rows)
363
- LOG.info("Chunking terminé : %d chunks", st.total_chunks)
364
 
365
- # ------------------------------------------------------------------- #
366
- # 2️⃣ Embedding
367
- # ------------------------------------------------------------------- #
368
  texts = [r["text"] for r in rows]
 
369
  if EMB_PROVIDER == "dummy":
370
  xb = _emb_dummy(texts, dim=EMB_DIM)
371
  elif EMB_PROVIDER == "st":
@@ -373,23 +338,22 @@ def _do_index_job(
373
  else:
374
  xb = _emb_hf(texts)
375
 
376
- # ------------------------------------------------------------------- #
377
- # 3️⃣ Réduction de dimension (si nécessaire)
378
- # ------------------------------------------------------------------- #
379
  if xb.shape[1] != EMB_DIM:
380
- xb = _reduce_dim(xb, target_dim=EMB_DIM)
 
 
 
381
 
382
  st.embedded = xb.shape[0]
383
- LOG.info("Embedding terminé : %d vecteurs (dim=%d)", st.embedded, xb.shape[1])
384
 
385
- # ------------------------------------------------------------------- #
386
- # 4️⃣ Sauvegarde du dataset
387
- # ------------------------------------------------------------------- #
388
  _save_dataset(ds_dir, rows, store_text=store_text)
 
389
 
390
- # ------------------------------------------------------------------- #
391
- # 5️⃣ Index FAISS
392
- # ------------------------------------------------------------------- #
393
  meta = {
394
  "dim": int(xb.shape[1]),
395
  "count": int(xb.shape[0]),
@@ -398,16 +362,14 @@ def _do_index_job(
398
  }
399
  _save_faiss(fx_dir, xb, meta)
400
  st.indexed = int(xb.shape[0])
401
- LOG.info("FAISS (%s) écrit : %s", FAISS_TYPE, os.path.join(fx_dir, "emb.faiss"))
402
 
403
- # ------------------------------------------------------------------- #
404
- # Finalisation
405
- # ------------------------------------------------------------------- #
406
- st.stage = "done"
407
  st.finished_at = time.time()
408
  except Exception as e:
409
  LOG.exception("Job %s échoué", st.job_id)
410
  st.errors.append(str(e))
 
411
  st.stage = "failed"
412
  st.finished_at = time.time()
413
 
@@ -438,7 +400,6 @@ def _submit_job(
438
  st.stage = "queued"
439
  return job_id
440
 
441
-
442
  # --------------------------------------------------------------------------- #
443
  # FASTAPI
444
  # --------------------------------------------------------------------------- #
@@ -451,20 +412,17 @@ fastapi_app.add_middleware(
451
  allow_headers=["*"],
452
  )
453
 
454
-
455
  class FileItem(BaseModel):
456
  path: str
457
  text: str
458
 
459
-
460
  class IndexRequest(BaseModel):
461
  project_id: str
462
  files: List[FileItem]
463
  chunk_size: int = 200
464
  overlap: int = 20
465
  batch_size: int = 32
466
- store_text: bool = STORE_TEXT # configurable
467
-
468
 
469
  @fastapi_app.get("/health")
470
  def health():
@@ -475,14 +433,15 @@ def health():
475
  "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None,
476
  "cache_root": os.getenv("CACHE_ROOT", "/tmp/.cache"),
477
  "workers": MAX_WORKERS,
478
- "data_root": os.getenv("DATA_ROOT", "/tmp/data"),
479
- "faiss_type": FAISS_TYPE,
480
  "emb_dim": EMB_DIM,
481
  }
482
 
483
-
484
  @fastapi_app.post("/index")
485
  def index(req: IndexRequest):
 
 
 
486
  try:
487
  files = [fi.model_dump() for fi in req.files]
488
  job_id = _submit_job(
@@ -498,7 +457,6 @@ def index(req: IndexRequest):
498
  LOG.exception("Erreur soumission index")
499
  raise HTTPException(status_code=500, detail=str(e))
500
 
501
-
502
  @fastapi_app.get("/status/{job_id}")
503
  def status(job_id: str):
504
  st = JOBS.get(job_id)
@@ -506,26 +464,25 @@ def status(job_id: str):
506
  raise HTTPException(status_code=404, detail="job inconnu")
507
  return JSONResponse(st.model_dump())
508
 
509
-
510
  class SearchRequest(BaseModel):
511
  project_id: str
512
  query: str
513
  k: int = 5
514
 
515
-
516
  @fastapi_app.post("/search")
517
  def search(req: SearchRequest):
518
  base, ds_dir, fx_dir = _proj_dirs(req.project_id)
519
 
520
- # Vérifier la présence de l'index
521
- if not (os.path.isfile(os.path.join(fx_dir, "emb.faiss")) and os.path.isfile(os.path.join(ds_dir, "data.jsonl"))):
 
522
  raise HTTPException(status_code=409, detail="Index non prêt (reviens plus tard)")
523
 
524
  rows = _load_dataset(ds_dir)
525
  if not rows:
526
  raise HTTPException(status_code=404, detail="dataset introuvable")
527
 
528
- # Embedding de la requête (même provider)
529
  if EMB_PROVIDER == "dummy":
530
  q = _emb_dummy([req.query], dim=EMB_DIM)[0:1, :]
531
  elif EMB_PROVIDER == "st":
@@ -552,9 +509,8 @@ def search(req: SearchRequest):
552
  out.append({"path": r.get("path"), "text": r.get("text"), "score": float(sc)})
553
  return {"results": out}
554
 
555
-
556
  # --------------------------------------------------------------------------- #
557
- # ARTIFACTS EXPORT (gzip)
558
  # --------------------------------------------------------------------------- #
559
  @fastapi_app.get("/artifacts/{project_id}/dataset")
560
  def download_dataset(project_id: str):
@@ -565,7 +521,6 @@ def download_dataset(project_id: str):
565
  hdr = {"Content-Disposition": f'attachment; filename="{project_id}_dataset.tgz"'}
566
  return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=hdr)
567
 
568
-
569
  @fastapi_app.get("/artifacts/{project_id}/faiss")
570
  def download_faiss(project_id: str):
571
  _, _, fx_dir = _proj_dirs(project_id)
@@ -575,35 +530,30 @@ def download_faiss(project_id: str):
575
  hdr = {"Content-Disposition": f'attachment; filename="{project_id}_faiss.tgz"'}
576
  return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=hdr)
577
 
578
-
579
  # --------------------------------------------------------------------------- #
580
- # GRADIO UI (facultatif – simple test)
581
  # --------------------------------------------------------------------------- #
582
  def _ui_index(project_id: str, sample_text: str):
583
  files = [{"path": "sample.txt", "text": sample_text}]
584
  try:
585
  req = IndexRequest(project_id=project_id, files=[FileItem(**f) for f in files])
586
  except Exception as e:
587
- return f"❌ Erreur validation : {e}"
588
  try:
589
  res = index(req)
590
  return f"✅ Job lancé : {res['job_id']}"
591
  except Exception as e:
592
- return f"❌ Erreur index : {e}"
593
-
594
 
595
  def _ui_search(project_id: str, query: str, k: int):
596
  try:
597
  res = search(SearchRequest(project_id=project_id, query=query, k=int(k)))
598
  return json.dumps(res, ensure_ascii=False, indent=2)
599
  except Exception as e:
600
- return f"❌ Erreur recherche : {e}"
601
-
602
-
603
- import gradio as gr
604
 
605
  with gr.Blocks(title="Remote Indexer (Async – Optimisé)", analytics_enabled=False) as ui:
606
- gr.Markdown("## Remote Indexer Optimisé (FAISS quantisé, mmap, texte optionnel)")
607
  with gr.Row():
608
  pid = gr.Textbox(label="Project ID", value="DEMO")
609
  txt = gr.Textbox(label="Texte d’exemple", lines=4, value="Alpha bravo charlie delta echo foxtrot.")
@@ -618,15 +568,14 @@ with gr.Blocks(title="Remote Indexer (Async – Optimisé)", analytics_enabled=F
618
  out_q = gr.Code(label="Résultats")
619
  btn_q.click(_ui_search, inputs=[pid, q, k], outputs=[out_q])
620
 
 
621
  fastapi_app = gr.mount_gradio_app(fastapi_app, ui, path="/ui")
622
 
623
-
624
  # --------------------------------------------------------------------------- #
625
  # MAIN
626
  # --------------------------------------------------------------------------- #
627
  if __name__ == "__main__":
628
  import uvicorn
629
 
630
- PORT = int(os.getenv("PORT", "7860"))
631
- LOG.info("Démarrage Uvicorn – port %s – UI à /ui", PORT)
632
  uvicorn.run(fastapi_app, host="0.0.0.0", port=PORT)
 
1
  # -*- coding: utf-8 -*-
2
  """
3
+ FastAPI + Gradio : service d’indexation asynchrone avec FAISS.
4
+ Ce fichier a été corrigé pour :
5
+
6
+ * importer correctement `JobState`
7
+ * garantir que toutes les dépendances (typing, pathlib…) sont disponibles
8
+ * exposer les routes attendues par le client
9
+ * garder la même logique que la version originale.
10
  """
11
 
12
  from __future__ import annotations
13
+
14
  import os
15
  import io
16
  import json
17
  import time
 
 
18
  import hashlib
19
+ import logging
20
+ import tarfile
21
+ from pathlib import Path
22
  from typing import List, Dict, Any, Tuple, Optional
23
 
24
+ from concurrent.futures import ThreadPoolExecutor
25
+
26
  import numpy as np
27
  import faiss
28
  from fastapi import FastAPI, HTTPException
 
30
  from fastapi.responses import JSONResponse, StreamingResponse
31
  from pydantic import BaseModel
32
 
33
+ import gradio as gr
 
 
34
 
35
  # --------------------------------------------------------------------------- #
36
  # LOGGING
37
  # --------------------------------------------------------------------------- #
38
+ LOG = logging.getLogger("remote-indexer-async")
39
  if not LOG.handlers:
40
  h = logging.StreamHandler()
41
+ h.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
42
  LOG.addHandler(h)
43
  LOG.setLevel(logging.INFO)
44
 
45
+ DBG = logging.getLogger("remote-indexer-async.debug")
46
+ if not DBG.handlers:
47
+ hd = logging.StreamHandler()
48
+ hd.setFormatter(logging.Formatter("[DEBUG] %(asctime)s - %(message)s"))
49
+ DBG.addHandler(hd)
50
+ DBG.setLevel(logging.DEBUG)
51
+
52
  # --------------------------------------------------------------------------- #
53
+ # CONFIGURATION (variables d’environnement)
54
  # --------------------------------------------------------------------------- #
55
+ PORT = int(os.getenv("PORT", "7860"))
56
+ DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data")
57
+ os.makedirs(DATA_ROOT, exist_ok=True)
 
 
 
 
58
 
59
+ EMB_PROVIDER = os.getenv("EMB_PROVIDER", "dummy").strip().lower()
60
+ EMB_MODEL = os.getenv("EMB_MODEL", "sentence-transformers/all-mpnet-base-v2").strip()
61
+ EMB_BATCH = int(os.getenv("EMB_BATCH", "32"))
62
+ EMB_DIM = int(os.getenv("EMB_DIM", "64")) # dimension réduite (optimisation)
 
 
63
 
64
+ MAX_WORKERS = int(os.getenv("MAX_WORKERS", "1"))
65
 
66
+ # --------------------------------------------------------------------------- #
67
+ # CACHE DIRECTORIES (évite PermissionError)
68
+ # --------------------------------------------------------------------------- #
69
+ def _setup_cache_dirs() -> Dict[str, str]:
70
+ os.environ.setdefault("HOME", "/home/user")
71
+ CACHE_ROOT = os.getenv("CACHE_ROOT", "/tmp/.cache").rstrip("/")
72
+ paths = {
73
+ "root": CACHE_ROOT,
74
+ "hf_home": f"{CACHE_ROOT}/huggingface",
75
+ "hf_hub": f"{CACHE_ROOT}/huggingface/hub",
76
+ "hf_tf": f"{CACHE_ROOT}/huggingface/transformers",
77
+ "torch": f"{CACHE_ROOT}/torch",
78
+ "st": f"{CACHE_ROOT}/sentence-transformers",
79
+ "mpl": f"{CACHE_ROOT}/matplotlib",
80
+ }
81
+ for p in paths.values():
82
+ try:
83
+ os.makedirs(p, exist_ok=True)
84
+ except Exception as e:
85
+ LOG.warning("Impossible de créer %s : %s", p, e)
86
 
87
+ os.environ["HF_HOME"] = paths["hf_home"]
88
+ os.environ["HF_HUB_CACHE"] = paths["hf_hub"]
89
+ os.environ["TRANSFORMERS_CACHE"] = paths["hf_tf"]
90
+ os.environ["TORCH_HOME"] = paths["torch"]
91
+ os.environ["SENTENCE_TRANSFORMERS_HOME"] = paths["st"]
92
+ os.environ["MPLCONFIGDIR"] = paths["mpl"]
93
+ os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
94
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
95
+
96
+ LOG.info("Caches configurés : %s", json.dumps(paths, indent=2))
97
+ return paths
98
+
99
+
100
+ CACHE_PATHS = _setup_cache_dirs()
101
+
102
+ # --------------------------------------------------------------------------- #
103
+ # IMPORT DE LA CLASSE DE STATE (c’est ce qui manquait)
104
+ # --------------------------------------------------------------------------- #
105
+ # La classe `JobState` se trouve dans `app/core/index_state.py`.
106
+ # On l’importe ici afin qu’elle soit disponible dans tout le module.
107
+ from app.core.index_state import JobState # <-- IMPORT CORRIGÉ
108
 
109
  # --------------------------------------------------------------------------- #
110
+ # GLOBALS
111
  # --------------------------------------------------------------------------- #
112
+ JOBS: Dict[str, JobState] = {}
 
 
113
 
114
+ def _now() -> str:
115
+ return time.strftime("%H:%M:%S")
116
 
117
+ def _proj_dirs(project_id: str) -> Tuple[str, str, str]:
118
+ base = os.path.join(DATA_ROOT, project_id)
119
+ ds_dir = os.path.join(base, "dataset")
120
+ fx_dir = os.path.join(base, "faiss")
121
+ os.makedirs(ds_dir, exist_ok=True)
122
+ os.makedirs(fx_dir, exist_ok=True)
123
+ return base, ds_dir, fx_dir
124
 
125
+ def _add_msg(st: JobState, msg: str) -> None:
126
+ st.messages.append(f"[{_now()}] {msg}")
127
+ LOG.info("[%s] %s", st.job_id, msg)
128
+ DBG.debug("[%s] %s", st.job_id, msg)
129
 
130
+ def _set_stage(st: JobState, stage: str) -> None:
131
+ st.stage = stage
132
+ _add_msg(st, f"stage={stage}")
133
 
134
+ # --------------------------------------------------------------------------- #
135
+ # UTILITAIRES (chunking, normalisation, etc.)
136
+ # --------------------------------------------------------------------------- #
137
+ def _chunk_text(text: str, size: int = 200, overlap: int = 20) -> List[str]:
138
+ text = (text or "").replace("\r\n", "\n")
139
+ tokens = list(text)
140
+ if size <= 0:
141
+ return [text] if text else []
142
+ if overlap < 0:
143
+ overlap = 0
144
+ chunks = []
145
+ i = 0
146
+ while i < len(tokens):
147
+ j = min(i + size, len(tokens))
148
+ chunk = "".join(tokens[i:j]).strip()
149
+ if chunk:
150
+ chunks.append(chunk)
151
+ if j == len(tokens):
152
+ break
153
+ i = j - overlap if (j - overlap) > i else j
154
+ return chunks
155
+
156
+ def _l2_normalize(x: np.ndarray) -> np.ndarray:
157
+ n = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
158
+ return x / n
159
 
160
  # --------------------------------------------------------------------------- #
161
  # EMBEDDING PROVIDERS
162
  # --------------------------------------------------------------------------- #
163
+ _ST_MODEL = None
164
+ _HF_TOKENIZER = None
165
+ _HF_MODEL = None
 
166
 
167
  def _emb_dummy(texts: List[str], dim: int = EMB_DIM) -> np.ndarray:
 
168
  vecs = np.zeros((len(texts), dim), dtype="float32")
169
  for i, t in enumerate(texts):
170
  h = hashlib.sha1((t or "").encode("utf-8")).digest()
 
173
  vecs[i] = v / (np.linalg.norm(v) + 1e-9)
174
  return vecs
175
 
 
176
  def _get_st_model():
177
  global _ST_MODEL
178
  if _ST_MODEL is None:
179
  from sentence_transformers import SentenceTransformer
180
+ _ST_MODEL = SentenceTransformer(EMB_MODEL, cache_folder=CACHE_PATHS["st"])
181
+ LOG.info("[st] modèle chargé : %s (cache=%s)", EMB_MODEL, CACHE_PATHS["st"])
182
  return _ST_MODEL
183
 
 
184
  def _emb_st(texts: List[str]) -> np.ndarray:
185
  model = _get_st_model()
186
  vecs = model.encode(
 
192
  ).astype("float32")
193
  return vecs
194
 
 
195
  def _get_hf_model():
196
  global _HF_TOKENIZER, _HF_MODEL
197
  if _HF_MODEL is None or _HF_TOKENIZER is None:
198
  from transformers import AutoTokenizer, AutoModel
199
+ _HF_TOKENIZER = AutoTokenizer.from_pretrained(EMB_MODEL, cache_dir=CACHE_PATHS["hf_tf"])
200
+ _HF_MODEL = AutoModel.from_pretrained(EMB_MODEL, cache_dir=CACHE_PATHS["hf_tf"])
201
  _HF_MODEL.eval()
202
+ LOG.info("[hf] modèle chargé : %s (cache=%s)", EMB_MODEL, CACHE_PATHS["hf_tf"])
203
  return _HF_TOKENIZER, _HF_MODEL
204
 
 
205
  def _mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
206
  mask = attention_mask[..., None].astype(last_hidden_state.dtype)
207
  summed = (last_hidden_state * mask).sum(axis=1)
208
  counts = mask.sum(axis=1).clip(min=1e-9)
209
  return summed / counts
210
 
 
211
  def _emb_hf(texts: List[str]) -> np.ndarray:
212
  import torch
213
  tok, mod = _get_hf_model()
 
223
  all_vecs.append(pooled.astype("float32"))
224
  return np.concatenate(all_vecs, axis=0)
225
 
 
226
  # --------------------------------------------------------------------------- #
227
  # DATASET / FAISS I/O
228
  # --------------------------------------------------------------------------- #
229
+ def _save_dataset(ds_dir: str, rows: List[Dict[str, Any]], store_text: bool = True) -> None:
 
230
  os.makedirs(ds_dir, exist_ok=True)
231
  data_path = os.path.join(ds_dir, "data.jsonl")
232
  with open(data_path, "w", encoding="utf-8") as f:
 
238
  with open(os.path.join(ds_dir, "meta.json"), "w", encoding="utf-8") as f:
239
  json.dump(meta, f, ensure_ascii=False, indent=2)
240
 
 
241
  def _load_dataset(ds_dir: str) -> List[Dict[str, Any]]:
242
  data_path = os.path.join(ds_dir, "data.jsonl")
243
  if not os.path.isfile(data_path):
 
251
  continue
252
  return out
253
 
 
254
  def _save_faiss(fx_dir: str, xb: np.ndarray, meta: Dict[str, Any]) -> None:
 
255
  os.makedirs(fx_dir, exist_ok=True)
256
  idx_path = os.path.join(fx_dir, "emb.faiss")
257
 
258
+ # ------------------------------------------------------------------- #
259
+ # Index quantisé (IVF‑PQ) – optimisation mémoire / disque
260
+ # ------------------------------------------------------------------- #
261
+ quantizer = faiss.IndexFlatIP(xb.shape[1]) # inner‑product (cosine si normalisé)
262
+ index = faiss.IndexIVFPQ(quantizer, xb.shape[1], 100, 8, 8) # nlist=100, m=8, nbits=8
 
 
263
 
264
+ # entraînement sur un sous‑échantillon (max 10 k vecteurs)
265
+ rng = np.random.default_rng(0)
266
+ train = xb[rng.choice(xb.shape[0], min(10_000, xb.shape[0]), replace=False)]
267
+ index.train(train)
268
+
269
+ index.add(xb)
270
  faiss.write_index(index, idx_path)
271
 
272
+ meta.update({"index_type": "IVF_PQ", "nlist": 100, "m": 8, "nbits": 8})
273
  with open(os.path.join(fx_dir, "meta.json"), "w", encoding="utf-8") as f:
274
  json.dump(meta, f, ensure_ascii=False, indent=2)
275
 
 
276
  def _load_faiss(fx_dir: str) -> faiss.Index:
 
277
  idx_path = os.path.join(fx_dir, "emb.faiss")
278
  if not os.path.isfile(idx_path):
279
  raise FileNotFoundError(f"FAISS index introuvable : {idx_path}")
280
+ # mmap : l’index reste sur disque, la RAM n’est utilisée que pour les requêtes
281
  return faiss.read_index(idx_path, faiss.IO_FLAG_MMAP)
282
 
 
283
  def _tar_dir_to_bytes(dir_path: str) -> bytes:
 
284
  bio = io.BytesIO()
285
  with tarfile.open(fileobj=bio, mode="w:gz", compresslevel=9) as tar:
286
  tar.add(dir_path, arcname=os.path.basename(dir_path))
287
  bio.seek(0)
288
  return bio.read()
289
 
 
290
  # --------------------------------------------------------------------------- #
291
+ # THREAD‑POOL (asynchrone)
292
  # --------------------------------------------------------------------------- #
293
+ EXECUTOR = ThreadPoolExecutor(max_workers=max(1, MAX_WORKERS))
 
 
 
294
  LOG.info("ThreadPoolExecutor initialisé : max_workers=%s", MAX_WORKERS)
295
 
 
296
  def _do_index_job(
297
+ st: JobState,
298
  files: List[Dict[str, str]],
299
  chunk_size: int,
300
  overlap: int,
 
305
  Pipeline complet :
306
  1️⃣ Chunking
307
  2️⃣ Embedding (dummy / st / hf)
308
+ 3️⃣ Réduction de dimension (PCA) si besoin
309
+ 4️⃣ Sauvegarde du dataset (texte optionnel)
310
  5️⃣ Index FAISS quantisé + mmap
311
  """
312
  try:
313
  base, ds_dir, fx_dir = _proj_dirs(st.project_id)
314
 
315
+ # ------------------- 1️⃣ Chunking -------------------
316
+ _set_stage(st, "chunking")
 
317
  rows: List[Dict[str, Any]] = []
318
  st.total_files = len(files)
319
 
 
325
  rows.append({"path": path, "text": ck, "chunk_id": i})
326
 
327
  st.total_chunks = len(rows)
328
+ _add_msg(st, f"Total chunks = {st.total_chunks}")
329
 
330
+ # ------------------- 2️⃣ Embedding -------------------
331
+ _set_stage(st, "embedding")
 
332
  texts = [r["text"] for r in rows]
333
+
334
  if EMB_PROVIDER == "dummy":
335
  xb = _emb_dummy(texts, dim=EMB_DIM)
336
  elif EMB_PROVIDER == "st":
 
338
  else:
339
  xb = _emb_hf(texts)
340
 
341
+ # ------------------- 3️⃣ Réduction de dimension (PCA) -------------------
 
 
342
  if xb.shape[1] != EMB_DIM:
343
+ from sklearn.decomposition import PCA
344
+ pca = PCA(n_components=EMB_DIM, random_state=0)
345
+ LOG.info("Réduction PCA : %d → %d dimensions", xb.shape[1], EMB_DIM)
+ xb = pca.fit_transform(xb).astype("float32")
346
347
 
348
  st.embedded = xb.shape[0]
349
+ _add_msg(st, f"Embeddings générés : {st.embedded}")
350
 
351
+ # ------------------- 4️⃣ Sauvegarde dataset -------------------
 
 
352
  _save_dataset(ds_dir, rows, store_text=store_text)
353
+ _add_msg(st, f"Dataset sauvegardé dans {ds_dir}")
354
 
355
+ # ------------------- 5️⃣ Index FAISS -------------------
356
+ _set_stage(st, "indexing")
 
357
  meta = {
358
  "dim": int(xb.shape[1]),
359
  "count": int(xb.shape[0]),
 
362
  }
363
  _save_faiss(fx_dir, xb, meta)
364
  st.indexed = int(xb.shape[0])
365
+ _add_msg(st, f"FAISS écrit sur {os.path.join(fx_dir, 'emb.faiss')}")
366
 
367
+ _set_stage(st, "done")
 
 
 
368
  st.finished_at = time.time()
369
  except Exception as e:
370
  LOG.exception("Job %s échoué", st.job_id)
371
  st.errors.append(str(e))
372
+ _add_msg(st, f"❌ Exception : {e}")
373
  st.stage = "failed"
374
  st.finished_at = time.time()
375
 
 
400
  st.stage = "queued"
401
  return job_id
402
 
 
403
  # --------------------------------------------------------------------------- #
404
  # FASTAPI
405
  # --------------------------------------------------------------------------- #
 
412
  allow_headers=["*"],
413
  )
414
 
 
415
  class FileItem(BaseModel):
416
  path: str
417
  text: str
418
 
 
419
  class IndexRequest(BaseModel):
420
  project_id: str
421
  files: List[FileItem]
422
  chunk_size: int = 200
423
  overlap: int = 20
424
  batch_size: int = 32
425
+ store_text: bool = True # désactivable via le payload
 
426
 
427
  @fastapi_app.get("/health")
428
  def health():
 
433
  "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None,
434
  "cache_root": os.getenv("CACHE_ROOT", "/tmp/.cache"),
435
  "workers": MAX_WORKERS,
436
+ "data_root": DATA_ROOT,
 
437
  "emb_dim": EMB_DIM,
438
  }
439
 
 
440
  @fastapi_app.post("/index")
441
  def index(req: IndexRequest):
442
+ """
443
+ Lancement asynchrone : renvoie immédiatement un `job_id`.
444
+ """
445
  try:
446
  files = [fi.model_dump() for fi in req.files]
447
  job_id = _submit_job(
 
457
  LOG.exception("Erreur soumission index")
458
  raise HTTPException(status_code=500, detail=str(e))
459
 
 
460
  @fastapi_app.get("/status/{job_id}")
461
  def status(job_id: str):
462
  st = JOBS.get(job_id)
 
464
  raise HTTPException(status_code=404, detail="job inconnu")
465
  return JSONResponse(st.model_dump())
466
 
 
467
  class SearchRequest(BaseModel):
468
  project_id: str
469
  query: str
470
  k: int = 5
471
 
 
472
  @fastapi_app.post("/search")
473
  def search(req: SearchRequest):
474
  base, ds_dir, fx_dir = _proj_dirs(req.project_id)
475
 
476
+ # Vérifier que l’index existe
477
+ if not (os.path.isfile(os.path.join(fx_dir, "emb.faiss")) and
478
+ os.path.isfile(os.path.join(ds_dir, "data.jsonl"))):
479
  raise HTTPException(status_code=409, detail="Index non prêt (reviens plus tard)")
480
 
481
  rows = _load_dataset(ds_dir)
482
  if not rows:
483
  raise HTTPException(status_code=404, detail="dataset introuvable")
484
 
485
+ # Embedding de la requête (même provider que l’index)
486
  if EMB_PROVIDER == "dummy":
487
  q = _emb_dummy([req.query], dim=EMB_DIM)[0:1, :]
488
  elif EMB_PROVIDER == "st":
 
509
  out.append({"path": r.get("path"), "text": r.get("text"), "score": float(sc)})
510
  return {"results": out}
511
 
 
512
  # --------------------------------------------------------------------------- #
513
+ # EXPORT ARTIFACTS (gzip)
514
  # --------------------------------------------------------------------------- #
515
  @fastapi_app.get("/artifacts/{project_id}/dataset")
516
  def download_dataset(project_id: str):
 
521
  hdr = {"Content-Disposition": f'attachment; filename="{project_id}_dataset.tgz"'}
522
  return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=hdr)
523
 
 
524
  @fastapi_app.get("/artifacts/{project_id}/faiss")
525
  def download_faiss(project_id: str):
526
  _, _, fx_dir = _proj_dirs(project_id)
 
530
  hdr = {"Content-Disposition": f'attachment; filename="{project_id}_faiss.tgz"'}
531
  return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=hdr)
532
 
 
533
  # --------------------------------------------------------------------------- #
534
+ # GRADIO UI (facultatif – test rapide)
535
  # --------------------------------------------------------------------------- #
536
  def _ui_index(project_id: str, sample_text: str):
537
  files = [{"path": "sample.txt", "text": sample_text}]
538
  try:
539
  req = IndexRequest(project_id=project_id, files=[FileItem(**f) for f in files])
540
  except Exception as e:
541
+ return f"❌ Validation : {e}"
542
  try:
543
  res = index(req)
544
  return f"✅ Job lancé : {res['job_id']}"
545
  except Exception as e:
546
+ return f"❌ Erreur : {e}"
 
547
 
548
  def _ui_search(project_id: str, query: str, k: int):
549
  try:
550
  res = search(SearchRequest(project_id=project_id, query=query, k=int(k)))
551
  return json.dumps(res, ensure_ascii=False, indent=2)
552
  except Exception as e:
553
+ return f"❌ Erreur : {e}"
 
 
 
554
 
555
  with gr.Blocks(title="Remote Indexer (Async – Optimisé)", analytics_enabled=False) as ui:
556
+ gr.Markdown("## Remote Indexer Async (FAISS quantisé, mmap, texte optionnel)")
557
  with gr.Row():
558
  pid = gr.Textbox(label="Project ID", value="DEMO")
559
  txt = gr.Textbox(label="Texte d’exemple", lines=4, value="Alpha bravo charlie delta echo foxtrot.")
 
568
  out_q = gr.Code(label="Résultats")
569
  btn_q.click(_ui_search, inputs=[pid, q, k], outputs=[out_q])
570
 
571
+ # Monte l’UI Gradio sur le même serveur FastAPI
572
  fastapi_app = gr.mount_gradio_app(fastapi_app, ui, path="/ui")
573
 
 
574
  # --------------------------------------------------------------------------- #
575
  # MAIN
576
  # --------------------------------------------------------------------------- #
577
  if __name__ == "__main__":
578
  import uvicorn
579
 
580
+ LOG.info("Démarrage Uvicorn – port %s – UI disponible à /ui", PORT)
 
581
  uvicorn.run(fastapi_app, host="0.0.0.0", port=PORT)
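
Note: below is a minimal client sketch (not part of the commit) that exercises the routes defined in this main.py. It assumes the service runs locally on the default PORT=7860, that /index returns {"job_id": ...} and that /status/{job_id} exposes a "stage" field, as shown in the diff; the `requests` dependency and the 1-second polling interval are illustrative choices.

# Minimal client sketch for the /index, /status/{job_id} and /search routes above.
import time

import requests

BASE = "http://localhost:7860"  # assumption: default PORT=7860, local deployment


def index_and_search(project_id: str, text: str, query: str, k: int = 5):
    # 1) Submit an asynchronous indexing job (payload mirrors the IndexRequest model).
    resp = requests.post(f"{BASE}/index", json={
        "project_id": project_id,
        "files": [{"path": "sample.txt", "text": text}],
        "chunk_size": 200,
        "overlap": 20,
        "batch_size": 32,
        "store_text": True,
    })
    resp.raise_for_status()
    job_id = resp.json()["job_id"]

    # 2) Poll /status/{job_id} until the background worker reports "done" or "failed".
    while True:
        st = requests.get(f"{BASE}/status/{job_id}").json()
        if st.get("stage") in ("done", "failed"):
            break
        time.sleep(1.0)
    if st.get("stage") == "failed":
        raise RuntimeError(f"Indexing failed: {st.get('errors')}")

    # 3) Query the FAISS index (payload mirrors the SearchRequest model).
    res = requests.post(f"{BASE}/search", json={"project_id": project_id, "query": query, "k": k})
    res.raise_for_status()
    return res.json()["results"]


if __name__ == "__main__":
    print(index_and_search("DEMO", "Alpha bravo charlie delta echo foxtrot.", "alpha"))

Since /search answers HTTP 409 until both data.jsonl and emb.faiss have been written, polling /status/{job_id} before searching avoids racing the background job.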