Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| import pandas as pd | |
# Base directory two levels above the directory containing this module
# (presumably the repository root); data file paths are built from it.
ROOT_DIR = os.path.join(os.path.split(__file__)[0], "../../")
class Question_Image_Match_Classifier:
    """Look up precomputed question/image retrieval decisions for MMQA.

    Results are from a T5-3b model finetuned on the train set of MMQA. The
    predictions are not computed here; they are read from
    ``qc_mmqa_dev.csv`` / ``qimc_mmqa_dev.csv``, and image captions are read
    from ``mmqa_captions.json``.

    Attributes:
        data_dir: directory passed to the constructor (``None`` means the
            default ``<ROOT_DIR>/utils/mmqa``).
        whether_retrieve_image: question id -> True when the question-level
            prediction is the string ``"['yes']"``.
        qi_pairs_should_retrieve: lower-cased question/caption pair string ->
            True when the pair-level prediction is ``"['yes']"``.
        caption_info: dict loaded from ``mmqa_captions.json``; looked up by
            image file name up to its first ``.``.
    """

    def __init__(self, data_dir=None):
        """Load both prediction tables and the image captions.

        Args:
            data_dir: directory containing ``qc_mmqa_dev.csv``,
                ``qimc_mmqa_dev.csv`` and ``mmqa_captions.json``. Defaults to
                ``<ROOT_DIR>/utils/mmqa``, the location used previously.
        """
        self.data_dir = data_dir
        self.whether_retrieve_image = None
        self.qi_pairs_should_retrieve = None
        self.load_retrieve_info()
        captions_path = os.path.join(self._get_data_dir(), "mmqa_captions.json")
        # Explicit UTF-8: JSON text is UTF-8 by spec, and the platform default
        # encoding (e.g. cp1252 on Windows) can fail on non-ASCII captions.
        with open(captions_path, "r", encoding="utf-8") as f:
            self.caption_info = json.load(f)

    def _get_data_dir(self):
        """Return the configured data directory, or the default MMQA one."""
        # getattr keeps instances that bypassed __init__ working.
        data_dir = getattr(self, "data_dir", None)
        if data_dir is None:
            data_dir = os.path.join(ROOT_DIR, "utils", "mmqa")
        return data_dir

    def load_retrieve_info(self):
        """Load the prediction CSVs into lookup dicts.

        ``qc_mmqa_dev.csv`` (columns ``id``, ``prediction``) populates
        ``self.whether_retrieve_image``. ``qimc_mmqa_dev.csv`` (columns
        ``question``, ``prediction``) populates
        ``self.qi_pairs_should_retrieve`` with lower-cased ``question`` keys.
        A row maps to True exactly when its prediction equals ``"['yes']"``.
        """
        data_dir = self._get_data_dir()

        df_qc = pd.read_csv(os.path.join(data_dir, "qc_mmqa_dev.csv"))
        # Zip over the needed columns instead of iterrows(): no per-row Series
        # construction, and column values keep their parsed types.
        self.whether_retrieve_image = {
            _id: prediction == "['yes']"
            for _id, prediction in zip(df_qc['id'], df_qc['prediction'])
        }

        df_qimc = pd.read_csv(os.path.join(data_dir, "qimc_mmqa_dev.csv"))
        self.qi_pairs_should_retrieve = {
            question.lower(): prediction == "['yes']"
            for question, prediction in zip(df_qimc['question'], df_qimc['prediction'])
        }

    def judge_match(self, _id, question, pic):
        """Return whether image ``pic`` should be retrieved for ``question``.

        Args:
            _id: question id, looked up in ``self.whether_retrieve_image``.
            question: question text; lower-cased before the pair lookup.
            pic: image path; its file name up to the first ``.`` is the key
                into ``self.caption_info``.

        Returns:
            False if the question-level prediction is negative; otherwise the
            stored pair-level prediction for (question, image caption).

        Raises:
            KeyError: if ``_id``, the image key, or the built question/caption
                pair string is absent from the loaded data.
        """
        # FIXME: question-level gating is hard-coded here because it is done in
        # the pipeline; change that in the future.
        if not self.whether_retrieve_image[_id]:
            return False
        image_key = os.path.basename(pic).split(".")[0]
        image_caption = self.caption_info[image_key]
        # Must match the lower-cased "question" strings loaded from the CSV.
        pair_key = 'qa: {} \n{}'.format(question.lower(), image_caption.lower())
        return self.qi_pairs_should_retrieve[pair_key]