chouchouvs committed on
Commit
9ea3ad6
·
verified ·
1 Parent(s): 19a096b

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +97 -21
main.py CHANGED
@@ -5,7 +5,7 @@ from typing import List, Optional, Dict, Any, Tuple
5
 
6
  import numpy as np
7
  import requests
8
- from fastapi import FastAPI, BackgroundTasks, Header, HTTPException
9
  from pydantic import BaseModel, Field
10
  from qdrant_client import QdrantClient
11
  from qdrant_client.http.models import VectorParams, Distance, PointStruct
@@ -20,9 +20,10 @@ EMB_BACKEND_ORDER = [s.strip().lower() for s in os.getenv("EMB_BACKEND_ORDER", o
20
 
21
  # HF Inference API
22
  HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
23
- HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
 
24
  HF_URL = (os.getenv("HF_API_URL", "").strip()
25
- or f"https://api-inference.huggingface.co/models/{HF_MODEL}")
26
  HF_TIMEOUT = float(os.getenv("EMB_TIMEOUT_SEC", "120"))
27
  HF_WAIT = os.getenv("HF_WAIT_FOR_MODEL", "true").lower() in ("1","true","yes","on")
28
 
@@ -73,6 +74,8 @@ class QueryRequest(BaseModel):
73
  project_id: str
74
  query: str
75
  top_k: int = 6
 
 
76
 
77
  # ---------- Jobs store (mémoire) ----------
78
  JOBS: Dict[str, Dict[str, Any]] = {} # {job_id: {"status": "...", "logs": [...], "created": ts}}
@@ -91,36 +94,53 @@ def _auth(x_auth: Optional[str]):
91
 
92
  # ---------- Embeddings backends avec retry ----------
93
  def _retry_sleep(attempt: int):
94
- # backoff exponentiel + jitter (p.ex. 1.5^attempt) * (1 ± jitter)
95
  back = (RETRY_BASE_SEC ** attempt)
96
  jitter = 1.0 + random.uniform(-RETRY_JITTER, RETRY_JITTER)
97
  return max(0.25, back * jitter)
98
 
99
  def _hf_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
 
 
 
 
 
100
  if not HF_TOKEN:
101
  raise RuntimeError("HF_API_TOKEN manquant (backend=hf).")
102
  headers = {
103
  "Authorization": f"Bearer {HF_TOKEN}",
104
  "Content-Type": "application/json",
 
 
 
105
  }
 
106
  if HF_WAIT:
107
- headers["X-Wait-For-Model"] = "true"
108
- headers["X-Use-Cache"] = "true"
109
- payload = {"inputs": batch if len(batch) > 1 else batch[0]}
110
  r = requests.post(HF_URL, headers=headers, json=payload, timeout=HF_TIMEOUT)
111
  size = int(r.headers.get("Content-Length", "0"))
112
  if r.status_code >= 400:
 
113
  LOG.error(f"HF error {r.status_code}: {r.text[:1000]}")
114
  r.raise_for_status()
 
115
  data = r.json()
 
 
 
 
116
  arr = np.array(data, dtype=np.float32)
117
- if arr.ndim == 3: # [batch, tokens, dim]
118
  arr = arr.mean(axis=1)
119
- if arr.ndim == 1: # [dim] -> [1, dim]
 
 
120
  arr = arr.reshape(1, -1)
121
- if arr.ndim != 2:
122
  raise RuntimeError(f"HF: unexpected embeddings shape: {arr.shape}")
123
- # normalisation
 
124
  norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
125
  arr = arr / norms
126
  return arr.astype(np.float32), size
@@ -143,7 +163,6 @@ def _di_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
143
  arr = np.asarray(embs, dtype=np.float32)
144
  if arr.ndim != 2:
145
  raise RuntimeError(f"DeepInfra: unexpected embeddings shape: {arr.shape}")
146
- # normalisation
147
  norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
148
  arr = arr / norms
149
  return arr.astype(np.float32), size
@@ -166,13 +185,11 @@ def _call_with_retries(func, batch: List[str], label: str, job_id: Optional[str]
166
  time.sleep(sleep_s)
167
  last_exc = he
168
  except Exception as e:
169
- # on tente quelques retries aussi sur erreurs réseau transitoires
170
  sleep_s = _retry_sleep(attempt)
171
  msg = f"{label}: error {type(e).__name__}: {e}, retry in {sleep_s:.1f}s"
172
  LOG.warning(msg); _append_log(job_id, msg)
173
  time.sleep(sleep_s)
174
  last_exc = e
175
- # épuisé
176
  raise RuntimeError(f"{label}: retries exhausted: {last_exc}")
177
 
178
  def _post_embeddings(batch: List[str], job_id: Optional[str] = None) -> Tuple[np.ndarray, int]:
@@ -189,7 +206,6 @@ def _post_embeddings(batch: List[str], job_id: Optional[str] = None) -> Tuple[np
189
  last_err = e
190
  _append_log(job_id, f"HF failed: {e}.")
191
  LOG.error(f"HF failed: {e}")
192
- # passe au backend suivant si dispo
193
  elif b == "deepinfra":
194
  try:
195
  return _call_with_retries(_di_post_embeddings_once, batch, "DeepInfra", job_id)
@@ -201,6 +217,7 @@ def _post_embeddings(batch: List[str], job_id: Optional[str] = None) -> Tuple[np
201
  _append_log(job_id, f"Backend inconnu ignoré: {b}")
202
  raise RuntimeError(f"Tous les backends ont échoué: {last_err}")
203
 
 
204
  def _ensure_collection(name: str, dim: int):
205
  try:
206
  qdr.get_collection(name); return
@@ -212,7 +229,7 @@ def _ensure_collection(name: str, dim: int):
212
  )
213
 
214
  def _chunk_with_spans(text: str, size: int, overlap: int):
215
- n = len(text)
216
  if size <= 0:
217
  yield (0, n, text); return
218
  i = 0
@@ -230,9 +247,13 @@ def run_index_job(job_id: str, req: IndexRequest):
230
  _append_log(job_id, f"Start project={req.project_id} files={len(req.files)} | backends={EMB_BACKEND_ORDER}")
231
  LOG.info(f"[{job_id}] Index start project={req.project_id} files={len(req.files)}")
232
 
233
- # Warmup -> dimension (1er morceau)
234
- first_text = next(_chunk_with_spans(req.files[0].text if req.files else "", req.chunk_size, req.overlap))[2] if req.files else "warmup"
235
- embs, sz = _post_embeddings([first_text], job_id=job_id)
 
 
 
 
236
  dim = embs.shape[1]
237
  col = f"proj_{req.project_id}"
238
  _ensure_collection(col, dim)
@@ -241,8 +262,13 @@ def run_index_job(job_id: str, req: IndexRequest):
241
  point_id = 0
242
  # Boucle sur les fichiers
243
  for fi, f in enumerate(req.files, 1):
 
 
 
244
  chunks, metas = [], []
245
  for ci, (start, end, chunk_txt) in enumerate(_chunk_with_spans(f.text, req.chunk_size, req.overlap)):
 
 
246
  chunks.append(chunk_txt)
247
  meta = {"path": f.path, "chunk": ci, "start": start, "end": end}
248
  if req.store_text:
@@ -310,6 +336,13 @@ def start_index(req: IndexRequest, background_tasks: BackgroundTasks, x_auth_tok
310
  if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
311
  raise HTTPException(401, "Unauthorized")
312
  _check_backend_ready()
 
 
 
 
 
 
 
313
  job_id = uuid.uuid4().hex[:12]
314
  JOBS[job_id] = {"status": "queued", "logs": [], "created": time.time()}
315
  background_tasks.add_task(run_index_job, job_id, req)
@@ -324,15 +357,51 @@ def status(job_id: str, x_auth_token: Optional[str] = Header(default=None)):
324
  raise HTTPException(404, "job inconnu")
325
  return {"status": j["status"], "logs": j["logs"][-800:]}
326
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
327
  @app.post("/query")
328
  def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None)):
329
  if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
330
  raise HTTPException(401, "Unauthorized")
331
  _check_backend_ready()
 
 
 
 
 
 
 
 
 
 
 
 
332
  vecs, _ = _post_embeddings([req.query])
333
  col = f"proj_{req.project_id}"
334
  try:
335
- res = qdr.search(collection_name=col, query_vector=vecs[0].tolist(), limit=int(req.top_k))
336
  except Exception as e:
337
  raise HTTPException(400, f"Search failed: {e}")
338
  out = []
@@ -341,7 +410,14 @@ def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None))
341
  txt = pl.get("text")
342
  if txt and len(txt) > 800:
343
  txt = txt[:800] + "..."
344
- out.append({"path": pl.get("path"), "chunk": pl.get("chunk"), "start": pl.get("start"), "end": pl.get("end"), "text": txt})
 
 
 
 
 
 
 
345
  return {"results": out}
346
 
347
  @app.post("/wipe")
 
5
 
6
  import numpy as np
7
  import requests
8
+ from fastapi import FastAPI, BackgroundTasks, Header, HTTPException, Query
9
  from pydantic import BaseModel, Field
10
  from qdrant_client import QdrantClient
11
  from qdrant_client.http.models import VectorParams, Distance, PointStruct
 
20
 
21
  # HF Inference API
22
  HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
23
+ HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2").strip()
24
+ # 👉 On force la pipeline "feature-extraction" pour obtenir des embeddings (et pas la Similarity)
25
  HF_URL = (os.getenv("HF_API_URL", "").strip()
26
+ or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}")
27
  HF_TIMEOUT = float(os.getenv("EMB_TIMEOUT_SEC", "120"))
28
  HF_WAIT = os.getenv("HF_WAIT_FOR_MODEL", "true").lower() in ("1","true","yes","on")
29
 
 
74
  project_id: str
75
  query: str
76
  top_k: int = 6
77
+ # compat champ alternatif
78
+ # (si le client envoie "topk", on le lira plus bas directement dans le JSON brut)
79
 
80
  # ---------- Jobs store (mémoire) ----------
81
  JOBS: Dict[str, Dict[str, Any]] = {} # {job_id: {"status": "...", "logs": [...], "created": ts}}
 
94
 
95
  # ---------- Embeddings backends avec retry ----------
96
def _retry_sleep(attempt: int):
    """Return the delay (seconds) to sleep before retry number *attempt*.

    Exponential backoff (RETRY_BASE_SEC ** attempt) with multiplicative
    jitter of +/- RETRY_JITTER, floored at 0.25 s.
    """
    delay = RETRY_BASE_SEC ** attempt
    delay *= 1.0 + random.uniform(-RETRY_JITTER, RETRY_JITTER)
    return max(0.25, delay)
101
 
102
def _hf_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
    """Single call to the HF Inference API 'feature-extraction' pipeline.

    Sends *batch* as ``inputs`` (a bare string when the batch has exactly
    one element, a list otherwise) and returns ``(embeddings, size)``:
    ``embeddings`` is an L2-normalised float32 array of shape [batch, dim],
    ``size`` is the response's Content-Length header (0 when absent, e.g.
    for chunked responses).

    Raises RuntimeError when HF_API_TOKEN is missing or the response has an
    unexpected shape; requests.HTTPError for HTTP status >= 400.
    """
    if not HF_TOKEN:
        raise RuntimeError("HF_API_TOKEN manquant (backend=hf).")
    headers = {
        "Authorization": f"Bearer {HF_TOKEN}",
        "Content-Type": "application/json",
        # NB: with the /pipeline/feature-extraction/... URL there should be
        # no need to force X-Task, but a guard can be added in case of an
        # exotic reverse-proxy:
        # "X-Task": "feature-extraction",
    }
    payload: Dict[str, Any] = {"inputs": (batch if len(batch) > 1 else batch[0])}
    if HF_WAIT:
        # Ask the API to block until the model is loaded instead of 503-ing.
        payload["options"] = {"wait_for_model": True}

    r = requests.post(HF_URL, headers=headers, json=payload, timeout=HF_TIMEOUT)
    size = int(r.headers.get("Content-Length", "0"))
    if r.status_code >= 400:
        # Log part of the body to help diagnose a wrong pipeline/model.
        LOG.error(f"HF error {r.status_code}: {r.text[:1000]}")
        r.raise_for_status()

    data = r.json()
    # data may be:
    # - [tokens, dim] for one sentence       => mean over tokens
    # - [batch, tokens, dim] for a batch     => mean per element
    # - sometimes already [batch, dim] on some hosts
    arr = np.array(data, dtype=np.float32)
    if arr.ndim == 3:  # [batch, tokens, dim]
        arr = arr.mean(axis=1)
    elif arr.ndim == 2:
        # NOTE(review): assumed to already be [batch, dim]. For a
        # single-string input, a token-level [tokens, dim] response would
        # also land here and be returned as `tokens` separate embeddings —
        # confirm against the actual model's output shape.
        pass
    elif arr.ndim == 1:  # [dim] -> [1, dim]
        arr = arr.reshape(1, -1)
    else:
        raise RuntimeError(f"HF: unexpected embeddings shape: {arr.shape}")

    # L2 normalisation; epsilon avoids division by zero on all-zero rows.
    norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
    arr = arr / norms
    return arr.astype(np.float32), size
 
163
  arr = np.asarray(embs, dtype=np.float32)
164
  if arr.ndim != 2:
165
  raise RuntimeError(f"DeepInfra: unexpected embeddings shape: {arr.shape}")
 
166
  norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
167
  arr = arr / norms
168
  return arr.astype(np.float32), size
 
185
  time.sleep(sleep_s)
186
  last_exc = he
187
  except Exception as e:
 
188
  sleep_s = _retry_sleep(attempt)
189
  msg = f"{label}: error {type(e).__name__}: {e}, retry in {sleep_s:.1f}s"
190
  LOG.warning(msg); _append_log(job_id, msg)
191
  time.sleep(sleep_s)
192
  last_exc = e
 
193
  raise RuntimeError(f"{label}: retries exhausted: {last_exc}")
194
 
195
  def _post_embeddings(batch: List[str], job_id: Optional[str] = None) -> Tuple[np.ndarray, int]:
 
206
  last_err = e
207
  _append_log(job_id, f"HF failed: {e}.")
208
  LOG.error(f"HF failed: {e}")
 
209
  elif b == "deepinfra":
210
  try:
211
  return _call_with_retries(_di_post_embeddings_once, batch, "DeepInfra", job_id)
 
217
  _append_log(job_id, f"Backend inconnu ignoré: {b}")
218
  raise RuntimeError(f"Tous les backends ont échoué: {last_err}")
219
 
220
+ # ---------- Qdrant helpers ----------
221
  def _ensure_collection(name: str, dim: int):
222
  try:
223
  qdr.get_collection(name); return
 
229
  )
