chouchouvs commited on
Commit
dd055bb
·
verified ·
1 Parent(s): 13ebb90

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +106 -27
main.py CHANGED
@@ -11,23 +11,37 @@ from qdrant_client import QdrantClient
11
  from qdrant_client.http.models import VectorParams, Distance, PointStruct
12
 
13
  # ---------- logging ----------
14
- logging.basicConfig(
15
- level=logging.INFO,
16
- format="%(levelname)s:%(name)s:%(message)s"
17
- )
18
  LOG = logging.getLogger("remote_indexer")
19
 
20
  # ---------- ENV ----------
21
- AUTH_TOKEN = os.getenv("REMOTE_INDEX_TOKEN", "").strip() # simple header auth
 
 
22
  HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
23
  HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
24
- HF_URL = os.getenv("HF_API_URL", "").strip() or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}"
 
 
 
 
 
 
 
 
25
 
 
26
  QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
27
  QDRANT_API = os.getenv("QDRANT_API_KEY", "").strip()
28
 
29
- if not HF_TOKEN:
30
- LOG.warning("HF_API_TOKEN manquant — le service refusera /index et /query.")
 
 
 
 
 
 
31
 
32
  # ---------- Clients ----------
33
  try:
@@ -61,24 +75,74 @@ def _auth(x_auth: Optional[str]):
61
  if AUTH_TOKEN and (x_auth or "") != AUTH_TOKEN:
62
  raise HTTPException(status_code=401, detail="Unauthorized")
63
 
64
- def _post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
65
  if not HF_TOKEN:
66
- raise RuntimeError("HF_API_TOKEN manquant (server).")
67
  headers = {"Authorization": f"Bearer {HF_TOKEN}"}
68
- r = requests.post(HF_URL, headers=headers, json=batch, timeout=120)
69
- size = int(r.headers.get("Content-Length", "0"))
70
- r.raise_for_status()
71
- data = r.json()
 
 
 
 
 
 
 
 
 
 
72
  arr = np.array(data, dtype=np.float32)
73
  # [batch, dim] (sentence-transformers) ou [batch, tokens, dim] -> mean-pooling
74
  if arr.ndim == 3:
75
  arr = arr.mean(axis=1)
76
  if arr.ndim != 2:
77
- raise RuntimeError(f"Unexpected embeddings shape: {arr.shape}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
  norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
79
  arr = arr / norms
80
  return arr.astype(np.float32), size
81
 
 
 
 
 
 
 
 
 
82
  def _ensure_collection(name: str, dim: int):
83
  try:
84
  qdr.get_collection(name)
@@ -119,9 +183,9 @@ def run_index_job(job_id: str, req: IndexRequest):
119
  _set_status(job_id, "running")
120
  total_chunks = 0
121
  LOG.info(f"[{job_id}] Index start project={req.project_id} files={len(req.files)}")
122
- _append_log(job_id, f"Start project={req.project_id} files={len(req.files)}")
123
 
124
- # warmup pour dimension
125
  warmup = []
126
  for f in req.files[:1]:
127
  warmup.append(next(_chunk_with_spans(f.text, req.chunk_size, req.overlap))[2])
@@ -178,17 +242,30 @@ app = FastAPI()
178
 
179
  @app.get("/")
180
  def root():
181
- return {"ok": True, "service": "remote-indexer", "docs": "/health, /index, /status/{job_id}, /query, /wipe"}
 
 
 
 
 
 
 
182
 
183
  @app.get("/health")
184
  def health():
185
  return {"ok": True}
186
 
 
 
 
 
 
 
187
  @app.post("/index")
188
  def start_index(req: IndexRequest, background_tasks: BackgroundTasks, x_auth_token: Optional[str] = Header(default=None)):
189
- _auth(x_auth_token)
190
- if not HF_TOKEN:
191
- raise HTTPException(400, "HF_API_TOKEN manquant côté serveur.")
192
  job_id = uuid.uuid4().hex[:12]
193
  JOBS[job_id] = {"status": "queued", "logs": [], "created": time.time()}
194
  background_tasks.add_task(run_index_job, job_id, req)
@@ -196,7 +273,8 @@ def start_index(req: IndexRequest, background_tasks: BackgroundTasks, x_auth_tok
196
 
197
  @app.get("/status/{job_id}")
198
  def status(job_id: str, x_auth_token: Optional[str] = Header(default=None)):
199
- _auth(x_auth_token)
 
200
  j = JOBS.get(job_id)
201
  if not j:
202
  raise HTTPException(404, "job inconnu")
@@ -204,9 +282,9 @@ def status(job_id: str, x_auth_token: Optional[str] = Header(default=None)):
204
 
205
  @app.post("/query")
206
  def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None)):
207
- _auth(x_auth_token)
208
- if not HF_TOKEN:
209
- raise HTTPException(400, "HF_API_TOKEN manquant côté serveur.")
210
  vec, _ = _post_embeddings([req.query])
211
  vec = vec[0].tolist()
212
  col = f"proj_{req.project_id}"
