Spaces:
Build error
Build error
| from model.base_model import SummModel | |
| from model.single_doc import TextRankModel | |
| from typing import List, Union | |
| from nltk import sent_tokenize, word_tokenize | |
| from nltk.corpus import stopwords | |
| from nltk.stem import PorterStemmer | |
class QueryBasedSummModel(SummModel):
    """Two-stage query-based summarizer: retrieve first, then summarize.

    For each (document, query) pair the model asks ``_retrieve`` (which a
    subclass must implement) for the sentences or dialogue turns most
    relevant to the query, and passes only those to a backend summarization
    model.
    """

    is_query_based = True

    def __init__(
        self,
        trained_domain: str = None,
        max_input_length: int = None,
        max_output_length: int = None,
        model_backend: SummModel = TextRankModel,
        retrieval_ratio: float = 0.5,
        preprocess: bool = True,
        **kwargs,
    ):
        """Set up the retriever options and instantiate the backend model.

        Args:
            trained_domain: Forwarded unchanged to ``SummModel.__init__``.
            max_input_length: Forwarded unchanged to ``SummModel.__init__``.
            max_output_length: Forwarded unchanged to ``SummModel.__init__``.
            model_backend: The summarizer *class* (not an instance); it is
                called as ``model_backend(**kwargs)`` to build ``self.model``.
            retrieval_ratio: Fraction of sentences/turns of each document
                that the retriever is asked to keep (at least one is always
                requested).
            preprocess: When true, documents and queries are normalized with
                ``Preprocessor`` before retrieval.
            **kwargs: Passed verbatim to ``model_backend``.
        """
        super().__init__(
            trained_domain=trained_domain,
            max_input_length=max_input_length,
            max_output_length=max_output_length,
        )
        self.model = model_backend(**kwargs)
        self.retrieval_ratio = retrieval_ratio
        self.preprocess = preprocess

    def _retrieve(
        self, instance: List[str], query: List[str], n_best: int
    ) -> List[str]:
        """Return the ``n_best`` items of ``instance`` most relevant to ``query``.

        Abstract hook: concrete retrievers must override it.

        Raises:
            NotImplementedError: Always, in this base class.
        """
        raise NotImplementedError()

    def summarize(
        self,
        corpus: Union[List[str], List[List[str]]],
        queries: List[str] = None,
    ) -> List[str]:
        """Summarize every document in ``corpus`` with respect to its query.

        Args:
            corpus: One entry per document. A ``str`` entry is treated as an
                article and split into sentences; any other entry is treated
                as a dialogue that is already a list of turns.
            queries: One query string per document.

        Returns:
            Whatever ``self.model.summarize`` returns for the retrieved text.

        Raises:
            TypeError: If ``queries`` is missing or is not a list of strings.
        """
        self.assert_summ_input_type(corpus, queries)

        # Created lazily and then reused: every Preprocessor() construction
        # reloads the stop-word list, so building one per document is wasted
        # work. Lazy creation keeps the empty-corpus path free of that load.
        preprocessor = None

        retrieval_output = []  # one str (article) or List[str] (dialogue) each
        # NOTE: zip() stops at the shorter of corpus/queries, so surplus
        # documents or queries are silently ignored.
        for instance, query in zip(corpus, queries):
            is_dialogue = not isinstance(instance, str)
            if not is_dialogue:
                instance = sent_tokenize(instance)
            query = [query]
            # instance and query are both List[str] from here on.

            if self.preprocess:
                if preprocessor is None:
                    preprocessor = Preprocessor()
                instance = preprocessor.preprocess(instance)
                query = preprocessor.preprocess(query)

            # Always ask for at least one item, even for very short inputs.
            n_best = max(int(len(instance) * self.retrieval_ratio), 1)
            top_n_sent = self._retrieve(instance, query, n_best)

            if not is_dialogue:
                # Articles go back to the backend as a single string.
                top_n_sent = " ".join(top_n_sent)
            retrieval_output.append(top_n_sent)

        summaries = self.model.summarize(retrieval_output)
        return summaries

    def generate_specific_description(self):
        """Return a one-sentence description of this retriever/summarizer pair.

        Reads ``is_neural``, ``is_extractive`` and ``model_name`` from both
        this object and the backend model (presumably provided by
        ``SummModel`` -- they are not defined in this class).
        """
        is_neural = self.model.is_neural & self.is_neural
        is_extractive = self.model.is_extractive | self.is_extractive
        model_name = "Pipeline with retriever: {}, summarizer: {}".format(
            self.model_name, self.model.model_name
        )

        extractive_abstractive = "extractive" if is_extractive else "abstractive"
        neural = "neural" if is_neural else "non-neural"

        basic_description = (
            f"{model_name} is a "
            f"{'query-based' if self.is_query_based else ''} "
            f"{extractive_abstractive}, {neural} model for summarization."
        )
        return basic_description

    @classmethod
    def assert_summ_input_type(cls, corpus, query):
        """Validate that ``query`` is a list of strings.

        ``corpus`` is accepted for signature compatibility but is not
        inspected here.

        Raises:
            TypeError: If ``query`` is ``None``, is not a list, or contains a
                non-string element.
        """
        if query is None:
            raise TypeError(
                "Query-based summarization models summarize instances of query-text pairs, however, query is missing."
            )

        if not isinstance(query, list):
            raise TypeError(
                "Query-based single-document summarization requires query of `List[str]`."
            )

        if not all(isinstance(q, str) for q in query):
            raise TypeError(
                "Query-based single-document summarization requires query of `List[str]`."
            )

    @classmethod
    def generate_basic_description(cls) -> str:
        """Return a short, fixed description of query-based summarization."""
        basic_description = (
            "QueryBasedSummModel performs query-based summarization. Given a query-text pair, "
            "the model will first extract the most relevant sentences in articles or turns in "
            "dialogues, then use the single document summarization model to generate the summary"
        )
        return basic_description

    @classmethod
    def show_capability(cls):
        """Print the basic description followed by strengths and weaknesses."""
        basic_description = cls.generate_basic_description()
        more_details = (
            "A query-based summarization model."
            " Allows for custom model backend selection at initialization."
            " Retrieve relevant turns and then summarize the retrieved turns\n"
            "Strengths: \n - Allows for control of backend model.\n"
            "Weaknesses: \n - Heavily depends on the performance of both retriever and summarizer.\n"
        )
        print(f"{basic_description}\n{'#' * 20}\n{more_details}")
