chouchouvs commited on
Commit
b678bb5
·
verified ·
1 Parent(s): dd055bb

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +35 -28
main.py CHANGED
@@ -17,15 +17,14 @@ LOG = logging.getLogger("remote_indexer")
17
  # ---------- ENV ----------
18
  EMB_BACKEND = os.getenv("EMB_BACKEND", "hf").strip().lower() # "hf" (défaut) ou "deepinfra"
19
 
20
- # HF
21
  HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
22
  HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
23
- # Si tu as un Inference Endpoint privé, ou si tu veux l’API "models/..." :
24
- # ex: https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2
25
  HF_URL = (os.getenv("HF_API_URL", "").strip()
26
- or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}")
27
 
28
- # DeepInfra
29
  DI_TOKEN = os.getenv("DEEPINFRA_API_KEY", "").strip()
30
  DI_MODEL = os.getenv("DEEPINFRA_EMBED_MODEL", "thenlper/gte-small").strip()
31
  DI_URL = os.getenv("DEEPINFRA_EMBED_URL", "https://api.deepinfra.com/v1/embeddings").strip()
@@ -76,14 +75,25 @@ def _auth(x_auth: Optional[str]):
76
  raise HTTPException(status_code=401, detail="Unauthorized")
77
 
78
  def _hf_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
 
 
 
 
 
 
79
  if not HF_TOKEN:
80
  raise RuntimeError("HF_API_TOKEN manquant (backend=hf).")
81
- headers = {"Authorization": f"Bearer {HF_TOKEN}"}
 
 
 
 
 
 
82
  try:
83
- r = requests.post(HF_URL, headers=headers, json=batch, timeout=120)
84
  size = int(r.headers.get("Content-Length", "0"))
85
  if r.status_code >= 400:
86
- # Log détaillé pour comprendre le 403/4xx
87
  try:
88
  LOG.error(f"HF error {r.status_code}: {r.text}")
89
  except Exception:
@@ -97,6 +107,9 @@ def _hf_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
97
  # [batch, dim] (sentence-transformers) ou [batch, tokens, dim] -> mean-pooling
98
  if arr.ndim == 3:
99
  arr = arr.mean(axis=1)
 
 
 
100
  if arr.ndim != 2:
101
  raise RuntimeError(f"HF: unexpected embeddings shape: {arr.shape}")
102
  # normalisation
@@ -105,6 +118,11 @@ def _hf_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
105
  return arr.astype(np.float32), size
106
 
107
  def _di_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
 
 
 
 
 
108
  if not DI_TOKEN:
109
  raise RuntimeError("DEEPINFRA_API_KEY manquant (backend=deepinfra).")
110
  headers = {"Authorization": f"Bearer {DI_TOKEN}", "Content-Type": "application/json"}
@@ -122,7 +140,6 @@ def _di_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
122
  except Exception as e:
123
  raise RuntimeError(f"DeepInfra POST failed: {e}")
124
 
125
- # OpenAI-like : {"data":[{"embedding":[...],"index":0}, ...]}
126
  data = js.get("data")
127
  if not isinstance(data, list) or not data:
128
  raise RuntimeError(f"DeepInfra embeddings: réponse invalide {js}")
@@ -130,7 +147,6 @@ def _di_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
130
  arr = np.asarray(embs, dtype=np.float32)
131
  if arr.ndim != 2:
132
  raise RuntimeError(f"DeepInfra: unexpected embeddings shape: {arr.shape}")
133
- # normalisation
134
  norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
135
  arr = arr / norms
136
  return arr.astype(np.float32), size
@@ -145,8 +161,7 @@ def _post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
145
 
146
  def _ensure_collection(name: str, dim: int):
147
  try:
148
- qdr.get_collection(name)
149
- return
150
  except Exception:
151
  pass
