# -*- coding: utf-8 -*-
from __future__ import annotations

import os
import io
import json
import time
import tarfile
import logging
import hashlib
from typing import Dict, Any, List, Tuple, Optional

import numpy as np
import faiss
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, Field
import gradio as gr

# =============================================================================
# LOGGING
# =============================================================================
LOG = logging.getLogger("remote-indexer-space")
if not LOG.handlers:
    h = logging.StreamHandler()
    h.setFormatter(logging.Formatter("%(asctime)s - %(levelname)s - %(message)s"))
    LOG.addHandler(h)
LOG.setLevel(logging.INFO)
# =============================================================================
# CONFIG (via ENV)
# =============================================================================
PORT = int(os.getenv("PORT", "7860"))
DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data")  # persistent inside the Space container
os.makedirs(DATA_ROOT, exist_ok=True)

# Embedding provider:
# - "dummy": deterministic random vectors (very fast)
# - "st"   : Sentence-Transformers (CPU-friendly, simple)
# - "hf"   : Transformers (AutoModel/AutoTokenizer, manual pooling)
EMB_PROVIDER = os.getenv("EMB_PROVIDER", "dummy").strip().lower()

# Embedding model (used when provider != "dummy")
# Fast multilingual pick (FR ok): paraphrase-multilingual-MiniLM-L12-v2 (dim=384)
EMB_MODEL = os.getenv("EMB_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2").strip()

# Encoding batch size
EMB_BATCH = int(os.getenv("EMB_BATCH", "32"))

# Default dimension (dummy); for st/hf the dimension is read from the model
EMB_DIM = int(os.getenv("EMB_DIM", "128"))

# Lazy global caches
_ST_MODEL = None
_HF_TOKENIZER = None
_HF_MODEL = None
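
# Usage sketch: the whole configuration is driven by the environment, so a
# provider swap needs no code change. The entrypoint name "app.py" below is an
# assumption, not taken from this file:
#   EMB_PROVIDER=st EMB_BATCH=64 python app.py
#   EMB_PROVIDER=dummy EMB_DIM=128 python app.py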
# =============================================================================
# JOB STATE
# =============================================================================
class JobState(BaseModel):
    job_id: str
    project_id: str
    stage: str = "pending"  # pending -> chunking -> embedding -> indexing -> done/failed
    total_files: int = 0
    total_chunks: int = 0
    embedded: int = 0
    indexed: int = 0
    errors: List[str] = Field(default_factory=list)
    messages: List[str] = Field(default_factory=list)
    # default_factory so each job records its own start time; a bare
    # `time.time()` default would be evaluated once at class definition.
    started_at: float = Field(default_factory=time.time)
    finished_at: Optional[float] = None

JOBS: Dict[str, JobState] = {}


def _now() -> str:
    return time.strftime("%H:%M:%S")


def _proj_dirs(project_id: str) -> Tuple[str, str, str]:
    base = os.path.join(DATA_ROOT, project_id)
    ds_dir = os.path.join(base, "dataset")
    fx_dir = os.path.join(base, "faiss")
    os.makedirs(ds_dir, exist_ok=True)
    os.makedirs(fx_dir, exist_ok=True)
    return base, ds_dir, fx_dir


def _add_msg(st: JobState, msg: str):
    st.messages.append(f"[{_now()}] {msg}")
    LOG.info("[%s] %s", st.job_id, msg)


def _set_stage(st: JobState, stage: str):
    st.stage = stage
    _add_msg(st, f"stage={stage}")
# =============================================================================
# UTILS
# =============================================================================
def _chunk_text(text: str, size: int = 200, overlap: int = 20) -> List[str]:
    # Character-level chunking with overlap (despite the name, "tokens" here
    # are individual characters, not model tokens).
    text = (text or "").replace("\r\n", "\n")
    tokens = list(text)
    if size <= 0:
        return [text] if text else []
    if overlap < 0:
        overlap = 0
    chunks = []
    i = 0
    while i < len(tokens):
        j = min(i + size, len(tokens))
        chunk = "".join(tokens[i:j]).strip()
        if chunk:
            chunks.append(chunk)
        if j == len(tokens):
            break
        i = j - overlap if (j - overlap) > i else j
    return chunks
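
# Example sketch: with size=4 and overlap=1, "abcdef" yields two chunks, the
# second restarting one character before the end of the first:
#   _chunk_text("abcdef", size=4, overlap=1)  ->  ["abcd", "def"]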

def _l2_normalize(x: np.ndarray) -> np.ndarray:
    n = np.linalg.norm(x, axis=1, keepdims=True) + 1e-12
    return x / n


# ----------------------- PROVIDER: DUMMY --------------------------------------
def _emb_dummy(texts: List[str], dim: int = EMB_DIM) -> np.ndarray:
    vecs = np.zeros((len(texts), dim), dtype="float32")
    for i, t in enumerate(texts):
        h = hashlib.sha1((t or "").encode("utf-8")).digest()
        rng = np.random.default_rng(int.from_bytes(h[:8], "little", signed=False))
        v = rng.standard_normal(dim).astype("float32")
        vecs[i] = v / (np.linalg.norm(v) + 1e-9)
    return vecs
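
# Example sketch: the RNG seed comes from SHA-1 of the text, so identical
# inputs map to identical unit vectors across processes and restarts:
#   a = _emb_dummy(["hello"]); b = _emb_dummy(["hello"])
#   assert np.allclose(a, b)                          # deterministic
#   assert abs(np.linalg.norm(a[0]) - 1.0) < 1e-5     # unit-norm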

# ----------------- PROVIDER: Sentence-Transformers ----------------------------
def _get_st_model():
    global _ST_MODEL
    if _ST_MODEL is None:
        from sentence_transformers import SentenceTransformer
        _ST_MODEL = SentenceTransformer(EMB_MODEL)
        LOG.info(f"[st] model loaded: {EMB_MODEL}")
    return _ST_MODEL


def _emb_st(texts: List[str]) -> np.ndarray:
    model = _get_st_model()
    vecs = model.encode(
        texts,
        batch_size=max(1, EMB_BATCH),
        convert_to_numpy=True,
        normalize_embeddings=True,
        show_progress_bar=False,
    ).astype("float32")
    return vecs


def _st_dim() -> int:
    model = _get_st_model()
    try:
        return int(model.get_sentence_embedding_dimension())
    except Exception:
        # Fallback: encode one sentence and read the shape
        v = model.encode(["dimension probe"], convert_to_numpy=True)
        return int(v.shape[1])

# ----------------------- PROVIDER: Transformers (HF) --------------------------
def _get_hf_model():
    global _HF_TOKENIZER, _HF_MODEL
    if _HF_MODEL is None or _HF_TOKENIZER is None:
        from transformers import AutoTokenizer, AutoModel
        _HF_TOKENIZER = AutoTokenizer.from_pretrained(EMB_MODEL)
        _HF_MODEL = AutoModel.from_pretrained(EMB_MODEL)
        _HF_MODEL.eval()
        LOG.info(f"[hf] model loaded: {EMB_MODEL}")
    return _HF_TOKENIZER, _HF_MODEL


def _mean_pool(last_hidden_state: np.ndarray, attention_mask: np.ndarray) -> np.ndarray:
    # Masked mean pooling: padded positions are excluded from the average
    mask = attention_mask[..., None].astype(last_hidden_state.dtype)  # (b, t, 1)
    summed = (last_hidden_state * mask).sum(axis=1)                   # (b, h)
    counts = mask.sum(axis=1).clip(min=1e-9)                          # (b, 1)
    return summed / counts
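
# Worked example (sketch): one sequence of two tokens where the second is
# padding. With hidden states [[1, 3], [5, 7]] and mask [1, 0], only the first
# token contributes to the mean:
#   _mean_pool(np.array([[[1., 3.], [5., 7.]]]), np.array([[1, 0]]))
#   -> array([[1., 3.]])   (an unmasked mean would give [[3., 5.]])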

def _emb_hf(texts: List[str]) -> np.ndarray:
    import torch
    tok, mod = _get_hf_model()
    all_vecs = []
    bs = max(1, EMB_BATCH)
    with torch.no_grad():
        for i in range(0, len(texts), bs):
            batch = texts[i:i+bs]
            enc = tok(batch, padding=True, truncation=True, return_tensors="pt")
            out = mod(**enc)
            last = out.last_hidden_state  # (b, t, h)
            pooled = _mean_pool(last.numpy(), enc["attention_mask"].numpy())  # numpy
            all_vecs.append(pooled.astype("float32"))
    vecs = np.concatenate(all_vecs, axis=0)
    return _l2_normalize(vecs)

def _hf_dim() -> int:
    # Try to read hidden_size from the model config
    try:
        _, mod = _get_hf_model()
        return int(getattr(mod.config, "hidden_size", 768))
    except Exception:
        return 768

# ---------------------------- DATASET / FAISS ---------------------------------
def _save_dataset(ds_dir: str, rows: List[Dict[str, Any]]):
    os.makedirs(ds_dir, exist_ok=True)
    data_path = os.path.join(ds_dir, "data.jsonl")
    with open(data_path, "w", encoding="utf-8") as f:
        for r in rows:
            f.write(json.dumps(r, ensure_ascii=False) + "\n")
    meta = {"format": "jsonl", "columns": ["path", "text", "chunk_id"], "count": len(rows)}
    with open(os.path.join(ds_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)


def _load_dataset(ds_dir: str) -> List[Dict[str, Any]]:
    data_path = os.path.join(ds_dir, "data.jsonl")
    if not os.path.isfile(data_path):
        return []
    out = []
    with open(data_path, "r", encoding="utf-8") as f:
        for line in f:
            try:
                out.append(json.loads(line))
            except Exception:
                continue
    return out

def _save_faiss(fx_dir: str, xb: np.ndarray, meta: Dict[str, Any]):
    os.makedirs(fx_dir, exist_ok=True)
    idx_path = os.path.join(fx_dir, "emb.faiss")
    index = faiss.IndexFlatIP(xb.shape[1])  # cosine ~ inner product when vectors are normalized
    index.add(xb)
    faiss.write_index(index, idx_path)
    with open(os.path.join(fx_dir, "meta.json"), "w", encoding="utf-8") as f:
        json.dump(meta, f, ensure_ascii=False, indent=2)


def _load_faiss(fx_dir: str) -> faiss.Index:
    idx_path = os.path.join(fx_dir, "emb.faiss")
    if not os.path.isfile(idx_path):
        raise FileNotFoundError(f"FAISS index not found: {idx_path}")
    return faiss.read_index(idx_path)
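
# Note (sketch): IndexFlatIP returns raw inner products. Since every provider
# above L2-normalizes its vectors, <q, d> equals cos(q, d), so /search scores
# are cosine similarities in [-1, 1]. A self-query illustrates this:
#   xb = _emb_dummy(["alpha", "bravo"])
#   ix = faiss.IndexFlatIP(xb.shape[1]); ix.add(xb)
#   scores, ids = ix.search(xb[:1], 2)   # top hit is "alpha" itself, score ~ 1.0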

def _tar_dir_to_bytes(dir_path: str) -> bytes:
    bio = io.BytesIO()
    with tarfile.open(fileobj=bio, mode="w:gz") as tar:
        tar.add(dir_path, arcname=os.path.basename(dir_path))
    bio.seek(0)
    return bio.read()

# =============================================================================
# FASTAPI
# =============================================================================
fastapi_app = FastAPI(title="remote-indexer", version="2.0.0")
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
)


class FileItem(BaseModel):
    path: str
    text: str


class IndexRequest(BaseModel):
    project_id: str
    files: List[FileItem]
    chunk_size: int = 200
    overlap: int = 20
    batch_size: int = 32
    store_text: bool = True
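
# Example /index call (sketch; assumes the Space is reachable on localhost:7860):
#   curl -X POST http://localhost:7860/index -H "Content-Type: application/json" -d '{
#     "project_id": "DEEPWEB",
#     "files": [{"path": "notes/a.txt", "text": "Alpha bravo charlie."}],
#     "chunk_size": 200, "overlap": 20
#   }'
# Returns {"job_id": "..."}; poll progress via /status/{job_id}.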

@fastapi_app.get("/health")  # path assumed from the function name; not in the UI route list
def health():
    info = {
        "ok": True,
        "service": "remote-indexer",
        "provider": EMB_PROVIDER,
        "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None,
    }
    return info


@fastapi_app.get("/")
def root_redirect():
    return {"ok": True, "service": "remote-indexer", "ui": "/ui"}

@fastapi_app.post("/index")
def index(req: IndexRequest):
    job_id = hashlib.sha1(f"{req.project_id}{time.time()}".encode()).hexdigest()[:12]
    st = JobState(job_id=job_id, project_id=req.project_id, stage="pending", messages=[])
    JOBS[job_id] = st
    _add_msg(st, f"Job {job_id} created for project {req.project_id}")
    _add_msg(st, f"Index start project={req.project_id} files={len(req.files)} chunk_size={req.chunk_size} overlap={req.overlap} batch_size={req.batch_size} store_text={req.store_text} provider={EMB_PROVIDER} model={EMB_MODEL if EMB_PROVIDER != 'dummy' else '-'}")
    try:
        base, ds_dir, fx_dir = _proj_dirs(req.project_id)

        # 1) Chunking
        _set_stage(st, "chunking")
        rows: List[Dict[str, Any]] = []
        st.total_files = len(req.files)
        for it in req.files:
            txt = it.text or ""
            chunks = _chunk_text(txt, size=req.chunk_size, overlap=req.overlap)
            _add_msg(st, f"{it.path}: len(text)={len(txt)} chunks={len(chunks)}")
            for ci, ck in enumerate(chunks):
                rows.append({"path": it.path, "text": ck, "chunk_id": ci})
        st.total_chunks = len(rows)
        _add_msg(st, f"Total chunks = {st.total_chunks}")

        # 2) Embedding
        _set_stage(st, "embedding")
        if EMB_PROVIDER == "dummy":
            xb = _emb_dummy([r["text"] for r in rows], dim=EMB_DIM)
        elif EMB_PROVIDER == "st":
            xb = _emb_st([r["text"] for r in rows])
        else:  # "hf"
            xb = _emb_hf([r["text"] for r in rows])
        dim = xb.shape[1]
        st.embedded = xb.shape[0]
        _add_msg(st, f"Embeddings {st.embedded}/{st.total_chunks}")
        _add_msg(st, f"Embeddings dim={dim}")

        # 3) Save dataset (text)
        _save_dataset(ds_dir, rows)
        _add_msg(st, f"Dataset (without index) saved to {ds_dir}")

        # 4) FAISS
        _set_stage(st, "indexing")
        faiss_meta = {
            "dim": int(dim),
            "count": int(xb.shape[0]),
            "provider": EMB_PROVIDER,
            "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None,
        }
        _save_faiss(fx_dir, xb, meta=faiss_meta)
        st.indexed = int(xb.shape[0])
        _add_msg(st, f"FAISS written to {os.path.join(fx_dir, 'emb.faiss')}")
        _add_msg(st, f"OK - dataset+index ready (project={req.project_id})")
        _set_stage(st, "done")
        st.finished_at = time.time()
        return {"job_id": job_id}
    except Exception as e:
        LOG.exception("index failed")
        st.errors.append(str(e))
        _add_msg(st, f"❌ Exception: {e}")
        st.stage = "failed"
        st.finished_at = time.time()
        raise HTTPException(status_code=500, detail=str(e))

@fastapi_app.get("/status/{job_id}")
def status(job_id: str):
    st = JOBS.get(job_id)
    if not st:
        raise HTTPException(status_code=404, detail="unknown job")
    return JSONResponse(st.model_dump())
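
# Example (sketch): poll the job created by /index until stage is done/failed.
# The response mirrors JobState (stage, counters, messages, errors):
#   curl http://localhost:7860/status/<job_id>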

class SearchRequest(BaseModel):
    project_id: str
    query: str
    k: int = 5


@fastapi_app.post("/search")
def search(req: SearchRequest):
    base, ds_dir, fx_dir = _proj_dirs(req.project_id)
    rows = _load_dataset(ds_dir)
    if not rows:
        raise HTTPException(status_code=404, detail="dataset not found (index not built yet?)")
    # Embed the query with the SAME provider as the index
    if EMB_PROVIDER == "dummy":
        q = _emb_dummy([req.query], dim=EMB_DIM)[0:1, :]
    elif EMB_PROVIDER == "st":
        q = _emb_st([req.query])[0:1, :]
    else:
        q = _emb_hf([req.query])[0:1, :]
    # FAISS
    index = _load_faiss(fx_dir)
    if index.d != q.shape[1]:
        raise HTTPException(status_code=500, detail=f"incompatible dims: index.d={index.d} vs query={q.shape[1]}")
    scores, ids = index.search(q, int(max(1, req.k)))
    ids = ids[0].tolist()
    scores = scores[0].tolist()
    out = []
    for idx, sc in zip(ids, scores):
        if idx < 0 or idx >= len(rows):
            continue
        r = rows[idx]
        out.append({"path": r.get("path"), "text": r.get("text"), "score": float(sc)})
    return {"results": out}
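
# Example (sketch): query a previously indexed project; scores are cosine
# similarities (see the note above _tar_dir_to_bytes):
#   curl -X POST http://localhost:7860/search -H "Content-Type: application/json" \
#        -d '{"project_id": "DEEPWEB", "query": "alpha", "k": 5}'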

# ----------- ARTIFACTS EXPORT -----------
# Sub-paths below are assumed; only the /artifacts/... prefix is documented in the UI note.
@fastapi_app.get("/artifacts/dataset/{project_id}")
def download_dataset(project_id: str):
    base, ds_dir, _ = _proj_dirs(project_id)
    # _proj_dirs() creates the directory, so check for actual content instead
    if not os.path.isfile(os.path.join(ds_dir, "data.jsonl")):
        raise HTTPException(status_code=404, detail="Dataset not found")
    buf = _tar_dir_to_bytes(ds_dir)
    headers = {"Content-Disposition": f'attachment; filename="{project_id}_dataset.tgz"'}
    return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=headers)


@fastapi_app.get("/artifacts/faiss/{project_id}")
def download_faiss(project_id: str):
    base, _, fx_dir = _proj_dirs(project_id)
    # Same reasoning: the directory always exists, so test for the index file
    if not os.path.isfile(os.path.join(fx_dir, "emb.faiss")):
        raise HTTPException(status_code=404, detail="FAISS index not found")
    buf = _tar_dir_to_bytes(fx_dir)
    headers = {"Content-Disposition": f'attachment; filename="{project_id}_faiss.tgz"'}
    return StreamingResponse(io.BytesIO(buf), media_type="application/gzip", headers=headers)
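
# Example (sketch): download and unpack the artifacts produced by /index:
#   curl -o DEEPWEB_faiss.tgz http://localhost:7860/artifacts/faiss/DEEPWEB
#   tar -xzf DEEPWEB_faiss.tgz   # -> faiss/emb.faiss + faiss/meta.json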

# =============================================================================
# GRADIO UI (optional)
# =============================================================================
def _ui_index(project_id: str, sample_text: str):
    files = [{"path": "sample.txt", "text": sample_text}]
    from pydantic import ValidationError
    try:
        req = IndexRequest(project_id=project_id, files=[FileItem(**f) for f in files])
    except ValidationError as e:
        return f"Error: {e}"
    try:
        res = index(req)
        return f"Job started: {res['job_id']}"
    except Exception as e:
        return f"Index error: {e}"


def _ui_search(project_id: str, query: str, k: int):
    try:
        res = search(SearchRequest(project_id=project_id, query=query, k=int(k)))
        return json.dumps(res, ensure_ascii=False, indent=2)
    except Exception as e:
        return f"Search error: {e}"

with gr.Blocks(title="Remote Indexer (FAISS)", analytics_enabled=False) as ui:
    gr.Markdown("## Remote Indexer — demo UI (API: `/index`, `/status/{job}`, `/search`, `/artifacts/...`).")
    gr.Markdown(f"**Provider**: `{EMB_PROVIDER}` — **Model**: `{EMB_MODEL if EMB_PROVIDER != 'dummy' else '-'}`")
    with gr.Tab("Index"):
        pid = gr.Textbox(label="Project ID", value="DEEPWEB")
        sample = gr.Textbox(label="Sample text", value="Alpha bravo charlie delta echo foxtrot.", lines=4)
        btn = gr.Button("Run index (sample)")
        out = gr.Textbox(label="Result")
        btn.click(_ui_index, inputs=[pid, sample], outputs=[out])
    with gr.Tab("Search"):
        pid2 = gr.Textbox(label="Project ID", value="DEEPWEB")
        q = gr.Textbox(label="Query", value="alpha")
        k = gr.Slider(1, 20, value=5, step=1, label="k")
        btn2 = gr.Button("Search")
        out2 = gr.Code(label="Results")
        btn2.click(_ui_search, inputs=[pid2, q, k], outputs=[out2])

fastapi_app = gr.mount_gradio_app(fastapi_app, ui, path="/ui")

# =============================================================================
# MAIN
# =============================================================================
if __name__ == "__main__":
    import uvicorn
    LOG.info("Starting Uvicorn on 0.0.0.0:%s (UI_PATH=/ui)", PORT)
    uvicorn.run(fastapi_app, host="0.0.0.0", port=PORT)