chouchouvs committed
Commit 1102a75 · verified · 1 Parent(s): e0f6e27

Update main.py

Files changed (1): main.py +37 -18
main.py CHANGED
@@ -22,10 +22,7 @@ EMB_BACKEND_ORDER = [s.strip().lower() for s in os.getenv("EMB_BACKEND_ORDER", o
 HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
 HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2").strip()
 
-# We support 3 variables to stay flexible:
-# - HF_API_URL_PIPELINE : forces the pipeline URL (feature-extraction)
-# - HF_API_URL_MODELS : forces the models URL
-# - HF_API_URL : compat; if it contains "/pipeline" it is used on the pipeline side, otherwise on the models side
+# Configurable URLs
 HF_API_URL_USER = os.getenv("HF_API_URL", "").strip()
 HF_API_URL_PIPELINE = os.getenv("HF_API_URL_PIPELINE", "").strip()
 HF_API_URL_MODELS = os.getenv("HF_API_URL_MODELS", "").strip()
@@ -36,6 +33,7 @@ if HF_API_URL_USER:
 else:
     HF_API_URL_MODELS = HF_API_URL_USER
 
+# Defaults
 HF_URL_PIPELINE = (HF_API_URL_PIPELINE or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}")
 HF_URL_MODELS = (HF_API_URL_MODELS or f"https://api-inference.huggingface.co/models/{HF_MODEL}")
 
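For readers tracing the configuration above: the precedence implied by these two hunks is explicit HF_API_URL_PIPELINE / HF_API_URL_MODELS first, then the legacy HF_API_URL (routed by whether it contains "/pipeline"), then the public api-inference.huggingface.co defaults. A minimal sketch of that resolution as a pure function follows; resolve_hf_urls is a hypothetical name and the HF_API_URL branch only approximates the part of the logic the diff does not show.

    import os
    from typing import Tuple

    def resolve_hf_urls(model: str) -> Tuple[str, str]:
        # Hypothetical helper mirroring the module-level logic in the hunks above.
        user = os.getenv("HF_API_URL", "").strip()
        pipeline = os.getenv("HF_API_URL_PIPELINE", "").strip()
        models = os.getenv("HF_API_URL_MODELS", "").strip()
        if user:
            # Legacy variable: "/pipeline" means it targets the pipeline endpoint,
            # anything else is treated as a models endpoint.
            if "/pipeline" in user:
                pipeline = pipeline or user
            else:
                models = models or user
        pipeline = pipeline or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{model}"
        models = models or f"https://api-inference.huggingface.co/models/{model}"
        return pipeline, models

    print(resolve_hf_urls("sentence-transformers/all-MiniLM-L6-v2"))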
@@ -115,9 +113,11 @@ def _retry_sleep(attempt: int):
     jitter = 1.0 + random.uniform(-RETRY_JITTER, RETRY_JITTER)
     return max(0.25, back * jitter)
 
-def _hf_http(
-    url: str, payload: Dict[str, Any], headers_extra: Optional[Dict[str, str]] = None
-) -> Tuple[np.ndarray, int]:
+def _with_task_param(url: str, task: str = "feature-extraction") -> str:
+    # Append ?task=feature-extraction (or &task=...) if it is not already present
+    return url + ("&" if "?" in url else "?") + f"task={task}"
+
+def _hf_http(url: str, payload: Dict[str, Any], headers_extra: Optional[Dict[str, str]] = None) -> Tuple[np.ndarray, int]:
     if not HF_TOKEN:
         raise RuntimeError("HF_API_TOKEN missing (backend=hf).")
 
@@ -126,19 +126,24 @@ def _hf_http(
         "Content-Type": "application/json",
         "Accept": "application/json",
     }
+    # options.wait_for_model in the JSON body plus the X-Wait-For-Model header -> maximum compatibility
+    if HF_WAIT:
+        payload.setdefault("options", {})["wait_for_model"] = True
+        headers["X-Wait-For-Model"] = "true"
+        headers["X-Use-Cache"] = "true"
+
     if headers_extra:
         headers.update(headers_extra)
 
     r = requests.post(url, headers=headers, json=payload, timeout=HF_TIMEOUT)
     size = int(r.headers.get("Content-Length", "0"))
     if r.status_code >= 400:
-        # Log part of the body to help diagnose
         LOG.error(f"HF error {r.status_code}: {r.text[:1000]}")
         r.raise_for_status()
 
     data = r.json()
     arr = np.array(data, dtype=np.float32)
-    # data may be: [tokens, dim] or [batch, tokens, dim] or [batch, dim] or [dim]
+    # data may be: [tokens, dim], [batch, tokens, dim], [batch, dim], [dim]
     if arr.ndim == 3:  # [batch, tokens, dim]
         arr = arr.mean(axis=1)
     elif arr.ndim == 2:
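The tail of _hf_http normalizes whatever array shape the Inference API returns into per-input vectors, mean-pooling token-level outputs. A self-contained sketch of that idea (hypothetical pool_to_batch name; the hunk is cut off before its 2-D branch, so that case is simply passed through here):

    import numpy as np

    def pool_to_batch(data) -> np.ndarray:
        # Stand-alone version of the shape handling above, for illustration only.
        arr = np.array(data, dtype=np.float32)
        if arr.ndim == 3:      # [batch, tokens, dim] -> mean-pool over tokens
            arr = arr.mean(axis=1)
        elif arr.ndim == 1:    # [dim] -> treat as a batch of one
            arr = arr[None, :]
        # 2-D input ([batch, dim] or [tokens, dim]) is left as-is in this sketch.
        return arr

    print(pool_to_batch([[[0.1, 0.2], [0.3, 0.4]]]).shape)  # (1, 2)
    print(pool_to_batch([0.1, 0.2, 0.3]).shape)             # (1, 3)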
@@ -155,31 +160,45 @@ def _hf_http(
 
 def _hf_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
     """
-    1) Try the PIPELINE feature-extraction endpoint (if enabled on the infra)
+    1) Try the PIPELINE feature-extraction endpoint (if available)
     2) Fall back to MODELS + the X-Task: feature-extraction header
+    3) If that still returns 400 because of SentenceSimilarityPipeline, also force ?task=feature-extraction on the MODELS URL
     """
-    # common payload
     payload: Dict[str, Any] = {"inputs": (batch if len(batch) > 1 else batch[0])}
-    if HF_WAIT:
-        payload["options"] = {"wait_for_model": True}
 
-    # order: pipeline first (configurable)
     urls = [HF_URL_PIPELINE, HF_URL_MODELS] if HF_PIPELINE_FIRST else [HF_URL_MODELS, HF_URL_PIPELINE]
+    last_exc: Optional[Exception] = None
+
     for idx, url in enumerate(urls, 1):
         try:
             if "/models/" in url:
+                # 2) MODELS with the X-Task header
                 return _hf_http(url, payload, headers_extra={"X-Task": "feature-extraction"})
             else:
+                # 1) PIPELINE
                 return _hf_http(url, payload, headers_extra=None)
         except requests.HTTPError as he:
             code = he.response.status_code if he.response is not None else 0
-            # on 404/405/501, try the other endpoint form
+            body = he.response.text if he.response is not None else ""
+            last_exc = he
             if code in (404, 405, 501) and idx < len(urls):
                 LOG.warning(f"HF endpoint {url} unavailable ({code}), falling back to the alternative ...")
                 continue
+            # If we hit MODELS and got SentenceSimilarityPipeline -> retry with ?task=feature-extraction
+            if "/models/" in url and "SentenceSimilarityPipeline" in (body or ""):
+                try:
+                    forced_url = _with_task_param(url, "feature-extraction")
+                    LOG.warning("HF MODELS picked Similarity -> retrying with %s + X-Task", forced_url)
+                    return _hf_http(forced_url, payload, headers_extra={"X-Task": "feature-extraction"})
+                except Exception as he2:
+                    last_exc = he2
             raise
-    # should never get here
-    raise RuntimeError("HF: no usable endpoint")
+        except Exception as e:
+            last_exc = e
+            raise
+
+    # should not happen
+    raise RuntimeError(f"HF: no usable endpoint ({last_exc})")
 
 def _di_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
     if not DI_TOKEN:
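_hf_post_embeddings_once is a single attempt; the surrounding module retries it with the jittered exponential backoff from _retry_sleep further up. A minimal sketch of that calling pattern, with hypothetical constants (RETRY_MAX, RETRY_BASE) standing in for the module's env-driven configuration:

    import random
    import time

    RETRY_MAX = 4        # hypothetical; the real limits come from environment variables
    RETRY_BASE = 0.5
    RETRY_JITTER = 0.2

    def retry_sleep(attempt: int) -> float:
        # Exponential backoff with +/- RETRY_JITTER relative jitter, floored at 0.25 s.
        back = RETRY_BASE * (2 ** attempt)
        jitter = 1.0 + random.uniform(-RETRY_JITTER, RETRY_JITTER)
        return max(0.25, back * jitter)

    def call_with_retries(fn, *args):
        last_exc = None
        for attempt in range(RETRY_MAX):
            try:
                return fn(*args)
            except Exception as exc:   # the module narrows this to HTTP/network errors
                last_exc = exc
                time.sleep(retry_sleep(attempt))
        raise RuntimeError(f"giving up after {RETRY_MAX} attempts") from last_exc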
@@ -455,7 +474,7 @@ def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(defaul
         raise HTTPException(401, "Unauthorized")
     col = f"proj_{project_id}"
     try:
-        qdrant.delete_collection(col); return {"ok": True}
+        qdr.delete_collection(col); return {"ok": True}
     except Exception as e:
         raise HTTPException(400, f"wipe failed: {e}")
 
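For context on this last hunk: the endpoint boils down to a guarded QdrantClient.delete_collection call. A self-contained FastAPI sketch under assumed wiring follows; the local Qdrant URL, constant token, and route path are illustrative and not taken from main.py beyond the handler body shown in the diff.

    from typing import Optional

    from fastapi import FastAPI, Header, HTTPException
    from qdrant_client import QdrantClient

    app = FastAPI()
    qdr = QdrantClient(url="http://localhost:6333")  # assumed local instance
    AUTH_TOKEN = "change-me"                         # placeholder for the real token check

    @app.delete("/wipe/{project_id}")                # illustrative path, not from the diff
    def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(default=None)):
        if x_auth_token != AUTH_TOKEN:
            raise HTTPException(401, "Unauthorized")
        col = f"proj_{project_id}"
        try:
            qdr.delete_collection(col)
            return {"ok": True}
        except Exception as e:
            raise HTTPException(400, f"wipe failed: {e}")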
 
 