Spaces:
Runtime error
Runtime error
| #!/usr/bin/env python3 | |
| # -*- coding:utf-8 -*- | |
| # Copyright (c) Megvii, Inc. and its affiliates. | |
| import os | |
| import torch | |
| import torch.nn as nn | |
| from yolox.exp import Exp as MyExp | |
class Exp(MyExp):
    """Experiment config: a YOLOX detector built from a YOLOFPN backbone and a YOLOXHead.

    Overrides the base experiment's model factory and training data loader.
    Everything else read here (num_classes, input_size, train_ann, augmentation
    ranges, seed, data_num_workers) is expected to be defined by ``MyExp``.
    """

    def __init__(self):
        super().__init__()
        # Model scaling factors. Only ``width`` is forwarded (to YOLOXHead) in
        # get_model below; YOLOFPN is constructed without ``depth``.
        self.depth = 1.0
        self.width = 1.0
        # Experiment name = this file's base name up to its first ".",
        # e.g. "yolov3.py" -> "yolov3".
        self.exp_name = os.path.split(os.path.realpath(__file__))[1].split(".")[0]

    def get_model(self, sublinear=False):
        """Build the detector on first use and return it.

        Args:
            sublinear: Unused; kept so existing callers that pass it still work.

        Returns:
            ``self.model``, a ``YOLOX`` wrapping ``YOLOFPN`` + ``YOLOXHead``.
            It is cached on the instance, so later calls return the same object.
        """

        def init_yolo(module):
            # Override the BatchNorm2d defaults on every BN layer in the model.
            for m in module.modules():
                if isinstance(m, nn.BatchNorm2d):
                    m.eps = 1e-3
                    m.momentum = 0.03

        # getattr(..., None) also covers an attribute that exists but is None.
        # A plain ``"model" not in self.__dict__`` test would skip the build in
        # that case and then fail on ``None.apply`` below.
        if getattr(self, "model", None) is None:
            from yolox.models import YOLOX, YOLOFPN, YOLOXHead

            backbone = YOLOFPN()
            # NOTE(review): in_channels presumably matches the three feature
            # maps YOLOFPN emits — confirm against the backbone definition.
            head = YOLOXHead(
                self.num_classes, self.width, in_channels=[128, 256, 512], act="lrelu"
            )
            self.model = YOLOX(backbone, head)

        # These run on every call, including when the cached model is returned.
        # NOTE(review): initialize_biases presumably resets the head biases, so
        # calling get_model again after loading or training weights may clobber them.
        self.model.apply(init_yolo)
        self.model.head.initialize_biases(1e-2)
        return self.model

    def get_data_loader(self, batch_size, is_distributed, no_aug=False):
        """Build the training DataLoader: COCO data wrapped in MosaicDetection.

        Args:
            batch_size: Total batch size. In distributed mode it is
                integer-divided by the world size to get the per-process
                batch size; any remainder is dropped.
            is_distributed: If true, split ``batch_size`` across
                ``torch.distributed`` ranks and sample with ``InfiniteSampler``.
                Otherwise use ``torch.utils.data.RandomSampler``.
            no_aug: Disable mosaic. Passed as ``mosaic=not no_aug`` to both
                ``MosaicDetection`` and ``YoloBatchSampler``.

        Returns:
            A ``DataLoader`` driven by a ``YoloBatchSampler``. As a side effect,
            the wrapped dataset is stored on ``self.dataset``.

        Raises:
            ValueError: In distributed mode, if ``batch_size`` is smaller than
                the world size, so the per-process batch size would be 0.
        """
        from data.datasets.cocodataset import COCODataset
        from data.datasets.mosaicdetection import MosaicDetection
        from data.datasets.data_augment import TrainTransform
        from data.datasets.dataloading import YoloBatchSampler, DataLoader, InfiniteSampler
        import torch.distributed as dist

        # Normalization statistics (these values are the common ImageNet mean/std).
        rgb_means = (0.485, 0.456, 0.406)
        std = (0.229, 0.224, 0.225)

        dataset = COCODataset(
            # NOTE(review): hard-coded relative path, presumably resolved
            # against the working directory.
            data_dir="data/COCO/",
            json_file=self.train_ann,
            img_size=self.input_size,
            # max_labels presumably caps the number of boxes kept per image — confirm.
            preproc=TrainTransform(rgb_means=rgb_means, std=std, max_labels=50),
        )
        dataset = MosaicDetection(
            dataset,
            mosaic=not no_aug,
            img_size=self.input_size,
            preproc=TrainTransform(rgb_means=rgb_means, std=std, max_labels=120),
            degrees=self.degrees,
            translate=self.translate,
            scale=self.scale,
            shear=self.shear,
            perspective=self.perspective,
        )
        self.dataset = dataset

        if is_distributed:
            world_size = dist.get_world_size()
            per_process_batch = batch_size // world_size
            # Fail early with a clear message instead of handing a batch size
            # of 0 to the batch sampler.
            if per_process_batch < 1:
                raise ValueError(
                    f"batch_size ({batch_size}) must be at least the distributed "
                    f"world size ({world_size})"
                )
            batch_size = per_process_batch
            # A falsy seed (None or 0) falls back to 0.
            sampler = InfiniteSampler(len(self.dataset), seed=self.seed if self.seed else 0)
        else:
            sampler = torch.utils.data.RandomSampler(self.dataset)

        batch_sampler = YoloBatchSampler(
            sampler=sampler,
            batch_size=batch_size,
            drop_last=False,
            input_dimension=self.input_size,
            mosaic=not no_aug,
        )

        dataloader_kwargs = {
            "num_workers": self.data_num_workers,
            "pin_memory": True,
            "batch_sampler": batch_sampler,
        }
        train_loader = DataLoader(self.dataset, **dataloader_kwargs)
        return train_loader