Update main.py

main.py
CHANGED
@@ -22,10 +22,7 @@ EMB_BACKEND_ORDER = [s.strip().lower() for s in os.getenv("EMB_BACKEND_ORDER", o
 HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
 HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2").strip()
 
-#
-# - HF_API_URL_PIPELINE : forces the pipeline URL (feature-extraction)
-# - HF_API_URL_MODELS : forces the models URL
-# - HF_API_URL : compat; if it contains "/pipeline", it is used on the pipeline side, otherwise on the models side
+# Configurable URLs
 HF_API_URL_USER = os.getenv("HF_API_URL", "").strip()
 HF_API_URL_PIPELINE = os.getenv("HF_API_URL_PIPELINE", "").strip()
 HF_API_URL_MODELS = os.getenv("HF_API_URL_MODELS", "").strip()
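
For context, a hedged sketch of how these variables are meant to resolve, pieced together from the removed comment block above and the "if HF_API_URL_USER:" branch shown in the next hunk. The pipeline branch of that if-statement is not shown in this diff, so treat it as an assumption rather than the file's exact code.

import os

# Sketch only, not the file's exact code.
HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2").strip()
HF_API_URL_USER = os.getenv("HF_API_URL", "").strip()               # legacy single URL
HF_API_URL_PIPELINE = os.getenv("HF_API_URL_PIPELINE", "").strip()  # forces the pipeline URL
HF_API_URL_MODELS = os.getenv("HF_API_URL_MODELS", "").strip()      # forces the models URL

if HF_API_URL_USER:
    if "/pipeline" in HF_API_URL_USER:
        HF_API_URL_PIPELINE = HF_API_URL_USER  # inferred branch (not shown in the diff)
    else:
        HF_API_URL_MODELS = HF_API_URL_USER

HF_URL_PIPELINE = HF_API_URL_PIPELINE or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}"
HF_URL_MODELS = HF_API_URL_MODELS or f"https://api-inference.huggingface.co/models/{HF_MODEL}"
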
@@ -36,6 +33,7 @@ if HF_API_URL_USER:
 else:
     HF_API_URL_MODELS = HF_API_URL_USER
 
+# Defaults
 HF_URL_PIPELINE = (HF_API_URL_PIPELINE or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}")
 HF_URL_MODELS = (HF_API_URL_MODELS or f"https://api-inference.huggingface.co/models/{HF_MODEL}")
 
@@ -115,9 +113,11 @@ def _retry_sleep(attempt: int):
     jitter = 1.0 + random.uniform(-RETRY_JITTER, RETRY_JITTER)
     return max(0.25, back * jitter)
 
-def _hf_http(
-
-)
+def _with_task_param(url: str, task: str = "feature-extraction") -> str:
+    # Adds ?task=feature-extraction (or &task=...) if absent
+    return url + ("&" if "?" in url else "?") + f"task={task}"
+
+def _hf_http(url: str, payload: Dict[str, Any], headers_extra: Optional[Dict[str, str]] = None) -> Tuple[np.ndarray, int]:
     if not HF_TOKEN:
         raise RuntimeError("HF_API_TOKEN manquant (backend=hf).")
 
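
A quick, self-contained check of the new _with_task_param helper (the function body is copied from the hunk above; the URLs are hypothetical), showing the two cases it handles: no query string yet vs. an existing one.

def _with_task_param(url: str, task: str = "feature-extraction") -> str:
    # Appends ?task=... or &task=... depending on whether a query string is already present
    return url + ("&" if "?" in url else "?") + f"task={task}"

assert _with_task_param("https://api-inference.huggingface.co/models/some-model") == \
    "https://api-inference.huggingface.co/models/some-model?task=feature-extraction"
assert _with_task_param("https://example.org/models/some-model?x=1") == \
    "https://example.org/models/some-model?x=1&task=feature-extraction"
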
@@ -126,19 +126,24 @@ def _hf_http(
         "Content-Type": "application/json",
         "Accept": "application/json",
     }
+    # options.wait_for_model in the JSON payload + X-Wait-For-Model header -> maximum compatibility
+    if HF_WAIT:
+        payload.setdefault("options", {})["wait_for_model"] = True
+        headers["X-Wait-For-Model"] = "true"
+        headers["X-Use-Cache"] = "true"
+
     if headers_extra:
         headers.update(headers_extra)
 
     r = requests.post(url, headers=headers, json=payload, timeout=HF_TIMEOUT)
     size = int(r.headers.get("Content-Length", "0"))
     if r.status_code >= 400:
-        # Show part of the response body for diagnostics
         LOG.error(f"HF error {r.status_code}: {r.text[:1000]}")
         r.raise_for_status()
 
     data = r.json()
     arr = np.array(data, dtype=np.float32)
-    # data can be: [tokens, dim]
+    # data can be: [tokens, dim], [batch, tokens, dim], [batch, dim], [dim]
     if arr.ndim == 3:  # [batch, tokens, dim]
         arr = arr.mean(axis=1)
     elif arr.ndim == 2:
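
To make the shape handling above concrete, here is a toy, self-contained illustration of the [batch, tokens, dim] case: the token axis is mean-pooled so each input ends up as a single vector (the numbers are made up).

import numpy as np

data = [[[1.0, 2.0], [3.0, 4.0]],   # input 1: 2 tokens, dim=2
        [[5.0, 6.0], [7.0, 8.0]]]   # input 2: 2 tokens, dim=2
arr = np.array(data, dtype=np.float32)
assert arr.ndim == 3                # [batch, tokens, dim]
pooled = arr.mean(axis=1)           # mean over the token axis
print(pooled.shape)                 # (2, 2): one vector per input
print(pooled)                       # [[2. 3.] [6. 7.]]
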
@@ -155,31 +160,45 @@ def _hf_http(
 
 def _hf_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
     """
-    1) Try PIPELINE feature-extraction (if
+    1) Try PIPELINE feature-extraction (if available)
     2) Fallback MODELS + header X-Task: feature-extraction
+    3) If still 400 because of SentenceSimilarityPipeline, also force ?task=feature-extraction on the MODELS URL
     """
-    # common payload
     payload: Dict[str, Any] = {"inputs": (batch if len(batch) > 1 else batch[0])}
-    if HF_WAIT:
-        payload["options"] = {"wait_for_model": True}
 
-    # order: pipeline first (configurable)
     urls = [HF_URL_PIPELINE, HF_URL_MODELS] if HF_PIPELINE_FIRST else [HF_URL_MODELS, HF_URL_PIPELINE]
+    last_exc: Optional[Exception] = None
+
     for idx, url in enumerate(urls, 1):
         try:
             if "/models/" in url:
+                # 2) MODELS with the X-Task header
                 return _hf_http(url, payload, headers_extra={"X-Task": "feature-extraction"})
             else:
+                # 1) PIPELINE
                 return _hf_http(url, payload, headers_extra=None)
         except requests.HTTPError as he:
             code = he.response.status_code if he.response is not None else 0
-
+            body = he.response.text if he.response is not None else ""
+            last_exc = he
             if code in (404, 405, 501) and idx < len(urls):
                 LOG.warning(f"HF endpoint {url} non dispo ({code}), fallback vers alternative ...")
                 continue
+            # If MODELS was hit and returned SentenceSimilarityPipeline -> retry with ?task=feature-extraction
+            if "/models/" in url and "SentenceSimilarityPipeline" in (body or ""):
+                try:
+                    forced_url = _with_task_param(url, "feature-extraction")
+                    LOG.warning("HF MODELS a choisi Similarity -> retry avec %s + X-Task", forced_url)
+                    return _hf_http(forced_url, payload, headers_extra={"X-Task": "feature-extraction"})
+                except Exception as he2:
+                    last_exc = he2
             raise
+        except Exception as e:
+            last_exc = e
+            raise
-
-
+
+    # should not happen
+    raise RuntimeError(f"HF: aucun endpoint utilisable ({last_exc})")
 
 def _di_post_embeddings_once(batch: List[str]) -> Tuple[np.ndarray, int]:
     if not DI_TOKEN:
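
A hedged usage sketch of the rewritten function, assuming the module-level configuration above is in place and HF_API_TOKEN is valid; the "_once" suffix suggests it is normally wrapped in a retry loop built on _retry_sleep.

# Hypothetical call; requires network access and a valid HF_API_TOKEN.
vectors, size = _hf_post_embeddings_once(["bonjour le monde", "hello world"])
print(vectors.shape)   # e.g. (2, 384) for sentence-transformers/all-MiniLM-L6-v2
print(size)            # Content-Length reported by the HF response (0 if the header is absent)
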
@@ -455,7 +474,7 @@ def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(defaul
         raise HTTPException(401, "Unauthorized")
     col = f"proj_{project_id}"
     try:
-
+        qdr.delete_collection(col); return {"ok": True}
     except Exception as e:
         raise HTTPException(400, f"wipe failed: {e}")
 
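
For reference, a hedged illustration of what the new wipe line does on the Qdrant side, written against the public qdrant-client API; the URL and project id are hypothetical, and note that the handler only deletes the collection, it does not recreate it.

from qdrant_client import QdrantClient

client = QdrantClient(url="http://localhost:6333")   # hypothetical connection details
project_id = "demo"                                  # hypothetical project id
col = f"proj_{project_id}"                           # same naming scheme as the handler
client.delete_collection(collection_name=col)        # equivalent of qdr.delete_collection(col)
print(col in [c.name for c in client.get_collections().collections])  # False after the wipe
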