chouchouvs commited on
Commit
6cb5d1b
·
verified ·
1 Parent(s): 7fb6049

Update main.py

Browse files
Files changed (1) hide show
  1. main.py +58 -27
main.py CHANGED
@@ -33,26 +33,59 @@ LOG.setLevel(logging.INFO)
33
  # CONFIG (via ENV)
34
  # =============================================================================
35
  PORT = int(os.getenv("PORT", "7860"))
36
- DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data") # persistant dans le conteneur Space
37
  os.makedirs(DATA_ROOT, exist_ok=True)
38
 
39
  # Provider d'embeddings:
40
  # - "dummy" : vecteurs aléatoires déterministes (très rapide)
41
- # - "st" : Sentence-Transformers (CPU-friendly, simple)
42
- # - "hf" : Transformers (AutoModel/AutoTokenizer, pooling manuel)
43
  EMB_PROVIDER = os.getenv("EMB_PROVIDER", "dummy").strip().lower()
44
-
45
- # Modèle embeddings (utilisé si provider != "dummy")
46
- # Reco rapide et multilingue (FR ok) : paraphrase-multilingual-MiniLM-L12-v2 (dim=384)
47
  EMB_MODEL = os.getenv("EMB_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2").strip()
48
-
49
- # Batch d'encodage
50
  EMB_BATCH = int(os.getenv("EMB_BATCH", "32"))
 
51
 
52
- # Dimension par défaut (dummy) — pour st/hf on lit depuis le modèle
53
- EMB_DIM = int(os.getenv("EMB_DIM", "128"))
54
-
55
- # Cache global lazy
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
  _ST_MODEL = None
57
  _HF_TOKENIZER = None
58
  _HF_MODEL = None
@@ -135,8 +168,8 @@ def _get_st_model():
135
  global _ST_MODEL
136
  if _ST_MODEL is None:
137
  from sentence_transformers import SentenceTransformer
138
- _ST_MODEL = SentenceTransformer(EMB_MODEL)
139
- LOG.info(f"[st] modèle chargé: {EMB_MODEL}")
140
  return _ST_MODEL
141
 
142
  def _emb_st(texts: List[str]) -> np.ndarray:
@@ -155,7 +188,6 @@ def _st_dim() -> int:
155
  try:
156
  return int(model.get_sentence_embedding_dimension())
157
  except Exception:
158
- # fallback : encode une phrase et lit la shape
159
  v = model.encode(["dimension probe"], convert_to_numpy=True)
160
  return int(v.shape[1])
161
 
@@ -164,17 +196,16 @@ def _get_hf_model():
164
  global _HF_TOKENIZER, _HF_MODEL
165
  if _HF_MODEL is None or _HF_TOKENIZER is None:
166
  from transformers import AutoTokenizer, AutoModel
167
- _HF_TOKENIZER = AutoTokenizer.from_pretrained(EMB_MODEL)
168
- _HF_MODEL = AutoModel.from_pretrained(EMB_MODEL)
169
  _HF_MODEL.eval()
170
- LOG.info(f"[hf] modèle chargé: {EMB_MODEL}")
171
  return _HF_TOKENIZER, _HF_MODEL
172
 
173
  def _mean_pool(last_hidden_state: "np.ndarray", attention_mask: "np.ndarray") -> "np.ndarray":
174
- # mean pooling masquée
175
- mask = attention_mask[..., None].astype(last_hidden_state.dtype) # (b, t, 1)
176
- summed = (last_hidden_state * mask).sum(axis=1) # (b, h)
177
- counts = mask.sum(axis=1).clip(min=1e-9) # (b, 1)
178
  return summed / counts
179
 
180
  def _emb_hf(texts: List[str]) -> np.ndarray:
@@ -194,7 +225,6 @@ def _emb_hf(texts: List[str]) -> np.ndarray:
194
  return _l2_normalize(vecs)
195
 
196
  def _hf_dim() -> int:
197
- # essaie de lire hidden_size
198
  try:
199
  _, mod = _get_hf_model()
200
  return int(getattr(mod.config, "hidden_size", 768))
@@ -228,7 +258,7 @@ def _load_dataset(ds_dir: str) -> List[Dict[str, Any]]:
228
  def _save_faiss(fx_dir: str, xb: np.ndarray, meta: Dict[str, Any]):
229
  os.makedirs(fx_dir, exist_ok=True)
230
  idx_path = os.path.join(fx_dir, "emb.faiss")
