DeepIndex / main.py
chouchouvs's picture
Update main.py
6eb5a6e verified
raw
history blame
23.2 kB
# -*- coding: utf-8 -*-
"""
Version optimisée du module FAISS :
- Réduction de la dimension des vecteurs (EMB_DIM, configurable)
- Index quantisé **IVF‑PQ** (faible empreinte disque)
- Chargement *on‑disk* (mmap) pour limiter la RAM
- Option `store_text` : ne pas persister le texte brut dans le dataset
- Compression gzip des artefacts exportés
- Paramètres contrôlables via variables d’environnement
"""
from __future__ import annotations
import os
import io
import json
import time
import tarfile
import logging
import hashlib
from typing import List, Dict, Any, Tuple, Optional
import numpy as np
import faiss
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
# --------------------------------------------------------------------------- #
# CONFIGURATION (variables d’environnement – modifiable à la volée)
# --------------------------------------------------------------------------- #
_getenv = os.environ.get  # same lookup as os.getenv(name, default)

# Embeddings: provider ("dummy", "st", or anything else -> hf), model name,
# batch size and target vector dimension (e.g. 64).
EMB_PROVIDER = _getenv("EMB_PROVIDER", "dummy").strip().lower()
EMB_MODEL = _getenv("EMB_MODEL", "sentence-transformers/all-mpnet-base-v2").strip()
EMB_BATCH = int(_getenv("EMB_BATCH", "32"))
EMB_DIM = int(_getenv("EMB_DIM", "64"))

# FAISS quantisation: index type ("FLAT" or "IVF_PQ"), number of IVF centroids,
# number of PQ sub-vectors and bits per sub-vector code.
FAISS_TYPE = _getenv("FAISS_TYPE", "IVF_PQ").upper()
FAISS_NLIST, FAISS_M, FAISS_NBITS = (
    int(_getenv(name, default))
    for name, default in (("FAISS_NLIST", "100"), ("FAISS_M", "8"), ("FAISS_NBITS", "8"))
)

# Persist raw chunk text in the dataset? False saves disk space.
STORE_TEXT = _getenv("STORE_TEXT", "false").lower() in {"1", "true", "yes"}
# --------------------------------------------------------------------------- #
# LOGGING
# --------------------------------------------------------------------------- #
def _attach_console_handler(logger: logging.Logger) -> None:
    """Give ``logger`` a stream handler with the service's format, unless it already has handlers."""
    if logger.handlers:
        return
    stream = logging.StreamHandler()
    stream.setFormatter(
        logging.Formatter("[%(levelname)s] %(asctime)s - %(message)s", "%H:%M:%S")
    )
    logger.addHandler(stream)


LOG = logging.getLogger("appli_v1")
_attach_console_handler(LOG)
LOG.setLevel(logging.INFO)
# --------------------------------------------------------------------------- #
# UTILITAIRES
# --------------------------------------------------------------------------- #
def list_repo_files(repo_dir: str, top_k: int = 500) -> List[str]:
    """Return up to ``top_k`` sorted, repo-relative file paths found under ``repo_dir``.

    When GitPython is importable and ``repo_dir`` is a Git work tree, the paths are
    the tracked files plus untracked files not excluded by .gitignore
    (``git ls-files`` and ``git ls-files --others --exclude-standard``). Otherwise
    the function falls back to ``os.walk``, which does NOT honour .gitignore.
    In both cases any path with a dot-prefixed component (``.git``, ``.env``, ...)
    is skipped.

    Returns ``[]`` when ``repo_dir`` is not a directory or ``top_k <= 0``.

    NOTE(review): no text/binary filtering is performed; binary files are listed too.
    """
    if top_k <= 0 or not os.path.isdir(repo_dir):
        return []

    def _hidden(rel_path: str) -> bool:
        # Git always reports "/"-separated paths; os.walk uses os.sep.
        return any(part.startswith(".") for part in rel_path.replace(os.sep, "/").split("/"))

    try:
        from git import Repo

        repo = Repo(repo_dir)
        listed = repo.git.ls_files().splitlines()
        listed += repo.git.ls_files(others=True, exclude_standard=True).splitlines()
        # Build the result in one expression so a failure part-way through never
        # leaks partial Git output into the os.walk fallback below.
        return sorted({p for p in listed if p and not _hidden(p)})[:top_k]
    except Exception as e:
        LOG.debug("Git indisponible / pas un dépôt → fallback os.walk : %s", e)

    files: List[str] = []
    for root, dirs, names in os.walk(repo_dir):
        # Prune hidden directories in place (os.walk then skips them) and sort so
        # the walk, and therefore which files survive the top_k cut, is deterministic.
        dirs[:] = sorted(d for d in dirs if not d.startswith("."))
        for name in sorted(names):
            rel = os.path.relpath(os.path.join(root, name), repo_dir)
            if _hidden(rel):
                continue
            files.append(rel)
            if len(files) >= top_k:
                return sorted(files)
    return sorted(files)