class Preprocessor:
    """Sentence-level text normalizer: lower-casing, stop-word removal, stemming."""

    def __init__(self, remove_stopwords=True, lower_case=True, stem=False):
        """Load the NLTK resources and record which steps are enabled.

        Args:
            remove_stopwords: Drop tokens found in the English stop-word list.
            lower_case: Lower-case every sentence before tokenizing.
            stem: Replace every token with its Porter stem.
        """
        # Both resources are loaded unconditionally, whatever the flags say.
        self.sw = stopwords.words("english")
        self.stemmer = PorterStemmer()
        self.remove_stopwords = remove_stopwords
        self.lower_case = lower_case
        self.stem = stem

    def preprocess(self, corpus: List[str]) -> List[str]:
        """Normalize each sentence of ``corpus`` and return the new sentences.

        Each sentence is tokenized with ``word_tokenize``, filtered/stemmed
        according to the flags given at construction time, and re-joined with
        single spaces. The output has the same length as ``corpus``; a
        sentence whose tokens are all removed becomes an empty string.

        Args:
            corpus: Sentences (or dialogue turns) to normalize.

        Returns:
            The normalized sentences, in the original order.
        """
        if self.lower_case:
            corpus = [sent.lower() for sent in corpus]

        tokenized_corpus = [word_tokenize(sent) for sent in corpus]

        if self.remove_stopwords:
            # Build the lookup set once per call: membership tests against
            # the raw list would rescan it for every single token.
            # NOTE(review): matching is exact, so with lower_case=False
            # capitalized stop words presumably survive -- confirm intended.
            stop_words = set(self.sw)
            tokenized_corpus = [
                [word for word in sent if word not in stop_words]
                for sent in tokenized_corpus
            ]

        if self.stem:
            stem_word = self.stemmer.stem
            tokenized_corpus = [
                [stem_word(word) for word in sent] for sent in tokenized_corpus
            ]

        return [" ".join(sent) for sent in tokenized_corpus]