Spaces:
Sleeping
Sleeping
✨ [Add] model logger to print model as table
Browse files- yolo/model/yolo.py +5 -0
- yolo/tools/log_helper.py +26 -0
yolo/model/yolo.py
CHANGED
|
@@ -6,6 +6,7 @@ from omegaconf import ListConfig, OmegaConf
|
|
| 6 |
|
| 7 |
from yolo.config.config import Config, Model, YOLOLayer
|
| 8 |
from yolo.tools.layer_helper import get_layer_map
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class YOLO(nn.Module):
|
|
@@ -23,6 +24,7 @@ class YOLO(nn.Module):
|
|
| 23 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
| 24 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
| 25 |
self.build_model(model_cfg.model)
|
|
|
|
| 26 |
|
| 27 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
| 28 |
self.layer_index = {}
|
|
@@ -55,6 +57,7 @@ class YOLO(nn.Module):
|
|
| 55 |
|
| 56 |
out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
|
| 57 |
output_dim.append(out_channels)
|
|
|
|
| 58 |
layer_idx += 1
|
| 59 |
|
| 60 |
def forward(self, x):
|
|
@@ -98,7 +101,9 @@ class YOLO(nn.Module):
|
|
| 98 |
def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
|
| 99 |
if layer_type in self.layer_map:
|
| 100 |
layer = self.layer_map[layer_type](**kwargs)
|
|
|
|
| 101 |
setattr(layer, "source", source)
|
|
|
|
| 102 |
setattr(layer, "output", layer_info.get("output", False))
|
| 103 |
setattr(layer, "tags", layer_info.get("tags", None))
|
| 104 |
return layer
|
|
|
|
| 6 |
|
| 7 |
from yolo.config.config import Config, Model, YOLOLayer
|
| 8 |
from yolo.tools.layer_helper import get_layer_map
|
| 9 |
+
from yolo.tools.log_helper import log_model
|
| 10 |
|
| 11 |
|
| 12 |
class YOLO(nn.Module):
|
|
|
|
| 24 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
| 25 |
self.model: List[YOLOLayer] = nn.ModuleList()
|
| 26 |
self.build_model(model_cfg.model)
|
| 27 |
+
log_model(self.model)
|
| 28 |
|
| 29 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
| 30 |
self.layer_index = {}
|
|
|
|
| 57 |
|
| 58 |
out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
|
| 59 |
output_dim.append(out_channels)
|
| 60 |
+
setattr(layer, "out_c", out_channels)
|
| 61 |
layer_idx += 1
|
| 62 |
|
| 63 |
def forward(self, x):
|
|
|
|
| 101 |
def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
|
| 102 |
if layer_type in self.layer_map:
|
| 103 |
layer = self.layer_map[layer_type](**kwargs)
|
| 104 |
+
setattr(layer, "layer_type", layer_type)
|
| 105 |
setattr(layer, "source", source)
|
| 106 |
+
setattr(layer, "in_c", kwargs.get("in_channels", None))
|
| 107 |
setattr(layer, "output", layer_info.get("output", False))
|
| 108 |
setattr(layer, "tags", layer_info.get("tags", None))
|
| 109 |
return layer
|
yolo/tools/log_helper.py
CHANGED
|
@@ -12,8 +12,13 @@ Example:
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import sys
|
|
|
|
| 15 |
|
| 16 |
from loguru import logger
|
|
|
|
|
|
|
|
|
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
def custom_logger():
|
|
@@ -22,3 +27,24 @@ def custom_logger():
|
|
| 22 |
sys.stderr,
|
| 23 |
format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
|
| 24 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import sys
|
| 15 |
+
from typing import List
|
| 16 |
|
| 17 |
from loguru import logger
|
| 18 |
+
from rich.console import Console
|
| 19 |
+
from rich.table import Table
|
| 20 |
+
|
| 21 |
+
from yolo.config.config import YOLOLayer
|
| 22 |
|
| 23 |
|
| 24 |
def custom_logger():
|
|
|
|
| 27 |
sys.stderr,
|
| 28 |
format="<green>{time:MM-DD HH:mm:ss}</green> | <level>{level: <8}</level> | <level>{message}</level>",
|
| 29 |
)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def log_model(model: List[YOLOLayer]):
|
| 33 |
+
console = Console()
|
| 34 |
+
table = Table(title="Model Layers")
|
| 35 |
+
|
| 36 |
+
table.add_column("Index", justify="center")
|
| 37 |
+
table.add_column("Layer Type", justify="center")
|
| 38 |
+
table.add_column("Tags", justify="center")
|
| 39 |
+
table.add_column("Params", justify="right")
|
| 40 |
+
table.add_column("Channels (IN->OUT)", justify="center")
|
| 41 |
+
|
| 42 |
+
for idx, layer in enumerate(model, start=1):
|
| 43 |
+
layer_param = sum(x.numel() for x in layer.parameters()) # number parameters
|
| 44 |
+
in_channels, out_channels = getattr(layer, "in_c", None), getattr(layer, "out_c", None)
|
| 45 |
+
if in_channels and out_channels:
|
| 46 |
+
channels = f"{in_channels:4} -> {out_channels:4}"
|
| 47 |
+
else:
|
| 48 |
+
channels = "-"
|
| 49 |
+
table.add_row(str(idx), layer.layer_type, layer.tags, f"{layer_param:,}", channels)
|
| 50 |
+
console.print(table)
|