# essprasad's picture — Upload 9 files (b816136 verified)
# raw | history | blame | 1.96 kB
import os
import json
import re
import math
from collections import defaultdict, Counter
class BM25:
    """Okapi BM25 ranking over a corpus of dicts carrying a 'text' field.

    Scores are non-negative; a term absent from a document contributes 0.
    Uses the standard smoothed IDF: log(1 + (N - df + 0.5) / (df + 0.5)).
    """

    def __init__(self, corpus, k1=1.5, b=0.75):
        """Index *corpus* for scoring.

        Args:
            corpus: list of dicts, each with a 'text' string to index.
            k1: term-frequency saturation parameter (default 1.5, as before).
            b: document-length normalization strength (default 0.75, as before).
        """
        self.k1 = k1
        self.b = b
        self.corpus = corpus
        self.tokenized_corpus = [self._tokenize(doc['text']) for doc in corpus]
        self.doc_lens = [len(tokens) for tokens in self.tokenized_corpus]
        # Guard: an empty corpus previously raised ZeroDivisionError here.
        self.avgdl = (sum(self.doc_lens) / len(self.doc_lens)) if self.doc_lens else 0.0
        self.doc_freqs = self._calc_doc_freqs()

    def _tokenize(self, text):
        """Lowercase *text* and split it into word-character runs."""
        return re.findall(r"\w+", text.lower())

    def _calc_doc_freqs(self):
        """Map each term to the number of documents that contain it."""
        freqs = defaultdict(int)
        for tokens in self.tokenized_corpus:
            for word in set(tokens):  # set(): count each doc at most once
                freqs[word] += 1
        return freqs

    def _idf(self, term):
        """Smoothed inverse document frequency of *term* (always >= 0)."""
        N = len(self.tokenized_corpus)
        df = self.doc_freqs.get(term, 0)
        return math.log(1 + (N - df + 0.5) / (df + 0.5))

    def get_scores(self, query_tokens):
        """Return one BM25 score per corpus document for *query_tokens*.

        Args:
            query_tokens: pre-tokenized query terms (e.g. from _tokenize).

        Returns:
            list[float] of len(corpus), aligned with the corpus order.
        """
        scores = [0.0] * len(self.tokenized_corpus)
        for idx, tokens in enumerate(self.tokenized_corpus):
            freqs = Counter(tokens)
            dl = self.doc_lens[idx]
            # Length-normalization term is constant per document; hoist it
            # out of the per-term loop. When avgdl == 0 every dl is 0 too,
            # so treat dl/avgdl as 0 instead of dividing by zero.
            ratio = dl / self.avgdl if self.avgdl else 0.0
            norm = self.k1 * (1 - self.b + self.b * ratio)
            for term in query_tokens:
                tf = freqs[term]
                if tf == 0:
                    continue  # numerator would be 0: identical result, less work
                scores[idx] += self._idf(term) * (tf * (self.k1 + 1)) / (tf + norm)
        return scores
def search_bm25(query, top_n=10, docs=None):
    """Rank stored text chunks against *query* with BM25.

    Bug fix: the original referenced ``docs`` without ever defining it,
    so every call raised NameError. The ``if docs is None`` guard shows a
    ``docs`` parameter was intended; it is restored as a trailing keyword
    argument so existing ``search_bm25(query)`` / ``search_bm25(query, n)``
    calls keep working.

    Args:
        query: free-text query string.
        top_n: number of results to return (default 10).
        docs: optional pre-loaded chunk list; when None, chunks are loaded
            from the vector store.

    Returns:
        Up to *top_n* chunk dicts (shallow copies), each with a 'score' key
        added, ordered by descending BM25 score.
    """
    if docs is None:
        # Lazy import: the vector-store dependency is only needed when the
        # caller does not supply the chunks directly.
        from core.vector_store import load_all_text_chunks
        docs = load_all_text_chunks()
    bm25 = BM25(docs)
    query_tokens = re.findall(r"\w+", query.lower())
    scores = bm25.get_scores(query_tokens)
    top_indices = sorted(range(len(scores)), key=lambda i: scores[i], reverse=True)[:top_n]
    results = []
    for i in top_indices:
        doc = docs[i].copy()  # shallow copy so cached chunks are not mutated
        doc['score'] = scores[i]
        results.append(doc)
    return results