# Source: essprasad's upload ("Upload 10 files", commit e61e934 verified, 4.83 kB)
# NOTE: the lines above were Hugging Face file-viewer residue; kept here as a
# comment so the module remains valid Python.
import os
import json
import re
import math
from collections import defaultdict, Counter
# --- πŸ”§ NEW: Lightweight stemming and lemmatization helpers
try:
import nltk
from nltk.stem import WordNetLemmatizer, PorterStemmer
from nltk.corpus import wordnet
nltk.download("wordnet", quiet=True)
nltk.download("omw-1.4", quiet=True)
except Exception:
WordNetLemmatizer = PorterStemmer = None
# Initialize stemmer and lemmatizer
_lemmatizer = WordNetLemmatizer() if WordNetLemmatizer else None
_stemmer = PorterStemmer() if PorterStemmer else None
def _normalize_token(token: str) -> str:
    """Normalize a token: lowercase + strip, then lemmatize and stem (best-effort).

    Each NLTK step is optional (skipped when the helper is unavailable) and
    individually guarded, so a failure in one never aborts normalization.
    """
    result = token.lower().strip()
    transforms = []
    if _lemmatizer:
        transforms.append(_lemmatizer.lemmatize)
    if _stemmer:
        transforms.append(_stemmer.stem)
    for transform in transforms:
        try:
            result = transform(result)
        except Exception:
            pass  # best-effort: keep the last successful form
    return result
class BM25:
    """Okapi BM25 ranking over a corpus of dict documents.

    Each document may carry its text under 'text', 'definition', or 'content';
    the first string value found (in that order) is indexed.
    """

    def __init__(self, corpus, k1=1.5, b=0.75):
        """Build the in-memory index.

        - corpus: list of dicts (see _get_text for the accepted text keys)
        - k1: term-frequency saturation parameter (default 1.5, as before)
        - b: document-length normalization strength (default 0.75, as before)
        """
        # Tuning parameters are now constructor keywords (previously hard-coded).
        self.k1 = k1
        self.b = b
        self.corpus = corpus
        self.tokenized_corpus = [self._tokenize(self._get_text(doc)) for doc in corpus]
        self.doc_lens = [len(tokens) for tokens in self.tokenized_corpus]
        # Average document length; 0.0 for an empty corpus (guarded in get_scores).
        self.avgdl = sum(self.doc_lens) / len(self.doc_lens) if self.doc_lens else 0.0
        self.doc_freqs = self._calc_doc_freqs()

    def _get_text(self, doc):
        """Return the first string under 'text'/'definition'/'content', else ''."""
        if isinstance(doc, dict):
            for key in ("text", "definition", "content"):
                value = doc.get(key)
                if isinstance(value, str):
                    return value
        return ""

    def _tokenize(self, text):
        """Split on word characters and normalize each token (stem + lemma)."""
        raw_tokens = re.findall(r"\w+", (text or "").lower())
        return [_normalize_token(t) for t in raw_tokens if t]

    def _calc_doc_freqs(self):
        """Map each term to the number of documents that contain it."""
        freqs = defaultdict(int)
        for tokens in self.tokenized_corpus:
            for term in set(tokens):  # set(): count each doc at most once
                freqs[term] += 1
        return freqs

    def _idf(self, term):
        """Smoothed inverse document frequency (never negative)."""
        N = len(self.tokenized_corpus)
        df = self.doc_freqs.get(term, 0)
        # +1 inside the log keeps idf >= 0 even for very common terms.
        return math.log(1 + (N - df + 0.5) / (df + 0.5)) if N > 0 else 0.0

    def get_scores(self, query_tokens):
        """Return a BM25 score for every document, in corpus order."""
        scores = [0.0] * len(self.tokenized_corpus)
        for idx, doc_tokens in enumerate(self.tokenized_corpus):
            freqs = Counter(doc_tokens)
            dl = self.doc_lens[idx]
            # Length-normalization term is constant per document — hoist it.
            norm = self.k1 * (1 - self.b + self.b * dl / (self.avgdl or 1.0))
            for term in query_tokens:
                tf = freqs.get(term, 0)
                if tf == 0:
                    continue  # zero tf contributes exactly zero (skip the idf call)
                denom = tf + norm
                if denom:
                    scores[idx] += self._idf(term) * (tf * (self.k1 + 1)) / denom
        return scores
def search_bm25(query, docs=None, top_n=10):
    """
    BM25 search helper.

    - query: string
    - docs: optional list of dicts (each may have 'text'/'definition'/'content');
      if None, will load from vector_store.load_all_text_chunks()
    - top_n: maximum number of results to return

    Returns a list of doc dicts (shallow copies) with an added float 'score'
    field, ranked best-first.
    """
    if docs is None:
        # FIX: import lazily and only on this path, so callers that supply
        # their own docs do not require core.vector_store to be importable.
        from core.vector_store import load_all_text_chunks
        docs = load_all_text_chunks() or []
    if not docs:
        return []
    bm25 = BM25(docs)
    # Normalize query tokens with the same stem/lemma pipeline used at indexing.
    query_tokens = [_normalize_token(t) for t in re.findall(r"\w+", (query or "").lower()) if t]
    if not query_tokens:
        return []
    scores = bm25.get_scores(query_tokens)
    # Boost Excel glossary sources (MRCT, xlsx/xls) by +15%.
    for i, doc in enumerate(docs):
        if not isinstance(doc, dict):
            continue  # robustness: BM25._get_text tolerates non-dicts, so we do too
        src = (doc.get("file") or doc.get("source") or "").lower()
        if any(marker in src for marker in (".xlsx", ".xls", "mrct", "clinical-research-glossary")):
            scores[i] *= 1.15  # Excel source boost
    # Rank by boosted score and keep the top_n documents.
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
    results = []
    for i in top_indices:
        doc = dict(docs[i]) if isinstance(docs[i], dict) else {}  # shallow copy
        # Ensure 'text' exists so the retriever can render the hit.
        if "text" not in doc:
            doc["text"] = bm25._get_text(docs[i])
        doc["score"] = float(scores[i])
        results.append(doc)
    return results