🔧 [Add] model name into modelcfg, and move reg_max
Browse files- yolo/config/config.py +1 -0
- yolo/config/model/v9-c.yaml +2 -4
- yolo/config/model/v9-m.yaml +2 -4
- yolo/config/model/v9-s.yaml +2 -4
- yolo/model/yolo.py +3 -1
yolo/config/config.py
CHANGED
|
@@ -24,6 +24,7 @@ class BlockConfig:
|
|
| 24 |
|
| 25 |
@dataclass
|
| 26 |
class ModelConfig:
|
|
|
|
| 27 |
anchor: AnchorConfig
|
| 28 |
model: Dict[str, BlockConfig]
|
| 29 |
|
|
|
|
| 24 |
|
| 25 |
@dataclass
|
| 26 |
class ModelConfig:
|
| 27 |
+
name: Optional[str]
|
| 28 |
anchor: AnchorConfig
|
| 29 |
model: Dict[str, BlockConfig]
|
| 30 |
|
yolo/config/model/v9-c.yaml
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
anchor:
|
| 2 |
reg_max: 16
|
| 3 |
strides: [8, 16, 32]
|
|
@@ -73,8 +75,6 @@ model:
|
|
| 73 |
- MultiheadDetection:
|
| 74 |
source: [P3, P4, P5]
|
| 75 |
tags: Main
|
| 76 |
-
args:
|
| 77 |
-
reg_max: ${model.anchor.reg_max}
|
| 78 |
output: True
|
| 79 |
|
| 80 |
auxiliary:
|
|
@@ -129,6 +129,4 @@ model:
|
|
| 129 |
- MultiheadDetection:
|
| 130 |
source: [A3, A4, A5]
|
| 131 |
tags: AUX
|
| 132 |
-
args:
|
| 133 |
-
reg_max: ${model.anchor.reg_max}
|
| 134 |
output: True
|
|
|
|
| 1 |
+
name: v9-c
|
| 2 |
+
|
| 3 |
anchor:
|
| 4 |
reg_max: 16
|
| 5 |
strides: [8, 16, 32]
|
|
|
|
| 75 |
- MultiheadDetection:
|
| 76 |
source: [P3, P4, P5]
|
| 77 |
tags: Main
|
|
|
|
|
|
|
| 78 |
output: True
|
| 79 |
|
| 80 |
auxiliary:
|
|
|
|
| 129 |
- MultiheadDetection:
|
| 130 |
source: [A3, A4, A5]
|
| 131 |
tags: AUX
|
|
|
|
|
|
|
| 132 |
output: True
|
yolo/config/model/v9-m.yaml
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
anchor:
|
| 2 |
reg_max: 16
|
| 3 |
|
|
@@ -72,8 +74,6 @@ model:
|
|
| 72 |
- MultiheadDetection:
|
| 73 |
source: [P3, P4, P5]
|
| 74 |
tags: Main
|
| 75 |
-
args:
|
| 76 |
-
reg_max: ${model.anchor.reg_max}
|
| 77 |
output: True
|
| 78 |
|
| 79 |
auxiliary:
|
|
@@ -128,6 +128,4 @@ model:
|
|
| 128 |
- MultiheadDetection:
|
| 129 |
source: [A3, A4, A5]
|
| 130 |
tags: AUX
|
| 131 |
-
args:
|
| 132 |
-
reg_max: ${model.anchor.reg_max}
|
| 133 |
output: True
|
|
|
|
| 1 |
+
name: v9-m
|
| 2 |
+
|
| 3 |
anchor:
|
| 4 |
reg_max: 16
|
| 5 |
|
|
|
|
| 74 |
- MultiheadDetection:
|
| 75 |
source: [P3, P4, P5]
|
| 76 |
tags: Main
|
|
|
|
|
|
|
| 77 |
output: True
|
| 78 |
|
| 79 |
auxiliary:
|
|
|
|
| 128 |
- MultiheadDetection:
|
| 129 |
source: [A3, A4, A5]
|
| 130 |
tags: AUX
|
|
|
|
|
|
|
| 131 |
output: True
|
yolo/config/model/v9-s.yaml
CHANGED
|
@@ -1,3 +1,5 @@
|
|
|
|
|
|
|
|
| 1 |
anchor:
|
| 2 |
reg_max: 16
|
| 3 |
|
|
@@ -92,8 +94,6 @@ model:
|
|
| 92 |
- MultiheadDetection:
|
| 93 |
source: [P3, P4, P5]
|
| 94 |
tags: Main
|
| 95 |
-
args:
|
| 96 |
-
reg_max: ${model.anchor.reg_max}
|
| 97 |
output: True
|
| 98 |
|
| 99 |
auxiliary:
|
|
@@ -129,6 +129,4 @@ model:
|
|
| 129 |
- MultiheadDetection:
|
| 130 |
source: [A3, A4, A5]
|
| 131 |
tags: AUX
|
| 132 |
-
args:
|
| 133 |
-
reg_max: ${model.anchor.reg_max}
|
| 134 |
output: True
|
|
|
|
| 1 |
+
name: v9-s
|
| 2 |
+
|
| 3 |
anchor:
|
| 4 |
reg_max: 16
|
| 5 |
|
|
|
|
| 94 |
- MultiheadDetection:
|
| 95 |
source: [P3, P4, P5]
|
| 96 |
tags: Main
|
|
|
|
|
|
|
| 97 |
output: True
|
| 98 |
|
| 99 |
auxiliary:
|
|
|
|
| 129 |
- MultiheadDetection:
|
| 130 |
source: [A3, A4, A5]
|
| 131 |
tags: AUX
|
|
|
|
|
|
|
| 132 |
output: True
|
yolo/model/yolo.py
CHANGED
|
@@ -25,8 +25,9 @@ class YOLO(nn.Module):
|
|
| 25 |
self.num_classes = class_num
|
| 26 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
| 27 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
| 28 |
-
self.
|
| 29 |
self.strides = getattr(model_cfg.anchor, "strides", None)
|
|
|
|
| 30 |
|
| 31 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
| 32 |
self.layer_index = {}
|
|
@@ -48,6 +49,7 @@ class YOLO(nn.Module):
|
|
| 48 |
if "Detection" in layer_type:
|
| 49 |
layer_args["in_channels"] = [output_dim[idx] for idx in source]
|
| 50 |
layer_args["num_classes"] = self.num_classes
|
|
|
|
| 51 |
|
| 52 |
# create layers
|
| 53 |
layer = self.create_layer(layer_type, source, layer_info, **layer_args)
|
|
|
|
| 25 |
self.num_classes = class_num
|
| 26 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
| 27 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
| 28 |
+
self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
|
| 29 |
self.strides = getattr(model_cfg.anchor, "strides", None)
|
| 30 |
+
self.build_model(model_cfg.model)
|
| 31 |
|
| 32 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
| 33 |
self.layer_index = {}
|
|
|
|
| 49 |
if "Detection" in layer_type:
|
| 50 |
layer_args["in_channels"] = [output_dim[idx] for idx in source]
|
| 51 |
layer_args["num_classes"] = self.num_classes
|
| 52 |
+
layer_args["reg_max"] = self.reg_max
|
| 53 |
|
| 54 |
# create layers
|
| 55 |
layer = self.create_layer(layer_type, source, layer_info, **layer_args)
|