231
- index = faiss.IndexFlatIP(xb.shape[1]) # cosine ~ inner product si normalisé
232
  index.add(xb)
233
  faiss.write_index(index, idx_path)
234
  with open(os.path.join(fx_dir, "meta.json"), "w", encoding="utf-8") as f:
@@ -250,7 +280,7 @@ def _tar_dir_to_bytes(dir_path: str) -> bytes:
250
  # =============================================================================
251
  # FASTAPI
252
  # =============================================================================
253
- fastapi_app = FastAPI(title="remote-indexer", version="2.0.0")
254
  fastapi_app.add_middleware(
255
  CORSMiddleware,
256
  allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
@@ -274,7 +304,8 @@ def health():
274
  "ok": True,
275
  "service": "remote-indexer",
276
  "provider": EMB_PROVIDER,
277
- "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None
 
278
  }
279
  return info
280
 
@@ -436,7 +467,7 @@ def _ui_search(project_id: str, query: str, k: int):
436
 
437
  with gr.Blocks(title="Remote Indexer (FAISS)", analytics_enabled=False) as ui:
438
  gr.Markdown("## Remote Indexer — demo UI (API: `/index`, `/status/{job}`, `/search`, `/artifacts/...`).")
439
- gr.Markdown(f"**Provider**: `{EMB_PROVIDER}` — **Model**: `{EMB_MODEL if EMB_PROVIDER!='dummy' else '-'}'")
440
  with gr.Tab("Index"):
441
  pid = gr.Textbox(label="Project ID", value="DEEPWEB")
442
  sample = gr.Textbox(label="Texte d’exemple", value="Alpha bravo charlie delta echo foxtrot.", lines=4)
 
33
  # CONFIG (via ENV)
34
  # =============================================================================
35
  PORT = int(os.getenv("PORT", "7860"))
36
+ DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data") # stockage interne du Space
37
  os.makedirs(DATA_ROOT, exist_ok=True)
38
 
39
  # Provider d'embeddings:
40
  # - "dummy" : vecteurs aléatoires déterministes (très rapide)
41
+ # - "st" : Sentence-Transformers (CPU-friendly)
42
+ # - "hf" : Transformers pur (AutoModel/AutoTokenizer)
43
  EMB_PROVIDER = os.getenv("EMB_PROVIDER", "dummy").strip().lower()
 
 
 
44
  EMB_MODEL = os.getenv("EMB_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2").strip()
 
 
45
  EMB_BATCH = int(os.getenv("EMB_BATCH", "32"))
46
+ EMB_DIM = int(os.getenv("EMB_DIM", "128")) # utilisé pour dummy
47
 
48
+ # =============================================================================
49
+ # CACHE DIRECTORIES (crucial pour éviter PermissionError: '/.cache')
50
+ # =============================================================================
51
+ def _setup_cache_dirs() -> Dict[str, str]:
52
+ # HOME peut être vide -> expanduser('~') => '/' -> '/.cache' -> Permission denied
53
+ os.environ.setdefault("HOME", "/home/user")
54
+
55
+ CACHE_ROOT = os.getenv("CACHE_ROOT", "/tmp/.cache").rstrip("/")
56
+ paths = {
57
+ "root": CACHE_ROOT,
58
+ "hf_home": f"{CACHE_ROOT}/huggingface",
59
+ "hf_hub": f"{CACHE_ROOT}/huggingface/hub",
60
+ "hf_tf": f"{CACHE_ROOT}/huggingface/transformers",
61
+ "torch": f"{CACHE_ROOT}/torch",
62
+ "st": f"{CACHE_ROOT}/sentence-transformers",
63
+ "mpl": f"{CACHE_ROOT}/matplotlib",
64
+ }
65
+ for p in paths.values():
66
+ try:
67
+ os.makedirs(p, exist_ok=True)
68
+ except Exception as e:
69
+ LOG.warning("Impossible de créer %s : %s", p, e)
70
+
71
+ # Variables standard HF/Transformers/Torch/ST
72
+ os.environ["HF_HOME"] = paths["hf_home"]
73
+ os.environ["HF_HUB_CACHE"] = paths["hf_hub"]
74
+ os.environ["TRANSFORMERS_CACHE"] = paths["hf_tf"]
75
+ os.environ["TORCH_HOME"] = paths["torch"]
76
+ os.environ["SENTENCE_TRANSFORMERS_HOME"] = paths["st"]
77
+ os.environ["MPLCONFIGDIR"] = paths["mpl"] # évite les warnings matplotlib
78
+
79
+ # Qualité de vie
80
+ os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
81
+ os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
82
+
83
+ LOG.info("Caches configurés: %s", json.dumps(paths, indent=2))
84
+ return paths
85
+
86
+ CACHE_PATHS = _setup_cache_dirs()
87
+
88
+ # Cache global lazy (pour les modèles)
89
  _ST_MODEL = None