def read_file_safe(file_path: str) -> str:
    """Return the UTF-8 text of ``file_path``; undecodable bytes are silently dropped.

    Never raises: on any error the problem is logged and a
    ``"# Erreur lecture : <error>"`` placeholder string is returned instead.
    """
    try:
        with open(file_path, encoding="utf-8", errors="ignore") as handle:
            content = handle.read()
    except Exception as exc:
        LOG.error("Erreur lecture %s : %s", file_path, exc)
        return f"# Erreur lecture : {exc}"
    return content
def write_file_safe(file_path: str, content: str) -> str:
    """Write ``content`` to ``file_path`` as UTF-8, creating parent directories if needed.

    Never raises: returns a user-facing status string, ``"✅ ..."`` on success or
    ``"❌ Erreur sauvegarde : <error>"`` on failure (the error is also logged).
    """
    try:
        parent = os.path.dirname(file_path)
        # A bare filename ("notes.txt") has no parent component, and
        # os.makedirs("") raises FileNotFoundError, so only create real parents.
        if parent:
            os.makedirs(parent, exist_ok=True)
        with open(file_path, "w", encoding="utf-8") as f:
            f.write(content)
        return f"✅ Fichier sauvegardé : {os.path.basename(file_path)}"
    except Exception as e:
        LOG.error("Erreur écriture %s : %s", file_path, e)
        return f"❌ Erreur sauvegarde : {e}"
# --------------------------------------------------------------------------- #
# FAKE / DUMMY FAISS (pour compatibilité)
# --------------------------------------------------------------------------- #
class DummyFAISS:
    """Empty placeholder handed out while local FAISS indexing is disabled."""


def _warn_faiss_disabled() -> None:
    """Log the shared 'local FAISS disabled' warning."""
    LOG.warning("FAISS désactivé – utilisation du client distant")


def create_faiss_index(*_, **__) -> DummyFAISS:
    """Compatibility stub: ignore all arguments, warn, and return a DummyFAISS."""
    _warn_faiss_disabled()
    return DummyFAISS()


def search_faiss_index(*_, **__) -> List[Any]:
    """Compatibility stub: ignore all arguments, warn, and return no hits."""
    _warn_faiss_disabled()
    return []
# --------------------------------------------------------------------------- #
# EMBEDDING PROVIDERS
# --------------------------------------------------------------------------- #
# Lazily-initialised model singletons. They stay None until the first call to
# _get_st_model() / _get_hf_model(), so importing this module loads no model.
_ST_MODEL: Optional[Any] = None  # sentence-transformers model ("st" provider)
_HF_TOKENIZER: Optional[Any] = None  # transformers tokenizer ("hf" provider)
_HF_MODEL: Optional[Any] = None  # transformers model ("hf" provider)
def _emb_dummy(texts: List[str], dim: int = EMB_DIM) -> np.ndarray:
    """Return one deterministic pseudo-random unit vector per text, as a float32 (len(texts), dim) array.

    The first 8 bytes of each text's SHA-1 digest (little-endian) seed a NumPy
    generator, so identical texts always map to identical vectors. ``None`` or
    empty texts hash as the empty string.
    """
    def _unit_vector(text):
        digest = hashlib.sha1((text or "").encode("utf-8")).digest()
        seed = int.from_bytes(digest[:8], "little", signed=False)
        raw = np.random.default_rng(seed).standard_normal(dim).astype("float32")
        return raw / (np.linalg.norm(raw) + 1e-9)

    matrix = np.zeros((len(texts), dim), dtype="float32")
    for row, text in enumerate(texts):
        matrix[row] = _unit_vector(text)
    return matrix
