# -*- coding: utf-8 -*-
"""
HF Space - drop-in main.py for Qdrant tests / minimal indexing

Endpoints:
- POST /wipe?project_id=XXX
- POST /index
- GET /status/{job_id}
- GET /collections/{project_id}/count
- POST /query
- GET /health <-- healthcheck OK

Gradio UI mounted on "/" for console-free testing.

Expected ENV vars:
- QDRANT_URL : https://...qdrant.io:6333
- QDRANT_API_KEY : Qdrant API key
- COLLECTION_PREFIX : "proj_" by default
- EMB_PROVIDER : "hf" (default) or "dummy"
- HF_EMBED_MODEL : "BAAI/bge-m3" by default
- HUGGINGFACEHUB_API_TOKEN (if EMB_PROVIDER=hf)
- LOG_LEVEL : DEBUG (default)
- PORT : 7860 (provided by HF)

Suggested dependencies:
fastapi>=0.111, uvicorn>=0.30, httpx>=0.27, pydantic>=2.7, gradio>=4.43, numpy>=2.0
"""
from __future__ import annotations

import os
import time
import uuid
import hashlib
import logging
import asyncio
from typing import List, Dict, Any, Optional, Tuple

import numpy as np
import httpx
import uvicorn
from pydantic import BaseModel, Field, ValidationError
from fastapi import FastAPI, HTTPException, Query
from fastapi.middleware.cors import CORSMiddleware
import gradio as gr
# ------------------------------------------------------------------------------
# Configuration & logging
# ------------------------------------------------------------------------------
LOG_LEVEL = os.getenv("LOG_LEVEL", "DEBUG").upper()
logging.basicConfig(
    level=getattr(logging, LOG_LEVEL, logging.DEBUG),
    format="%(asctime)s - %(levelname)s - %(message)s",
)
LOG = logging.getLogger("remote_indexer_min")

QDRANT_URL = os.getenv("QDRANT_URL", "").rstrip("/")
QDRANT_API_KEY = os.getenv("QDRANT_API_KEY", "")
COLLECTION_PREFIX = os.getenv("COLLECTION_PREFIX", "proj_").strip() or "proj_"
EMB_PROVIDER = os.getenv("EMB_PROVIDER", "hf").lower()  # "hf" | "dummy"
HF_EMBED_MODEL = os.getenv("HF_EMBED_MODEL", "BAAI/bge-m3")
HF_TOKEN = os.getenv("HUGGINGFACEHUB_API_TOKEN", "")

if not QDRANT_URL or not QDRANT_API_KEY:
    LOG.warning("QDRANT_URL / QDRANT_API_KEY not provided: upserts will fail.")
if EMB_PROVIDER == "hf" and not HF_TOKEN:
    LOG.warning("EMB_PROVIDER=hf without HUGGINGFACEHUB_API_TOKEN. Use EMB_PROVIDER=dummy to test without a token.")
# ------------------------------------------------------------------------------
# Pydantic schemas
# ------------------------------------------------------------------------------
class FileItem(BaseModel):
    path: str
    text: str


class IndexRequest(BaseModel):
    project_id: str = Field(..., min_length=1)
    files: List[FileItem] = Field(default_factory=list)
    chunk_size: int = Field(200, ge=64, le=4096)
    overlap: int = Field(20, ge=0, le=512)
    batch_size: int = Field(32, ge=1, le=1024)
    store_text: bool = True


class QueryRequest(BaseModel):
    project_id: str
    text: str
    top_k: int = Field(5, ge=1, le=100)
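# Example request bodies matching the schemas above (values are illustrative):
#   POST /index:
#     {"project_id": "DEEPWEB",
#      "files": [{"path": "a.txt", "text": "Alpha bravo charlie delta echo..."}],
#      "chunk_size": 200, "overlap": 20, "batch_size": 32, "store_text": true}
#   POST /query:
#     {"project_id": "DEEPWEB", "text": "alpha bravo", "top_k": 5}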
# ------------------------------------------------------------------------------
# Job store (in memory)
# ------------------------------------------------------------------------------
class JobState(BaseModel):
    job_id: str
    project_id: str
    stage: str = "pending"  # pending -> embedding -> upserting -> done/failed
    total_files: int = 0
    total_chunks: int = 0
    embedded: int = 0
    upserted: int = 0
    errors: List[str] = Field(default_factory=list)
    messages: List[str] = Field(default_factory=list)
    started_at: float = Field(default_factory=time.time)
    finished_at: Optional[float] = None

    def log(self, msg: str) -> None:
        stamp = time.strftime("%H:%M:%S")
        line = f"[{stamp}] {msg}"
        self.messages.append(line)
        LOG.debug(f"[{self.job_id}] {msg}")


JOBS: Dict[str, JobState] = {}
# ------------------------------------------------------------------------------
# Utilities
# ------------------------------------------------------------------------------
def hash8(s: str) -> str:
    """First 16 hex chars (8 bytes) of the SHA-256 of s."""
    return hashlib.sha256(s.encode("utf-8")).hexdigest()[:16]


def l2_normalize(vec: List[float]) -> List[float]:
    arr = np.array(vec, dtype=np.float32)
    n = float(np.linalg.norm(arr))
    if n > 0:
        arr = arr / n
    return arr.astype(np.float32).tolist()


def flatten_any(x: Any) -> List[float]:
    """Flatten potential [[...]] or [[[...]]] nesting down to 1D."""
    if isinstance(x, (list, tuple)):
        if len(x) > 0 and isinstance(x[0], (list, tuple)):
            return flatten_any(x[0])
        return list(map(float, x))
    raise ValueError("Malformed embedding vector")


def chunk_text(text: str, chunk_size: int, overlap: int) -> List[Tuple[int, int, str]]:
    """Return [(start, end, chunk)]; fragments shorter than 30 chars are skipped."""
    text = text or ""
    if not text.strip():
        return []
    res: List[Tuple[int, int, str]] = []
    n = len(text)
    i = 0
    while i < n:
        j = min(i + chunk_size, n)
        chunk = text[i:j]
        if len(chunk.strip()) >= 30:
            res.append((i, j, chunk))
        if j >= n:
            # Last window reached: stop here, otherwise `i = j - overlap` would
            # sit still and loop forever whenever overlap > 0.
            break
        # Always advance by at least one char, even if overlap >= chunk_size.
        i += max(chunk_size - overlap, 1)
    return res
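# Worked example: with chunk_size=200 and overlap=20 on a 500-char text, the
# windows are [0:200], [180:380], [360:500] — each new window starts
# chunk_size - overlap = 180 chars after the previous one.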
async def ensure_collection(client: httpx.AsyncClient, coll: str, vector_size: int) -> None:
    """Create the Qdrant collection (distance=Cosine), or recreate it on a dimension mismatch."""
    url = f"{QDRANT_URL}/collections/{coll}"
    r = await client.get(url, headers={"api-key": QDRANT_API_KEY}, timeout=20)
    recreate = False
    if r.status_code == 200:
        # Vector params live under result.config.params.vectors in Qdrant's REST API.
        vectors = r.json().get("result", {}).get("config", {}).get("params", {}).get("vectors", {})
        existing_size = vectors.get("size") if isinstance(vectors, dict) else None
        if existing_size and int(existing_size) != int(vector_size):
            LOG.warning(f"Collection {coll} dim={existing_size} ≠ expected {vector_size} → recreating")
            await client.delete(url, headers={"api-key": QDRANT_API_KEY}, timeout=20)
            recreate = True
        else:
            LOG.debug(f"Collection {coll} already exists (dim={existing_size})")
    if r.status_code != 200 or recreate:
        body = {"vectors": {"size": vector_size, "distance": "Cosine"}}
        r2 = await client.put(url, headers={"api-key": QDRANT_API_KEY}, json=body, timeout=30)
        if r2.status_code not in (200, 201):
            raise HTTPException(status_code=500, detail=f"Qdrant PUT collection failed: {r2.text}")
async def qdrant_upsert(client: httpx.AsyncClient, coll: str, points: List[Dict[str, Any]]) -> int:
    if not points:
        return 0
    url = f"{QDRANT_URL}/collections/{coll}/points?wait=true"
    body = {"points": points}
    r = await client.put(url, headers={"api-key": QDRANT_API_KEY}, json=body, timeout=60)
    if r.status_code not in (200, 202):
        raise HTTPException(status_code=500, detail=f"Qdrant upsert failed: {r.text}")
    return len(points)


async def qdrant_count(client: httpx.AsyncClient, coll: str) -> int:
    url = f"{QDRANT_URL}/collections/{coll}/points/count"
    r = await client.post(url, headers={"api-key": QDRANT_API_KEY}, json={"exact": True}, timeout=20)
    if r.status_code != 200:
        raise HTTPException(status_code=500, detail=f"Qdrant count failed: {r.text}")
    return int(r.json().get("result", {}).get("count", 0))


async def qdrant_search(client: httpx.AsyncClient, coll: str, vector: List[float], limit: int = 5) -> Dict[str, Any]:
    url = f"{QDRANT_URL}/collections/{coll}/points/search"
    r = await client.post(
        url,
        headers={"api-key": QDRANT_API_KEY},
        json={"vector": vector, "limit": limit, "with_payload": True},
        timeout=30,
    )
    if r.status_code != 200:
        raise HTTPException(status_code=500, detail=f"Qdrant search failed: {r.text}")
    return r.json()
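# A successful search response is shaped roughly like (abridged):
#   {"result": [{"id": "...", "score": 0.93,
#                "payload": {"path": "a.txt", "chunk": 0, "start": 0, "end": 200, "text": "..."}},
#               ...]}
# which is what ui_query() below unpacks.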
# ------------------------------------------------------------------------------
# Embeddings (HF Inference or dummy)
# ------------------------------------------------------------------------------
async def embed_hf(client: httpx.AsyncClient, texts: List[str], model: str = HF_EMBED_MODEL, token: str = HF_TOKEN) -> List[List[float]]:
    if not token:
        raise HTTPException(status_code=400, detail="HUGGINGFACEHUB_API_TOKEN missing for EMB_PROVIDER=hf")
    url = f"https://api-inference.huggingface.co/models/{model}"
    headers = {"Authorization": f"Bearer {token}"}
    payload = {"inputs": texts, "options": {"wait_for_model": True}}
    r = await client.post(url, headers=headers, json=payload, timeout=120)
    if r.status_code != 200:
        raise HTTPException(status_code=502, detail=f"HF Inference error: {r.text}")
    data = r.json()
    embeddings: List[List[float]] = []
    if isinstance(data, list):
        for row in data:
            vec = flatten_any(row)
            embeddings.append(l2_normalize(vec))
    else:
        vec = flatten_any(data)
        embeddings.append(l2_normalize(vec))
    return embeddings


def embed_dummy(texts: List[str], dim: int = 128) -> List[List[float]]:
    """Deterministic, token-free embeddings: SHA-256 bytes tiled to `dim`, centered, then L2-normalized."""
    out: List[List[float]] = []
    for t in texts:
        h = hashlib.sha256(t.encode("utf-8")).digest()
        arr = np.frombuffer((h * ((dim // len(h)) + 1))[:dim], dtype=np.uint8).astype(np.float32)
        arr = (arr - 127.5) / 127.5
        arr = arr / (np.linalg.norm(arr) + 1e-9)
        out.append(arr.astype(np.float32).tolist())
    return out
async def embed_texts(client: httpx.AsyncClient, texts: List[str]) -> List[List[float]]:
    if EMB_PROVIDER == "hf":
        return await embed_hf(client, texts)
    return embed_dummy(texts, dim=128)
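# Standalone smoke test for the embedding layer (a sketch — run manually with
# EMB_PROVIDER=dummy; `_demo` is a hypothetical helper, not used by the app):
#
#   async def _demo():
#       async with httpx.AsyncClient() as c:
#           vecs = await embed_texts(c, ["hello world"])
#           print(len(vecs), len(vecs[0]))  # 1 vector of dim 128 with the dummy provider
#   asyncio.run(_demo())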
# ------------------------------------------------------------------------------
# Indexing pipeline
# ------------------------------------------------------------------------------
async def run_index_job(job: JobState, req: IndexRequest) -> None:
    job.stage = "embedding"
    job.total_files = len(req.files)
    job.log(f"Index start project={req.project_id} files={len(req.files)} chunk_size={req.chunk_size} overlap={req.overlap} batch_size={req.batch_size} store_text={req.store_text}")

    # Global dedup by hash of each file's text
    file_hashes = [hash8(f.text) for f in req.files]
    uniq = len(set(file_hashes))
    if uniq != len(file_hashes):
        job.log(f"Warning: {len(file_hashes) - uniq} file(s) have identical text (duplicate hash).")

    # Chunking
    records: List[Dict[str, Any]] = []
    for f in req.files:
        chunks = chunk_text(f.text, req.chunk_size, req.overlap)
        if not chunks:
            job.log(f"{f.path}: 0 chunks (too short or empty)")
        for idx, (start, end, ch) in enumerate(chunks):
            payload = {"path": f.path, "chunk": idx, "start": start, "end": end}
            if req.store_text:
                payload["text"] = ch
            records.append({"payload": payload, "raw": ch})
    job.total_chunks = len(records)
    job.log(f"Total chunks = {job.total_chunks}")
    if job.total_chunks == 0:
        job.stage = "failed"
        job.errors.append("No chunks to index.")
        job.finished_at = time.time()
        return

    try:
        async with httpx.AsyncClient(timeout=120) as client:
            # Warmup to discover the vector dimension
            warmup_vec = (await embed_texts(client, [records[0]["raw"]]))[0]
            vec_dim = len(warmup_vec)
            job.log(f"Warmup embeddings dim={vec_dim} provider={EMB_PROVIDER}")

            # Qdrant collection
            coll = f"{COLLECTION_PREFIX}{req.project_id}"
            await ensure_collection(client, coll, vector_size=vec_dim)

            job.stage = "upserting"
            batch_points: List[Dict[str, Any]] = []

            async def flush_batch():
                nonlocal batch_points
                if not batch_points:
                    return 0
                added = await qdrant_upsert(client, coll, batch_points)
                job.upserted += added
                job.log(f"+{added} points upserted (total={job.upserted})")
                batch_points = []
                return added

            EMB_BATCH = max(8, min(64, req.batch_size * 2))
            i = 0
            while i < len(records):
                sub = records[i : i + EMB_BATCH]
                texts = [r["raw"] for r in sub]
                vecs = await embed_texts(client, texts)
                if len(vecs) != len(sub):
                    raise HTTPException(status_code=500, detail="Embedding batch size mismatch")
                job.embedded += len(vecs)
                for r, v in zip(sub, vecs):
                    point = {"id": str(uuid.uuid4()), "vector": v, "payload": r["payload"]}
                    batch_points.append(point)
                    if len(batch_points) >= req.batch_size:
                        await flush_batch()
                i += EMB_BATCH
            await flush_batch()

        job.stage = "done"
        job.finished_at = time.time()
        job.log("Index job finished.")
    except Exception as e:  # record the failure so /status doesn't stay stuck on "embedding"
        job.stage = "failed"
        job.errors.append(str(e))
        job.finished_at = time.time()
        job.log(f"Index job failed: {e}")
# ------------------------------------------------------------------------------
# FastAPI app + endpoints
# ------------------------------------------------------------------------------
fastapi_app = FastAPI(title="Remote Indexer - Minimal Test Space")
fastapi_app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["*"],
    allow_headers=["*"],
)


@fastapi_app.get("/health")
async def health():
    return {"status": "ok"}


async def root():
    # Not routed here: the Gradio UI owns "/" (mounted below).
    return {"ok": True, "service": "remote-indexer-min", "qdrant": bool(QDRANT_URL), "emb_provider": EMB_PROVIDER}


@fastapi_app.post("/wipe")
async def wipe(project_id: str = Query(..., min_length=1)):
    if not QDRANT_URL or not QDRANT_API_KEY:
        raise HTTPException(status_code=400, detail="QDRANT_URL / QDRANT_API_KEY required")
    coll = f"{COLLECTION_PREFIX}{project_id}"
    async with httpx.AsyncClient() as client:
        r = await client.delete(f"{QDRANT_URL}/collections/{coll}", headers={"api-key": QDRANT_API_KEY}, timeout=30)
        if r.status_code not in (200, 202, 404):
            raise HTTPException(status_code=500, detail=f"Wipe failed: {r.text}")
    return {"ok": True, "collection": coll, "wiped": True}


@fastapi_app.post("/index")
async def index(req: IndexRequest):
    if not QDRANT_URL or not QDRANT_API_KEY:
        raise HTTPException(status_code=400, detail="QDRANT_URL / QDRANT_API_KEY required")
    job_id = uuid.uuid4().hex[:12]
    job = JobState(job_id=job_id, project_id=req.project_id)
    JOBS[job_id] = job
    asyncio.create_task(run_index_job(job, req))
    job.log(f"Job {job_id} created for project {req.project_id}")
    return {"job_id": job_id, "project_id": req.project_id}


@fastapi_app.get("/status/{job_id}")
async def status(job_id: str):
    job = JOBS.get(job_id)
    if not job:
        raise HTTPException(status_code=404, detail="unknown job_id")
    return job.model_dump()


@fastapi_app.get("/collections/{project_id}/count")
async def coll_count(project_id: str):
    if not QDRANT_URL or not QDRANT_API_KEY:
        raise HTTPException(status_code=400, detail="QDRANT_URL / QDRANT_API_KEY required")
    coll = f"{COLLECTION_PREFIX}{project_id}"
    async with httpx.AsyncClient() as client:
        cnt = await qdrant_count(client, coll)
    return {"project_id": project_id, "collection": coll, "count": cnt}


@fastapi_app.post("/query")
async def query(req: QueryRequest):
    if not QDRANT_URL or not QDRANT_API_KEY:
        raise HTTPException(status_code=400, detail="QDRANT_URL / QDRANT_API_KEY required")
    coll = f"{COLLECTION_PREFIX}{req.project_id}"
    async with httpx.AsyncClient() as client:
        vec = (await embed_texts(client, [req.text]))[0]
        data = await qdrant_search(client, coll, vec, limit=req.top_k)
    return data
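# End-to-end flow with curl (a sketch; host/port assume a local run on the default
# PORT, and <job_id> comes from the /index response):
#   curl -X POST 'http://localhost:7860/wipe?project_id=DEEPWEB'
#   curl -X POST 'http://localhost:7860/index' -H 'Content-Type: application/json' \
#        -d '{"project_id": "DEEPWEB", "files": [{"path": "a.txt", "text": "Alpha bravo charlie delta echo foxtrot golf hotel india."}]}'
#   curl 'http://localhost:7860/status/<job_id>'
#   curl 'http://localhost:7860/collections/DEEPWEB/count'
#   curl -X POST 'http://localhost:7860/query' -H 'Content-Type: application/json' \
#        -d '{"project_id": "DEEPWEB", "text": "alpha bravo", "top_k": 5}'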
# ------------------------------------------------------------------------------
# Gradio UI
# ------------------------------------------------------------------------------
def _default_two_docs() -> List[Dict[str, str]]:
    a = "Alpha bravo charlie delta echo foxtrot golf hotel india. " * 3
    b = "Lorem ipsum dolor sit amet, consectetuer adipiscing elit, sed diam nonummy. " * 3
    return [{"path": "a.txt", "text": a}, {"path": "b.txt", "text": b}]


async def ui_wipe(project: str):
    try:
        resp = await wipe(project)  # calls the route handler directly
        return f"✅ Wipe OK — collection {resp['collection']} deleted."
    except Exception as e:
        LOG.exception("wipe UI error")
        return f"❌ Wipe error: {e}"


async def ui_index_sample(project: str, chunk_size: int, overlap: int, batch_size: int, store_text: bool):
    try:
        # Build the request inside the try block so ValidationError is actually caught.
        files = _default_two_docs()
        req = IndexRequest(
            project_id=project,
            files=[FileItem(**f) for f in files],
            chunk_size=chunk_size,
            overlap=overlap,
            batch_size=batch_size,
            store_text=store_text,
        )
        data = await index(req)
        job_id = data["job_id"]
        return f"🚀 Job launched: {job_id}"
    except ValidationError as ve:
        return f"❌ Invalid payload: {ve}"
    except Exception as e:
        LOG.exception("index UI error")
        return f"❌ Index error: {e}"


async def ui_status(job_id: str):
    if not job_id.strip():
        return "⚠️ Provide a job_id"
    try:
        st = await status(job_id)
        lines = [f"Job {st['job_id']} — stage={st['stage']} files={st['total_files']} chunks={st['total_chunks']} embedded={st['embedded']} upserted={st['upserted']}"]
        lines += st.get("messages", [])[-50:]
        if st.get("errors"):
            lines.append("Errors:")
            lines += [f" - {e}" for e in st["errors"]]
        return "\n".join(lines)
    except Exception as e:
        return f"❌ Status error: {e}"


async def ui_count(project: str):
    try:
        resp = await coll_count(project)
        return f"📊 Count — collection={resp['collection']} → {resp['count']} points"
    except Exception as e:
        LOG.exception("count UI error")
        return f"❌ Count error: {e}"


async def ui_query(project: str, text: str, topk: int):
    try:
        data = await query(QueryRequest(project_id=project, text=text, top_k=topk))
        hits = data.get("result", [])
        if not hits:
            return "No results."
        out = []
        for h in hits:
            score = h.get("score")
            payload = h.get("payload", {})
            path = payload.get("path")
            chunk = payload.get("chunk")
            preview = (payload.get("text") or "")[:120].replace("\n", " ")
            out.append(f"{score:.4f} — {path} [chunk {chunk}] — {preview}…")
        return "\n".join(out)
    except Exception as e:
        LOG.exception("query UI error")
        return f"❌ Query error: {e}"
with gr.Blocks(title="Remote Indexer - Minimal Test", analytics_enabled=False) as ui:
    gr.Markdown("## 🔬 Remote Indexer — Console-free tests\n"
                "Wipe → Index 2 docs → Status → Count → Query\n"
                f"- **Embeddings**: `{EMB_PROVIDER}` (model: `{HF_EMBED_MODEL}`)\n"
                f"- **Qdrant**: `{'OK' if QDRANT_URL else 'MISSING'}`\n"
                "Tip: without an HF token, set `EMB_PROVIDER=dummy`.")
    with gr.Row():
        project_tb = gr.Textbox(label="Project ID", value="DEEPWEB")
        jobid_tb = gr.Textbox(label="Job ID (for Status)", value="", interactive=True)
    with gr.Row():
        wipe_btn = gr.Button("🧨 Wipe collection", variant="stop")
        index_btn = gr.Button("🚀 Index 2 documents", variant="primary")
        status_btn = gr.Button("📡 Status")
        count_btn = gr.Button("📊 Count points", variant="secondary")
    with gr.Row():
        chunk_size = gr.Slider(64, 1024, value=200, step=8, label="chunk_size")
        overlap = gr.Slider(0, 256, value=20, step=2, label="overlap")
        batch_size = gr.Slider(1, 128, value=32, step=1, label="batch_size")
        store_text = gr.Checkbox(value=True, label="store_text (payload)")
    out_log = gr.Textbox(lines=18, label="Logs / Results", interactive=False)
    with gr.Row():
        query_tb = gr.Textbox(label="Query text", value="alpha bravo")
        topk = gr.Slider(1, 20, value=5, step=1, label="top_k")
        query_btn = gr.Button("🔎 Query")
    query_out = gr.Textbox(lines=10, label="Query results", interactive=False)

    wipe_btn.click(ui_wipe, inputs=[project_tb], outputs=[out_log])
    index_btn.click(ui_index_sample, inputs=[project_tb, chunk_size, overlap, batch_size, store_text], outputs=[out_log])
    status_btn.click(ui_status, inputs=[jobid_tb], outputs=[out_log])
    count_btn.click(ui_count, inputs=[project_tb], outputs=[out_log])
    query_btn.click(ui_query, inputs=[project_tb, query_tb, topk], outputs=[query_out])

# Mount the Gradio UI on the FastAPI app
app = gr.mount_gradio_app(fastapi_app, ui, path="/")
if __name__ == "__main__":
    # Start Uvicorn for Docker Spaces (CMD: python -u /app/main.py)
    port = int(os.getenv("PORT", "7860"))
    LOG.info(f"Starting Uvicorn on 0.0.0.0:{port}")
    uvicorn.run(app, host="0.0.0.0", port=port)