import json
import os

import torch
from sentence_transformers import SentenceTransformer, util

FAQ_PATHS = ["data/faq_data.json", "data/clinical_faq.json"]
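
# load_faqs() accepts either JSON shape below. The file names above are what
# this module actually reads; the sample records are illustrative only.
#
#   List form (e.g. data/faq_data.json):
#     [{"question": "What are your hours?", "answer": "Mon-Fri, 9am-5pm."}]
#
#   Mapping form (e.g. data/clinical_faq.json):
#     {"Do you accept referrals?": "Yes, via fax or the patient portal."}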
_FAQ_CACHE = None
_FAQ_EMBEDDINGS = None
_MODEL = None


def _get_model():
    """Load and cache the embedding model (shared with main app if possible)."""
    global _MODEL
    if _MODEL is None:
        print("📦 [faq] Loading embedding model: all-MiniLM-L6-v2 ...")
        _MODEL = SentenceTransformer("sentence-transformers/all-MiniLM-L6-v2")
    return _MODEL


def load_faqs():
    """Load FAQ data from the JSON files in FAQ_PATHS and cache the result."""
    global _FAQ_CACHE
    if _FAQ_CACHE is not None:
        return _FAQ_CACHE

    all_faqs = []
    for path in FAQ_PATHS:
        if not os.path.exists(path):
            continue
        try:
            with open(path, "r", encoding="utf-8") as f:
                data = json.load(f)
            if isinstance(data, list):
                all_faqs.extend(data)
            elif isinstance(data, dict):
                # Normalize flat {"question": "answer"} mappings.
                for k, v in data.items():
                    all_faqs.append({"question": k, "answer": v})
        except Exception as e:
            print(f"⚠️ Failed to load FAQ file {path}: {e}")

    # Drop malformed entries up front so the cached list stays index-aligned
    # with the embedding tensor built in _build_embeddings().
    _FAQ_CACHE = [f for f in all_faqs if isinstance(f, dict) and f.get("question")]
    print(f"✅ [faq] Loaded {len(_FAQ_CACHE)} FAQ entries.")
    return _FAQ_CACHE


def _build_embeddings():
    """Precompute embeddings for all cached FAQ questions."""
    global _FAQ_EMBEDDINGS
    faqs = load_faqs()
    if not faqs:
        _FAQ_EMBEDDINGS = torch.empty(0)
        return _FAQ_EMBEDDINGS

    model = _get_model()
    # load_faqs() guarantees every entry has a "question", so this list stays
    # index-aligned with the cached FAQ list.
    questions = [f["question"] for f in faqs]
    _FAQ_EMBEDDINGS = model.encode(questions, convert_to_tensor=True, show_progress_bar=False)
    print(f"✅ [faq] Encoded {len(_FAQ_EMBEDDINGS)} FAQ embeddings.")
    return _FAQ_EMBEDDINGS
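
# Shape sketch (for orientation, assuming N cached questions): all-MiniLM-L6-v2
# produces 384-dimensional vectors, so _FAQ_EMBEDDINGS is an (N, 384) tensor and
# util.cos_sim(query_emb, _FAQ_EMBEDDINGS) below yields a (1, N) similarity row.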


def get_faq_answer(query: str) -> str:
    """
    Return the most semantically similar FAQ answer to the query.
    Uses MiniLM embeddings and cosine similarity; returns "" when no
    match clears the similarity threshold.
    """
    faqs = load_faqs()
    if not faqs:
        return ""
    if _FAQ_EMBEDDINGS is None:
        _build_embeddings()

    model = _get_model()
    query_emb = model.encode(query, convert_to_tensor=True)
    sims = util.cos_sim(query_emb, _FAQ_EMBEDDINGS)[0]
    top_idx = int(torch.argmax(sims))
    best_score = float(sims[top_idx])
    best_item = faqs[top_idx]

    if best_score < 0.45:  # threshold to avoid weak matches
        return ""

    print(f"💡 [faq] Best match: \"{best_item.get('question')}\" (score={best_score:.2f})")
    return best_item.get("answer", "")


def lookup_faq(query: str, top_k: int = 3) -> str:
    """
    Return an HTML-formatted list of the top-k semantically similar FAQ
    matches. Useful for admin or verbose display.
    """
    faqs = load_faqs()
    if not faqs:
        return "No FAQ data loaded."
    if _FAQ_EMBEDDINGS is None:
        _build_embeddings()

    model = _get_model()
    query_emb = model.encode(query, convert_to_tensor=True)
    sims = util.cos_sim(query_emb, _FAQ_EMBEDDINGS)[0]
    top_indices = torch.topk(sims, k=min(top_k, len(faqs))).indices.tolist()

    html = []
    for idx in top_indices:
        score = float(sims[idx])
        item = faqs[idx]
        # Simple inline HTML, matching the docstring's promise of HTML output.
        html.append(
            f"<b>{item['question']}</b><br>{item.get('answer', '')}<br>"
            f"<i>(score={score:.2f})</i>"
        )
    return "<br><br>".join(html)