152
  qdr.create_collection(
@@ -157,25 +172,21 @@ def _ensure_collection(name: str, dim: int):
157
  def _chunk_with_spans(text: str, size: int, overlap: int):
158
  n = len(text)
159
  if size <= 0:
160
- yield (0, n, text)
161
- return
162
  i = 0
163
  while i < n:
164
  j = min(n, i + size)
165
  yield (i, j, text[i:j])
166
  i = max(0, j - overlap)
167
- if i >= n:
168
- break
169
 
170
  def _append_log(job_id: str, line: str):
171
  job = JOBS.get(job_id)
172
- if not job: return
173
- job["logs"].append(line)
174
 
175
  def _set_status(job_id: str, status: str):
176
  job = JOBS.get(job_id)
177
- if not job: return
178
- job["status"] = status
179
 
180
  # ---------- Background task ----------
181
  def run_index_job(job_id: str, req: IndexRequest):
@@ -196,7 +207,6 @@ def run_index_job(job_id: str, req: IndexRequest):
196
  _append_log(job_id, f"Collection ready: {col} (dim={dim})")
197
 
198
  point_id = 0
199
-
200
  # boucle fichiers
201
  for fi, f in enumerate(req.files, 1):
202
  chunks, metas = [], []
@@ -218,7 +228,6 @@ def run_index_job(job_id: str, req: IndexRequest):
218
  _append_log(job_id, f"file {fi}/{len(req.files)}: +{len(chunks)} chunks (total={total_chunks}) ~{sz/1024:.1f}KiB")
219
  chunks, metas = [], []
220
 
221
- # flush fin de fichier
222
  if chunks:
223
  vecs, sz = _post_embeddings(chunks)
224
  batch_points = []
@@ -255,7 +264,7 @@ def root():
255
  def health():
256
  return {"ok": True}
257
 
258
- def _check_backend_ready(for_query=False):
259
  if EMB_BACKEND == "hf" and not HF_TOKEN:
260
  raise HTTPException(400, "HF_API_TOKEN manquant côté serveur (backend=hf).")
261
  if EMB_BACKEND == "deepinfra" and not DI_TOKEN:
@@ -284,12 +293,11 @@ def status(job_id: str, x_auth_token: Optional[str] = Header(default=None)):
284
  def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None)):
285
  if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
286
  raise HTTPException(401, "Unauthorized")
287
- _check_backend_ready(for_query=True)
288
  vec, _ = _post_embeddings([req.query])
289
- vec = vec[0].tolist()
290
  col = f"proj_{req.project_id}"
291
  try:
292
- res = qdr.search(collection_name=col, query_vector=vec, limit=int(req.top_k))
293
  except Exception as e:
294
  raise HTTPException(400, f"Search failed: {e}")
295
  out = []
