# ClinicalTrialBasics/core/hybrid_retriever.py
"""
Final Hybrid Retriever for Clinical Research Chatbot
---------------------------------------------------
Updated: Prioritize sources: GCDMP_Glossary.pdf > MRCT excel > ICH docs > other PDFs > web
Includes: GCDMP glossary-style extraction, acronym handling, MRCT extra field, glossary.json fallback
"""
import os
import re
import time
from urllib.parse import urlparse
from difflib import SequenceMatcher
from core.glossary import _normalize_term
from core.vector_store import _ensure_faiss_index, search_index, load_all_text_chunks
from core.bm25 import search_bm25
from utils.nlp_helpers import extract_van_tokens, normalize_query_text
# ----------------------------
# CONFIG
# ----------------------------
DENSE_TOP_K = 10          # number of dense (FAISS) hits to retrieve
FUZZY_THRESHOLD = 0.15
TOP_RESULTS_LIMIT = 5     # maximum answers rendered in the final HTML
GCDMP_FILENAME = "GCDMP_Glossary.pdf"  # exact filename in your HF Space / persistent store
# ----------------------------
# UTILITIES
# ----------------------------
def fuzzy_ratio(a: str, b: str) -> float:
    """Return a 0.0-1.0 similarity ratio between two strings."""
    return SequenceMatcher(None, a or "", b or "").ratio()
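# e.g. fuzzy_ratio("case report form", "case report forms") is close to 1.0,
# while fuzzy_ratio("crf", "case report form") is well below 0.5.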
def strip_question_phrases(text: str) -> str:
    """Strip leading question/filler words and trailing punctuation from a query."""
    text = (text or "").lower().strip()
    prefixes = [
        "what", "how", "when", "why", "define", "definition", "meaning", "explain",
        "describe", "expand", "abbreviate", "compare", "identify", "classify",
        "determine", "differentiate", "do", "did", "does", "done", "can", "shall",
        "will", "where", "which", "who", "whose", "have", "might", "could", "would",
        "kindly", "please", "may", "you", "i", "we", "they", "there", "here",
        "what's", "i'll", "where's", "how's", "there's", "who's", "didn't", "doesn't",
        "give", "provide", "mention", "state", "arrange", "asking", "tell", "explain me",
        "can you", "could you", "would you", "please explain", "let me know",
        "say something about", "give details of", "show me", "find", "list", "expand on"
    ]
    prefix_pattern = r"^(" + "|".join(re.escape(p) for p in prefixes) + r")(\s+|['’])"
    while re.match(prefix_pattern, text):
        text = re.sub(prefix_pattern, "", text).strip()
    text = re.sub(r"[?.!]+$", "", text)
    text = re.sub(r"\s{2,}", " ", text)
    return text.strip()
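# e.g. strip_question_phrases("Can you define adverse event?") -> "adverse event"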
def add_links_to_text(text: str) -> str:
    """Wrap bare http(s) URLs in clickable anchor tags."""
    return re.sub(r"(https?://[^\s<]+)", r'<a href="\1" target="_blank" rel="noopener noreferrer">\1</a>', text)
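# e.g. "See https://example.org for details" becomes
# 'See <a href="https://example.org" target="_blank" rel="noopener noreferrer">https://example.org</a> for details'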
def get_source_rank(src: str, src_type: str) -> int:
    """Rank a source for ordering: lower number means higher priority."""
    s = (src or "").lower()
    # 1. GCDMP glossary PDF → highest priority
    if GCDMP_FILENAME.lower() in s:
        return 1
    # 2. MRCT Excel or MRCT filename
    if src_type == "excel" or "mrct" in s:
        return 2
    # 3. ICH documents (E6, E3, E2A, E9, E1) - try a few patterns
    if any(x in s for x in ["ich_e6", "ich-e6", "ich e6", "ich_e3", "ich-e3", "ich e3", "ich_e2", "ich-e2", "ich e2", "ich_e9", "ich-e9", "ich e9", "ich_e1", "ich-e1", "ich e1"]):
        return 3
    # 4. Other PDFs
    if src_type == "pdf":
        return 4
    # 5. Web sources
    if src_type == "web":
        return 5
    return 6
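# Examples (filenames illustrative):
#   get_source_rank("GCDMP_Glossary.pdf", "pdf")       -> 1
#   get_source_rank("mrct_glossary.xlsx", "excel")     -> 2
#   get_source_rank("ICH_E6_R2.pdf", "pdf")            -> 3
#   get_source_rank("https://example.org/page", "web") -> 5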
# ----------------------------
# MAIN RETRIEVER
# ----------------------------
def summarize_combined(query: str, mode: str = "short") -> str:
    """Hybrid FAISS + BM25 retrieval, grouped per source and rendered as HTML.

    Note: the `mode` argument is currently unused in this function.
    """
    start = time.time()
    if not query or not query.strip():
        return "<i>No query provided.</i>"
    # Normalize user query
    cleaned = strip_question_phrases(query)
    expanded = normalize_query_text(cleaned)
    van_tokens = extract_van_tokens(expanded)
    normalized = " ".join(van_tokens).strip() or cleaned
    nq = normalized.lower().strip()
    print(f"🔍 summarize_combined() | cleaned='{cleaned}' normalized='{nq}'")
    # Acronym expansion map (preserve/extend)
    acronym_map = {
        "ae": "adverse event", "adr": "adverse drug reaction",
        "crf": "case report form", "ecrf": "electronic case report form",
        "cro": "contract research organization", "csr": "clinical study report",
        "ctms": "clinical trial management system", "edc": "electronic data capture",
        "ehr": "electronic health record", "emr": "electronic medical record",
        "gcp": "good clinical practice", "irb": "institutional review board",
        "iec": "independent ethics committee", "ind": "investigational new drug application",
        "mrct": "multi-regional clinical trials", "qa": "quality assurance",
        "qc": "quality control", "sae": "serious adverse event", "sap": "statistical analysis plan",
        "siv": "site initiation visit", "sop": "standard operating procedure",
        "ssu": "study start-up", "uat": "user acceptance testing",
        "whodrug": "world health organization drug dictionary",
    }
    glossary_key = _normalize_term(nq)
    if glossary_key in acronym_map:
        expanded_term = acronym_map[glossary_key]
        nq = _normalize_term(expanded_term)
        print(f"🔍 Acronym expanded → {expanded_term}")
    # ----------------------------
    # FAISS + BM25 retrieval
    # ----------------------------
    dense_hits, bm25_hits = [], []
    try:
        if _ensure_faiss_index():
            dense_hits = search_index(normalized, top_k=DENSE_TOP_K) or []
            print(f"✅ FAISS hits: {len(dense_hits)}")
    except Exception as e:
        print(f"⚠️ FAISS search failed: {e}")
    try:
        docs = load_all_text_chunks()
        if docs:
            bm25_hits = search_bm25(normalized, docs, top_n=8) or []
            print(f"✅ BM25 hits: {len(bm25_hits)}")
    except Exception as e:
        print(f"⚠️ BM25 fallback failed: {e}")
    hits = (dense_hits or []) + (bm25_hits or [])
    if not hits:
        return "<i>No relevant information found.</i>"
    # ----------------------------
    # Group by original resolved source (prefer real source over glossary.json)
    # ----------------------------
    grouped = {}
    glossary_fallbacks = []
    for h in hits:
        raw_src = h.get("file") or h.get("source") or h.get("source_file") or "unknown"
        meta_sources = h.get("sources") or h.get("source_list") or []
        # prefer a non-glossary meta source if available
        src = raw_src
        if isinstance(meta_sources, (list, tuple)) and meta_sources:
            chosen = None
            for s in meta_sources:
                if isinstance(s, str) and not s.lower().endswith("glossary.json"):
                    chosen = s
                    break
            if chosen:
                src = chosen
            else:
                src = meta_sources[0]
        src_type = (h.get("type") or "").lower()
        term = (h.get("term") or "").strip()
        term_lower = term.lower()
        txt = (h.get("definition") or h.get("text") or h.get("content") or h.get("full_text") or "").strip()
        if not txt:
            continue
        # If the original stored file was glossary.json, keep it as fallback only
        if str(raw_src).lower().endswith("glossary.json"):
            glossary_fallbacks.append({"hit": h, "text": txt, "src": src})
        # Save resolved sources for provenance
        h["_resolved_sources"] = meta_sources if meta_sources else [raw_src]
        # Group key based on resolved original source + type + term
        key = f"{os.path.basename(src).lower()}__{src_type}__{term_lower[:200]}"
        # Prefer glossary PDF entries (GCDMP / 'glossary' in filename) when colliding with long chunks
        prefer_glossary = (GCDMP_FILENAME.lower() in str(src).lower()) or ("glossary" in str(src).lower())
        if key not in grouped:
            grouped[key] = {"hit": h, "text": txt, "src": src, "src_type": src_type, "term": term}
        else:
            existing_src = grouped[key]["src"]
            existing_is_glossary = (GCDMP_FILENAME.lower() in str(existing_src).lower()) or ("glossary" in str(existing_src).lower())
            if prefer_glossary and not existing_is_glossary:
                grouped[key] = {"hit": h, "text": txt, "src": src, "src_type": src_type, "term": term}
            else:
                # otherwise prefer the longer chunk, unless the new hit is a glossary entry and the existing one is not
                if not prefer_glossary and len(txt) > len(grouped[key]["text"]):
                    grouped[key] = {"hit": h, "text": txt, "src": src, "src_type": src_type, "term": term}
    # ----------------------------
    # Format answers: one per original source
    # ----------------------------
    answers = []
    src_counts = {"excel": 0, "pdf": 0, "web": 0, "other": 0}
    for entry in grouped.values():
        h = entry["hit"]
        txt = entry["text"]
        src = entry["src"]
        src_type = entry.get("src_type") or (h.get("type") or "").lower()
        term = entry.get("term") or (h.get("term") or "").strip()
        term_lower = (term or "").lower()
        # Skip entries resolved to glossary.json here (we'll use them only as fallback)
        if str(src).lower().endswith("glossary.json"):
            continue
        # Skip noisy PDF sections unless they look like short glossary entries
        txt_lower = txt.lower()
        if src_type == "pdf" and any(k in txt_lower[:300] for k in ["table of contents", "appendix", "index", "section"]):
            if not (len(txt.split()) < 80 and term_lower and term_lower in txt_lower[:120]):
                # treat as noise
                continue
        # Determine icon and counts
        if src_type == "excel":
            icon, cat = "📘", "excel"
        elif src_type == "pdf":
            icon, cat = "📄", "pdf"
        elif src_type == "web":
            icon, cat = "🌐", "web"
        else:
            icon, cat = "📝", "other"
        src_counts[cat] += 1
        # SAFE acronym handling:
        # If the user query is a short single token (<= 4 chars), treat it as an acronym query and accept matches.
        is_acronym_query = (len(nq) > 0 and " " not in nq and len(nq) <= 4)
        # Soft subset/superset filter: allow acronyms and glossary terms
        if term_lower and term_lower != nq and not is_acronym_query:
            if (term_lower in nq or nq in term_lower) and fuzzy_ratio(term_lower, nq) < 0.5:
                # reject only when one string contains the other yet overall similarity is still low
                continue
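        # e.g. term "adverse event" vs nq "serious adverse event": one contains the other
        # and fuzzy_ratio is roughly 0.76 (> 0.5), so the entry is kept.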
        # Extract excerpt (PDF / web special handling for glossary-style text)
        excerpt = ""
        if src_type in ("pdf", "web"):
            paragraphs = re.split(r"\n\s*\n", txt)
            paragraphs = [p.strip() for p in paragraphs if p.strip()]
            # 1) If acronym query and first paragraph equals the acronym -> next paragraph is the definition
            if paragraphs and is_acronym_query:
                heading = paragraphs[0].strip().lower()
                if heading == nq:
                    excerpt = paragraphs[1].strip() if len(paragraphs) > 1 else paragraphs[0].strip()
            # 2) If the full term matches the heading (e.g., "electronic health record")
            if not excerpt and paragraphs and term_lower:
                heading = paragraphs[0].strip().lower()
                if heading == term_lower or (term_lower in heading):
                    excerpt = paragraphs[1].strip() if len(paragraphs) > 1 else paragraphs[0].strip()
            # 3) If not yet found, try the exact normalized query inside paragraphs
            if not excerpt:
                found = None
                for p in paragraphs:
                    if nq and nq in p.lower():
                        found = p.strip()
                        break
                # 4) Fuzzy match against paragraph starts
                if not found and term_lower:
                    for p in paragraphs:
                        if fuzzy_ratio(term_lower, p.lower()[:100]) > 0.75:
                            found = p.strip()
                            break
                # 5) Paragraph following a heading that contains the term
                if not found and term_lower:
                    for i, p in enumerate(paragraphs[:-1]):
                        if term_lower in p.lower():
                            found = paragraphs[i + 1].strip()
                            break
                excerpt = (found or (paragraphs[0] if paragraphs else txt)).strip()
            excerpt = excerpt[:2000] + ("..." if len(excerpt) > 2000 else "")
            excerpt = add_links_to_text(excerpt)
elif src_type == "excel":
# Capture MRCT Excel fields including the "Other Info..." column
fields = {
"Glossary Definition": re.search(r"Glossary Definition:\s*(.+?)(?=\n[A-Z]|$)", txt, re.S),
"Use in Context": re.search(r"Use in Context:\s*(.+?)(?=\n[A-Z]|$)", txt, re.S),
"More Info": re.search(r"More Info:\s*(.+?)(?=\n[A-Z]|$)", txt, re.S),
"Other Info to Think About When Joining a Study": re.search(
r"Other Info to Think About When Joining a Study:\s*(.+?)(?=\n[A-Z]|$)",
txt, re.S
),
"Related Terms": re.search(r"Related Terms:\s*(.+?)(?=\n[A-Z]|$)", txt, re.S),
"Term URL": re.search(r"Term URL:\s*(https?://[^\s]+)", txt),
}
lines = []
for label, match in fields.items():
if match:
val = match.group(1).strip()
if "http" in val:
val = f'<a href="{val}" target="_blank">{val}</a>'
lines.append(f"<b>{label}:</b> {val}")
excerpt = "<br>".join(lines) or txt
        else:
            excerpt = txt
        # Prepare heading and display sources (exclude internal glossary.json from display)
        heading_term = term.strip() or os.path.splitext(os.path.basename(src))[0]
        heading_html = f"<h4>{icon} {heading_term}</h4>"
        resolved_sources = h.get("_resolved_sources") or []
        display_sources = [os.path.basename(s) for s in resolved_sources if isinstance(s, str) and not s.lower().endswith("glossary.json")]
        if not display_sources:
            display_sources = [os.path.basename(src)]
        sources_line = f"<p>🔗 <i>Source:</i> " + " · ".join(dict.fromkeys(display_sources)) + "</p>"
        answers.append({
            "rank": get_source_rank(src, src_type),
            "type": cat,
            "term": term,
            "html": f"{heading_html}{sources_line}<blockquote>{excerpt}</blockquote>"
        })
    # ----------------------------
    # Fallback: only use glossary.json definitions if no other original sources matched
    # ----------------------------
    if not answers and glossary_fallbacks:
        print("⚙️ Using glossary.json fallback definitions (no original sources found)")
        for item in glossary_fallbacks:
            h = item["hit"]
            txt = item["text"]
            src = item.get("src") or (h.get("file") or h.get("source") or "glossary.json")
            term = (h.get("term") or "").strip() or "Definition"
            heading_html = f"<h4>📄 {term}</h4>"
            excerpt = txt.strip()
            answers.append({
                "rank": 10,
                "type": "pdf",
                "term": term,
                "html": f"{heading_html}<p>🔗 <i>Source:</i> {os.path.basename(src)}</p><blockquote>{excerpt}</blockquote>"
            })
    # ----------------------------
    # Final sort & output
    # ----------------------------
    if not answers:
        return "<i>No relevant results found.</i>"
    answers = sorted(answers, key=lambda a: a["rank"])
    final_html_parts = [a["html"] for a in answers[:TOP_RESULTS_LIMIT]]
    summary_counts = " | ".join(f"{k.capitalize()}: {v}" for k, v in src_counts.items() if v > 0)
    elapsed = time.time() - start
    print(f"✅ Answers from {len(answers)} sources in {elapsed:.2f}s")
    return (
        f"<h3>🧠 Answers (one per source):</h3>"
        f"<p><i>Sources → {summary_counts}</i></p>"
        + "<br>".join(final_html_parts)
    )
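# Minimal local smoke test (a sketch, not part of the app wiring); assumes the FAISS
# index, BM25 corpus and source documents are already available via core.vector_store.
if __name__ == "__main__":
    print(summarize_combined("What is an SAE?"))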