Spaces:
Runtime error
Runtime error
| import json | |
| import os | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from torchvision.datasets import ImageFolder | |
| from open_flamingo.eval.imagenet_utils import IMAGENET_1K_CLASS_ID_TO_LABEL | |
class COCOFlickrDataset(Dataset):
    """Image-captioning dataset backed by a JSON file with an "annotations" list.

    Each item corresponds to one entry of that list. If an image has several
    caption annotations, it appears once per annotation.
    """

    def __init__(
        self,
        image_dir_path,
        annotations_path,
        is_flickr=False,
    ):
        """
        Args:
            image_dir_path: directory containing the image files.
            annotations_path: JSON file whose top-level "annotations" list holds
                entries with "image_id" and "caption" keys.
            is_flickr: if True, image files are named "<image_id>.jpg";
                otherwise "<image_id zero-padded to 12 digits>.jpg".
        """
        self.image_dir_path = image_dir_path
        # Context manager closes the file handle as soon as the JSON is parsed
        # (the previous json.load(open(...)) left it to the garbage collector).
        with open(annotations_path, encoding="utf-8") as f:
            self.annotations = json.load(f)["annotations"]
        self.is_flickr = is_flickr

    def __len__(self):
        return len(self.annotations)

    def get_img_path(self, idx):
        """Return the path of the image for annotation ``idx``."""
        image_id = self.annotations[idx]["image_id"]
        if self.is_flickr:
            file_name = f"{image_id}.jpg"
        else:
            # Non-Flickr ids are zero-padded to 12 digits; requires an int id.
            file_name = f"{image_id:012d}.jpg"
        return os.path.join(self.image_dir_path, file_name)

    def __getitem__(self, idx):
        """Return a dict with the PIL image, its caption and its image id.

        ``Image.open`` is lazy: pixel data is read on first access.
        """
        annotation = self.annotations[idx]
        image = Image.open(self.get_img_path(idx))
        return {
            "image": image,
            "caption": annotation["caption"],
            "image_id": annotation["image_id"],
        }
class VQADataset(Dataset):
    """Visual question answering dataset over COCO-named images (VQAv2 / OK-VQA).

    Each item pairs one question with the list of its annotated answers and
    the corresponding image.
    """

    def __init__(
        self,
        image_dir_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/train2014/",
        question_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_OpenEnded_mscoco_train2014_questions.json",
        annotations_path="/mmfs1/gscratch/efml/anasa2/data/vqav2/v2_mscoco_train2014_annotations.json",
        vqa_dataset="vqa",
        image_file_prefix="COCO_val2014_",
    ):
        """
        Args:
            image_dir_path: directory containing the image files.
            question_path: JSON file with a top-level "questions" list; each
                entry needs "image_id", "question" and "question_id".
            annotations_path: JSON file with a top-level "annotations" list;
                each entry needs an "answers" list of {"answer": ...} dicts.
            vqa_dataset: dataset name, either "vqa" or "ok_vqa".
            image_file_prefix: file-name prefix placed before the 12-digit
                zero-padded image id. NOTE(review): the default image_dir_path
                is a train2014 directory while this default prefix is
                "COCO_val2014_"; pass the prefix matching the image directory.
        """
        with open(question_path, "r", encoding="utf-8") as f:
            self.questions = json.load(f)["questions"]
        with open(annotations_path, "r", encoding="utf-8") as f:
            # Answers are matched to questions by position: annotation i is
            # assumed to belong to question i — confirm the files are aligned.
            self.answers = json.load(f)["annotations"]
        self.image_dir_path = image_dir_path
        self.vqa_dataset = vqa_dataset
        self.image_file_prefix = image_file_prefix

    def __len__(self):
        return len(self.questions)

    def get_img_path(self, question):
        """Return the image path for a question record.

        Raises:
            ValueError: if ``vqa_dataset`` is not "vqa" or "ok_vqa".
        """
        if self.vqa_dataset not in ("vqa", "ok_vqa"):
            raise ValueError(f"Unknown VQA dataset {self.vqa_dataset}")
        # Both supported datasets share the same file-naming scheme.
        return os.path.join(
            self.image_dir_path,
            f"{self.image_file_prefix}{question['image_id']:012d}.jpg",
        )

    def __getitem__(self, idx):
        """Return a dict with the image, question text, answers and question id."""
        question = self.questions[idx]
        answers = self.answers[idx]
        img_path = self.get_img_path(question)
        image = Image.open(img_path)
        return {
            "image": image,
            "question": question["question"],
            "answers": [a["answer"] for a in answers["answers"]],
            "question_id": question["question_id"],
        }
class ImageNetDataset(ImageFolder):
    """ImageNet-1k read through ``ImageFolder``, with items returned as dicts.

    Each item is ``{"image": ..., "class_id": ..., "class_name": ...}``:
    the (possibly transformed) sample and integer target produced by
    ``ImageFolder``, plus the label looked up in
    ``IMAGENET_1K_CLASS_ID_TO_LABEL``.
    """

    def __init__(self, root, **kwargs):
        # Pure pass-through: ``root`` and any keyword options go to ImageFolder.
        super(ImageNetDataset, self).__init__(root=root, **kwargs)

    def __getitem__(self, idx):
        image, class_id = super(ImageNetDataset, self).__getitem__(idx)
        # Human-readable class name for the numeric target.
        class_name = IMAGENET_1K_CLASS_ID_TO_LABEL[class_id]
        return dict(image=image, class_id=class_id, class_name=class_name)
if __name__ == "__main__":
    # Smoke test: build the VQA dataset from its default paths and print every
    # sample. (GQADataset is not defined in this module, so it cannot be used.)
    # Iteration relies on the sequence protocol: __getitem__ is called with
    # 0, 1, 2, ... until the question lookup raises IndexError.
    vqa_dataset = VQADataset()
    for sample in vqa_dataset:
        print(sample)