chouchouvs committed on
Commit e0f6e27 · verified · 1 Parent(s): 9ea3ad6

Update main.py

Files changed (1)
  1. main.py +67 -38
main.py CHANGED
@@ -21,11 +21,27 @@ EMB_BACKEND_ORDER = [s.strip().lower() for s in os.getenv("EMB_BACKEND_ORDER", o
  # HF Inference API
  HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
  HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2").strip()
- # 👉 Force the "feature-extraction" pipeline to get embeddings (and not Similarity)
- HF_URL = (os.getenv("HF_API_URL", "").strip()
-           or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}")
+
+ # Three variables are supported, to stay flexible:
+ #  - HF_API_URL_PIPELINE : forces the pipeline URL (feature-extraction)
+ #  - HF_API_URL_MODELS   : forces the models URL
+ #  - HF_API_URL          : compat; if it contains "/pipeline" it is used as the pipeline URL, otherwise as the models URL
+ HF_API_URL_USER = os.getenv("HF_API_URL", "").strip()
+ HF_API_URL_PIPELINE = os.getenv("HF_API_URL_PIPELINE", "").strip()
+ HF_API_URL_MODELS = os.getenv("HF_API_URL_MODELS", "").strip()
+
+ if HF_API_URL_USER:
+     if "/pipeline" in HF_API_URL_USER:
+         HF_API_URL_PIPELINE = HF_API_URL_USER
+     else:
+         HF_API_URL_MODELS = HF_API_URL_USER
+
+ HF_URL_PIPELINE = (HF_API_URL_PIPELINE or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}")
+ HF_URL_MODELS = (HF_API_URL_MODELS or f"https://api-inference.huggingface.co/models/{HF_MODEL}")
+
  HF_TIMEOUT = float(os.getenv("EMB_TIMEOUT_SEC", "120"))
  HF_WAIT = os.getenv("HF_WAIT_FOR_MODEL", "true").lower() in ("1","true","yes","on")
+ HF_PIPELINE_FIRST = os.getenv("HF_PIPELINE_FIRST", "true").lower() in ("1","true","yes","on")

  # DeepInfra Embeddings (OpenAI-like)
  DI_TOKEN = os.getenv("DEEPINFRA_API_KEY", "").strip()
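Reviewer note: the precedence between HF_API_URL, HF_API_URL_PIPELINE and HF_API_URL_MODELS can be summarised with a small standalone sketch. The resolve_hf_urls helper and the example URL below are hypothetical and not part of this commit; only the precedence rule mirrors the diff above.

def resolve_hf_urls(model: str, api_url: str = "", pipeline_url: str = "", models_url: str = ""):
    # Mirrors the block above: a legacy HF_API_URL wins on whichever side it matches.
    if api_url:
        if "/pipeline" in api_url:
            pipeline_url = api_url
        else:
            models_url = api_url
    pipeline = pipeline_url or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model}"
    models = models_url or f"https://api-inference.huggingface.co/models/{model}"
    return pipeline, models

# Example: a legacy HF_API_URL pointing at /models/ only replaces the models URL.
p, m = resolve_hf_urls("sentence-transformers/all-MiniLM-L6-v2",
                       api_url="https://example.test/models/my-model")
assert m == "https://example.test/models/my-model"
assert "/pipeline/feature-extraction/" in p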
@@ -46,6 +62,8 @@ QDRANT_API = os.getenv("QDRANT_API_KEY", "").strip()
  AUTH_TOKEN = os.getenv("REMOTE_INDEX_TOKEN", "").strip()

  LOG.info(f"Embeddings backend order = {EMB_BACKEND_ORDER}")
+ LOG.info(f"HF pipeline URL = {HF_URL_PIPELINE}")
+ LOG.info(f"HF models URL = {HF_URL_MODELS}")
  if "hf" in EMB_BACKEND_ORDER and not HF_TOKEN:
      LOG.warning("HF_API_TOKEN manquant — tentatives HF échoueront.")
  if "deepinfra" in EMB_BACKEND_ORDER and not DI_TOKEN:
@@ -74,8 +92,6 @@ class QueryRequest(BaseModel):
      project_id: str
      query: str
      top_k: int = 6
-     # compat: alternative field name
-     # (if the client sends "topk", it is read further down directly from the raw JSON)

  # ---------- Jobs store (in memory) ----------
  JOBS: Dict[str, Dict[str, Any]] = {}  # {job_id: {"status": "...", "logs": [...], "created": ts}}
@@ -99,38 +115,30 @@ def _retry_sleep(attempt: int):
      jitter = 1.0 + random.uniform(-RETRY_JITTER, RETRY_JITTER)
      return max(0.25, back * jitter)

- def _hf_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
-     """
-     Inference API call on the 'feature-extraction' pipeline (returns embeddings).
-     - inputs: str or list[str]
-     - options.wait_for_model: True if requested
-     """
+ def _hf_http(
+     url: str, payload: Dict[str, Any], headers_extra: Optional[Dict[str, str]] = None
+ ) -> Tuple[np.ndarray, int]:
      if not HF_TOKEN:
          raise RuntimeError("HF_API_TOKEN manquant (backend=hf).")
+
      headers = {
          "Authorization": f"Bearer {HF_TOKEN}",
          "Content-Type": "application/json",
-         # NB: with the /pipeline/feature-extraction/... URL we should not need to force X-Task,
-         # but a guard can be added in case of an exotic reverse proxy:
-         # "X-Task": "feature-extraction",
+         "Accept": "application/json",
      }
-     payload: Dict[str, Any] = {"inputs": (batch if len(batch) > 1 else batch[0])}
-     if HF_WAIT:
-         payload["options"] = {"wait_for_model": True}
+     if headers_extra:
+         headers.update(headers_extra)

-     r = requests.post(HF_URL, headers=headers, json=payload, timeout=HF_TIMEOUT)
+     r = requests.post(url, headers=headers, json=payload, timeout=HF_TIMEOUT)
      size = int(r.headers.get("Content-Length", "0"))
      if r.status_code >= 400:
-         # Log part of the response body to diagnose a wrong pipeline, just in case
+         # Log part of the response body for diagnosis
          LOG.error(f"HF error {r.status_code}: {r.text[:1000]}")
          r.raise_for_status()

      data = r.json()
-     # data can be:
-     #  - [tokens, dim] for a single sentence => average over tokens
-     #  - [batch, tokens, dim] for a batch => average per element
-     #  - sometimes already [batch, dim] depending on the host
      arr = np.array(data, dtype=np.float32)
+     # data can be: [tokens, dim] or [batch, tokens, dim] or [batch, dim] or [dim]
      if arr.ndim == 3:  # [batch, tokens, dim]
          arr = arr.mean(axis=1)
      elif arr.ndim == 2:
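Reviewer note: the shape handling that starts here (and continues in the next hunk) amounts to mean-pooling token vectors and L2-normalising the rows. A minimal offline sketch, assuming the 3-D case shown above; the pool_and_normalize name is hypothetical and the 1-D/2-D branches are simplified compared to main.py.

import numpy as np

def pool_and_normalize(data) -> np.ndarray:
    arr = np.array(data, dtype=np.float32)
    if arr.ndim == 3:            # [batch, tokens, dim] -> mean over tokens
        arr = arr.mean(axis=1)
    elif arr.ndim == 1:          # [dim] -> [1, dim] (simplified)
        arr = arr[None, :]
    norms = np.linalg.norm(arr, axis=1, keepdims=True)
    norms[norms == 0] = 1.0      # avoid division by zero
    return (arr / norms).astype(np.float32)

# Two fake token-embedding sequences: batch=2, tokens=5, dim=4.
vecs = pool_and_normalize(np.random.rand(2, 5, 4))
assert vecs.shape == (2, 4)
assert np.allclose(np.linalg.norm(vecs, axis=1), 1.0, atol=1e-5)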
@@ -145,10 +153,38 @@ def _hf_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
          arr = arr / norms
      return arr.astype(np.float32), size

+ def _hf_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
+     """
+     1) Try the PIPELINE feature-extraction endpoint (if available on the infrastructure)
+     2) Fall back to MODELS + header X-Task: feature-extraction
+     """
+     # shared payload
+     payload: Dict[str, Any] = {"inputs": (batch if len(batch) > 1 else batch[0])}
+     if HF_WAIT:
+         payload["options"] = {"wait_for_model": True}
+
+     # order: pipeline first (configurable)
+     urls = [HF_URL_PIPELINE, HF_URL_MODELS] if HF_PIPELINE_FIRST else [HF_URL_MODELS, HF_URL_PIPELINE]
+     for idx, url in enumerate(urls, 1):
+         try:
+             if "/models/" in url:
+                 return _hf_http(url, payload, headers_extra={"X-Task": "feature-extraction"})
+             else:
+                 return _hf_http(url, payload, headers_extra=None)
+         except requests.HTTPError as he:
+             code = he.response.status_code if he.response is not None else 0
+             # on 404/405/501 → try the other form
+             if code in (404, 405, 501) and idx < len(urls):
+                 LOG.warning(f"HF endpoint {url} non dispo ({code}), fallback vers alternative ...")
+                 continue
+             raise
+     # should never get here
+     raise RuntimeError("HF: aucun endpoint utilisable")
+
  def _di_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
      if not DI_TOKEN:
          raise RuntimeError("DEEPINFRA_API_KEY manquant (backend=deepinfra).")
-     headers = {"Authorization": f"Bearer {DI_TOKEN}", "Content-Type": "application/json"}
+     headers = {"Authorization": f"Bearer {DI_TOKEN}", "Content-Type": "application/json", "Accept": "application/json"}
      payload = {"model": DI_MODEL, "input": batch}
      r = requests.post(DI_URL, headers=headers, json=payload, timeout=DI_TIMEOUT)
      size = int(r.headers.get("Content-Length", "0"))
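Reviewer note: the pipeline-then-models fallback can be exercised without any network access. The sketch below is hypothetical (first_working and the fake URLs are not in main.py); only the 404/405/501 retry rule is taken from the diff.

from typing import Callable, List

def first_working(urls: List[str], post: Callable[[str], int]) -> str:
    # On 404/405/501, try the next URL; any other error (or exhaustion) is fatal.
    for idx, url in enumerate(urls, 1):
        status = post(url)
        if status < 400:
            return url
        if status in (404, 405, 501) and idx < len(urls):
            continue
        raise RuntimeError(f"HF error {status} on {url}")
    raise RuntimeError("no usable endpoint")

# The pipeline endpoint is "missing" (404); the models endpoint answers 200.
fake = {"https://pipeline.example/feature-extraction": 404,
        "https://models.example/my-model": 200}
assert first_working(list(fake), fake.get) == "https://models.example/my-model"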
@@ -316,7 +352,8 @@ def root():
          "ok": True,
          "service": "remote-indexer",
          "backends": EMB_BACKEND_ORDER,
-         "hf_url": HF_URL if "hf" in EMB_BACKEND_ORDER else None,
+         "hf_url_pipeline": HF_URL_PIPELINE if "hf" in EMB_BACKEND_ORDER else None,
+         "hf_url_models": HF_URL_MODELS if "hf" in EMB_BACKEND_ORDER else None,
          "di_model": DI_MODEL if "deepinfra" in EMB_BACKEND_ORDER else None,
          "docs": "/health, /index, /status/{job_id}, /query, /wipe"
      }
@@ -337,7 +374,7 @@ def start_index(req: IndexRequest, background_tasks: BackgroundTasks, x_auth_tok
          raise HTTPException(401, "Unauthorized")
      _check_backend_ready()

-     # Defensive filtering of empty files to avoid a 422 on the client/server side
+     # Defensive filtering of empty files to avoid a 422
      non_empty = [f for f in req.files if (f.text or "").strip()]
      if not non_empty:
          raise HTTPException(422, "Aucun fichier non vide à indexer.")
@@ -357,7 +394,7 @@ def status(job_id: str, x_auth_token: Optional[str] = Header(default=None)):
          raise HTTPException(404, "job inconnu")
      return {"status": j["status"], "logs": j["logs"][-800:]}

- # --- Compat endpoints (for legacy clients) ---
+ # --- Compat endpoints (legacy clients) ---
  @app.get("/status")
  def status_qp(job_id: str = Query(None), x_auth_token: Optional[str] = Header(default=None)):
      if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
@@ -387,16 +424,8 @@ def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None))
          raise HTTPException(401, "Unauthorized")
      _check_backend_ready()

-     # Accept topk/top_k (compat)
-     k = req.top_k
-     try:
-         # if the client sent "topk" we would read it from the raw JSON via an x-raw-body header (HF does not provide one),
-         # so we take a defensive pass: if top_k is not consistent we still clamp it.
-         k = int(k)
-     except Exception:
-         k = 6
-     if k <= 0: k = 6
-     if k > 50: k = 50
+     # clamp top_k
+     k = int(max(1, min(50, req.top_k or 6)))

      vecs, _ = _post_embeddings([req.query])
      col = f"proj_{req.project_id}"
@@ -426,7 +455,7 @@ def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(defaul
          raise HTTPException(401, "Unauthorized")
      col = f"proj_{project_id}"
      try:
-         qdr.delete_collection(col); return {"ok": True}
+         qdrant.delete_collection(col); return {"ok": True}
      except Exception as e:
          raise HTTPException(400, f"wipe failed: {e}")