"""
utils/nlp_helpers.py — Enhanced NLP Utilities for Clinical Research Chatbot
----------------------------------------------------------------------------
✅ Domain-aware abbreviation normalization (ICH-GCP, CDISC, FDA)
✅ Glossary-synonym expansion with prioritization
✅ Improved VAN (Verb–Adjective–Noun) normalization
✅ Compatible with Hugging Face Spaces (persistent NLTK path)
"""
import os
import re
import json
import nltk
from nltk.corpus import stopwords
from nltk.stem import WordNetLemmatizer
# --------------------------------------------------------------------
# 🧠 NLTK Setup (force consistent path for Hugging Face Spaces)
# --------------------------------------------------------------------
NLTK_PATH = "/usr/local/share/nltk_data"
os.environ["NLTK_DATA"] = NLTK_PATH
nltk.data.path.clear()
nltk.data.path.append(NLTK_PATH)
# Map each downloadable package to the resource path nltk.data.find() expects;
# probing the bare package name (e.g. "punkt") fails even when the data is
# installed, which would force a re-download on every start.
required_pkgs = {
    "punkt": "tokenizers/punkt",
    "punkt_tab": "tokenizers/punkt_tab",
    "averaged_perceptron_tagger": "taggers/averaged_perceptron_tagger",
    "averaged_perceptron_tagger_eng": "taggers/averaged_perceptron_tagger_eng",
    "stopwords": "corpora/stopwords",
    "wordnet": "corpora/wordnet",
}
for pkg, resource in required_pkgs.items():
    try:
        nltk.data.find(resource)
    except LookupError:
        nltk.download(pkg, download_dir=NLTK_PATH, quiet=True)
STOPWORDS = set(stopwords.words("english"))
lemmatizer = WordNetLemmatizer()
# --------------------------------------------------------------------
# ⚕️ Clinical Abbreviation & Synonym Normalization
# --------------------------------------------------------------------
NORMALIZATION_MAP = {
    # Core trial terms
    r"\be[-_ ]?crf(s)?\b": "electronic case report form",
    r"\bedc(s)?\b": "electronic data capture",
    r"\bctms\b": "clinical trial management system",
    r"\bcsr(s)?\b": "clinical study report",
    r"\bcrf\b": "case report form",
    # Data standards
    r"\bsdtm(s)?\b": "study data tabulation model",
    r"\badam(s)?\b": "analysis data model",
    r"\bdefine[-_ ]?xml\b": "define xml metadata",
    # Compliance / Ethics
    r"\bgcp\b": "good clinical practice",
    r"\biec\b": "independent ethics committee",
    r"\birb\b": "institutional review board",
    r"\bpi\b": "principal investigator",
    r"\bsub[-_ ]?investigators?\b": "sub investigator",
    r"\bsae(s)?\b": "serious adverse event",
    r"\bae(s)?\b": "adverse event",
    r"\bsusar(s)?\b": "suspected unexpected serious adverse reaction",
    # Misc
    r"\bsdv\b": "source data verification",
    r"\bsop(s)?\b": "standard operating procedure",
    r"\bqms\b": "quality management system",
    r"\bicf\b": "informed consent form",
    r"\bregulatory\b(?!\s+compliance)": "regulatory compliance",
}
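# Note: the patterns above are applied in insertion order (dicts preserve
# order in Python 3.7+), so the e-CRF pattern must stay ahead of the bare CRF
# pattern; otherwise "e-crf" would be rewritten to "e-case report form".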
DOMAIN_SYNONYMS = {
    "edc": ["data entry system", "data management platform"],
    "ecrf": ["electronic data entry form", "study data form"],
    "gcp": ["good clinical practice", "ich e6", "regulatory compliance"],
    "sdtm": ["data tabulation model", "cdisc standard"],
    "adam": ["analysis dataset model", "statistical dataset"],
    "ae": ["adverse event", "side effect"],
    "sae": ["serious adverse event", "life threatening event"],
    "susar": ["unexpected serious adverse reaction", "drug safety event"],
    "ctms": ["trial management tool", "site tracking system"],
    "pi": ["principal investigator", "study doctor"],
    "csr": ["clinical study report", "final study document"],
    "qms": ["quality management framework", "audit system"],
    "sop": ["standard operating procedure", "company process document"],
}
GLOSSARY_PATH = "data/glossary.json"
# --------------------------------------------------------------------
# 🧹 Text Normalization
# --------------------------------------------------------------------
def normalize_query_text(text: str) -> str:
    """Lowercase, remove punctuation, and expand known abbreviations."""
    text = text.strip().lower()
    text = re.sub(r"[^\w\s\-]", " ", text)  # drop punctuation, keep hyphens/underscores
    text = re.sub(r"\s+", " ", text)        # collapse runs of whitespace
    for pattern, repl in NORMALIZATION_MAP.items():
        text = re.sub(pattern, repl, text)
    return text.strip()
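# Illustrative example (output follows the current NORMALIZATION_MAP):
#   normalize_query_text("How do eCRFs relate to GCP?")
#   -> "how do electronic case report form relate to good clinical practice"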
# --------------------------------------------------------------------
# ⚙️ VAN (Verb–Adjective–Noun) Extraction — IMPROVED
# --------------------------------------------------------------------
def extract_van_tokens(text: str):
    """
    Extract and normalize core content words using VAN logic.
    - Lowercases and expands abbreviations
    - Removes stopwords and determiners ('a', 'an', 'the')
    - Keeps only Verbs, Adjectives, and Nouns
    - Lemmatizes words to singular or base form
    - Deduplicates tokens
    """
    text = normalize_query_text(text)
    if not text:
        return []
    try:
        tokens = nltk.word_tokenize(text)
        pos_tags = nltk.pos_tag(tokens)
    except LookupError:
        # Retry after downloading the tokenizer/tagger models on demand
        for pkg in [
            "punkt",
            "punkt_tab",
            "averaged_perceptron_tagger",
            "averaged_perceptron_tagger_eng",
        ]:
            nltk.download(pkg, download_dir=NLTK_PATH, quiet=True)
        pos_tags = nltk.pos_tag(nltk.word_tokenize(text))
    filtered = []
    for w, t in pos_tags:
        if not w.isalpha():
            continue
        # Remove determiners and common auxiliaries
        if w in {"a", "an", "the", "is", "are", "was", "were", "be", "been", "being"}:
            continue
        if w in STOPWORDS:
            continue
        if len(w) <= 2:
            continue
        # Keep only nouns (N*), verbs (V*), and adjectives (J*)
        if t.startswith(("N", "V", "J")):
            pos = (
                "v" if t.startswith("V")
                else "a" if t.startswith("J")
                else "n"
            )
            lemma = lemmatizer.lemmatize(w, pos)
            filtered.append(lemma)
    # Deduplicate while preserving order
    seen, unique = set(), []
    for w in filtered:
        if w not in seen:
            seen.add(w)
            unique.append(w)
    return unique
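# Illustrative example (exact tokens depend on the installed POS-tagger models):
#   extract_van_tokens("What are the responsibilities of the PI in a clinical trial?")
#   -> ["responsibility", "principal", "investigator", "clinical", "trial"]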
# --------------------------------------------------------------------
# 📘 Glossary-based Synonym Expansion
# --------------------------------------------------------------------
def expand_with_glossary(tokens: list):
    """Expand tokens using both internal DOMAIN_SYNONYMS and glossary.json."""
    expanded = list(tokens)
    # Add domain synonym expansion
    for token in tokens:
        key = token.lower()
        if key in DOMAIN_SYNONYMS:
            expanded.extend(DOMAIN_SYNONYMS[key])
    # Glossary-driven enrichment
    if os.path.exists(GLOSSARY_PATH):
        try:
            with open(GLOSSARY_PATH, "r", encoding="utf-8") as f:
                glossary = json.load(f)
        except Exception:
            glossary = {}
        for token in tokens:
            t_norm = re.sub(r"[^a-z0-9]", "", token.lower())
            for term, definition in glossary.items():
                term_norm = re.sub(r"[^a-z0-9]", "", term.lower())
                if t_norm in term_norm or term_norm in t_norm:
                    defs = [
                        w for w in re.findall(r"[a-z]+", str(definition).lower())
                        if w not in STOPWORDS and len(w) > 3
                    ]
                    expanded.extend(defs[:3])
    # Deduplicate
    seen, out = set(), []
    for w in expanded:
        if w not in seen:
            seen.add(w)
            out.append(w)
    return out
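# Illustrative example using only the built-in DOMAIN_SYNONYMS (a matching
# entry in data/glossary.json would append further definition words):
#   expand_with_glossary(["sae", "protocol"])
#   -> ["sae", "protocol", "serious adverse event", "life threatening event"]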
# --------------------------------------------------------------------
# 🔍 Unified Token Extraction
# --------------------------------------------------------------------
def extract_content_words(query: str):
    """Normalize, extract VAN tokens, and expand using domain synonyms & glossary."""
    print(f"🔎 [NLP] Extracting VANs from query: {query}")
    tokens = extract_van_tokens(query)
    expanded = expand_with_glossary(tokens)
    print(f"🔎 [NLP] VAN tokens → {expanded}")
    return expanded
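# Behavioral note: normalize_query_text() expands abbreviations before POS
# tagging, so the VAN tokens reaching expand_with_glossary() are usually full
# words ("electronic", "capture", ...) rather than abbreviation keys such as
# "edc"; DOMAIN_SYNONYMS mainly fires when callers pass raw abbreviations
# directly to expand_with_glossary().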
# --------------------------------------------------------------------
# 🧪 Self-test
# --------------------------------------------------------------------
if __name__ == "__main__":
sample = "Explain how EDC and eCRF relate to GCP compliance in a clinical trial?"
print(extract_content_words(sample))