ClinicalTrialBasics / core /vector_search.py
essprasad's picture
Upload 9 files
b816136 verified
raw
history blame
3.13 kB
"""
core/vector_search.py
-----------------------------------------------------
Performs FAISS semantic search for hybrid retrieval.
Includes:
- SentenceTransformer embedding for query
- FAISS similarity search
- Metadata + citation extraction
- Robust fallback if index missing
"""
import os
import json
import numpy as np
import faiss
from sentence_transformers import SentenceTransformer
# Paths (shared with vector_store/vector_sync)
FAISS_INDEX = "persistent/faiss.index"
FAISS_META = "persistent/faiss.index.meta.json"
_model = None
_index = None
_meta = []
# ----------------------------
# πŸ”Ή Loaders
# ----------------------------
def _load_model():
    """Return the shared SentenceTransformer embedder, loading it lazily on first call."""
    global _model
    if _model is not None:
        return _model
    print("📥 Loading embedding model for retrieval...")
    _model = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    print("✅ Model loaded.")
    return _model
def _load_faiss():
    """Load the FAISS index + metadata, caching the result in module globals.

    Candidate locations are tried in priority order: the absolute app path
    first, then the module-level relative paths (FAISS_INDEX / FAISS_META)
    shared with vector_store/vector_sync — previously those constants were
    declared but never consulted here, so a valid relative index was missed.

    Returns:
        (index, meta): the FAISS index and its metadata list, or
        (None, []) when no index can be found — callers must treat
        that as "no results available".
    """
    global _index, _meta
    if _index is not None:
        return _index, _meta

    # Priority order: Space-local absolute copy, then shared relative paths.
    candidates = [
        ("/home/user/app/persistent/faiss.index",
         "/home/user/app/persistent/faiss.index.meta.json"),
        (FAISS_INDEX, FAISS_META),
    ]
    for index_path, meta_path in candidates:
        if os.path.exists(index_path) and os.path.exists(meta_path):
            print("📂 [vector_search] Using local FAISS index.")
            _index = faiss.read_index(index_path)
            with open(meta_path, "r", encoding="utf-8") as f:
                _meta = json.load(f)
            print(f"✅ Loaded local FAISS index ({len(_meta)} entries).")
            return _index, _meta

    # NOTE(review): no remote fetch is implemented despite the message below;
    # this still returns (None, []) so downstream search degrades to empty.
    print("☁️ [vector_search] Local FAISS missing, using fallback remote index.")
    return _index, _meta
# ----------------------------
# πŸ”Ή Core Query Function
# ----------------------------
def query_faiss(query: str, top_k: int = 5):
    """
    Perform FAISS semantic similarity search.

    Args:
        query: Natural-language query to embed and search with.
        top_k: Maximum number of nearest chunks to return.

    Returns:
        results: list of matched text chunks (cleaned and truncated)
        citations: list of HTML citation strings, one per result
        Both lists are empty when no index/metadata is loaded.
        (The previous docstring claimed metadata dicts were returned;
        this function has only ever returned citation strings.)
    """
    index, meta = _load_faiss()
    if index is None or len(meta) == 0:
        return [], []

    model = _load_model()
    q_emb = np.array(model.encode([query]), dtype=np.float32)
    # Distances are not used downstream; only the neighbor ids matter.
    _, ids = index.search(q_emb, top_k)

    results, citations = [], []
    for idx in ids[0]:
        # The lower bound also filters FAISS's -1 padding when the index
        # holds fewer than top_k vectors.
        if 0 <= idx < len(meta):
            doc = meta[idx]
            results.append(clean_text(doc.get("text", "")))
            src = doc.get("source", "Unknown Source")
            citations.append(f"📄 <b>Source:</b> {os.path.basename(src)}")
    return results, citations
# ----------------------------
# πŸ”Ή Utilities
# ----------------------------
def clean_text(text: str, max_len: int = 800) -> str:
    """
    Normalize whitespace and truncate text for readability.

    Collapses every run of whitespace (newlines, tabs, repeated spaces)
    to a single space, strips the ends, then truncates at the last word
    boundary before *max_len*, appending "..." when anything was cut.

    Args:
        text: Raw chunk text (may be empty).
        max_len: Maximum character length before truncation.

    Returns:
        Cleaned, possibly truncated string.
    """
    # str.split() with no argument splits on any whitespace run, so this
    # collapses triple-plus spaces and tabs that the old single-pass
    # replace("  ", " ") left behind.
    text = " ".join(text.split())
    if len(text) > max_len:
        text = text[:max_len].rsplit(" ", 1)[0] + "..."
    return text
def has_index():
    """Return True when both the FAISS index file and its metadata JSON exist on disk."""
    return all(os.path.exists(path) for path in (FAISS_INDEX, FAISS_META))