Spaces:
Running
Running
| """ | |
| core/vector_store.py | |
| ------------------------------------------------------------ | |
| Unified FAISS + BM25 storage utility for Clinical-Trial Chatbot. | |
| ✅ Works with glossary.json or FAISS metadata | |
| ✅ Returns normalized dicts for hybrid_retriever | |
| ✅ Adds load_all_text_chunks() for BM25 fallback | |
| ✅ Safe against missing files | |
| """ | |
| import os | |
| import re | |
| import json | |
| import faiss | |
| from sentence_transformers import SentenceTransformer | |
# Globals used by retriever — populated lazily by _ensure_faiss_index().
_index = None  # FAISS index object (faiss.read_index result); None until loaded
_model = None  # SentenceTransformer("all-MiniLM-L6-v2") used to embed queries
_meta = None   # sequence of per-vector metadata records (typically dicts), parallel to _index rows
| # -------------------------------------------------------------------- | |
| # 1️⃣ Utility: load FAISS index + metadata (MVP version) | |
| # -------------------------------------------------------------------- | |
def _ensure_faiss_index():
    """Load FAISS index + metadata + embedding model into module globals.

    Prefers the local persistent files (rebuilt copy that includes Excel and
    Web sources); falls back to downloading from the Hugging Face dataset.

    Returns:
        bool: True when ``_index``/``_meta``/``_model`` are ready to use,
        False when neither the local files nor the remote download could be
        loaded.  (The caller ``search_index`` relies on a falsy return to
        degrade gracefully instead of crashing.)
    """
    global _index, _model, _meta
    # Already initialized by a previous call — nothing to do.
    if _index is not None and _meta is not None:
        return True

    from huggingface_hub import hf_hub_download

    local_dir = "/home/user/app/persistent"
    local_index = os.path.join(local_dir, "faiss.index")
    local_meta = os.path.join(local_dir, "faiss.index.meta.json")

    # 1️⃣ Prefer local FAISS (rebuilt and includes URL + Excel)
    if os.path.exists(local_index) and os.path.exists(local_meta):
        try:
            print("📂 Using local FAISS index (includes Excel + Web sources).")
            _index = faiss.read_index(local_index)
            with open(local_meta, "r", encoding="utf-8") as f:
                _meta = json.load(f)
            _model = SentenceTransformer("all-MiniLM-L6-v2")
            print(f"✅ [vector_store] Loaded local FAISS ({len(_meta)} vectors).")
            return True
        except Exception as e:
            # Corrupt/partial local files should not be fatal — fall through
            # to the remote dataset instead of raising to the caller.
            print(f"⚠️ [vector_store] Local FAISS load failed ({e}); trying remote dataset.")
            _index = _model = _meta = None

    # 2️⃣ Fallback: remote dataset
    try:
        print("☁️ Local FAISS missing — downloading from Hugging Face dataset...")
        repo_id = "essprasad/CT-Chat-Index"
        repo_type = "dataset"
        runtime_dir = "/home/user/app/runtime_faiss"
        os.makedirs(runtime_dir, exist_ok=True)
        index_path = hf_hub_download(
            repo_id=repo_id,
            filename="persistent/faiss.index",
            repo_type=repo_type,
            local_dir=runtime_dir,
            cache_dir=runtime_dir,
            force_download=True,
        )
        meta_path = hf_hub_download(
            repo_id=repo_id,
            filename="persistent/faiss.index.meta.json",
            repo_type=repo_type,
            local_dir=runtime_dir,
            cache_dir=runtime_dir,
            force_download=True,
        )
        print(f"🧠 [vector_store] Loading FAISS index + metadata from {runtime_dir} ...")
        _index = faiss.read_index(index_path)
        with open(meta_path, "r", encoding="utf-8") as f:
            _meta = json.load(f)
        _model = SentenceTransformer("all-MiniLM-L6-v2")
        print(f"✅ [vector_store] Loaded remote FAISS ({len(_meta)} vectors).")
        return True
    except Exception as e:
        # Honor the documented bool contract: callers check the return value,
        # so report failure instead of propagating download/IO errors.
        print(f"⚠️ [vector_store] Could not load FAISS index: {e}")
        _index = _model = _meta = None
        return False