90
  _HF_TOKENIZER = None
91
  _HF_MODEL = None
 
168
  global _ST_MODEL
169
  if _ST_MODEL is None:
170
  from sentence_transformers import SentenceTransformer
171
+ _ST_MODEL = SentenceTransformer(EMB_MODEL, cache_folder=CACHE_PATHS["st"])
172
+ LOG.info("[st] modèle chargé: %s (cache=%s)", EMB_MODEL, CACHE_PATHS["st"])
173
  return _ST_MODEL
174
 
175
  def _emb_st(texts: List[str]) -> np.ndarray:
 
188
  try:
189
  return int(model.get_sentence_embedding_dimension())
190
  except Exception:
 
191
  v = model.encode(["dimension probe"], convert_to_numpy=True)
192
  return int(v.shape[1])
193
 
 
196
  global _HF_TOKENIZER, _HF_MODEL
197
  if _HF_MODEL is None or _HF_TOKENIZER is None:
198
  from transformers import AutoTokenizer, AutoModel
199
+ _HF_TOKENIZER = AutoTokenizer.from_pretrained(EMB_MODEL, cache_dir=CACHE_PATHS["hf_tf"])
200
+ _HF_MODEL = AutoModel.from_pretrained(EMB_MODEL, cache_dir=CACHE_PATHS["hf_tf"])
201
  _HF_MODEL.eval()
202
+ LOG.info("[hf] modèle chargé: %s (cache=%s)", EMB_MODEL, CACHE_PATHS["hf_tf"])
203
  return _HF_TOKENIZER, _HF_MODEL
204
 
205
  def _mean_pool(last_hidden_state: "np.ndarray", attention_mask: "np.ndarray") -> "np.ndarray":
206
+ mask = attention_mask[..., None].astype(last_hidden_state.dtype)
207
+ summed = (last_hidden_state * mask).sum(axis=1)
208
+ counts = mask.sum(axis=1).clip(min=1e-9)
 
209
  return summed / counts
210
 
211
  def _emb_hf(texts: List[str]) -> np.ndarray:
 
225
  return _l2_normalize(vecs)
226
 
227
  def _hf_dim() -> int:
 
228
  try:
229
  _, mod = _get_hf_model()
230
  return int(getattr(mod.config, "hidden_size", 768))
 
258
  def _save_faiss(fx_dir: str, xb: np.ndarray, meta: Dict[str, Any]):
259
  os.makedirs(fx_dir, exist_ok=True)
260
  idx_path = os.path.join(fx_dir, "emb.faiss")
261
+ index = faiss.IndexFlatIP(xb.shape[1]) # cosine ~ inner product si embeddings normalisés
262
  index.add(xb)
263
  faiss.write_index(index, idx_path)
264
  with open(os.path.join(fx_dir, "meta.json"), "w", encoding="utf-8") as f:
 
280
  # =============================================================================
281
  # FASTAPI
282
  # =============================================================================
283
+ fastapi_app = FastAPI(title="remote-indexer", version="2.1.0")
284
  fastapi_app.add_middleware(
285
  CORSMiddleware,
286
  allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
 
304
  "ok": True,
305
  "service": "remote-indexer",
306
  "provider": EMB_PROVIDER,
307
+ "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None,
308
+ "cache_root": os.getenv("CACHE_ROOT", "/tmp/.cache"),
309
  }
310
  return info
311
 
 
467
 
468
  with gr.Blocks(title="Remote Indexer (FAISS)", analytics_enabled=False) as ui:
469
  gr.Markdown("## Remote Indexer — demo UI (API: `/index`, `/status/{job}`, `/search`, `/artifacts/...`).")
470
+ gr.Markdown(f"**Provider**: `{EMB_PROVIDER}` — **Model**: `{EMB_MODEL if EMB_PROVIDER!='dummy' else '-'}` — **Cache**: `{os.getenv('CACHE_ROOT', '/tmp/.cache')}`")
471
  with gr.Tab("Index"):
472
  pid = gr.Textbox(label="Project ID", value="DEEPWEB")
473
  sample = gr.Textbox(label="Texte d’exemple", value="Alpha bravo charlie delta echo foxtrot.", lines=4)