Spaces:
Build error
Build error
| from .base_query_based_model import QueryBasedSummModel | |
| from model.base_model import SummModel | |
| from model.single_doc import TextRankModel | |
| from typing import List | |
| from gensim.summarization.bm25 import BM25 | |
| from nltk import word_tokenize | |
class BM25SummModel(QueryBasedSummModel):
    """Query-based summarization model whose retrieval step ranks with BM25.

    This class only supplies the retrieval step (``_retrieve``); everything
    else, including how the retrieved sentences are summarized by
    ``model_backend``, is inherited from ``QueryBasedSummModel``.
    """

    # static variables
    model_name = "BM25"
    is_extractive = True  # only represents the retrieval part
    is_neural = False  # only represents the retrieval part
    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs
    ):
        """Forward every argument unchanged to ``QueryBasedSummModel``.

        The meaning of each argument is defined by the parent class; this
        subclass only fixes the defaults (``TextRankModel`` as the backend,
        a retrieval ratio of 0.5, preprocessing enabled).
        """
        super().__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
            model_backend=model_backend,
            retrieval_ratio=retrieval_ratio,
            preprocess=preprocess,
            **kwargs
        )

    def _retrieve(self, instance: List[str], query: List[str], n_best):
        """Return the ``n_best`` entries of ``instance`` best matching ``query``.

        Each entry of ``instance`` is tokenized with ``word_tokenize`` and
        indexed with BM25; ``query`` is passed to ``BM25.get_scores`` as-is
        (presumably already a list of tokens -- confirm against the caller).

        :param instance: candidate text units (e.g. sentences) to rank.
        :param query: the query handed to the BM25 scorer.
        :param n_best: maximum number of entries to keep.
        :return: the selected entries, in their original order in
            ``instance`` rather than in score order.
        """
        # Nothing to rank. Also avoids building a BM25 index over an empty
        # corpus, which divides by the corpus size (gensim 3.8.x behaviour).
        if not instance:
            return []

        # Materialize the tokenized corpus: a generator can be consumed only
        # once and has no len(), which the BM25 constructor may rely on.
        tokenized_corpus = [word_tokenize(sentence) for sentence in instance]
        bm25 = BM25(tokenized_corpus)
        scores = bm25.get_scores(query)

        # Indices of the highest-scoring entries; the sort is stable, so ties
        # keep their original relative order.
        best_sent_ind = sorted(
            range(len(scores)), key=lambda i: scores[i], reverse=True
        )[:n_best]

        # Re-sort the chosen indices so the output follows document order.
        top_n_sent = [instance[ind] for ind in sorted(best_sent_ind)]
        return top_n_sent