🔧 [Move] model class num config out of model yaml

- yolo/config/config.py +0 -1
- yolo/config/model/v9-c.yaml +0 -2
- yolo/lazy.py +1 -1
- yolo/model/module.py +3 -3
- yolo/model/yolo.py +5 -4
- yolo/tools/loss_functions.py +1 -1
- yolo/utils/deploy_utils.py +3 -3
yolo/config/config.py

```diff
@@ -25,7 +25,6 @@ class BlockConfig:
 @dataclass
 class ModelConfig:
     anchor: AnchorConfig
-    class_num: int
     model: Dict[str, BlockConfig]
 
 
```
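With `class_num` deleted from `ModelConfig`, the value presumably moves to the top-level `Config` dataclass, which this hunk does not show. A minimal sketch of the resulting split, where the `Config` fields and the default of 80 are illustrative assumptions:

```python
from dataclasses import dataclass
from typing import Any, Dict

@dataclass
class ModelConfig:
    anchor: Any            # AnchorConfig: reg_max, strides, ...
    model: Dict[str, Any]  # BlockConfig entries; class_num no longer lives here

@dataclass
class Config:              # hypothetical top-level config
    model: ModelConfig
    class_num: int = 80    # assumed new home for the class count (COCO default)
```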
yolo/config/model/v9-c.yaml

```diff
@@ -2,8 +2,6 @@ anchor:
   reg_max: 16
   strides: [8, 16, 32]
 
-class_num: ${class_num}
-
 model:
   backbone:
     - Conv:
```
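Since the `${class_num}` interpolation is gone, the model yaml no longer depends on a root-level key, and consumers read the count from the root of the composed config rather than through `cfg.model`. A hedged OmegaConf sketch of that layout (the literal values are illustrative):

```python
from omegaconf import OmegaConf

# Hypothetical composed config after this change: class_num sits at the
# root of the config tree, not inside the model yaml.
cfg = OmegaConf.create(
    {
        "class_num": 80,
        "model": {"anchor": {"reg_max": 16, "strides": [8, 16, 32]}},
    }
)
print(cfg.class_num)             # 80, read directly from the root
print("class_num" in cfg.model)  # False, absent from the model tree
```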
yolo/lazy.py

```diff
@@ -25,7 +25,7 @@ def main(cfg: Config):
         model = FastModelLoader(cfg).load_model()
         device = torch.device(cfg.device)
     else:
-        model = create_model(cfg.model, cfg.weight).to(device)
+        model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight).to(device)
 
     if cfg.task.task == "train":
         trainer = ModelTrainer(cfg, model, save_path, device)
```
yolo/model/module.py

```diff
@@ -93,13 +93,13 @@ class MultiheadDetection(nn.Module):
 
 
 class Anchor2Box(nn.Module):
-    def __init__(self, reg_max, strides) -> None:
+    def __init__(self, reg_max, strides, num_classes: int) -> None:
         super().__init__()
         self.reg_max = reg_max
         self.strides = strides
         # TODO: read by cfg!
         image_size = [640, 640]
-        self.class_num = 80
+        self.num_classes = num_classes
         self.anchors, self.scaler = generate_anchors(image_size, self.strides)
         reverse_reg = torch.arange(self.reg_max, dtype=torch.float32)
         self.reverse_reg = nn.Parameter(reverse_reg, requires_grad=False)
@@ -117,7 +117,7 @@ class Anchor2Box(nn.Module):
         for pred in predicts:
             preds.append(rearrange(pred, "B AC h w -> B (h w) AC"))  # B x AC x h x w -> B x hw x AC
         preds = torch.concat(preds, dim=1)  # -> B x (H W) x AC
-        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.class_num), dim=-1)
+        preds_anc, preds_cls = torch.split(preds, (self.reg_max * 4, self.num_classes), dim=-1)
         preds_anc = rearrange(preds_anc, "B hw (P R)-> B hw P R", P=4)
 
         pred_LTRB = preds_anc.softmax(dim=-1) @ self.reverse_reg * self.scaler.view(1, -1, 1)
```
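For context on what the new `num_classes` feeds: each flattened prediction row has width `4 * reg_max + num_classes`, and the split separates the box-distribution bins from the class scores before a DFL-style expectation decodes the left/top/right/bottom distances. A standalone shape sketch under assumed values (`reg_max=16`, `num_classes=80`, and a made-up anchor count):

```python
import torch
from einops import rearrange

reg_max, num_classes = 16, 80  # values matching v9-c.yaml and COCO
B, hw = 2, 8400                # hypothetical batch size and total anchors
preds = torch.randn(B, hw, 4 * reg_max + num_classes)

# Same carve-up as Anchor2Box.forward: distribution bins vs. class scores
preds_anc, preds_cls = torch.split(preds, (reg_max * 4, num_classes), dim=-1)
preds_anc = rearrange(preds_anc, "B hw (P R) -> B hw P R", P=4)

# DFL-style decode: expectation over the reg_max bins for each box side
reverse_reg = torch.arange(reg_max, dtype=torch.float32)
pred_LTRB = preds_anc.softmax(dim=-1) @ reverse_reg  # B x hw x 4

assert pred_LTRB.shape == (B, hw, 4)
assert preds_cls.shape == (B, hw, num_classes)
```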
yolo/model/yolo.py

```diff
@@ -22,9 +22,9 @@ class YOLO(nn.Module):
         parameters, and any other relevant configuration details.
     """
 
-    def __init__(self, model_cfg: ModelConfig):
+    def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
         super(YOLO, self).__init__()
-        self.num_classes = model_cfg.class_num
+        self.num_classes = class_num
         self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
         self.model: List[YOLOLayer] = nn.ModuleList()
         self.build_model(model_cfg.model)
@@ -47,6 +47,7 @@ class YOLO(nn.Module):
                 layer_args["in_channels"] = output_dim[source]
             if "Detection" in layer_type:
                 layer_args["in_channels"] = [output_dim[idx] for idx in source]
+            if "Detection" in layer_type or "Anchor2Box" in layer_type:
                 layer_args["num_classes"] = self.num_classes
 
             # create layers
@@ -116,7 +117,7 @@ class YOLO(nn.Module):
         raise ValueError(f"Unsupported layer type: {layer_type}")
 
 
-def create_model(model_cfg: ModelConfig, weight_path: str) -> YOLO:
+def create_model(model_cfg: ModelConfig, class_num: int = 80, weight_path: str = "weights/v9-c.pt") -> YOLO:
     """Constructs and returns a model from a Dictionary configuration file.
 
     Args:
@@ -126,7 +127,7 @@ def create_model(model_cfg: ModelConfig, weight_path: str) -> YOLO:
         YOLO: An instance of the model defined by the given configuration.
     """
     OmegaConf.set_struct(model_cfg, False)
-    model = YOLO(model_cfg)
+    model = YOLO(model_cfg, class_num)
     logger.info("✅ Success load model")
     if weight_path:
         if os.path.exists(weight_path):
```
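With the new signatures, a non-COCO class count is passed explicitly at build time while the architecture yaml is reused unchanged. A hedged usage sketch; the 20-class value is illustrative, and the empty `weight_path` relies on the `if weight_path:` guard above to skip weight loading:

```python
from omegaconf import OmegaConf
from yolo.model.yolo import create_model  # repo-local import

# e.g. a 20-class dataset such as VOC: only class_num changes,
# yolo/config/model/v9-c.yaml is reused as-is.
model_cfg = OmegaConf.load("yolo/config/model/v9-c.yaml")
model = create_model(model_cfg, class_num=20, weight_path="")  # "" skips loading
```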
yolo/tools/loss_functions.py

```diff
@@ -70,7 +70,7 @@ class DFLoss(nn.Module):
 class YOLOLoss:
     def __init__(self, cfg: Config) -> None:
         self.reg_max = cfg.model.anchor.reg_max
-        self.class_num = cfg.model.class_num
+        self.class_num = cfg.class_num
         self.image_size = list(cfg.image_size)
         self.strides = cfg.model.anchor.strides
         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
```
yolo/utils/deploy_utils.py

```diff
@@ -28,7 +28,7 @@ class FastModelLoader:
             return self._load_onnx_model()
         elif self.compiler == "trt":
             return self._load_trt_model()
-        return create_model(self.cfg)
+        return create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight)
 
     def _load_onnx_model(self):
         from onnxruntime import InferenceSession
@@ -53,7 +53,7 @@ class FastModelLoader:
         from onnxruntime import InferenceSession
         from torch.onnx import export
 
-        model = create_model(self.cfg).eval()
+        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
         dummy_input = torch.ones((1, 3, *self.cfg.image_size))
         export(
             model,
@@ -81,7 +81,7 @@ class FastModelLoader:
     def _create_trt_model(self):
         from torch2trt import torch2trt
 
-        model = create_model(self.cfg).eval()
+        model = create_model(self.cfg.model, class_num=self.cfg.class_num, weight_path=self.cfg.weight).eval()
         dummy_input = torch.ones((1, 3, *self.cfg.image_size))
         logger.info(f"♻️ Creating TensorRT model")
         model_trt = torch2trt(model, [dummy_input])
```
|