@@ -307,8 +315,7 @@ def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(defaul
307
  raise HTTPException(401, "Unauthorized")
308
  col = f"proj_{project_id}"
309
  try:
310
- qdr.delete_collection(col)
311
- return {"ok": True}
312
  except Exception as e:
313
  raise HTTPException(400, f"wipe failed: {e}")
314
 
 
17
  # ---------- ENV ----------
18
  EMB_BACKEND = os.getenv("EMB_BACKEND", "hf").strip().lower() # "hf" (défaut) ou "deepinfra"
19
 
20
+ # Hugging Face
21
  HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
22
  HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
23
+ # Recommandé: endpoint "models" (plus tolerant)
 
24
  HF_URL = (os.getenv("HF_API_URL", "").strip()
25
+ or f"https://api-inference.huggingface.co/models/{HF_MODEL}")
26
 
27
+ # DeepInfra (option)
28
  DI_TOKEN = os.getenv("DEEPINFRA_API_KEY", "").strip()
29
  DI_MODEL = os.getenv("DEEPINFRA_EMBED_MODEL", "thenlper/gte-small").strip()
30
  DI_URL = os.getenv("DEEPINFRA_EMBED_URL", "https://api.deepinfra.com/v1/embeddings").strip()
 
75
  raise HTTPException(status_code=401, detail="Unauthorized")
76
 
77
  def _hf_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
78
+ """
79
+ Hugging Face Inference API:
80
+ - envoyer {"inputs": ...} (string ou liste de strings)
81
+ - endpoint recommandé: /models/<repo_id>
82
+ Retour: liste de vecteurs [batch, dim] OU [batch, tokens, dim]
83
+ """
84
  if not HF_TOKEN:
85
  raise RuntimeError("HF_API_TOKEN manquant (backend=hf).")
86
+
87
+ headers = {
88
+ "Authorization": f"Bearer {HF_TOKEN}",
89
+ "Content-Type": "application/json",
90
+ # Optionnel (forçage warmup) : "X-Wait-For-Model": "true"
91
+ }
92
+ payload = {"inputs": batch if len(batch) > 1 else batch[0]}
93
  try:
94
+ r = requests.post(HF_URL, headers=headers, json=payload, timeout=120)
95
  size = int(r.headers.get("Content-Length", "0"))
96
  if r.status_code >= 400:
 
97
  try:
98
  LOG.error(f"HF error {r.status_code}: {r.text}")
99
  except Exception:
 
107
  # [batch, dim] (sentence-transformers) ou [batch, tokens, dim] -> mean-pooling
108
  if arr.ndim == 3:
109
  arr = arr.mean(axis=1)
110
+ if arr.ndim == 1:
111
+ # cas rare: un seul vecteur (batch=1)
112
+ arr = arr.reshape(1, -1)
113
  if arr.ndim != 2:
114
  raise RuntimeError(f"HF: unexpected embeddings shape: {arr.shape}")
115
  # normalisation
 
118
  return arr.astype(np.float32), size
119
 
120
  def _di_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
121
+ """
122
+ DeepInfra embeddings (OpenAI-like):
123
+ POST /v1/embeddings {model: ..., input: [...]}
124
+ Réponse: {"data":[{"embedding":[...],"index":0}, ...]}
125
+ """
126
  if not DI_TOKEN:
127
  raise RuntimeError("DEEPINFRA_API_KEY manquant (backend=deepinfra).")
128
  headers = {"Authorization": f"Bearer {DI_TOKEN}", "Content-Type": "application/json"}
 
140
  except Exception as e:
141
  raise RuntimeError(f"DeepInfra POST failed: {e}")
142
 
 
143
  data = js.get("data")
144
  if not isinstance(data, list) or not data:
145
  raise RuntimeError(f"DeepInfra embeddings: réponse invalide {js}")
 
147
  arr = np.asarray(embs, dtype=np.float32)
148
  if arr.ndim != 2:
149
  raise RuntimeError(f"DeepInfra: unexpected embeddings shape: {arr.shape}")
 
150
  norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
151
  arr = arr / norms
152
  return arr.astype(np.float32), size
 
161
 
162
  def _ensure_collection(name: str, dim: int):
163
  try:
164
+ qdr.get_collection(name); return
 
165
  except Exception:
166
  pass
167
  qdr.create_collection(
 
172
  def _chunk_with_spans(text: str, size: int, overlap: int):
173
  n = len(text)
174
  if size <= 0:
175
+ yield (0, n, text); return
 
176
  i = 0
177
  while i < n:
178
  j = min(n, i + size)
179
  yield (i, j, text[i:j])
180
  i = max(0, j - overlap)
181
+ if i >= n: break
 
182
 
183
  def _append_log(job_id: str, line: str):
184
  job = JOBS.get(job_id)
185
+ if job: job["logs"].append(line)
 
186
 
187
  def _set_status(job_id: str, status: str):
188
  job = JOBS.get(job_id)
189
+ if job: job["status"] = status
 
190
 
191
  # ---------- Background task ----------
192
  def run_index_job(job_id: str, req: IndexRequest):
 
207
  _append_log(job_id, f"Collection ready: {col} (dim={dim})")
208
 
209
  point_id = 0
 
210
  # boucle fichiers
211
  for fi, f in enumerate(req.files, 1):
212
  chunks, metas = [], []
 
228
  _append_log(job_id, f"file {fi}/{len(req.files)}: +{len(chunks)} chunks (total={total_chunks}) ~{sz/1024:.1f}KiB")
229
  chunks, metas = [], []
230
 
 
231
  if chunks:
232
  vecs, sz = _post_embeddings(chunks)
233
  batch_points = []
 
264
  def health():
265
  return {"ok": True}
266
 
267
+ def _check_backend_ready():
268
  if EMB_BACKEND == "hf" and not HF_TOKEN:
269
  raise HTTPException(400, "HF_API_TOKEN manquant côté serveur (backend=hf).")
270
  if EMB_BACKEND == "deepinfra" and not DI_TOKEN:
 
293
  def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None)):
294
  if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
295
  raise HTTPException(401, "Unauthorized")
296
+ _check_backend_ready()
297
  vec, _ = _post_embeddings([req.query])
 
298
  col = f"proj_{req.project_id}"
299
  try:
300
+ res = qdr.search(collection_name=col, query_vector=vec[0].tolist(), limit=int(req.top_k))
301
  except Exception as e:
302
  raise HTTPException(400, f"Search failed: {e}")
303
  out = []
 
315
  raise HTTPException(401, "Unauthorized")
316
  col = f"proj_{project_id}"
317
  try:
318
+ qdr.delete_collection(col); return {"ok": True}
 
319
  except Exception as e:
320
  raise HTTPException(400, f"wipe failed: {e}")
321