Spaces:
Runtime error
Runtime error
| import os | |
| import os.path | |
| from pathlib import Path | |
| from typing import Any, Callable, Optional, Tuple | |
| import torch | |
| from maskrcnn_benchmark.structures.bounding_box import BoxList | |
| from PIL import Image, ImageDraw | |
| from torchvision.datasets.vision import VisionDataset | |
| from .modulated_coco import ConvertCocoPolysToMask, has_valid_annotation | |
class CustomCocoDetection(VisionDataset):
    """Coco-style dataset imported from TorchVision.

    It is modified to handle several image sources: every image record in the
    annotation file must carry a ``"data_source"`` entry. Images whose source
    is ``"coco"`` are read from ``root_coco``; any other value is read from
    ``root_vg``. Images whose annotations are rejected by
    ``has_valid_annotation`` are dropped when the dataset is constructed.

    Args:
        root_coco (string): Path to the coco images
        root_vg (string): Path to the vg images
        annFile (string): Path to json annotation file.
        transform (callable, optional): A function/transform that takes in an PIL image
            and returns a transformed version. E.g, ``transforms.ToTensor``
        target_transform (callable, optional): A function/transform that takes in the
            target and transforms it.
        transforms (callable, optional): A function/transform that takes input sample and its target as entry
            and returns a transformed version.
    """

    def __init__(
        self,
        root_coco: str,
        root_vg: str,
        annFile: str,
        transform: Optional[Callable] = None,
        target_transform: Optional[Callable] = None,
        transforms: Optional[Callable] = None,
    ) -> None:
        super().__init__(root_coco, transforms, transform, target_transform)
        # Imported lazily so pycocotools is only required when a dataset is built.
        from pycocotools.coco import COCO

        self.coco = COCO(annFile)
        self.root_coco = root_coco
        self.root_vg = root_vg
        # Keep, in sorted id order, only the images with usable annotations.
        self.ids = [
            img_id
            for img_id in sorted(self.coco.imgs)
            if has_valid_annotation(self._load_annotations(img_id))
        ]

    def _load_annotations(self, img_id):
        """Return the list of annotation dicts attached to ``img_id``."""
        # Always pass a one-element list: pycocotools treats a bare ``str`` id
        # as an iterable of characters, so string ids must never go in unwrapped.
        ann_ids = self.coco.getAnnIds(imgIds=[img_id])
        return self.coco.loadAnns(ann_ids)

    def __getitem__(self, index):
        """
        Args:
            index (int): Index

        Returns:
            tuple: Tuple (image, target). target is the object returned by ``coco.loadAnns``
            (after ``self.transforms``, when one is set).
        """
        img_id = self.ids[index]
        target = self._load_annotations(img_id)

        # Direct dict lookup instead of ``loadImgs`` so string ids work too.
        img_info = self.coco.imgs[img_id]
        path = img_info["file_name"]
        dataset = img_info["data_source"]
        cur_root = self.root_coco if dataset == "coco" else self.root_vg

        # ``convert`` returns a fully loaded copy, so the underlying file can be
        # closed right away instead of waiting for garbage collection.
        with Image.open(os.path.join(cur_root, path)) as raw_img:
            img = raw_img.convert("RGB")

        if self.transforms is not None:
            img, target = self.transforms(img, target)
        return img, target

    def __len__(self):
        """Return the number of images that passed the annotation filter."""
        return len(self.ids)
class MixedDataset(CustomCocoDetection):
    """Same as the modulated detection dataset, except with multiple img sources.

    Each item is an ``(image, BoxList, index)`` triple. The ``BoxList`` holds
    the boxes and labels produced by ``ConvertCocoPolysToMask`` and, as extra
    fields, every other entry of the annotation dict that converter returns.

    Args:
        img_folder_coco: Root directory of the coco images.
        img_folder_vg: Root directory of the vg images.
        ann_file: Path to the json annotation file.
        transforms: Optional callable applied as ``transforms(img, target)``
            after the target has been converted to a ``BoxList``.
        return_masks: Forwarded to ``ConvertCocoPolysToMask``.
        return_tokens: Forwarded to ``ConvertCocoPolysToMask``.
        tokenizer: Forwarded to ``ConvertCocoPolysToMask``.
        disable_clip_to_image: When true, boxes are not clipped to the image.
        no_mask_for_gold: When true, an extra ``(-1, -1, -1)`` entry is appended
            to the ``greenlight_span_for_masked_lm_objective`` list.
        max_query_len: Forwarded to ``ConvertCocoPolysToMask``.
        **kwargs: Accepted and ignored.
    """

    def __init__(self,
                 img_folder_coco,
                 img_folder_vg,
                 ann_file,
                 transforms,
                 return_masks,
                 return_tokens,
                 tokenizer=None,
                 disable_clip_to_image=False,
                 no_mask_for_gold=False,
                 max_query_len=256,
                 **kwargs):
        # ``transforms`` is deliberately not passed to the parent: the parent
        # would apply it to the raw annotation list, while here it has to run
        # on the BoxList built in ``__getitem__``.
        super().__init__(img_folder_coco, img_folder_vg, ann_file)
        self._transforms = transforms
        self.max_query_len = max_query_len
        self.prepare = ConvertCocoPolysToMask(
            return_masks, return_tokens, tokenizer=tokenizer, max_query_len=max_query_len
        )
        # Dataset index -> image id, used by ``get_img_info``.
        self.id_to_img_map = dict(enumerate(self.ids))
        self.disable_clip_to_image = disable_clip_to_image
        self.no_mask_for_gold = no_mask_for_gold

    def __getitem__(self, idx):
        """Return ``(img, target, idx)`` for the sample at position ``idx``.

        ``target`` is a ``BoxList`` in ``xyxy`` mode carrying a ``labels`` field
        plus one field per key of the converted annotation dict.

        Raises:
            AssertionError: If clipping to the image removes any box.
        """
        img, target = super().__getitem__(idx)
        image_id = self.ids[idx]
        # Same direct lookup as ``get_img_info``; unlike ``loadImgs`` it also
        # works when image ids are strings.
        caption = self.coco.imgs[image_id]["caption"]
        anno = {"image_id": image_id, "annotations": target, "caption": caption}

        # By default the whole caption is one span.
        # NOTE(review): the (-1, -1, -1) entry looks like a sentinel consumed
        # downstream by the masked-LM objective; confirm against its consumer.
        anno["greenlight_span_for_masked_lm_objective"] = [(0, len(caption))]
        if self.no_mask_for_gold:
            anno["greenlight_span_for_masked_lm_objective"].append((-1, -1, -1))

        img, anno = self.prepare(img, anno)

        # convert to BoxList (bboxes, labels)
        boxes = torch.as_tensor(anno["boxes"]).reshape(-1, 4)  # guard against no boxes
        target = BoxList(boxes, img.size, mode="xyxy")
        target.add_field("labels", anno["labels"])

        if not self.disable_clip_to_image:
            num_boxes = len(boxes)
            target = target.clip_to_image(remove_empty=True)
            # Labels and other per-box fields would fall out of sync with the
            # boxes if clipping dropped any of them.
            assert len(target.bbox) == num_boxes, "Box removed in MixedDataset!!!"

        if self._transforms is not None:
            img, target = self._transforms(img, target)

        # add additional property
        for key, value in anno.items():
            target.add_field(key, value)

        return img, target, idx

    def get_img_info(self, index):
        """Return the COCO image record for the sample at dataset position ``index``."""
        img_id = self.id_to_img_map[index]
        return self.coco.imgs[img_id]