🔀 [Merge] branch 'MODEL' into TEST
Browse files
- yolo/config/config.py +13 -0
- yolo/model/module.py +12 -6
- yolo/model/yolo.py +30 -20
- yolo/tools/log_helper.py +26 -0
- yolo/tools/module_helper.py +1 -1
yolo/config/config.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
from typing import Dict, List, Union
|
| 3 |
|
|
|
|
|
|
|
| 4 |
|
| 5 |
@dataclass
|
| 6 |
class AnchorConfig:
|
|
@@ -100,6 +102,17 @@ class Download:
|
|
| 100 |
datasets: Datasets
|
| 101 |
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
@dataclass
|
| 104 |
class Config:
|
| 105 |
model: Model
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
from typing import Dict, List, Union
|
| 3 |
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
|
| 7 |
@dataclass
|
| 8 |
class AnchorConfig:
|
|
|
|
| 102 |
datasets: Datasets
|
| 103 |
|
| 104 |
|
| 105 |
+
# eq=False is required: the default eq=True makes dataclass set __hash__ to
# None, and an unhashable nn.Module breaks modules()/named_modules(), which
# keep visited modules in a set.
@dataclass(eq=False)
class YOLOLayer(nn.Module):
    """Typed description of the metadata attributes carried by a model layer.

    Instances compare and hash by identity, like any other ``nn.Module``.
    """

    # Index, tag name, or list of indices identifying where the layer's input comes from.
    source: Union[int, str, List[int]]
    # Whether the layer's result is collected as a model output.
    output: bool
    # Optional name under which other layers can refer to this layer.
    tags: str
    # Name of the layer class this layer was built from.
    layer_type: str

    def __post_init__(self):
        # Runs after the dataclass-generated __init__ has assigned the fields;
        # sets up the nn.Module internals (parameters, buffers, hooks, ...).
        super().__init__()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