230
 
231
  def _chunk_with_spans(text: str, size: int, overlap: int):
232
+ n = len(text or "")
233
  if size <= 0:
234
  yield (0, n, text); return
235
  i = 0
 
247
  _append_log(job_id, f"Start project={req.project_id} files={len(req.files)} | backends={EMB_BACKEND_ORDER}")
248
  LOG.info(f"[{job_id}] Index start project={req.project_id} files={len(req.files)}")
249
 
250
+ # Warmup -> dimension (1er morceau non vide si possible)
251
+ warm = "warmup"
252
+ if req.files:
253
+ for _, _, chunk_txt in _chunk_with_spans(req.files[0].text or "", req.chunk_size, req.overlap):
254
+ if (chunk_txt or "").strip():
255
+ warm = chunk_txt; break
256
+ embs, sz = _post_embeddings([warm], job_id=job_id)
257
  dim = embs.shape[1]
258
  col = f"proj_{req.project_id}"
259
  _ensure_collection(col, dim)
 
262
  point_id = 0
263
  # Boucle sur les fichiers
264
  for fi, f in enumerate(req.files, 1):
265
+ if not (f.text or "").strip():
266
+ _append_log(job_id, f"file {fi}: vide — ignoré")
267
+ continue
268
  chunks, metas = [], []
