"""
core/vector_store.py
------------------------------------------------------------
Unified FAISS + BM25 storage utility for Clinical-Trial Chatbot.
✅ Works with glossary.json or FAISS metadata
✅ Returns normalized dicts for hybrid_retriever
✅ Adds load_all_text_chunks() for BM25 fallback
✅ Safe against missing files
"""
import os
import re
import json
import faiss
from sentence_transformers import SentenceTransformer
# Globals used by retriever
_index = None
_model = None
_meta = None
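# The index, embedding model, and metadata are loaded lazily on first use and
# cached at module level, so repeated searches reuse the same objects instead of
# re-reading the index or re-instantiating the SentenceTransformer.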
# --------------------------------------------------------------------
# 1️⃣ Utility: load FAISS index + metadata (MVP version)
# --------------------------------------------------------------------
def _ensure_faiss_index():
    """Load FAISS index and metadata — prefer local persistent files, fall back to the Hugging Face dataset."""
    global _index, _model, _meta
    if _index is not None and _meta is not None:
        return True
    from huggingface_hub import hf_hub_download
    local_dir = "/home/user/app/persistent"
    local_index = os.path.join(local_dir, "faiss.index")
    local_meta = os.path.join(local_dir, "faiss.index.meta.json")
    # 1️⃣ Prefer local FAISS (rebuilt and includes URL + Excel sources)
    if os.path.exists(local_index) and os.path.exists(local_meta):
        print("📂 Using local FAISS index (includes Excel + Web sources).")
        _index = faiss.read_index(local_index)
        with open(local_meta, "r", encoding="utf-8") as f:
            _meta = json.load(f)
        _model = SentenceTransformer("all-MiniLM-L6-v2")
        print(f"✅ [vector_store] Loaded local FAISS ({len(_meta)} vectors).")
        return True
    # 2️⃣ Fallback: remote dataset
    print("☁️ Local FAISS missing — downloading from Hugging Face dataset...")
    repo_id = "essprasad/CT-Chat-Index"
    repo_type = "dataset"
    runtime_dir = "/home/user/app/runtime_faiss"
    os.makedirs(runtime_dir, exist_ok=True)
    try:
        index_path = hf_hub_download(
            repo_id=repo_id,
            filename="persistent/faiss.index",
            repo_type=repo_type,
            local_dir=runtime_dir,
            cache_dir=runtime_dir,
            force_download=True,
        )
        meta_path = hf_hub_download(
            repo_id=repo_id,
            filename="persistent/faiss.index.meta.json",
            repo_type=repo_type,
            local_dir=runtime_dir,
            cache_dir=runtime_dir,
            force_download=True,
        )
    except Exception as e:
        # Keep the "safe against missing files" promise: report and degrade gracefully.
        print(f"⚠️ [vector_store] Could not download FAISS index: {e}")
        return False
    print(f"🧠 [vector_store] Loading FAISS index + metadata from {runtime_dir} ...")
    _index = faiss.read_index(index_path)
    with open(meta_path, "r", encoding="utf-8") as f:
        _meta = json.load(f)
    _model = SentenceTransformer("all-MiniLM-L6-v2")
    print(f"✅ [vector_store] Loaded remote FAISS ({len(_meta)} vectors).")
    return True
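# A minimal warm-up sketch (hypothetical caller code, not part of this module's
# API): invoking the loader once at app startup hides the first-query latency.
# Assumes this module is importable as core.vector_store.
#
#   from core import vector_store
#   vector_store._ensure_faiss_index()  # loads or downloads the index up front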
# --------------------------------------------------------------------
# 2️⃣ Helper: Load all text chunks (for BM25 fallback)
# --------------------------------------------------------------------
def load_all_text_chunks():
    """
    Return a list of dicts for BM25 fallback and inspection.
    Each dict: {'text', 'file', 'source', 'term', '_meta'}
    """
    meta_path = os.path.join("persistent", "faiss.index.meta.json")
    gloss_path = os.path.join("persistent", "glossary.json")
    docs = []
    # Prefer FAISS metadata (vector_sync output)
    if os.path.exists(meta_path):
        try:
            with open(meta_path, "r", encoding="utf-8") as f:
                meta = json.load(f)
            for m in meta:
                text = m.get("definition") or m.get("text") or m.get("chunk") or ""
                sources = m.get("sources") or m.get("source") or m.get("file") or []
                if isinstance(sources, list) and sources:
                    src = sources[0]
                elif isinstance(sources, str) and sources:
                    src = sources
                else:
                    src = m.get("file") or m.get("source") or "unknown"
                docs.append({
                    "text": text,
                    "file": src,
                    "source": src,
                    "term": m.get("term") or m.get("normalized") or "",
                    "_meta": m,
                })
            return docs
        except Exception as e:
            print(f"⚠️ [vector_store] Failed to read meta.json: {e}")
    # Fallback: glossary.json
    if os.path.exists(gloss_path):
        try:
            with open(gloss_path, "r", encoding="utf-8") as f:
                gloss = json.load(f)
            for k, v in gloss.items():
                if not isinstance(v, dict):
                    # Tolerate plain-string glossary entries.
                    v = {"definition": str(v)}
                term = v.get("term", k)
                definition = v.get("definition", "")
                srcs = v.get("sources", [])
                src = srcs[0] if isinstance(srcs, list) and srcs else (srcs if isinstance(srcs, str) else "glossary")
                docs.append({
                    "text": definition,
                    "file": src,
                    "source": src,
                    "term": term,
                    "_meta": {"glossary_key": k},
                })
            return docs
        except Exception as e:
            print(f"⚠️ [vector_store] Failed to read glossary.json: {e}")
    return docs
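# A minimal sketch of the BM25 fallback this helper feeds (hypothetical; the real
# scoring lives in hybrid_retriever). Assumes the third-party `rank_bm25` package
# (BM25Okapi) is installed; tokenization here is a naive lowercase split.
def _bm25_fallback_example(query, top_k=5):
    from rank_bm25 import BM25Okapi  # lazy import: only needed for this sketch
    docs = load_all_text_chunks()
    if not docs:
        return []
    corpus = [d["text"].lower().split() for d in docs]
    bm25 = BM25Okapi(corpus)
    scores = bm25.get_scores(query.lower().split())
    ranked = sorted(zip(scores, docs), key=lambda p: p[0], reverse=True)[:top_k]
    return [{**d, "_score": float(s)} for s, d in ranked]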
# --------------------------------------------------------------------
# 3️⃣ FAISS Search
# --------------------------------------------------------------------
def search_index(query, top_k=10):
    """
    Search FAISS and return a list of dict hits for hybrid_retriever.
    Each hit: {'text', 'file', 'source', 'term', '_score', '_meta'}
    """
    global _index, _model, _meta
    if not _ensure_faiss_index():
        return []
    # Normalize the query embedding so inner-product search behaves like cosine similarity.
    q_emb = _model.encode([query], convert_to_numpy=True).astype("float32")
    faiss.normalize_L2(q_emb)
    D, I = _index.search(q_emb, top_k)
    results = []
    for score, idx in zip(D[0].tolist(), I[0].tolist()):
        # FAISS pads missing hits with -1; also guard against stale metadata.
        if idx < 0 or idx >= len(_meta):
            continue
        m = _meta[idx] if isinstance(_meta[idx], dict) else {"raw": str(_meta[idx])}
        text = m.get("definition") or m.get("text") or m.get("chunk") or ""
        srcs = m.get("sources") or m.get("source") or m.get("file") or []
        if isinstance(srcs, list) and srcs:
            src = srcs[0]
        elif isinstance(srcs, str) and srcs:
            src = srcs
        else:
            src = m.get("file") or m.get("source") or "unknown"
        results.append({
            "text": text,
            "file": src,
            "source": src,
            "term": m.get("term") or m.get("normalized") or "",
            "_score": float(score),
            "_meta": m,
        })
    return results
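# Quick smoke test (a sketch, not part of the retriever API): run this module
# directly to verify the index loads and search returns hits. The query string
# below is an arbitrary example.
if __name__ == "__main__":
    for hit in search_index("informed consent", top_k=3):
        print(f"{hit['_score']:.3f}  {hit['term'] or hit['file']}: {hit['text'][:80]}")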