Update main.py
main.py CHANGED
@@ -33,26 +33,59 @@ LOG.setLevel(logging.INFO)
 # CONFIG (via ENV)
 # =============================================================================
 PORT = int(os.getenv("PORT", "7860"))
-DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data")  #
+DATA_ROOT = os.getenv("DATA_ROOT", "/tmp/data")  # internal Space storage
 os.makedirs(DATA_ROOT, exist_ok=True)

 # Embedding provider:
 # - "dummy" : deterministic random vectors (very fast)
-# - "st"    : Sentence-Transformers (CPU-friendly
-# - "hf"    : Transformers (AutoModel/AutoTokenizer
+# - "st"    : Sentence-Transformers (CPU-friendly)
+# - "hf"    : plain Transformers (AutoModel/AutoTokenizer)
 EMB_PROVIDER = os.getenv("EMB_PROVIDER", "dummy").strip().lower()
-
-# Embedding model (used when provider != "dummy")
-# Fast multilingual pick (FR ok): paraphrase-multilingual-MiniLM-L12-v2 (dim=384)
 EMB_MODEL = os.getenv("EMB_MODEL", "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2").strip()
-
-# Encoding batch size
 EMB_BATCH = int(os.getenv("EMB_BATCH", "32"))
-#
-
-
-
+EMB_DIM = int(os.getenv("EMB_DIM", "128"))  # used by the "dummy" provider
+
+# =============================================================================
+# CACHE DIRECTORIES (crucial to avoid PermissionError: '/.cache')
+# =============================================================================
+def _setup_cache_dirs() -> Dict[str, str]:
+    # HOME may be empty -> expanduser('~') => '/' -> '/.cache' -> Permission denied
+    os.environ.setdefault("HOME", "/home/user")
+
+    CACHE_ROOT = os.getenv("CACHE_ROOT", "/tmp/.cache").rstrip("/")
+    paths = {
+        "root": CACHE_ROOT,
+        "hf_home": f"{CACHE_ROOT}/huggingface",
+        "hf_hub": f"{CACHE_ROOT}/huggingface/hub",
+        "hf_tf": f"{CACHE_ROOT}/huggingface/transformers",
+        "torch": f"{CACHE_ROOT}/torch",
+        "st": f"{CACHE_ROOT}/sentence-transformers",
+        "mpl": f"{CACHE_ROOT}/matplotlib",
+    }
+    for p in paths.values():
+        try:
+            os.makedirs(p, exist_ok=True)
+        except Exception as e:
+            LOG.warning("Could not create %s: %s", p, e)
+
+    # Standard HF/Transformers/Torch/ST variables
+    os.environ["HF_HOME"] = paths["hf_home"]
+    os.environ["HF_HUB_CACHE"] = paths["hf_hub"]
+    os.environ["TRANSFORMERS_CACHE"] = paths["hf_tf"]
+    os.environ["TORCH_HOME"] = paths["torch"]
+    os.environ["SENTENCE_TRANSFORMERS_HOME"] = paths["st"]
+    os.environ["MPLCONFIGDIR"] = paths["mpl"]  # avoids matplotlib warnings
+
+    # Quality-of-life settings
+    os.environ.setdefault("HF_HUB_DISABLE_SYMLINKS_WARNING", "1")
+    os.environ.setdefault("TOKENIZERS_PARALLELISM", "false")
+
+    LOG.info("Caches configured: %s", json.dumps(paths, indent=2))
+    return paths
+
+CACHE_PATHS = _setup_cache_dirs()
+
+# Lazy global cache (for the models)
 _ST_MODEL = None
 _HF_TOKENIZER = None
 _HF_MODEL = None
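For context on the PermissionError this hunk works around: when HOME is set but empty, Python's expanduser("~") collapses to "/", so any library that derives its cache location from the home directory tries to create "/.cache" and fails in a Space. A minimal sketch of the failure mode and the redirect (paths illustrative only):

    import os

    os.environ["HOME"] = ""                                   # what a bare container can look like
    print(os.path.expanduser("~"))                            # "/" on POSIX when HOME is empty
    print(os.path.join(os.path.expanduser("~"), ".cache"))    # "/.cache" -> unwritable

    # Pointing every cache at a writable root *before* importing the model
    # libraries, as _setup_cache_dirs() does, avoids the problem entirely:
    os.environ["HF_HOME"] = "/tmp/.cache/huggingface"
    os.environ["TORCH_HOME"] = "/tmp/.cache/torch"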
@@ -135,8 +168,8 @@ def _get_st_model():
     global _ST_MODEL
     if _ST_MODEL is None:
         from sentence_transformers import SentenceTransformer
-        _ST_MODEL = SentenceTransformer(EMB_MODEL)
-        LOG.info(
+        _ST_MODEL = SentenceTransformer(EMB_MODEL, cache_folder=CACHE_PATHS["st"])
+        LOG.info("[st] model loaded: %s (cache=%s)", EMB_MODEL, CACHE_PATHS["st"])
     return _ST_MODEL

 def _emb_st(texts: List[str]) -> np.ndarray:
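Passing cache_folder here is belt-and-suspenders on top of the SENTENCE_TRANSFORMERS_HOME variable set earlier; either alone should suffice. A usage sketch of the resulting encode path, assuming the default model configured above (its dimension, 384, comes from the old comment in the first hunk):

    from sentence_transformers import SentenceTransformer

    model = SentenceTransformer(
        "sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2",
        cache_folder="/tmp/.cache/sentence-transformers",  # mirrors CACHE_PATHS["st"]
    )
    vecs = model.encode(["bonjour le monde"], batch_size=32, convert_to_numpy=True)
    print(vecs.shape)  # (1, 384) for this model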
@@ -155,7 +188,6 @@ def _st_dim() -> int:
     try:
         return int(model.get_sentence_embedding_dimension())
     except Exception:
-        # fallback: encode one sentence and read the shape
         v = model.encode(["dimension probe"], convert_to_numpy=True)
         return int(v.shape[1])

@@ -164,17 +196,16 @@ def _get_hf_model():
     global _HF_TOKENIZER, _HF_MODEL
     if _HF_MODEL is None or _HF_TOKENIZER is None:
         from transformers import AutoTokenizer, AutoModel
-        _HF_TOKENIZER = AutoTokenizer.from_pretrained(EMB_MODEL)
-        _HF_MODEL = AutoModel.from_pretrained(EMB_MODEL)
+        _HF_TOKENIZER = AutoTokenizer.from_pretrained(EMB_MODEL, cache_dir=CACHE_PATHS["hf_tf"])
+        _HF_MODEL = AutoModel.from_pretrained(EMB_MODEL, cache_dir=CACHE_PATHS["hf_tf"])
         _HF_MODEL.eval()
-        LOG.info(
+        LOG.info("[hf] model loaded: %s (cache=%s)", EMB_MODEL, CACHE_PATHS["hf_tf"])
     return _HF_TOKENIZER, _HF_MODEL

 def _mean_pool(last_hidden_state: "np.ndarray", attention_mask: "np.ndarray") -> "np.ndarray":
-
-
-
-    counts = mask.sum(axis=1).clip(min=1e-9)  # (b, 1)
+    mask = attention_mask[..., None].astype(last_hidden_state.dtype)
+    summed = (last_hidden_state * mask).sum(axis=1)
+    counts = mask.sum(axis=1).clip(min=1e-9)
     return summed / counts

 def _emb_hf(texts: List[str]) -> np.ndarray:
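The new _mean_pool body is standard masked mean pooling: padding positions are zeroed by the attention mask before averaging, and the clip guards against division by zero on an all-pad row. A tiny NumPy check (shapes invented for illustration):

    import numpy as np

    b, t, h = 2, 4, 3                                 # batch, tokens, hidden
    last_hidden_state = np.ones((b, t, h), dtype=np.float32)
    attention_mask = np.array([[1, 1, 1, 0],          # 3 real tokens, 1 pad
                               [1, 1, 0, 0]])         # 2 real tokens, 2 pads

    mask = attention_mask[..., None].astype(last_hidden_state.dtype)  # (b, t, 1)
    summed = (last_hidden_state * mask).sum(axis=1)                   # (b, h)
    counts = mask.sum(axis=1).clip(min=1e-9)                          # (b, 1)
    print(summed / counts)                            # all ones: pads do not dilute the mean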
@@ -194,7 +225,6 @@ def _emb_hf(texts: List[str]) -> np.ndarray:
     return _l2_normalize(vecs)

 def _hf_dim() -> int:
-    # try to read hidden_size
     try:
         _, mod = _get_hf_model()
         return int(getattr(mod.config, "hidden_size", 768))
@@ -228,7 +258,7 @@ def _load_dataset(ds_dir: str) -> List[Dict[str, Any]]:
 def _save_faiss(fx_dir: str, xb: np.ndarray, meta: Dict[str, Any]):
     os.makedirs(fx_dir, exist_ok=True)
     idx_path = os.path.join(fx_dir, "emb.faiss")
-    index = faiss.IndexFlatIP(xb.shape[1])  # cosine ~ inner product if
+    index = faiss.IndexFlatIP(xb.shape[1])  # cosine ~ inner product if embeddings are normalized
     index.add(xb)
     faiss.write_index(index, idx_path)
     with open(os.path.join(fx_dir, "meta.json"), "w", encoding="utf-8") as f:
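The completed comment is the usual FAISS identity: for L2-normalized vectors, inner product equals cosine similarity, so IndexFlatIP over normalized embeddings behaves as a cosine index. A quick self-check, with a stand-in _l2_normalize assumed to match the helper used elsewhere in main.py:

    import numpy as np
    import faiss

    def _l2_normalize(x: np.ndarray) -> np.ndarray:
        # stand-in: divide each row by its L2 norm (guarding against zero rows)
        norms = np.linalg.norm(x, axis=1, keepdims=True)
        return x / np.clip(norms, 1e-12, None)

    xb = _l2_normalize(np.random.rand(100, 384).astype("float32"))
    index = faiss.IndexFlatIP(xb.shape[1])   # inner product == cosine once rows are unit-norm
    index.add(xb)
    scores, ids = index.search(xb[:1], 3)
    print(float(scores[0][0]))               # ~1.0: a vector's cosine with itself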
@@ -250,7 +280,7 @@ def _tar_dir_to_bytes(dir_path: str) -> bytes:
 # =============================================================================
 # FASTAPI
 # =============================================================================
-fastapi_app = FastAPI(title="remote-indexer", version="2.
+fastapi_app = FastAPI(title="remote-indexer", version="2.1.0")
 fastapi_app.add_middleware(
     CORSMiddleware,
     allow_origins=["*"], allow_credentials=True, allow_methods=["*"], allow_headers=["*"],
@@ -274,7 +304,8 @@ def health():
         "ok": True,
         "service": "remote-indexer",
         "provider": EMB_PROVIDER,
-        "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None
+        "model": EMB_MODEL if EMB_PROVIDER != "dummy" else None,
+        "cache_root": os.getenv("CACHE_ROOT", "/tmp/.cache"),
     }
     return info

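With the two added fields, the health payload now reports the embedding setup and the cache root. A hypothetical probe, assuming the handler is mounted at /health on the default port:

    import requests

    r = requests.get("http://localhost:7860/health", timeout=5)
    print(r.json())
    # e.g. {'ok': True, 'service': 'remote-indexer', 'provider': 'st',
    #       'model': 'sentence-transformers/paraphrase-multilingual-MiniLM-L12-v2',
    #       'cache_root': '/tmp/.cache'}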
@@ -436,7 +467,7 @@ def _ui_search(project_id: str, query: str, k: int):

 with gr.Blocks(title="Remote Indexer (FAISS)", analytics_enabled=False) as ui:
     gr.Markdown("## Remote Indexer — demo UI (API: `/index`, `/status/{job}`, `/search`, `/artifacts/...`).")
-    gr.Markdown(f"**Provider**: `{EMB_PROVIDER}` — **Model**: `{EMB_MODEL if EMB_PROVIDER!='dummy' else '-'}'")
+    gr.Markdown(f"**Provider**: `{EMB_PROVIDER}` — **Model**: `{EMB_MODEL if EMB_PROVIDER!='dummy' else '-'}` — **Cache**: `{os.getenv('CACHE_ROOT', '/tmp/.cache')}`")
     with gr.Tab("Index"):
         pid = gr.Textbox(label="Project ID", value="DEEPWEB")
         sample = gr.Textbox(label="Sample text", value="Alpha bravo charlie delta echo foxtrot.", lines=4)
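One more detail in this hunk: the old Provider/Model Markdown line closed its inline-code span with a stray apostrophe instead of a backtick, so the model name rendered broken in the UI. The new line fixes the delimiter and appends the cache root.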