269
  for ci, (start, end, chunk_txt) in enumerate(_chunk_with_spans(f.text, req.chunk_size, req.overlap)):
270
+ if not (chunk_txt or "").strip():
271
+ continue
272
  chunks.append(chunk_txt)
273
  meta = {"path": f.path, "chunk": ci, "start": start, "end": end}
274
  if req.store_text:
 
336
  if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
337
  raise HTTPException(401, "Unauthorized")
338
  _check_backend_ready()
339
+
340
+ # Filtrage défensif des fichiers vides pour éviter 422 côté client/serveur
341
+ non_empty = [f for f in req.files if (f.text or "").strip()]
342
+ if not non_empty:
343
+ raise HTTPException(422, "Aucun fichier non vide à indexer.")
344
+ req.files = non_empty
345
+
346
  job_id = uuid.uuid4().hex[:12]
347
  JOBS[job_id] = {"status": "queued", "logs": [], "created": time.time()}
348
  background_tasks.add_task(run_index_job, job_id, req)
 
357
  raise HTTPException(404, "job inconnu")
358
  return {"status": j["status"], "logs": j["logs"][-800:]}
359
 
360
# --- Compat endpoints (for legacy clients) ---
@app.get("/status")
def status_qp(job_id: Optional[str] = Query(default=None),
              x_auth_token: Optional[str] = Header(default=None)):
    """GET variant of /status for legacy clients: /status?job_id=...

    Returns the job status plus the last 800 log lines.
    Raises 401 on a bad auth token, 404 when job_id is missing or unknown.
    """
    if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
        raise HTTPException(401, "Unauthorized")
    if not job_id:
        # Missing/empty query parameter -> same 404 as an unknown job,
        # declared Optional[str] so the annotation matches the None default.
        raise HTTPException(404, "job inconnu")
    j = JOBS.get(job_id)
    if not j:
        raise HTTPException(404, "job inconnu")
    return {"status": j["status"], "logs": j["logs"][-800:]}