@@ -225,7 +303,8 @@ def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None))
225
 
226
  @app.post("/wipe")
227
  def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(default=None)):
228
- _auth(x_auth_token)
 
229
  col = f"proj_{project_id}"
230
  try:
231
  qdr.delete_collection(col)
@@ -233,7 +312,7 @@ def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(defaul
233
  except Exception as e:
234
  raise HTTPException(400, f"wipe failed: {e}")
235
 
236
- # ---------- Entrypoint (respecte $PORT des Spaces) ----------
237
  if __name__ == "__main__":
238
  import uvicorn
239
  port = int(os.getenv("PORT", "7860"))
 
11
  from qdrant_client.http.models import VectorParams, Distance, PointStruct
12
 
13
  # ---------- logging ----------
14
+ logging.basicConfig(level=logging.INFO, format="%(levelname)s:%(name)s:%(message)s")
 
 
 
15
  LOG = logging.getLogger("remote_indexer")
16
 
17
  # ---------- ENV ----------
18
+ EMB_BACKEND = os.getenv("EMB_BACKEND", "hf").strip().lower() # "hf" (défaut) ou "deepinfra"
19
+
20
+ # HF
21
  HF_TOKEN = os.getenv("HF_API_TOKEN", "").strip()
22
  HF_MODEL = os.getenv("HF_EMBED_MODEL", "sentence-transformers/all-MiniLM-L6-v2")
23
+ # Si tu as un Inference Endpoint privé, ou si tu veux l’API "models/..." :
24
+ # ex: https://api-inference.huggingface.co/models/sentence-transformers/all-MiniLM-L6-v2
25
+ HF_URL = (os.getenv("HF_API_URL", "").strip()
26
+ or f"https://api-inference.huggingface.co/pipeline/feature-extraction/{HF_MODEL}")
27
+
28
+ # DeepInfra
29
+ DI_TOKEN = os.getenv("DEEPINFRA_API_KEY", "").strip()
30
+ DI_MODEL = os.getenv("DEEPINFRA_EMBED_MODEL", "thenlper/gte-small").strip()
31
+ DI_URL = os.getenv("DEEPINFRA_EMBED_URL", "https://api.deepinfra.com/v1/embeddings").strip()
32
 
33
+ # Qdrant
34
  QDRANT_URL = os.getenv("QDRANT_URL", "http://localhost:6333")
35
  QDRANT_API = os.getenv("QDRANT_API_KEY", "").strip()
36
 
37
+ # Auth d’API du service (simple header)
38
+ AUTH_TOKEN = os.getenv("REMOTE_INDEX_TOKEN", "").strip()
39
+
40
+ LOG.info(f"Embeddings backend = {EMB_BACKEND}")
41
+ if EMB_BACKEND == "hf" and not HF_TOKEN:
42
+ LOG.warning("HF_API_TOKEN manquant — HF /index et /query échoueront.")
43
+ if EMB_BACKEND == "deepinfra" and not DI_TOKEN:
44
+ LOG.warning("DEEPINFRA_API_KEY manquant — DeepInfra embeddings échoueront.")
45
 
46
  # ---------- Clients ----------
47
  try:
 
75
  if AUTH_TOKEN and (x_auth or "") != AUTH_TOKEN:
76
  raise HTTPException(status_code=401, detail="Unauthorized")
77
 
78
+ def _hf_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
79
  if not HF_TOKEN:
80
+ raise RuntimeError("HF_API_TOKEN manquant (backend=hf).")
81
  headers = {"Authorization": f"Bearer {HF_TOKEN}"}
82
+ try:
83
+ r = requests.post(HF_URL, headers=headers, json=batch, timeout=120)
84
+ size = int(r.headers.get("Content-Length", "0"))
85
+ if r.status_code >= 400:
86
+ # Log détaillé pour comprendre le 403/4xx
87
+ try:
88
+ LOG.error(f"HF error {r.status_code}: {r.text}")
89
+ except Exception:
90
+ LOG.error(f"HF error {r.status_code} (no body)")
91
+ r.raise_for_status()
92
+ data = r.json()
93
+ except Exception as e:
94
+ raise RuntimeError(f"HF POST failed: {e}")
95
+
96
  arr = np.array(data, dtype=np.float32)
97
  # [batch, dim] (sentence-transformers) ou [batch, tokens, dim] -> mean-pooling
98
  if arr.ndim == 3:
99
  arr = arr.mean(axis=1)
100
  if arr.ndim != 2:
101
+ raise RuntimeError(f"HF: unexpected embeddings shape: {arr.shape}")
102
+ # normalisation
103
+ norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
104
+ arr = arr / norms
105
+ return arr.astype(np.float32), size
106
+
107
+ def _di_post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
108
+ if not DI_TOKEN:
109
+ raise RuntimeError("DEEPINFRA_API_KEY manquant (backend=deepinfra).")
110
+ headers = {"Authorization": f"Bearer {DI_TOKEN}", "Content-Type": "application/json"}
111
+ payload = {"model": DI_MODEL, "input": batch}
112
+ try:
113
+ r = requests.post(DI_URL, headers=headers, json=payload, timeout=120)
114
+ size = int(r.headers.get("Content-Length", "0"))
115
+ if r.status_code >= 400:
116
+ try:
117
+ LOG.error(f"DeepInfra error {r.status_code}: {r.text}")
118
+ except Exception:
119
+ LOG.error(f"DeepInfra error {r.status_code} (no body)")
120
+ r.raise_for_status()
121
+ js = r.json()
122
+ except Exception as e:
123
+ raise RuntimeError(f"DeepInfra POST failed: {e}")
124
+
125
+ # OpenAI-like : {"data":[{"embedding":[...],"index":0}, ...]}
126
+ data = js.get("data")
127
+ if not isinstance(data, list) or not data:
128
+ raise RuntimeError(f"DeepInfra embeddings: réponse invalide {js}")
129
+ embs = [d.get("embedding") for d in data]
130
+ arr = np.asarray(embs, dtype=np.float32)
131
+ if arr.ndim != 2:
132
+ raise RuntimeError(f"DeepInfra: unexpected embeddings shape: {arr.shape}")
133
+ # normalisation
134
  norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
135
  arr = arr / norms
136
  return arr.astype(np.float32), size
137
 
138
+ def _post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
139
+ if EMB_BACKEND == "hf":
140
+ return _hf_post_embeddings(batch)
141
+ elif EMB_BACKEND == "deepinfra":
142
+ return _di_post_embeddings(batch)
143
+ else:
144
+ raise RuntimeError(f"EMB_BACKEND inconnu: {EMB_BACKEND}")
145
+
146
  def _ensure_collection(name: str, dim: int):
147
  try:
148
  qdr.get_collection(name)
 
183
  _set_status(job_id, "running")
184
  total_chunks = 0
185
  LOG.info(f"[{job_id}] Index start project={req.project_id} files={len(req.files)}")
186
+ _append_log(job_id, f"Start project={req.project_id} files={len(req.files)} | backend={EMB_BACKEND}")
187
 
188
+ # warmup -> dimension
189
  warmup = []
190
  for f in req.files[:1]:
191
  warmup.append(next(_chunk_with_spans(f.text, req.chunk_size, req.overlap))[2])
 
242
 
243
  @app.get("/")
244
  def root():
245
+ return {
246
+ "ok": True,
247
+ "service": "remote-indexer",
248
+ "backend": EMB_BACKEND,
249
+ "hf_url": HF_URL if EMB_BACKEND == "hf" else None,
250
+ "di_model": DI_MODEL if EMB_BACKEND == "deepinfra" else None,
251
+ "docs": "/health, /index, /status/{job_id}, /query, /wipe"
252
+ }
253
 
254
  @app.get("/health")
255
  def health():
256
  return {"ok": True}
257
 
258
+ def _check_backend_ready(for_query=False):
259
+ if EMB_BACKEND == "hf" and not HF_TOKEN:
260
+ raise HTTPException(400, "HF_API_TOKEN manquant côté serveur (backend=hf).")
261
+ if EMB_BACKEND == "deepinfra" and not DI_TOKEN:
262
+ raise HTTPException(400, "DEEPINFRA_API_KEY manquant côté serveur (backend=deepinfra).")
263
+
264
  @app.post("/index")
265
  def start_index(req: IndexRequest, background_tasks: BackgroundTasks, x_auth_token: Optional[str] = Header(default=None)):
266
+ if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
267
+ raise HTTPException(401, "Unauthorized")
268
+ _check_backend_ready()
269
  job_id = uuid.uuid4().hex[:12]
270
  JOBS[job_id] = {"status": "queued", "logs": [], "created": time.time()}
271
  background_tasks.add_task(run_index_job, job_id, req)
 
273
 
274
  @app.get("/status/{job_id}")
275
  def status(job_id: str, x_auth_token: Optional[str] = Header(default=None)):
276
+ if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
277
+ raise HTTPException(401, "Unauthorized")
278
  j = JOBS.get(job_id)
279
  if not j:
280
  raise HTTPException(404, "job inconnu")
 
282
 
283
  @app.post("/query")
284
  def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None)):
285
+ if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
286
+ raise HTTPException(401, "Unauthorized")
287
+ _check_backend_ready(for_query=True)
288
  vec, _ = _post_embeddings([req.query])
289
  vec = vec[0].tolist()
290
  col = f"proj_{req.project_id}"
 
303
 
304
  @app.post("/wipe")
305
  def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(default=None)):
306
+ if AUTH_TOKEN and (x_auth_token or "") != AUTH_TOKEN:
307
+ raise HTTPException(401, "Unauthorized")
308
  col = f"proj_{project_id}"
309
  try:
310
  qdr.delete_collection(col)
 
312
  except Exception as e:
313
  raise HTTPException(400, f"wipe failed: {e}")
314
 
315
+ # ---------- Entrypoint ----------
316
  if __name__ == "__main__":
317
  import uvicorn
318
  port = int(os.getenv("PORT", "7860"))