@dataclass
|
| 117 |
class Config:
|
| 118 |
model: Model
|
yolo/model/module.py
CHANGED
|
@@ -24,7 +24,7 @@ class Conv(nn.Module):
|
|
| 24 |
):
|
| 25 |
super().__init__()
|
| 26 |
kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
|
| 27 |
-
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, **kwargs)
|
| 28 |
self.bn = nn.BatchNorm2d(out_channels)
|
| 29 |
self.act = get_activation(activation)
|
| 30 |
|
|
@@ -49,14 +49,16 @@ class Pool(nn.Module):
|
|
| 49 |
class Detection(nn.Module):
|
| 50 |
"""A single YOLO Detection head for detection models"""
|
| 51 |
|
| 52 |
-
def __init__(self, in_channels: int, num_classes: int, *, reg_max: int = 16, use_group: bool = True):
|
| 53 |
super().__init__()
|
| 54 |
|
| 55 |
groups = 4 if use_group else 1
|
| 56 |
anchor_channels = 4 * reg_max
|
|
|
|
|
|
|
| 57 |
# TODO: round up head[0] channels or each head?
|
| 58 |
-
anchor_neck = max(round_up(
|
| 59 |
-
class_neck = max(
|
| 60 |
|
| 61 |
self.anchor_conv = nn.Sequential(
|
| 62 |
Conv(in_channels, anchor_neck, 3),
|
|
@@ -78,8 +80,12 @@ class MultiheadDetection(nn.Module):
|
|
| 78 |
|
| 79 |
def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
|
| 80 |
super().__init__()
|
|
|
|
| 81 |
self.heads = nn.ModuleList(
|
| 82 |
-
[
|
|
|
|
|
|
|
|
|
|
| 83 |
)
|
| 84 |
|
| 85 |
def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
|
@@ -118,7 +124,7 @@ class RepNBottleneck(nn.Module):
|
|
| 118 |
*,
|
| 119 |
kernel_size: Tuple[int, int] = (3, 3),
|
| 120 |
residual: bool = True,
|
| 121 |
-
expand: float = 0
|
| 122 |
**kwargs
|
| 123 |
):
|
| 124 |
super().__init__()
|
|
|
|
| 24 |
):
|
| 25 |
super().__init__()
|
| 26 |
kwargs.setdefault("padding", auto_pad(kernel_size, **kwargs))
|
| 27 |
+
self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, bias=False, **kwargs)
|
| 28 |
self.bn = nn.BatchNorm2d(out_channels)
|
| 29 |
self.act = get_activation(activation)
|
| 30 |
|
|
|
|
| 49 |
class Detection(nn.Module):
|
| 50 |
"""A single YOLO Detection head for detection models"""
|
| 51 |
|
| 52 |
+
def __init__(self, in_channels: Tuple[int], num_classes: int, *, reg_max: int = 16, use_group: bool = True):
|
| 53 |
super().__init__()
|
| 54 |
|
| 55 |
groups = 4 if use_group else 1
|
| 56 |
anchor_channels = 4 * reg_max
|
| 57 |
+
|
| 58 |
+
first_neck, in_channels = in_channels
|
| 59 |
# TODO: round up head[0] channels or each head?
|
| 60 |
+
anchor_neck = max(round_up(first_neck // 4, groups), anchor_channels, 16)
|
| 61 |
+
class_neck = max(first_neck, min(num_classes * 2, 128))
|
| 62 |
|
| 63 |
self.anchor_conv = nn.Sequential(
|
| 64 |
Conv(in_channels, anchor_neck, 3),
|
|
|
|
| 80 |
|
| 81 |
def __init__(self, in_channels: List[int], num_classes: int, **head_kwargs):
|
| 82 |
super().__init__()
|
| 83 |
+
# TODO: Refactor these parts
|
| 84 |
self.heads = nn.ModuleList(
|
| 85 |
+
[
|
| 86 |
+
Detection((in_channels[3 * (idx // 3)], in_channel), num_classes, **head_kwargs)
|
| 87 |
+
for idx, in_channel in enumerate(in_channels)
|
| 88 |
+
]
|
| 89 |
)
|
| 90 |
|
| 91 |
def forward(self, x_list: List[torch.Tensor]) -> List[torch.Tensor]:
|
|
|
|
| 124 |
*,
|
| 125 |
kernel_size: Tuple[int, int] = (3, 3),
|
| 126 |
residual: bool = True,
|
| 127 |
+
expand: float = 1.0,
|
| 128 |
**kwargs
|
| 129 |
):
|
| 130 |
super().__init__()
|
yolo/model/yolo.py
CHANGED
|
@@ -4,8 +4,9 @@ import torch.nn as nn
|
|
| 4 |
from loguru import logger
|
| 5 |
from omegaconf import ListConfig, OmegaConf
|
| 6 |
|
| 7 |
-
from yolo.config.config import Config, Model
|
| 8 |
from yolo.tools.layer_helper import get_layer_map
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class YOLO(nn.Module):
|
|
@@ -21,13 +22,13 @@ class YOLO(nn.Module):
|
|
| 21 |
super(YOLO, self).__init__()
|
| 22 |
self.num_classes = num_classes
|
| 23 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
|
|
|
| 24 |
self.build_model(model_cfg.model)
|
|
|
|
| 25 |
|
| 26 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
| 27 |
-
|
| 28 |
-
output_dim = [3]
|
| 29 |
-
layer_indices_by_tag = {}
|
| 30 |
-
layer_idx = 1
|
| 31 |
logger.info(f"π Building YOLO")
|
| 32 |
for arch_name in model_arch:
|
| 33 |
logger.info(f" ποΈ Building {arch_name}")
|
|
@@ -36,11 +37,7 @@ class YOLO(nn.Module):
|
|
| 36 |
layer_args = layer_info.get("args", {})
|
| 37 |
|
| 38 |
# Get input source
|
| 39 |
-
source = layer_info.get("source", -1)
|
| 40 |
-
if isinstance(source, str):
|
| 41 |
-
source = layer_indices_by_tag[source]
|
| 42 |
-
elif isinstance(source, ListConfig):
|
| 43 |
-
source = [layer_indices_by_tag[idx] if isinstance(idx, str) else idx for idx in source]
|
| 44 |
|
| 45 |
# Find in channels
|
| 46 |
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
|
|
@@ -51,29 +48,29 @@ class YOLO(nn.Module):
|
|
| 51 |
|
| 52 |
# create layers
|
| 53 |
layer = self.create_layer(layer_type, source, layer_info, **layer_args)
|
| 54 |
-
|
| 55 |
|
| 56 |
-
if
|
| 57 |
-
if
|
| 58 |
raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
|
| 59 |
-
|
| 60 |
|
| 61 |
out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
|
| 62 |
output_dim.append(out_channels)
|
|
|
|
| 63 |
layer_idx += 1
|
| 64 |
|
| 65 |
-
self.model = model_list
|
| 66 |
-
|
| 67 |
def forward(self, x):
|
| 68 |
-
y =
|
| 69 |
output = []
|
| 70 |
-
for layer in self.model:
|
| 71 |
if isinstance(layer.source, list):
|
| 72 |
model_input = [y[idx] for idx in layer.source]
|
| 73 |
else:
|
| 74 |
model_input = y[layer.source]
|
| 75 |
x = layer(model_input)
|
| 76 |
-
|
|
|
|
| 77 |
if layer.output:
|
| 78 |
output.append(x)
|
| 79 |
return output
|
|
@@ -90,10 +87,23 @@ class YOLO(nn.Module):
|
|
| 90 |
if layer_type == "IDetect":
|
| 91 |
return None
|
| 92 |
|
| 93 |
-
def
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 94 |
if layer_type in self.layer_map:
|
| 95 |
layer = self.layer_map[layer_type](**kwargs)
|
|
|
|
| 96 |
setattr(layer, "source", source)
|
|
|
|
| 97 |
setattr(layer, "output", layer_info.get("output", False))
|
| 98 |
setattr(layer, "tags", layer_info.get("tags", None))
|
| 99 |
return layer
|
|
|
|
| 4 |
from loguru import logger
|
| 5 |
from omegaconf import ListConfig, OmegaConf
|
| 6 |
|
| 7 |
+
from yolo.config.config import Config, Model, YOLOLayer
|
| 8 |
from yolo.tools.layer_helper import get_layer_map
|
| 9 |
+
from yolo.tools.log_helper import log_model
|
| 10 |
|
| 11 |
|
| 12 |
class YOLO(nn.Module):
|
|
|
|
| 22 |
super(YOLO, self).__init__()
|
| 23 |
self.num_classes = num_classes
|
| 24 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
| 25 |
+
self.model: List[YOLOLayer] = nn.ModuleList()
|
| 26 |
self.build_model(model_cfg.model)
|
| 27 |
+
log_model(self.model)
|
| 28 |
|
| 29 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
| 30 |
+
self.layer_index = {}
|
| 31 |
+
output_dim, layer_idx = [3], 1
|
|
|
|
|
|
|
| 32 |
logger.info(f"π Building YOLO")
|
| 33 |
for arch_name in model_arch:
|
| 34 |
logger.info(f" ποΈ Building {arch_name}")
|
|
|
|
| 37 |
layer_args = layer_info.get("args", {})
|
| 38 |
|
| 39 |
# Get input source
|
| 40 |
+
source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
# Find in channels
|
| 43 |
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
|
|
|
|
| 48 |
|
| 49 |
# create layers
|
| 50 |
layer = self.create_layer(layer_type, source, layer_info, **layer_args)
|
| 51 |
+
self.model.append(layer)
|
| 52 |
|
| 53 |
+
if layer.tags:
|
| 54 |
+
if layer.tags in self.layer_index:
|
| 55 |
raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
|
| 56 |
+
self.layer_index[layer.tags] = layer_idx
|
| 57 |
|
| 58 |
out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
|
| 59 |
output_dim.append(out_channels)
|
| 60 |
+
setattr(layer, "out_c", out_channels)
|
| 61 |
layer_idx += 1
|
| 62 |
|
|
|
|
|
|
|
| 63 |
def forward(self, x):
    """Run ``x`` through every layer in order and collect the flagged outputs.

    Intermediate results are cached by 1-based layer position (position 0 is
    the network input), but only for layers carrying a ``save`` attribute.
    Each layer reads its input(s) from that cache via ``layer.source``.

    Returns:
        A list with the result of every layer whose ``output`` flag is truthy.
    """
    cache = {0: x}
    collected = []
    for position, block in enumerate(self.model, start=1):
        src = block.source
        # A list source means the layer consumes several earlier results at once.
        feed = [cache[i] for i in src] if isinstance(src, list) else cache[src]
        x = block(feed)
        if hasattr(block, "save"):
            cache[position] = x
        if block.output:
            collected.append(x)
    return collected
|
|
|
|
| 87 |
if layer_type == "IDetect":
|
| 88 |
return None
|
| 89 |
|
| 90 |
+
def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
    """Resolve a layer's ``source`` specification to absolute layer indices.

    Args:
        source: A tag name, an absolute or negative (relative) index, or a
            sequence (``ListConfig``, ``list`` or ``tuple``) of those.
        layer_idx: 1-based index of the layer currently being built; negative
            sources are interpreted relative to it.

    Returns:
        The resolved index, or a list of resolved indices for a sequence
        source. Index 0 denotes the network input.

    Raises:
        KeyError: If ``source`` is a tag that has not been registered yet.
    """
    # Plain Python sequences are accepted alongside OmegaConf's ListConfig so
    # that configs not built through OmegaConf resolve the same way.
    if isinstance(source, (list, tuple, ListConfig)):
        return [self.get_source_idx(index, layer_idx) for index in source]
    if isinstance(source, str):
        try:
            source = self.layer_index[source]
        except KeyError:
            raise KeyError(
                f"Unknown source tag '{source}' referenced by layer {layer_idx}"
            ) from None
    if source < 0:
        source += layer_idx
    if source > 0:
        # Flag the producing layer so forward() keeps its result available.
        setattr(self.model[source - 1], "save", True)
    return source
|
| 100 |
+
|
| 101 |
+
def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
|
| 102 |
if layer_type in self.layer_map:
|
| 103 |
layer = self.layer_map[layer_type](**kwargs)
|
| 104 |
+
setattr(layer, "layer_type", layer_type)
|
| 105 |
setattr(layer, "source", source)
|
| 106 |
+
setattr(layer, "in_c", kwargs.get("in_channels", None))
|
| 107 |
setattr(layer, "output", layer_info.get("output", False))
|
| 108 |
setattr(layer, "tags", layer_info.get("tags", None))
|
| 109 |
return layer
|
yolo/tools/log_helper.py
CHANGED
|
@@ -12,8 +12,13 @@ Example:
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import sys
|
|
|
|
| 15 |
|
| 16 |
from loguru import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
def custom_logger():
|
|
@@ -22,3 +27,24 @@ def custom_logger():
|
|
| 22 |
sys.stderr,
|
| 23 |
format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
|
| 24 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import sys
|
| 15 |
+
from typing import List
|
| 16 |
|
| 17 |
from loguru import logger
|
| 18 |
+
from rich.console import Console
|
| 19 |
+
from rich.table import Table
|
| 20 |
+
|
| 21 |
+
from yolo.config.config import YOLOLayer
|
| 22 |
|
| 23 |
|
| 24 |
def custom_logger():
|
|
|
|
| 27 |
sys.stderr,
|
| 28 |
format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
|
| 29 |
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def log_model(model: List[YOLOLayer]):
    """Print a table summarising each layer of ``model`` to the console.

    One row is emitted per layer with its 1-based index, layer type, tags,
    parameter count and, when both are known, its input/output channels.

    Args:
        model: Iterable of layers exposing ``layer_type``, ``tags`` and
            ``parameters()``; ``in_c`` / ``out_c`` are read if present.
    """
    console = Console()
    table = Table(title="Model Layers")

    table.add_column("Index", justify="center")
    table.add_column("Layer Type", justify="center")
    table.add_column("Tags", justify="center")
    table.add_column("Params", justify="right")
    table.add_column("Channels (IN->OUT)", justify="center")

    for idx, layer in enumerate(model, start=1):
        layer_param = sum(x.numel() for x in layer.parameters())  # number parameters
        in_channels, out_channels = getattr(layer, "in_c", None), getattr(layer, "out_c", None)
        if in_channels and out_channels:
            # Convert to str before padding: a width spec applied directly to a
            # list-like channel value (multi-input layers) raises TypeError.
            # For ints the output is identical to the previous ":4" format.
            channels = f"{str(in_channels):>4} -> {str(out_channels):>4}"
        else:
            channels = "-"
        table.add_row(str(idx), layer.layer_type, layer.tags, f"{layer_param:,}", channels)
    console.print(table)
|
yolo/tools/module_helper.py
CHANGED
|
@@ -31,7 +31,7 @@ def get_activation(activation: str) -> nn.Module:
|
|
| 31 |
if isinstance(obj, type) and issubclass(obj, nn.Module)
|
| 32 |
}
|
| 33 |
if activation.lower() in activation_map:
|
| 34 |
-
return activation_map[activation.lower()]()
|
| 35 |
else:
|
| 36 |
raise ValueError(f"Activation function '{activation}' is not found in torch.nn")
|
| 37 |
|
|
|
|
| 31 |
if isinstance(obj, type) and issubclass(obj, nn.Module)
|
| 32 |
}
|
| 33 |
if activation.lower() in activation_map:
|
| 34 |
+
return activation_map[activation.lower()](inplace=True)
|
| 35 |
else:
|
| 36 |
raise ValueError(f"Activation function '{activation}' is not found in torch.nn")
|
| 37 |
|