chouchouvs committed · verified
Commit 13ebb90 · Parent(s): f1d2d4d

Update main.py

Files changed (1): main.py (+25 -18)
main.py CHANGED
@@ -2,14 +2,19 @@
 from __future__ import annotations
 import os, time, uuid, logging
 from typing import List, Optional, Dict, Any, Tuple
-import requests
+
 import numpy as np
+import requests
 from fastapi import FastAPI, BackgroundTasks, Header, HTTPException
 from pydantic import BaseModel, Field
 from qdrant_client import QdrantClient
 from qdrant_client.http.models import VectorParams, Distance, PointStruct
 
-logging.basicConfig(level=logging.INFO)
+# ---------- logging ----------
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(levelname)s:%(name)s:%(message)s"
+)
 LOG = logging.getLogger("remote_indexer")
 
 # ---------- ENV ----------
@@ -25,7 +30,10 @@ if not HF_TOKEN:
     LOG.warning("HF_API_TOKEN manquant — le service refusera /index et /query.")
 
 # ---------- Clients ----------
-qdr = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API if QDRANT_API else None)
+try:
+    qdr = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API if QDRANT_API else None)
+except Exception as e:
+    LOG.warning(f"Qdrant client init: {e}")
 
 # ---------- Pydantic ----------
 class FileIn(BaseModel):
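
Review note: the try/except keeps the Space booting when Qdrant is unreachable at import time, but if construction fails, `qdr` stays undefined and the first endpoint to touch it raises a NameError. A minimal sketch of a lazier alternative, reusing the module's imports and env globals (the `get_qdr` helper is hypothetical, not part of this commit):

    _qdr: Optional[QdrantClient] = None

    def get_qdr() -> QdrantClient:
        # Hypothetical: build the client on first use so a bad QDRANT_URL
        # surfaces as a per-request error instead of an import-time crash.
        global _qdr
        if _qdr is None:
            _qdr = QdrantClient(url=QDRANT_URL, api_key=QDRANT_API or None)
        return _qdr
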
@@ -45,7 +53,7 @@ class QueryRequest(BaseModel):
     query: str
     top_k: int = 6
 
-# ---------- Jobs store (en mémoire) ----------
+# ---------- Jobs store (mémoire) ----------
 JOBS: Dict[str, Dict[str, Any]] = {}  # {job_id: {"status": "...", "logs": [...], "created": ts}}
 
 # ---------- Utils ----------
@@ -62,13 +70,11 @@ def _post_embeddings(batch: List[str]) -> Tuple[np.ndarray, int]:
     r.raise_for_status()
     data = r.json()
     arr = np.array(data, dtype=np.float32)
-    # arr: [batch, dim] (sentence-transformers)
-    # ou [batch, tokens, dim] -> mean pooling
+    # [batch, dim] (sentence-transformers) ou [batch, tokens, dim] -> mean-pooling
     if arr.ndim == 3:
        arr = arr.mean(axis=1)
     if arr.ndim != 2:
         raise RuntimeError(f"Unexpected embeddings shape: {arr.shape}")
-    # normalisation
     norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12
     arr = arr / norms
     return arr.astype(np.float32), size
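
Review note: the pooling/normalisation block is easy to verify in isolation. A self-contained sketch of the same logic, with shapes taken from the comment in the hunk ([batch, dim] or [batch, tokens, dim]):

    import numpy as np

    def pool_and_normalize(data) -> np.ndarray:
        arr = np.array(data, dtype=np.float32)
        if arr.ndim == 3:                    # [batch, tokens, dim] -> mean over tokens
            arr = arr.mean(axis=1)
        if arr.ndim != 2:
            raise RuntimeError(f"Unexpected embeddings shape: {arr.shape}")
        norms = np.linalg.norm(arr, axis=1, keepdims=True) + 1e-12  # guard div-by-zero
        return (arr / norms).astype(np.float32)

    vecs = pool_and_normalize(np.random.rand(2, 4, 3))  # token-level input
    assert vecs.shape == (2, 3)                         # pooled to [batch, dim]
    assert np.allclose(np.linalg.norm(vecs, axis=1), 1.0, atol=1e-5)
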
@@ -115,8 +121,7 @@ def run_index_job(job_id: str, req: IndexRequest):
     LOG.info(f"[{job_id}] Index start project={req.project_id} files={len(req.files)}")
     _append_log(job_id, f"Start project={req.project_id} files={len(req.files)}")
 
-    # premier batch pour récupérer la dimension
-    # on prépare un mini lot
+    # warmup pour dimension
     warmup = []
     for f in req.files[:1]:
         warmup.append(next(_chunk_with_spans(f.text, req.chunk_size, req.overlap))[2])
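
Review note: `_chunk_with_spans` itself is not shown in this diff; the warmup only needs the first yielded chunk's text (index [2]) to learn the embedding dimension before creating the collection. A hypothetical generator with a compatible shape, assuming character-based sliding windows:

    def _chunk_with_spans(text: str, chunk_size: int, overlap: int):
        # Hypothetical sketch: yields (start, end, chunk_text) tuples.
        step = max(1, chunk_size - overlap)
        for start in range(0, max(1, len(text)), step):
            end = min(len(text), start + chunk_size)
            yield start, end, text[start:end]
            if end == len(text):
                break
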
@@ -126,15 +131,8 @@
     _ensure_collection(col, dim)
     _append_log(job_id, f"Collection ready: {col} (dim={dim})")
 
-    points_buffer: List[PointStruct] = []
     point_id = 0
 
-    def flush_points():
-        nonlocal points_buffer
-        if points_buffer:
-            qdr.upsert(collection_name=col, points=points_buffer)
-            points_buffer = []
-
     # boucle fichiers
     for fi, f in enumerate(req.files, 1):
         chunks, metas = [], []
@@ -167,7 +165,6 @@
         total_chunks += len(chunks)
         _append_log(job_id, f"file {fi}/{len(req.files)}: +{len(chunks)} chunks (total={total_chunks}) ~{sz/1024:.1f}KiB")
 
-    flush_points()
     _append_log(job_id, f"Done. chunks={total_chunks}")
     _set_status(job_id, "done")
     LOG.info(f"[{job_id}] Index finished. chunks={total_chunks}")
@@ -179,6 +176,10 @@
 # ---------- API ----------
 app = FastAPI()
 
+@app.get("/")
+def root():
+    return {"ok": True, "service": "remote-indexer", "docs": "/health, /index, /status/{job_id}, /query, /wipe"}
+
 @app.get("/health")
 def health():
     return {"ok": True}
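
Review note: the new root route gives uptime probes something to hit besides /health and advertises the available paths. A quick check once deployed (base URL hypothetical):

    import requests

    BASE = "https://your-space.hf.space"             # hypothetical deployment URL

    print(requests.get(f"{BASE}/").json()["docs"])   # lists the declared routes
    assert requests.get(f"{BASE}/health").json() == {"ok": True}
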
@@ -217,7 +218,6 @@ def query(req: QueryRequest, x_auth_token: Optional[str] = Header(default=None))
     for p in res:
         pl = p.payload or {}
         txt = pl.get("text")
-        # hard cap snippet size
         if txt and len(txt) > 800:
             txt = txt[:800] + "..."
         out.append({"path": pl.get("path"), "chunk": pl.get("chunk"), "start": pl.get("start"), "end": pl.get("end"), "text": txt})
@@ -232,3 +232,10 @@ def wipe_collection(project_id: str, x_auth_token: Optional[str] = Header(defaul
         return {"ok": True}
     except Exception as e:
         raise HTTPException(400, f"wipe failed: {e}")
+
+# ---------- Entrypoint (respecte $PORT des Spaces) ----------
+if __name__ == "__main__":
+    import uvicorn
+    port = int(os.getenv("PORT", "7860"))
+    LOG.info(f"===== Application Startup on PORT {port} =====")
+    uvicorn.run(app, host="0.0.0.0", port=port)
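
Review note: the `__main__` block honours the $PORT convention of Spaces, defaulting to 7860. A local smoke test after `python main.py`; the /query body below uses only the two fields visible in this diff (`query`, `top_k`), the real QueryRequest may require more, and the token value is deployment-specific:

    import requests

    BASE = "http://localhost:7860"   # PORT defaults to 7860 in the entrypoint

    r = requests.post(f"{BASE}/query",
                      json={"query": "hello", "top_k": 3},
                      headers={"X-Auth-Token": "change-me"})  # hypothetical token
    print(r.status_code, r.json() if r.ok else r.text)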