Spaces:
Runtime error
Runtime error
| import json | |
| from pathlib import Path | |
| import torch | |
| import torchvision | |
| from .modulated_coco import ConvertCocoPolysToMask, ModulatedDataset | |
class GQADataset(ModulatedDataset):
    """GQA flavour of ``ModulatedDataset``.

    Adds no attributes or methods of its own; all behaviour is inherited
    unchanged from ``ModulatedDataset``.
    """
class GQAQuestionAnswering(torchvision.datasets.CocoDetection):
    """GQA question-answering dataset backed by a COCO-format annotation file.

    Each item is an ``(image, target)`` pair. The target is produced by
    ``ConvertCocoPolysToMask`` and then extended with the answer labels
    described in :meth:`__getitem__`.

    The image records of the annotation file are read for the keys
    ``"caption"``, ``"dataset_name"``, ``"questionId"``, ``"answer"`` and
    ``"question_type"``.
    """

    # Question types that get their own ``answer_<type>`` entry, listed in the
    # order in which those entries are inserted into the target dict.
    _TYPED_ANSWERS = ("attr", "global", "rel", "cat", "obj")

    # Label stored in ``answer_<type>`` when the question is of another type.
    # NOTE(review): presumably chosen to match the default ``ignore_index`` of
    # torch's cross-entropy loss -- confirm against the training criterion.
    _IGNORE_LABEL = -100

    def __init__(self, img_folder, ann_file, transforms, return_masks, return_tokens, tokenizer, ann_folder):
        """Load the COCO index and the answer vocabularies.

        Args:
            img_folder: Image root, forwarded to ``CocoDetection``.
            ann_file: COCO-format annotation file, forwarded to ``CocoDetection``.
            transforms: Optional callable invoked as ``transforms(img, target)``
                and expected to return the transformed pair; ``None`` disables it.
            return_masks: Forwarded to ``ConvertCocoPolysToMask``.
            return_tokens: Forwarded to ``ConvertCocoPolysToMask``.
            tokenizer: Forwarded to ``ConvertCocoPolysToMask``.
            ann_folder: Directory (``str`` or ``Path``) containing
                ``gqa_answer2id.json`` and ``gqa_answer2id_by_type.json``.
        """
        super().__init__(img_folder, ann_file)
        self._transforms = transforms
        self.prepare = ConvertCocoPolysToMask(return_masks, return_tokens, tokenizer=tokenizer)

        # Accept plain strings as well as Path objects for the folder.
        ann_folder = Path(ann_folder)
        # Read the JSON as UTF-8 explicitly instead of relying on the locale default.
        with open(ann_folder / "gqa_answer2id.json", "r", encoding="utf-8") as f:
            self.answer2id = json.load(f)
        with open(ann_folder / "gqa_answer2id_by_type.json", "r", encoding="utf-8") as f:
            self.answer2id_by_type = json.load(f)

        self.type2id = {"obj": 0, "attr": 1, "rel": 2, "global": 3, "cat": 4}

    @staticmethod
    def _answer_id(vocab, answer):
        """Return the id of ``answer`` in ``vocab``, falling back to the ``"unknown"`` entry."""
        return vocab[answer if answer in vocab else "unknown"]

    def __getitem__(self, idx):
        """Return the ``(image, target)`` pair stored at position ``idx``.

        On top of whatever ``self.prepare`` and the optional transforms put in
        the target, the following entries are set:

        * ``"dataset_name"`` and ``"questionId"``: copied from the image record.
        * ``"answer"``: id of the answer in the global vocabulary.
        * ``"answer_type"``: id of the question type (see ``self.type2id``).
        * ``"answer_attr"``, ``"answer_global"``, ``"answer_rel"``,
          ``"answer_cat"``, ``"answer_obj"``: id of the answer in the
          vocabulary of that question type when the question is of that type,
          ``-100`` otherwise.

        Answers missing from a vocabulary are encoded as that vocabulary's
        ``"unknown"`` entry. All labels are 0-dim ``torch.long`` tensors.

        Raises:
            KeyError: If the question type is not in ``self.type2id``, or if a
                vocabulary that is needed lacks both the answer and ``"unknown"``.
        """
        img, annotations = super().__getitem__(idx)
        image_id = self.ids[idx]
        coco_img = self.coco.loadImgs(image_id)[0]

        target = {"image_id": image_id, "annotations": annotations, "caption": coco_img["caption"]}
        img, target = self.prepare(img, target)
        if self._transforms is not None:
            img, target = self._transforms(img, target)

        target["dataset_name"] = coco_img["dataset_name"]
        target["questionId"] = coco_img["questionId"]

        raw_answer = coco_img["answer"]
        question_type = coco_img["question_type"]

        target["answer"] = torch.as_tensor(self._answer_id(self.answer2id, raw_answer), dtype=torch.long)
        target["answer_type"] = torch.as_tensor(self.type2id[question_type], dtype=torch.long)

        for qtype in self._TYPED_ANSWERS:
            key = "answer_" + qtype
            if question_type == qtype:
                # Only the vocabulary of the matching type is consulted.
                label = self._answer_id(self.answer2id_by_type[key], raw_answer)
            else:
                label = self._IGNORE_LABEL
            target[key] = torch.as_tensor(label, dtype=torch.long)

        return img, target