⚡️ [Update] use anchors if given, otherwise auto-anchor
Browse files- yolo/config/model/v9-c.yaml +1 -0
- yolo/tools/solver.py +1 -1
- yolo/utils/bounding_box_utils.py +11 -8
yolo/config/model/v9-c.yaml
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
anchor:
|
| 2 |
reg_max: 16
|
|
|
|
| 3 |
|
| 4 |
model:
|
| 5 |
backbone:
|
|
|
|
| 1 |
anchor:
|
| 2 |
reg_max: 16
|
| 3 |
+
anchors: [8, 16, 32]
|
| 4 |
|
| 5 |
model:
|
| 6 |
backbone:
|
yolo/tools/solver.py
CHANGED
|
@@ -35,7 +35,7 @@ class ModelTrainer:
|
|
| 35 |
self.num_epochs = cfg.task.epoch
|
| 36 |
|
| 37 |
self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
|
| 38 |
-
self.validator = ModelValidator(cfg.task.validation, model, vec2box, progress, device
|
| 39 |
|
| 40 |
if getattr(train_cfg.ema, "enabled", False):
|
| 41 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
|
|
|
| 35 |
self.num_epochs = cfg.task.epoch
|
| 36 |
|
| 37 |
self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
|
| 38 |
+
self.validator = ModelValidator(cfg.task.validation, model, vec2box, progress, device)
|
| 39 |
|
| 40 |
if getattr(train_cfg.ema, "enabled", False):
|
| 41 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
yolo/utils/bounding_box_utils.py
CHANGED
|
@@ -264,14 +264,17 @@ class BoxMatcher:
|
|
| 264 |
|
| 265 |
|
| 266 |
class Vec2Box:
|
| 267 |
-
def __init__(self, model, image_size, device):
|
| 268 |
-
|
| 269 |
-
|
| 270 |
-
|
| 271 |
-
|
| 272 |
-
|
| 273 |
-
|
| 274 |
-
|
|
|
|
|
|
|
|
|
|
| 275 |
anchor_grid, scaler = generate_anchors(image_size, anchors_num)
|
| 276 |
self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
|
| 277 |
self.anchor_norm = (anchor_grid / scaler[:, None])[None].to(device)
|
|
|
|
| 264 |
|
| 265 |
|
| 266 |
class Vec2Box:
|
| 267 |
+
def __init__(self, model, image_size, device, anchors: list = None):
|
| 268 |
+
if anchors is None:
|
| 269 |
+
logger.info("🧸 Found no anchor, Make a dummy test for auto-anchor size")
|
| 270 |
+
dummy_input = torch.zeros(1, 3, *image_size).to(device)
|
| 271 |
+
dummy_output = model(dummy_input)
|
| 272 |
+
anchors_num = []
|
| 273 |
+
for predict_head in dummy_output["Main"]:
|
| 274 |
+
_, _, *anchor_num = predict_head[2].shape
|
| 275 |
+
anchors_num.append(anchor_num)
|
| 276 |
+
else:
|
| 277 |
+
anchors_num = [[image_size[0] / anchor, image_size[0] / anchor] for anchor in anchors]
|
| 278 |
anchor_grid, scaler = generate_anchors(image_size, anchors_num)
|
| 279 |
self.anchor_grid, self.scaler = anchor_grid.to(device), scaler.to(device)
|
| 280 |
self.anchor_norm = (anchor_grid / scaler[:, None])[None].to(device)
|