File size: 4,834 Bytes
e61e934
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os
import json
import re
import math
from collections import defaultdict, Counter

# --- 🔧 NEW: Lightweight stemming and lemmatization helpers.
# NLTK is optional: if importing it (or downloading the WordNet data)
# fails for any reason, both factory names are reset to None and token
# normalization degrades to plain lowercase/strip in _normalize_token.
try:
    import nltk
    from nltk.stem import WordNetLemmatizer, PorterStemmer
    from nltk.corpus import wordnet  # NOTE(review): unused below — possibly meant to warm the corpus loader; confirm before removing
    nltk.download("wordnet", quiet=True)   # quiet=True: no console noise; failures fall into except
    nltk.download("omw-1.4", quiet=True)
except Exception:
    # Any failure (nltk missing, offline download, corpus error) disables both.
    WordNetLemmatizer = PorterStemmer = None

# Module-level singletons so every tokenization call reuses one instance.
_lemmatizer = WordNetLemmatizer() if WordNetLemmatizer else None
_stemmer = PorterStemmer() if PorterStemmer else None


def _normalize_token(token: str) -> str:
    """Return *token* stripped and lowercased, then lemmatized and stemmed
    when the optional NLTK normalizers are available.

    A normalizer that raises is silently skipped, so this never fails on
    odd input.
    """
    word = token.strip().lower()
    # Build the ordered pipeline; entries are None when NLTK is unavailable.
    pipeline = (
        _lemmatizer.lemmatize if _lemmatizer else None,
        _stemmer.stem if _stemmer else None,
    )
    for step in pipeline:
        if step is None:
            continue
        try:
            word = step(word)
        except Exception:
            pass  # best-effort: keep the token as-is on failure
    return word


class BM25:
    """Okapi BM25 ranking over an in-memory corpus of document dicts.

    Each document's text is read from the first of the keys 'text',
    'definition', or 'content' that holds a string; any other shape is
    treated as an empty document.
    """

    def __init__(self, corpus, k1=1.5, b=0.75):
        """Tokenize *corpus* and precompute the statistics BM25 needs.

        corpus: list of dicts (see class docstring for accepted keys).
        k1, b:  standard BM25 tuning parameters. Defaults match the
                previously hard-coded values, so existing callers get
                identical scores.
        """
        self.corpus = corpus
        self.tokenized_corpus = [self._tokenize(self._get_text(doc)) for doc in corpus]
        self.doc_lens = [len(tokens) for tokens in self.tokenized_corpus]
        # Average document length; 0.0 for an empty corpus (guarded in get_scores).
        self.avgdl = sum(self.doc_lens) / len(self.doc_lens) if self.doc_lens else 0.0
        self.doc_freqs = self._calc_doc_freqs()
        # Per-document term frequencies, computed once here instead of
        # rebuilding a Counter on every get_scores() call.
        self.term_freqs = [Counter(tokens) for tokens in self.tokenized_corpus]
        self.k1 = k1
        self.b = b

    def _get_text(self, doc):
        """Safely extract text from multiple possible keys ('text', 'definition', 'content')."""
        if not isinstance(doc, dict):
            return ""
        for key in ("text", "definition", "content"):
            value = doc.get(key)
            if isinstance(value, str):
                return value
        return ""

    def _tokenize(self, text):
        """Lowercase *text*, split on word characters, and normalize each token."""
        raw_tokens = re.findall(r"\w+", (text or "").lower())
        return [_normalize_token(t) for t in raw_tokens if t]

    def _calc_doc_freqs(self):
        """Return {term: number of documents containing that term}."""
        freqs = defaultdict(int)
        for tokens in self.tokenized_corpus:
            for word in set(tokens):  # set(): count each doc at most once per term
                freqs[word] += 1
        return freqs

    def _idf(self, term):
        """Smoothed inverse document frequency; never negative, 0.0 on an empty corpus."""
        N = len(self.tokenized_corpus)
        df = self.doc_freqs.get(term, 0)
        # +1 inside the log keeps the idf non-negative even when df > N/2.
        return math.log(1 + (N - df + 0.5) / (df + 0.5)) if N > 0 else 0.0

    def get_scores(self, query_tokens):
        """Return one BM25 score per corpus document for *query_tokens*.

        query_tokens: already-normalized query terms; duplicates contribute
        additively, exactly as before.
        """
        scores = [0.0] * len(self.tokenized_corpus)
        avgdl = self.avgdl or 1.0  # guard division by zero on an empty corpus
        # idf depends only on the term, so compute it once per distinct term.
        idf = {term: self._idf(term) for term in set(query_tokens)}
        for idx, freqs in enumerate(self.term_freqs):
            dl = self.doc_lens[idx]
            # Document-length normalization is constant per document; hoist it.
            length_norm = self.k1 * (1 - self.b + self.b * dl / avgdl)
            for term in query_tokens:
                tf = freqs.get(term, 0)
                denom = tf + length_norm
                if denom != 0:
                    scores[idx] += idf[term] * (tf * (self.k1 + 1)) / denom
        return scores


def search_bm25(query, docs=None, top_n=10):
    """
    BM25 search helper.
    - query: string
    - docs: optional list of dicts (each may have 'text'/'definition'/'content');
      if None, will load from vector_store.load_all_text_chunks()
    - top_n: int
    Returns list of doc dicts with added 'score' field.
    """
    if docs is None:
        # Import lazily so callers that supply their own docs don't require
        # the core.vector_store module to be importable.
        from core.vector_store import load_all_text_chunks
        docs = load_all_text_chunks() or []
    if not docs:
        return []

    bm25 = BM25(docs)

    # 🔧 Normalize query tokens with the same stem/lemma logic as indexing.
    query_tokens = [_normalize_token(t) for t in re.findall(r"\w+", (query or "").lower()) if t]
    if not query_tokens:
        return []

    scores = bm25.get_scores(query_tokens)

    # --- 🎯 Boost Excel glossary sources (MRCT, xlsx/xls) by +15%
    for i, doc in enumerate(docs):
        if not isinstance(doc, dict):
            continue  # BM25 scores non-dict entries as empty text; no boost either
        # str() defends against non-string 'file'/'source' values before .lower().
        src = str(doc.get("file") or doc.get("source") or "").lower()
        if any(marker in src for marker in (".xlsx", ".xls", "mrct", "clinical-research-glossary")):
            scores[i] *= 1.15  # Excel source boost

    # --- Rank and return the top_n highest-scoring docs.
    top_indices = sorted(range(len(scores)), key=scores.__getitem__, reverse=True)[:top_n]
    results = []
    for i in top_indices:
        # Shallow-copy dict entries; wrap non-dict entries so every result is a dict.
        doc = dict(docs[i]) if isinstance(docs[i], dict) else {}
        # 🔧 Ensure 'text' key exists so retriever can render it.
        if "text" not in doc:
            doc["text"] = bm25._get_text(docs[i])
        doc["score"] = float(scores[i])
        results.append(doc)
    return results