def _get_st_model():
    """Return the shared SentenceTransformer for EMB_MODEL, loading it on the first call.

    The model cache lives in $HF_HOME (default ``/tmp/.cache/huggingface``).
    """
    global _ST_MODEL
    if _ST_MODEL is not None:
        return _ST_MODEL
    from sentence_transformers import SentenceTransformer

    cache_dir = os.getenv("HF_HOME", "/tmp/.cache/huggingface")
    _ST_MODEL = SentenceTransformer(EMB_MODEL, cache_folder=cache_dir)
    LOG.info("[st] modèle chargé : %s", EMB_MODEL)
    return _ST_MODEL
def _emb_st(texts: List[str]) -> np.ndarray:
    """Encode ``texts`` with the sentence-transformers backend into L2-normalised float32 rows."""
    encoded = _get_st_model().encode(
        texts,
        batch_size=max(1, EMB_BATCH),
        show_progress_bar=False,
        convert_to_numpy=True,
        normalize_embeddings=True,
    )
    return encoded.astype("float32")
def _get_hf_model():
    """Return the shared ``(tokenizer, model)`` pair for EMB_MODEL, loading both on first use.

    The model is switched to eval mode. The cache lives in $HF_HOME
    (default ``/tmp/.cache/huggingface``).
    """
    global _HF_TOKENIZER, _HF_MODEL
    if _HF_MODEL is not None and _HF_TOKENIZER is not None:
        return _HF_TOKENIZER, _HF_MODEL
    from transformers import AutoModel, AutoTokenizer

    cache_dir = os.getenv("HF_HOME", "/tmp/.cache/huggingface")
    _HF_TOKENIZER = AutoTokenizer.from_pretrained(EMB_MODEL, cache_dir=cache_dir)
    _HF_MODEL = AutoModel.from_pretrained(EMB_MODEL, cache_dir=cache_dir)
    _HF_MODEL.eval()
    LOG.info("[hf] modèle chargé : %s", EMB_MODEL)
    return _HF_TOKENIZER, _HF_MODEL
def _mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
mask = attention_mask[..., None].astype(last_hidden_state.dtype)
summed = (last_hidden_state * mask).sum(axis=1)
counts = mask.sum(axis=1).clip(min=1e-9)
return summed / counts
def _emb_hf(texts: List[str]) -> np.ndarray:
    """Embed ``texts`` with the transformers backend: masked mean pooling, then L2 normalisation.

    Returns a float32 array of shape ``(len(texts), hidden_size)``.

    - An empty ``texts`` yields an empty ``(0, hidden_size)`` array instead of
      failing in ``np.concatenate``.
    - Rows are L2-normalised so inner-product FAISS search ranks by cosine
      similarity. This matches the "st" backend (``normalize_embeddings=True``)
      and the unit vectors of the dummy backend.
    """
    import torch

    tok, mod = _get_hf_model()
    if not texts:
        hidden = int(getattr(mod.config, "hidden_size", 0))
        return np.zeros((0, hidden), dtype="float32")

    bs = max(1, EMB_BATCH)
    parts: List[np.ndarray] = []
    with torch.no_grad():
        for start in range(0, len(texts), bs):
            enc = tok(texts[start:start + bs], padding=True, truncation=True, return_tensors="pt")
            out = mod(**enc)
            # .cpu() makes .numpy() safe even if the model was moved to an accelerator.
            pooled = _mean_pool(
                out.last_hidden_state.cpu().numpy(),
                enc["attention_mask"].cpu().numpy(),
            )
            parts.append(pooled.astype("float32"))
    vecs = np.concatenate(parts, axis=0)
    norms = np.linalg.norm(vecs, axis=1, keepdims=True)
    return vecs / np.maximum(norms, 1e-12)
def _reduce_dim(vectors: np.ndarray, target_dim: int = EMB_DIM) -> np.ndarray:
    """Project ``vectors`` of shape (n, d) onto their top ``target_dim`` principal components.

    Returns ``vectors`` unchanged when ``target_dim >= d``. Otherwise returns a
    float32 array of shape ``(n, target_dim)``.

    PCA yields at most ``min(n, d)`` components; when there are fewer rows than
    ``target_dim``, the missing trailing columns are zero. (sklearn's PCA rejects
    ``n_components > n_samples`` outright.) Component signs are fixed by making
    each component's largest-magnitude loading positive, so the result is
    deterministic.

    Implemented with NumPy's SVD, so scikit-learn is not required.
    """
    if target_dim >= vectors.shape[1]:
        return vectors
    n_rows = vectors.shape[0]
    out = np.zeros((n_rows, max(0, target_dim)), dtype="float32")
    if n_rows == 0 or target_dim <= 0:
        return out

    data = np.asarray(vectors, dtype="float64")
    centered = data - data.mean(axis=0, keepdims=True)
    # Thin SVD: centered = U @ diag(S) @ Vt; the rows of Vt are the principal axes,
    # ordered by decreasing singular value.
    _, _, vt = np.linalg.svd(centered, full_matrices=False)
    k = min(target_dim, vt.shape[0])
    components = vt[:k]
    pivots = np.argmax(np.abs(components), axis=1)
    signs = np.sign(components[np.arange(k), pivots])
    signs[signs == 0] = 1.0
    components = components * signs[:, None]
    out[:, :k] = centered @ components.T
    return out