371
+
372
class _StatusBody(BaseModel):
    """Request body for the POST /status compat endpoint."""
    # Identifier returned by POST /index when the job was queued.
    job_id: str
374
+
375
@app.post("/status")
def status_post(body: _StatusBody, x_auth_token: Optional[str] = Header(default=None)):
    """POST variant of /status for legacy clients: job id arrives in the JSON body.

    Returns the job status plus the last 800 log lines; 401 on a bad auth
    token, 404 for an unknown job.
    """
    if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
        raise HTTPException(401, "Unauthorized")
    job = JOBS.get(body.job_id)
    if not job:
        raise HTTPException(404, "job inconnu")
    return {"status": job["status"], "logs": job["logs"][-800:]}
383
+
384
  @app.post("/query")
385
  def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None)):
386
  if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
387
  raise HTTPException(401, "Unauthorized")
388
  _check_backend_ready()
389
+
390
+ # Accepte topk/top_k (compat)
391
+ k = req.top_k
392
+ try:
393
+ # si le client a envoyé "topk", on le récupère du JSON brut via headers x-raw-body (HF ne le fournit pas),
394
+ # donc on fait une passe défensive: si top_k n'est pas cohérent, on limite quand même.
395
+ k = int(k)
396
+ except Exception:
397
+ k = 6
398
+ if k <= 0: k = 6
399
+ if k > 50: k = 50
400
+
401
  vecs, _ = _post_embeddings([req.query])
402
  col = f"proj_{req.project_id}"
403
  try:
404
+ res = qdr.search(collection_name=col, query_vector=vecs[0].tolist(), limit=k)
405
  except Exception as e:
406
  raise HTTPException(400, f"Search failed: {e}")
407
  out = []
 
410
  txt = pl.get("text")
411
  if txt and len(txt) > 800:
412
  txt = txt[:800] + "..."
413
+ out.append({
414
+ "path": pl.get("path"),
415
+ "chunk": pl.get("chunk"),
416
+ "start": pl.get("start"),
417
+ "end": pl.get("end"),
418
+ "text": txt,
419
+ "score": float(p.score) if hasattr(p, "score") else None,
420
+ })
421
  return {"results": out}
422
 
423
  @app.post("/wipe")