| # -------------------------------------------------------------------- | |
| # 2️⃣ Helper: Load all text chunks (for BM25 fallback) | |
| # -------------------------------------------------------------------- | |
def load_all_text_chunks():
    """
    Return list of dicts for BM25 fallback and inspection.
    Each dict: {'text', 'file', 'source', 'term', '_meta'}
    """
    meta_file = os.path.join("persistent", "faiss.index.meta.json")
    glossary_file = os.path.join("persistent", "glossary.json")
    chunks = []

    def _pick_source(record):
        # Resolve a single display source: prefer 'sources' (list or str),
        # then 'source', then 'file'; default to "unknown".
        candidate = record.get("sources") or record.get("source") or record.get("file") or []
        if isinstance(candidate, list) and candidate:
            return candidate[0]
        if isinstance(candidate, str) and candidate:
            return candidate
        return record.get("file") or record.get("source") or "unknown"

    # Prefer FAISS meta (vector_sync output)
    if os.path.exists(meta_file):
        try:
            with open(meta_file, "r", encoding="utf-8") as fh:
                records = json.load(fh)
            for record in records:
                origin = _pick_source(record)
                chunks.append({
                    "text": record.get("definition") or record.get("text") or record.get("chunk") or "",
                    "file": origin,
                    "source": origin,
                    "term": record.get("term") or record.get("normalized") or "",
                    "_meta": record,
                })
            return chunks
        except Exception as e:
            print(f"⚠️ [vector_store] Failed to read meta.json: {e}")

    # fallback: glossary.json
    if os.path.exists(glossary_file):
        try:
            with open(glossary_file, "r", encoding="utf-8") as fh:
                glossary = json.load(fh)
            for key, entry in glossary.items():
                raw = entry.get("sources", [])
                if isinstance(raw, list) and raw:
                    origin = raw[0]
                elif isinstance(raw, str):
                    origin = raw
                else:
                    origin = "glossary"
                chunks.append({
                    "text": entry.get("definition", ""),
                    "file": origin,
                    "source": origin,
                    "term": entry.get("term", key),
                    "_meta": {"glossary_key": key},
                })
            return chunks
        except Exception as e:
            print(f"⚠️ [vector_store] Failed to read glossary.json: {e}")

    return chunks
| # -------------------------------------------------------------------- | |
| # 3️⃣ FAISS Search | |
| # -------------------------------------------------------------------- | |
def search_index(query, top_k=10):
    """
    Search FAISS and return a list of dict hits for hybrid_retriever.
    Each hit: {'text','file','source','term','_score','_meta'}
    """
    global _index, _model, _meta
    if not _ensure_faiss_index():
        return []

    # Embed the query and L2-normalize it (presumably the index stores
    # normalized vectors so inner product ≈ cosine — confirm at build time).
    query_vec = _model.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(query_vec)
    scores, rows = _index.search(query_vec, top_k)

    hits = []
    for score, row in zip(scores[0].tolist(), rows[0].tolist()):
        # FAISS pads with -1 when fewer than top_k neighbors exist.
        if not (0 <= row < len(_meta)):
            continue
        entry = _meta[row]
        if not isinstance(entry, dict):
            entry = {"raw": str(entry)}
        body = entry.get("definition") or entry.get("text") or entry.get("chunk") or ""
        raw_src = entry.get("sources") or entry.get("source") or entry.get("file") or []
        if isinstance(raw_src, list) and raw_src:
            origin = raw_src[0]
        elif isinstance(raw_src, str) and raw_src:
            origin = raw_src
        else:
            origin = entry.get("file") or entry.get("source") or "unknown"
        hits.append({
            "text": body,
            "file": origin,
            "source": origin,
            "term": entry.get("term") or entry.get("normalized") or "",
            "_score": float(score),
            "_meta": entry,
        })
    return hits