| """ | |
| Retriever to retrieve relevant examples from annotations. | |
| """ | |
| import copy | |
| from typing import Dict, List, Tuple, Any | |
| import nltk | |
| from nltk.stem import SnowballStemmer | |
| from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction | |
| from utils.normalizer import normalize | |
| from retrieval.retrieve_pool import OpenAIQARetrievePool, QAItem | |


class OpenAIQARetriever(object):
    def __init__(self, retrieve_pool: OpenAIQARetrievePool):
        self.retrieve_pool = retrieve_pool

    @staticmethod
    def _string_bleu(q1: str, q2: str, stop_words=None, stemmer=None):
        """
        Smoothed BLEU score between two normalized (and optionally stemmed) strings.
        """
        q1, q2 = normalize(q1), normalize(q2)
        reference = [[tk for tk in nltk.word_tokenize(q1)]]
        candidate = [tk for tk in nltk.word_tokenize(q2)]
        if stemmer is not None:
            reference = [[stemmer.stem(tk) for tk in reference[0]]]
            candidate = [stemmer.stem(tk) for tk in candidate]
        # BLEU smoothing avoids hard zero scores when there is no n-gram overlap.
        chencherry_smooth = SmoothingFunction()
        bleu_score = sentence_bleu(
            reference,
            candidate,
            weights=(0.25, 0.3, 0.3, 0.15),
            smoothing_function=chencherry_smooth.method1
        )
        return bleu_score

    def _qh2qh_similarity(
            self,
            item: QAItem,
            num_retrieve_samples: int,
            score_func: str,
            qa_type: str,
            weight_h: float = 0.2,
            verbose: bool = False
    ):
        """
        Retrieve the top-K nsqls based on query & header to query & header similarities.
        """
        q = item.qa_question
        header_wo_row_id = copy.copy(item.table['header'])
        header_wo_row_id.remove('row_id')
        h = ' '.join(header_wo_row_id)
        stemmer = SnowballStemmer('english')
        if score_func == 'bleu':
            # Question-to-question and header-to-header BLEU similarities over pool items
            # of the same QA type (the prefix before '@' in qa_question).
            retrieve_q_list = [(d, self._string_bleu(q, d.qa_question.split('@')[1], stemmer=stemmer))
                               for d in self.retrieve_pool if d.qa_question.split('@')[0] == qa_type]
            retrieve_h_list = [(d, self._string_bleu(h, ' '.join(d.table['header']), stemmer=stemmer))
                               for d in self.retrieve_pool if d.qa_question.split('@')[0] == qa_type]
            # Final score: question similarity plus a down-weighted header similarity.
            retrieve_list = [(retrieve_q_list[idx][0], retrieve_q_list[idx][1] + weight_h * retrieve_h_list[idx][1])
                             for idx in range(len(retrieve_q_list))]
        else:
            raise ValueError(f'Score function {score_func} is not supported.')
        retrieve_list = sorted(retrieve_list, key=lambda x: x[1], reverse=True)
        retrieve_list = list(map(lambda x: x[0], retrieve_list))[:num_retrieve_samples]
        if verbose:
            print(retrieve_list)
        return retrieve_list

    def retrieve(
            self,
            item: QAItem,
            num_shots: int,
            method: str = 'qh2qh_bleu',
            qa_type: str = 'map',
            verbose: bool = False
    ) -> List[QAItem]:
        """
        Retrieve a list of relevant QA samples.
        """
        if method == 'qh2qh_bleu':
            retrieved_items = self._qh2qh_similarity(
                item=item,
                num_retrieve_samples=num_shots,
                score_func='bleu',
                qa_type=qa_type,
                verbose=verbose
            )
            return retrieved_items
        else:
            raise ValueError(f'Retrieve method {method} is not supported.')
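

# --- Usage sketch (illustrative only) ---
# A minimal sketch of how this retriever might be driven, assuming a populated
# OpenAIQARetrievePool and a query QAItem. How those two objects are built is
# not defined in this module, so `load_pool` and `load_query_item` below are
# hypothetical stand-ins; only the OpenAIQARetriever calls reflect the API above.
if __name__ == '__main__':
    pool: OpenAIQARetrievePool = load_pool()          # hypothetical loader for the annotation pool
    query_item: QAItem = load_query_item()            # hypothetical builder for the current question

    retriever = OpenAIQARetriever(retrieve_pool=pool)
    few_shot_examples = retriever.retrieve(
        item=query_item,   # qa_question is expected to look like '<qa_type>@<question text>'
        num_shots=8,
        method='qh2qh_bleu',
        qa_type='map',
        verbose=False
    )
    print(f'Retrieved {len(few_shot_examples)} in-context examples.')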