import json
from typing import Any, Optional, Union

import numpy as np
import torch
import torchvision.transforms as transforms
from pycocotools import mask as mask_utils
from skimage import io
from torch.utils.data import DataLoader, Dataset
from torch.utils.data.distributed import DistributedSampler

from efficientvit.apps.data_provider import DataProvider
from efficientvit.samcore.data_provider.utils import (
    Normalize_and_Pad,
    RandomHFlip,
    ResizeLongestSide,
    SAMDistributedSampler,
)

__all__ = ["SAMDataProvider"]


class OnlineDataset(Dataset):
    def __init__(self, root, train=True, num_masks=64, transform=None):
        self.root = root
        self.train = train
        self.num_masks = num_masks
        self.transform = transform

        with open(f"{self.root}/sa_images_ids.txt", "r") as f:
            self.data = f.read().splitlines()

        # Simple split: the first 99% of image ids are used for training, the last 1% for validation.
        if self.train:
            self.data = self.data[: int(len(self.data) * 0.99)]
        else:
            self.data = self.data[int(len(self.data) * 0.99) :]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        """
        Note: this uses the simplest possible data organization; modify the paths below to match your own layout.
        """
        index = int(self.data[idx])

        image_path = f"{self.root}/images/sa_{index}.jpg"
        image = io.imread(image_path)

        json_path = f"{self.root}/masks/sa_{index}.json"
        with open(json_path, "r") as f:
            annotations = json.load(f)["annotations"]

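        # Each sample carries exactly num_masks annotations. When the image has more, we
        # subsample; when it has fewer, we tile the full annotation list and top it up with
        # `residue` extra indices. For example (illustrative numbers), with 5 annotations and
        # num_masks=8: repeat=1, residue=3, so r = [0..4] plus 3 more indices, 8 in total.
        # Training picks indices at random; validation picks them deterministically.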
        if self.train:
            if len(annotations) > self.num_masks:
                r = np.random.choice(len(annotations), size=self.num_masks, replace=False)
            else:
                repeat, residue = self.num_masks // len(annotations), self.num_masks % len(annotations)
                r = np.random.choice(len(annotations), size=residue, replace=False)
                r = np.concatenate([np.arange(len(annotations)) for _ in range(repeat)] + [r], axis=0)
        else:
            if len(annotations) > self.num_masks:
                r = np.arange(self.num_masks)
            else:
                repeat, residue = self.num_masks // len(annotations), self.num_masks % len(annotations)
                r = np.arange(residue)
                r = np.concatenate([np.arange(len(annotations)) for _ in range(repeat)] + [r], axis=0)

        masks = np.stack([mask_utils.decode(annotations[i]["segmentation"]) for i in r])
        points = np.stack([annotations[i]["point_coords"][0] for i in r])
        bboxs = np.stack([annotations[i]["bbox"] for i in r])

        image = torch.tensor(image, dtype=torch.float32)
        image = image.permute(2, 0, 1)  # HWC -> CHW
        masks = torch.tensor(masks, dtype=torch.float32)
        points = torch.tensor(points, dtype=torch.float32)
        bboxs = torch.tensor(bboxs, dtype=torch.float32)

        sample = {
            "image": image,
            "masks": masks,
            "points": points,
            "bboxs": bboxs,
            "shape": torch.tensor(image.shape[-2:]),
        }

        if self.transform:
            sample = self.transform(sample)

        return sample


class SAMDataProvider(DataProvider):
    name = "sam"

    def __init__(
        self,
        root: str,
        sub_epochs_per_epoch: int,
        num_masks: int,
        train_batch_size: int,
        test_batch_size: int,
        valid_size: Optional[Union[int, float]] = None,
        n_worker=8,
        image_size: int = 1024,
        num_replicas: Optional[int] = None,
        rank: Optional[int] = None,
        train_ratio: Optional[float] = None,
        drop_last: bool = False,
    ):
        self.root = root
        self.num_masks = num_masks
        self.sub_epochs_per_epoch = sub_epochs_per_epoch

        super().__init__(
            train_batch_size,
            test_batch_size,
            valid_size,
            n_worker,
            image_size,
            num_replicas,
            rank,
            train_ratio,
            drop_last,
        )

    def build_train_transform(self):
        # SAM-style preprocessing: resize the longest side to image_size, then normalize and pad to a square.
        train_transforms = [
            RandomHFlip(),
            ResizeLongestSide(target_length=self.image_size[0]),
            Normalize_and_Pad(target_length=self.image_size[0]),
        ]
        return transforms.Compose(train_transforms)

    def build_valid_transform(self):
        valid_transforms = [
            ResizeLongestSide(target_length=self.image_size[0]),
            Normalize_and_Pad(target_length=self.image_size[0]),
        ]
        return transforms.Compose(valid_transforms)

    def build_datasets(self) -> tuple[Any, Any, Any]:
        train_transform = self.build_train_transform()
        valid_transform = self.build_valid_transform()

        train_dataset = OnlineDataset(root=self.root, train=True, num_masks=self.num_masks, transform=train_transform)
        val_dataset = OnlineDataset(root=self.root, train=False, num_masks=2, transform=valid_transform)
        test_dataset = None

        return train_dataset, val_dataset, test_dataset

    def build_dataloader(self, dataset: Optional[Dataset], batch_size: int, n_worker: int, drop_last: bool, train: bool):
        if dataset is None:
            return None
        if train:
            sampler = SAMDistributedSampler(dataset, sub_epochs_per_epoch=self.sub_epochs_per_epoch)
            # Training always drops the last partial batch, regardless of `drop_last`.
            dataloader = DataLoader(dataset, batch_size, sampler=sampler, drop_last=True, num_workers=n_worker)
            return dataloader
        else:
            sampler = DistributedSampler(dataset, shuffle=False)
            dataloader = DataLoader(dataset, batch_size, sampler=sampler, drop_last=False, num_workers=n_worker)
            return dataloader

    def set_epoch_and_sub_epoch(self, epoch: int, sub_epoch: int) -> None:
        if isinstance(self.train.sampler, SAMDistributedSampler):
            self.train.sampler.set_epoch_and_sub_epoch(epoch, sub_epoch)
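

# Minimal usage sketch (illustrative, not part of the original module): read a single sample
# from an SA-1B-style directory and inspect its tensors. The root path is a placeholder.
# Exercising the full SAMDataProvider additionally requires torch.distributed to be
# initialized (e.g. via torchrun), since its dataloaders use distributed samplers.
if __name__ == "__main__":
    dataset = OnlineDataset(root="/path/to/sa1b", train=True, num_masks=8)  # placeholder root
    sample = dataset[0]
    # image is CHW float32; masks, points, and bboxs are stacked along the mask dimension
    print({k: tuple(v.shape) for k, v in sample.items()})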