Spaces:
Sleeping
Sleeping
🐛 [Add] Tensorlize @ dataloader
Browse files- yolo/tools/data_loader.py +15 -3
yolo/tools/data_loader.py
CHANGED
|
@@ -32,7 +32,19 @@ class YoloDataset(Dataset):
|
|
| 32 |
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
| 33 |
self.transform = AugmentationComposer(transforms, self.image_size)
|
| 34 |
self.transform.get_more_data = self.get_more_data
|
| 35 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
def load_data(self, dataset_path: Path, phase_name: str):
|
| 38 |
"""
|
|
@@ -132,7 +144,7 @@ class YoloDataset(Dataset):
|
|
| 132 |
return torch.zeros((0, 5))
|
| 133 |
|
| 134 |
def get_data(self, idx):
|
| 135 |
-
img_path, bboxes = self.
|
| 136 |
img = Image.open(img_path).convert("RGB")
|
| 137 |
return img, bboxes, img_path
|
| 138 |
|
|
@@ -146,7 +158,7 @@ class YoloDataset(Dataset):
|
|
| 146 |
return img, bboxes, rev_tensor, img_path
|
| 147 |
|
| 148 |
def __len__(self) -> int:
|
| 149 |
-
return len(self.
|
| 150 |
|
| 151 |
|
| 152 |
class YoloDataLoader(DataLoader):
|
|
|
|
| 32 |
transforms = [eval(aug)(prob) for aug, prob in augment_cfg.items()]
|
| 33 |
self.transform = AugmentationComposer(transforms, self.image_size)
|
| 34 |
self.transform.get_more_data = self.get_more_data
|
| 35 |
+
self.img_paths, self.bboxes = self.tensorlize(self.load_data(Path(dataset_cfg.path), phase_name))
|
| 36 |
+
|
| 37 |
+
def tensorlize(self, data):
|
| 38 |
+
img_paths, bboxes = zip(*data)
|
| 39 |
+
max_box = max(bbox.size(0) for bbox in bboxes)
|
| 40 |
+
padded_bbox_list = []
|
| 41 |
+
for bbox in bboxes:
|
| 42 |
+
padding = torch.full((max_box, 5), -1, dtype=torch.float32)
|
| 43 |
+
padding[: bbox.size(0)] = bbox
|
| 44 |
+
padded_bbox_list.append(padding)
|
| 45 |
+
bboxes = torch.stack(padded_bbox_list)
|
| 46 |
+
img_paths = np.array(img_paths)
|
| 47 |
+
return img_paths, bboxes
|
| 48 |
|
| 49 |
def load_data(self, dataset_path: Path, phase_name: str):
|
| 50 |
"""
|
|
|
|
| 144 |
return torch.zeros((0, 5))
|
| 145 |
|
| 146 |
def get_data(self, idx):
|
| 147 |
+
img_path, bboxes = self.img_paths[idx], self.bboxes[idx]
|
| 148 |
img = Image.open(img_path).convert("RGB")
|
| 149 |
return img, bboxes, img_path
|
| 150 |
|
|
|
|
| 158 |
return img, bboxes, rev_tensor, img_path
|
| 159 |
|
| 160 |
def __len__(self) -> int:
|
| 161 |
+
return len(self.bboxes)
|
| 162 |
|
| 163 |
|
| 164 |
class YoloDataLoader(DataLoader):
|