Spaces:
Runtime error
Runtime error
| """ | |
| Retrieval pool of candidates | |
| """ | |
| from dataclasses import dataclass | |
| from typing import List, Dict | |
| import json | |
| class OpenAIQARetrievePool(object): | |
| def __init__(self, data_path): | |
| with open(data_path, 'r') as f: | |
| data = json.load(f) | |
| self.data = [] | |
| for d in data: | |
| if isinstance(d['qa_column'], List): | |
| d['qa_column'] = '|'.join(d['qa_column']) | |
| qa_item = QAItem( | |
| id=d['id'], | |
| qa_question=d['qa_question'], | |
| qa_column=d['qa_column'], | |
| qa_answer=d['qa_answer'], | |
| table=d['table'], | |
| title=d['title'] | |
| ) | |
| self.data.append(qa_item) | |
| self.pointer = 0 | |
| def __iter__(self): | |
| return self | |
| def __next__(self): | |
| pointer = self.pointer | |
| if pointer < len(self): | |
| self.pointer += 1 | |
| return self.data[pointer] | |
| else: | |
| self.pointer = 0 | |
| raise StopIteration | |
| def __getitem__(self, item): | |
| return self.data[item] | |
| def __len__(self): | |
| return len(self.data) | |
| class QAItem(object): | |
| id: int = None | |
| qa_question: str = None | |
| qa_column: str = None | |
| qa_answer: str = None | |
| table: Dict = None | |
| title: str = None |