Spaces:
Build error
Build error
| from .base_query_based_model import QueryBasedSummModel | |
| from model.base_model import SummModel | |
| from model.single_doc import TextRankModel | |
| from typing import List | |
| from sklearn.feature_extraction.text import TfidfVectorizer | |
| from sklearn.metrics.pairwise import cosine_similarity | |
class TFIDFSummModel(QueryBasedSummModel):
    """Query-based model whose retrieval step ranks sentences by TF-IDF.

    Sentences of an instance are scored by the cosine similarity between
    their TF-IDF vectors and the TF-IDF vector(s) of the query; the
    highest-scoring sentences are returned by ``_retrieve``.

    NOTE(review): how the retrieved sentences are consumed (presumably they
    are handed to ``model_backend`` by the base class) is not visible in
    this file -- confirm against ``QueryBasedSummModel``.
    """

    # static variables
    model_name = "TF-IDF"
    is_extractive = True
    is_neural = False
    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs
    ):
        """Forward all settings to the base class and create the vectorizer.

        All named arguments and ``**kwargs`` are passed unchanged to
        ``QueryBasedSummModel.__init__``. ``model_backend`` defaults to
        ``TextRankModel`` (passed as-is, not instantiated here).
        """
        super().__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
            model_backend=model_backend,
            retrieval_ratio=retrieval_ratio,
            preprocess=preprocess,
            **kwargs
        )
        # Re-fitted on every ``_retrieve`` call, so the vocabulary always
        # reflects the instance currently being processed.
        self.vectorizer = TfidfVectorizer()

    def _retrieve(
        self, instance: List[str], query: List[str], n_best: int
    ) -> List[str]:
        """Return the ``n_best`` sentences of ``instance`` closest to ``query``.

        Args:
            instance: Candidate sentences; the TF-IDF vocabulary is fitted
                on these.
            query: One or more query strings. With several queries a
                sentence is scored by its best similarity to any of them.
            n_best: Maximum number of sentences to return.

        Returns:
            Up to ``n_best`` sentences ordered from most to least similar.
            An empty list when ``instance`` is empty or ``n_best <= 0``.
        """
        # Nothing to rank: avoid the vectorizer's "empty vocabulary" error
        # and the surprising result of slicing with a negative bound.
        if not instance or n_best <= 0:
            return []

        instance_vectors = self.vectorizer.fit_transform(instance)
        query_vectors = self.vectorizer.transform(query)

        # Shape (len(query), len(instance)). Reduce over the query axis
        # explicitly instead of relying on squeeze(), which only yields a
        # 1-D vector when there is exactly one query.
        similarities = cosine_similarity(query_vectors, instance_vectors)
        scores = similarities.max(axis=0)

        top_n_index = scores.argsort()[::-1][:n_best]
        return [instance[ind] for ind in top_n_index]