# --------------------------------------------------------------------------- #
# DATASET / FAISS I/O
# --------------------------------------------------------------------------- #
def _save_dataset(ds_dir: str, rows: List[Dict[str, Any]], store_text: bool = STORE_TEXT) -> None:
    """Write ``rows`` to ``ds_dir/data.jsonl`` plus a ``meta.json`` summary.

    ``data.jsonl`` holds one JSON object per line, in the same order as the FAISS
    ids. When ``store_text`` is false, the ``"text"`` key is dropped from every
    row, and ``meta.json`` no longer advertises a ``text`` column.

    The JSONL file is written to a temporary name and then atomically renamed, so
    a concurrent reader never sees a half-written dataset.
    """
    os.makedirs(ds_dir, exist_ok=True)
    data_path = os.path.join(ds_dir, "data.jsonl")
    tmp_path = data_path + ".tmp"
    with open(tmp_path, "w", encoding="utf-8") as f:
        for r in rows:
            if not store_text:
                r = {k: v for k, v in r.items() if k != "text"}
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    os.replace(tmp_path, data_path)

    columns = ["path", "text", "chunk_id"] if store_text else ["path", "chunk_id"]
    meta = {
        "format": "jsonl",
        "columns": columns,
        "count": len(rows),
        "store_text": bool(store_text),
    }
    with open(os.path.join(ds_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
def _load_dataset(ds_dir: str) -> List[Dict[str, Any]]:
data_path = os.path.join(ds_dir, "data.jsonl")
if not os.path.isfile(data_path):
return []
out: List[Dict[str, Any]] = []
with open(data_path, "r", encoding="utf-8") as f:
for line in f:
try:
out.append(json.loads(line))
except Exception:
continue
return out
def _save_faiss(fx_dir: str, xb: np.ndarray, meta: Dict[str, Any]) -> None:
    """Build a FAISS index over ``xb`` and write ``emb.faiss`` and ``meta.json`` into ``fx_dir``.

    Index selection:
    - With FAISS_TYPE == "IVF_PQ", an IVF-PQ index with the inner-product metric
      is built when the data allows it.
    - Any other FAISS_TYPE, or too little data for IVF-PQ, gives an exact flat
      inner-product index.

    ``meta`` is updated in place with the parameters actually used, then written
    to ``meta.json``.

    IVF-PQ requires:
    - at least ``max(FAISS_NLIST, 2**FAISS_NBITS)`` training vectors, because
      k-means (the coarse quantizer and each PQ sub-quantizer) cannot train with
      fewer points than centroids;
    - a vector dimension divisible by FAISS_M.
    """
    os.makedirs(fx_dir, exist_ok=True)
    idx_path = os.path.join(fx_dir, "emb.faiss")
    # FAISS only accepts C-contiguous float32 matrices.
    xb = np.ascontiguousarray(xb, dtype="float32")
    n, d = xb.shape

    use_ivfpq = FAISS_TYPE == "IVF_PQ"
    if use_ivfpq:
        min_train = max(FAISS_NLIST, 2 ** FAISS_NBITS)
        if n < min_train or FAISS_M <= 0 or d % FAISS_M != 0:
            LOG.warning(
                "IVF_PQ impossible (n=%d, d=%d, nlist=%d, m=%d, nbits=%d) → index FLAT",
                n, d, FAISS_NLIST, FAISS_M, FAISS_NBITS,
            )
            use_ivfpq = False
            meta["fallback_from"] = "IVF_PQ"

    if use_ivfpq:
        quantizer = faiss.IndexFlatIP(d)
        # The metric must be passed explicitly: IndexIVFPQ defaults to L2, which
        # would disagree with the inner-product coarse quantizer and the FLAT index.
        index = faiss.IndexIVFPQ(
            quantizer, d, FAISS_NLIST, FAISS_M, FAISS_NBITS, faiss.METRIC_INNER_PRODUCT
        )
        # Train on a reproducible random sample of at most 10k vectors.
        rng = np.random.default_rng(0)
        train = xb[rng.choice(n, min(10_000, n), replace=False)]
        index.train(train)
        index.add(xb)
        meta.update({
            "index_type": "IVF_PQ",
            "nlist": FAISS_NLIST,
            "m": FAISS_M,
            "nbits": FAISS_NBITS,
        })
    else:
        index = faiss.IndexFlatIP(d)
        index.add(xb)
        meta.update({"index_type": "FLAT"})

    # Write to a temporary file and rename atomically, so a concurrent /search
    # never opens a half-written index.
    tmp_path = idx_path + ".tmp"
    faiss.write_index(index, tmp_path)
    os.replace(tmp_path, idx_path)

    with open(os.path.join(fx_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)
def _load_faiss(fx_dir: str) -> faiss.Index:
"""Charge l’index en mode mmap (lecture à la volée)."""
idx_path = os.path.join(fx_dir, "emb.faiss")
if not os.path.isfile(idx_path):
raise FileNotFoundError(f"FAISS index introuvable : {idx_path}")
# mmap minimise la RAM utilisée
return faiss.read_index(idx_path, faiss.IO_FLAG_MMAP)
def _tar_dir_to_bytes(dir_path: str) -> bytes:
"""Archive gzip du répertoire (compression maximale)."""
bio = io.BytesIO()
with tarfile.open(fileobj=bio, mode="w:gz", compresslevel=9) as tar:
tar.add(dir_path, arcname=os.path.basename(dir_path))
bio.seek(0)
return bio.read()
# --------------------------------------------------------------------------- #
# WORKER POOL (asynchrone)
# --------------------------------------------------------------------------- #
from concurrent.futures import ThreadPoolExecutor

MAX_WORKERS = max(1, int(os.getenv("MAX_WORKERS", "1")))
EXECUTOR = ThreadPoolExecutor(max_workers=MAX_WORKERS)
LOG.info("ThreadPoolExecutor initialisé : max_workers=%s", MAX_WORKERS)


class JobState(BaseModel):
    """Progress record of one indexing job, mutated by the worker thread and returned by /status."""

    job_id: str
    project_id: str
    stage: str = "pending"  # e.g. "pending", "queued", "done", "failed"
    messages: List[str] = []  # pydantic copies list defaults per instance
    errors: List[str] = []  # str() of exceptions raised by the job
    total_files: int = 0
    total_chunks: int = 0
    embedded: int = 0
    indexed: int = 0
    finished_at: Optional[float] = None  # time.time() when the job ended


# In-memory job registry keyed by job_id. It is lost on restart and nothing
# ever removes entries from it.
JOBS: Dict[str, JobState] = {}
def _proj_dirs(project_id: str) -> Tuple[str, str, str]:
base = os.path.join(os.getenv("DATA_ROOT", "/tmp/data"), project_id)
ds_dir = os.path.join(base, "dataset")
fx_dir = os.path.join(base, "faiss")
os.makedirs(ds_dir, exist_ok=True)
os.makedirs(fx_dir, exist_ok=True)
return base, ds_dir, fx_dir
def _do_index_job(
st: "JobState",
files: List[Dict[str, str]],
chunk_size: int,
overlap: int,
batch_size: int,
store_text: bool,
) -> None:
"""
Pipeline complet :
1️⃣ Chunking
2️⃣ Embedding (dummy / st / hf)
3️⃣ Réduction de dimension (PCA) si EMB_DIM < dim du modèle
4️⃣ Sauvegarde dataset (optionnel texte)
5️⃣ Index FAISS quantisé + mmap
"""
try:
base, ds_dir, fx_dir = _proj_dirs(st.project_id)
# ------------------------------------------------------------------- #
# 1️⃣ Chunking
# ------------------------------------------------------------------- #
rows: List[Dict[str, Any]] = []
st.total_files = len(files)
for f in files:
path = (f.get("path") or "unknown").strip()
txt = f.get("text") or ""
chunks = _chunk_text(txt, size=chunk_size, overlap=overlap)
for i, ck in enumerate(chunks):
rows.append({"path": path, "text": ck, "chunk_id": i})
st.total_chunks = len(rows)
LOG.info("Chunking terminé : %d chunks", st.total_chunks)
# ------------------------------------------------------------------- #
# 2️⃣ Embedding
# ------------------------------------------------------------------- #
texts = [r["text"] for r in rows]
if EMB_PROVIDER == "dummy":
xb = _emb_dummy(texts, dim=EMB_DIM)
elif EMB_PROVIDER == "st":
xb = _emb_st(texts)
else:
xb = _emb_hf(texts)
# ------------------------------------------------------------------- #
# 3️⃣ Réduction de dimension (si nécessaire)
# ------------------------------------------------------------------- #
if xb.shape[1] != EMB_DIM:
xb = _reduce_dim(xb, target_dim=EMB_DIM)
st.embedded = xb.shape[0]
LOG.info("Embedding terminé : %d vecteurs (dim=%d)", st.embedded, xb.shape[1])
# ------------------------------------------------------------------- #
# 4️⃣ Sauvegarde du dataset
# ------------------------------------------------------------------- #
_save_dataset(ds_dir, rows, store_text=store_text)
# ------------------------------------------------------------------- #
# 5️⃣ Index FAISS
# ------------------------------------------------------------------- #
meta = {
"dim": int(xb.shape[1]),
"count": int(xb.shape[0]),
"provider": EMB_PROVIDER,
"model": EMB_MODEL if EMB_PROVIDER != "dummy" else None,
}
_save_faiss(fx_dir, xb, meta)
st.indexed = int(xb.shape[0])
LOG.info("FAISS (%s) écrit : %s", FAISS_TYPE, os.path.join(fx_dir, "emb.faiss"))
# ------------------------------------------------------------------- #
# Finalisation
# ------------------------------------------------------------------- #
st.stage = "done"
st.finished_at = time.time()
except Exception as e:
LOG.exception("Job %s échoué", st.job_id)
st.errors.append(str(e))
st.stage = "failed"
st.finished_at = time.time()
def _submit_job(
    project_id: str,
    files: List[Dict[str, str]],
    chunk_size: int,
    overlap: int,
    batch_size: int,
    store_text: bool,
) -> str:
    """Register a new JobState in JOBS, queue ``_do_index_job`` on EXECUTOR, and return the job id.

    The id is 12 random hex characters from ``uuid4``; a hash of
    ``project_id + time.time()`` can collide for submissions within the clock
    resolution.

    Raises whatever ``EXECUTOR.submit`` raises (e.g. after shutdown), after
    marking the job ``"failed"``.
    """
    import uuid

    job_id = uuid.uuid4().hex[:12]
    st = JobState(job_id=job_id, project_id=project_id, stage="pending", messages=[])
    JOBS[job_id] = st
    LOG.info("Job %s créé – %d fichiers", job_id, len(files))
    # Mark the job queued *before* handing it to the pool. The worker may start,
    # and even finish, before submit() returns; assigning "queued" afterwards
    # would overwrite its "done"/"failed" stage.
    st.stage = "queued"
    try:
        EXECUTOR.submit(
            _do_index_job,
            st,
            files,
            chunk_size,
            overlap,
            batch_size,
            store_text,
        )
    except Exception as e:
        st.errors.append(str(e))
        st.stage = "failed"
        st.finished_at = time.time()
        raise
    return job_id
# --------------------------------------------------------------------------- #
# FASTAPI
# --------------------------------------------------------------------------- #
# CORS: every origin, method and header is allowed, with credentials.
# NOTE(review): the CORS spec forbids a literal "*" origin together with
# credentials; confirm how Starlette's CORSMiddleware resolves this before
# relying on cookie/credentialed cross-origin calls.
_CORS_SETTINGS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}

fastapi_app = FastAPI(title="remote-indexer-async", version="3.0.0")
fastapi_app.add_middleware(CORSMiddleware, **_CORS_SETTINGS)
# One source file in a POST /index payload.
class FileItem(BaseModel):
    path: str  # label stored with each chunk; an empty path becomes "unknown" in the job
    text: str  # raw file content to chunk and embed
# Body of POST /index.
class IndexRequest(BaseModel):
    project_id: str  # selects the $DATA_ROOT/<project_id> storage directory
    files: List[FileItem]
    chunk_size: int = 200  # passed to the chunker as `size`
    overlap: int = 20  # passed to the chunker as `overlap`
    batch_size: int = 32  # NOTE(review): forwarded to the job, but embedding uses EMB_BATCH
    store_text: bool = STORE_TEXT  # ← configurable; False drops chunk text from data.jsonl
@fastapi_app.get("/health")
def health():
    """Liveness probe that also reports the active embedding / FAISS configuration."""
    model = None if EMB_PROVIDER == "dummy" else EMB_MODEL
    return dict(
        ok=True,
        service="remote-indexer-async",
        provider=EMB_PROVIDER,
        model=model,
        cache_root=os.getenv("CACHE_ROOT", "/tmp/.cache"),
        workers=MAX_WORKERS,
        data_root=os.getenv("DATA_ROOT", "/tmp/data"),
        faiss_type=FAISS_TYPE,
        emb_dim=EMB_DIM,
    )
@fastapi_app.post("/index")
def index(req: IndexRequest):
    """Queue an indexing job for ``req.files`` and return ``{"job_id": ...}``.

    Any error while submitting is logged and reported as HTTP 500.
    """
    try:
        payload = [item.model_dump() for item in req.files]
        new_job = _submit_job(
            project_id=req.project_id,
            files=payload,
            chunk_size=int(req.chunk_size),
            overlap=int(req.overlap),
            batch_size=int(req.batch_size),
            store_text=bool(req.store_text),
        )
    except Exception as exc:
        LOG.exception("Erreur soumission index")
        raise HTTPException(status_code=500, detail=str(exc))
    return {"job_id": new_job}
@fastapi_app.get("/status/{job_id}")
def status(job_id: str):
    """Return the recorded state of ``job_id`` as JSON, or HTTP 404 if the id is unknown."""
    job = JOBS.get(job_id)
    if job:
        return JSONResponse(job.model_dump())
    raise HTTPException(status_code=404, detail="job inconnu")
# Body of POST /search.
class SearchRequest(BaseModel):
    project_id: str  # project whose index is queried
    query: str  # free text, embedded with the configured provider
    k: int = 5  # number of hits requested (values below 1 are treated as 1)
@fastapi_app.post("/search")
def search(req: SearchRequest):
    """Return the ``req.k`` nearest chunks to ``req.query`` in the project's FAISS index.

    Each hit is ``{"path", "text", "score"}``; ``text`` is None when the dataset
    was written without text.

    Raises HTTPException:
    - 409 when the index or dataset file is missing;
    - 404 when the dataset is empty;
    - 500 when the query dimension differs from the index dimension.
    """
    _, ds_dir, fx_dir = _proj_dirs(req.project_id)
    # Check that the index exists
    if not (os.path.isfile(os.path.join(fx_dir, "emb.faiss")) and os.path.isfile(os.path.join(ds_dir, "data.jsonl"))):
        raise HTTPException(status_code=409, detail="Index non prêt (reviens plus tard)")
    rows = _load_dataset(ds_dir)
    if not rows:
        raise HTTPException(status_code=404, detail="dataset introuvable")

    # Embed the query with the same provider used at indexing time.
    if EMB_PROVIDER == "dummy":
        q = _emb_dummy([req.query], dim=EMB_DIM)
    elif EMB_PROVIDER == "st":
        q = _emb_st([req.query])
    else:
        q = _emb_hf([req.query])
    # FAISS needs a C-contiguous float32 (1, d) matrix.
    q = np.ascontiguousarray(q[0:1, :], dtype="float32")

    index = _load_faiss(fx_dir)
    if index.d != q.shape[1]:
        raise HTTPException(
            status_code=500,
            detail=f"dim incompatibles : index.d={index.d} vs query={q.shape[1]}",
        )

    # Clamp k to the index size. FAISS allocates k result slots per query, so an
    # arbitrarily large k from the request could exhaust memory, and it can never
    # return more than ntotal real hits anyway.
    k = min(max(1, int(req.k)), int(index.ntotal))
    if k < 1:
        return {"results": []}
    scores, ids = index.search(q, k)

    out = []
    for idx, sc in zip(ids[0].tolist(), scores[0].tolist()):
        # FAISS pads missing hits with -1; also guard against an index/dataset size mismatch.
        if 0 <= idx < len(rows):
            r = rows[idx]
            out.append({"path": r.get("path"), "text": r.get("text"), "score": float(sc)})
    return {"results": out}
# --------------------------------------------------------------------------- #
# ARTIFACTS EXPORT (gzip)
# --------------------------------------------------------------------------- #
def _safe_filename(project_id: str) -> str:
    """Reduce ``project_id`` to ASCII letters, digits, ``-``, ``_`` and ``.`` for use in a header filename."""
    cleaned = "".join(
        c if (c.isascii() and (c.isalnum() or c in "-_.")) else "_" for c in project_id
    )
    return cleaned or "project"


def _tgz_response(dir_path: str, required_file: str, filename: str, missing_detail: str) -> StreamingResponse:
    """Send ``dir_path`` as a gzip tar attachment; raise HTTP 404 if ``required_file`` is absent from it."""
    # _proj_dirs() always creates the directory, so an isdir() check can never
    # fail; the file written by the indexing job is what proves the artifact exists.
    if not os.path.isfile(os.path.join(dir_path, required_file)):
        raise HTTPException(status_code=404, detail=missing_detail)
    buf = _tar_dir_to_bytes(dir_path)
    hdr = {"Content-Disposition": f'attachment; filename="{filename}"'}
    return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=hdr)


@fastapi_app.get("/artifacts/{project_id}/dataset")
def download_dataset(project_id: str):
    """Download the project's dataset directory (data.jsonl, meta.json) as a .tgz archive."""
    _, ds_dir, _ = _proj_dirs(project_id)
    return _tgz_response(
        ds_dir, "data.jsonl", f"{_safe_filename(project_id)}_dataset.tgz", "Dataset introuvable"
    )


@fastapi_app.get("/artifacts/{project_id}/faiss")
def download_faiss(project_id: str):
    """Download the project's FAISS directory (emb.faiss, meta.json) as a .tgz archive."""
    _, _, fx_dir = _proj_dirs(project_id)
    return _tgz_response(
        fx_dir, "emb.faiss", f"{_safe_filename(project_id)}_faiss.tgz", "FAISS introuvable"
    )
# --------------------------------------------------------------------------- #
# GRADIO UI (facultatif – simple test)
# --------------------------------------------------------------------------- #
def _ui_index(project_id: str, sample_text: str):
files = [{"path": "sample.txt", "text": sample_text}]
try:
req = IndexRequest(project_id=project_id, files=[FileItem(**f) for f in files])
except Exception as e:
return f"❌ Erreur validation : {e}"
try:
res = index(req)
return f"✅ Job lancé : {res['job_id']}"
except Exception as e:
return f"❌ Erreur index : {e}"
def _ui_search(project_id: str, query: str, k: int):
try:
res = search(SearchRequest(project_id=project_id, query=query, k=int(k)))
return json.dumps(res, ensure_ascii=False, indent=2)
except Exception as e:
return f"❌ Erreur recherche : {e}"
# Minimal Gradio test UI, mounted under /ui of the FastAPI app (last line below).
# NOTE(review): the paste lost its indentation; the Row nesting here is inferred.
import gradio as gr
with gr.Blocks(title="Remote Indexer (Async – Optimisé)", analytics_enabled=False) as ui:
    gr.Markdown("## Remote Indexer – Optimisé (FAISS quantisé, mmap, texte optionnel)")
    # Indexing: project id and one sample text, submitted through _ui_index.
    with gr.Row():
        pid = gr.Textbox(label="Project ID", value="DEMO")
        txt = gr.Textbox(label="Texte d’exemple", lines=4, value="Alpha bravo charlie delta echo foxtrot.")
    btn_idx = gr.Button("Lancer index (sample)")
    out_idx = gr.Textbox(label="Résultat")
    btn_idx.click(_ui_index, inputs=[pid, txt], outputs=[out_idx])
    # Search: reuses the same Project ID box; results shown as JSON by _ui_search.
    # These module-level names (q, k, ...) are globals; search() uses its own locals.
    with gr.Row():
        q = gr.Textbox(label="Query", value="alpha")
        k = gr.Slider(1, 20, value=5, step=1, label="Top‑K")
    btn_q = gr.Button("Rechercher")
    out_q = gr.Code(label="Résultats")
    btn_q.click(_ui_search, inputs=[pid, q, k], outputs=[out_q])
fastapi_app = gr.mount_gradio_app(fastapi_app, ui, path="/ui")
# --------------------------------------------------------------------------- #
# MAIN
# --------------------------------------------------------------------------- #
if __name__ == "__main__":
    import uvicorn

    # Port from $PORT (default 7860); bind on all interfaces.
    PORT = int(os.getenv("PORT", "7860"))
    _HOST = "0.0.0.0"
    LOG.info("Démarrage Uvicorn – port %s – UI à /ui", PORT)
    uvicorn.run(fastapi_app, host=_HOST, port=PORT)