🔧 [Add] TypeHint for yolo.model
Browse files- yolo/config/config.py +13 -0
- yolo/model/yolo.py +2 -1
yolo/config/config.py
CHANGED
|
@@ -1,6 +1,8 @@
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
from typing import Dict, List, Union
|
| 3 |
|
|
|
|
|
|
|
| 4 |
|
| 5 |
@dataclass
|
| 6 |
class AnchorConfig:
|
|
@@ -100,6 +102,17 @@ class Download:
|
|
| 100 |
datasets: Datasets
|
| 101 |
|
| 102 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 103 |
@dataclass
|
| 104 |
class Config:
|
| 105 |
model: Model
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
from typing import Dict, List, Union
|
| 3 |
|
| 4 |
+
from torch import nn
|
| 5 |
+
|
| 6 |
|
| 7 |
@dataclass
|
| 8 |
class AnchorConfig:
|
|
|
|
| 102 |
datasets: Datasets
|
| 103 |
|
| 104 |
|
| 105 |
+
@dataclass
|
| 106 |
+
class YOLOLayer(nn.Module):
|
| 107 |
+
source: Union[int, str, List[int]]
|
| 108 |
+
output: bool
|
| 109 |
+
tags: str
|
| 110 |
+
layer_type: str
|
| 111 |
+
|
| 112 |
+
def __post_init__(self):
|
| 113 |
+
super().__init__()
|
| 114 |
+
|
| 115 |
+
|
| 116 |
@dataclass
|
| 117 |
class Config:
|
| 118 |
model: Model
|
yolo/model/yolo.py
CHANGED
|
@@ -4,7 +4,7 @@ import torch.nn as nn
|
|
| 4 |
from loguru import logger
|
| 5 |
from omegaconf import ListConfig, OmegaConf
|
| 6 |
|
| 7 |
-
from yolo.config.config import Config, Model
|
| 8 |
from yolo.tools.layer_helper import get_layer_map
|
| 9 |
|
| 10 |
|
|
@@ -21,6 +21,7 @@ class YOLO(nn.Module):
|
|
| 21 |
super(YOLO, self).__init__()
|
| 22 |
self.num_classes = num_classes
|
| 23 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
|
|
|
| 24 |
self.build_model(model_cfg.model)
|
| 25 |
|
| 26 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|
|
|
|
| 4 |
from loguru import logger
|
| 5 |
from omegaconf import ListConfig, OmegaConf
|
| 6 |
|
| 7 |
+
from yolo.config.config import Config, Model, YOLOLayer
|
| 8 |
from yolo.tools.layer_helper import get_layer_map
|
| 9 |
|
| 10 |
|
|
|
|
| 21 |
super(YOLO, self).__init__()
|
| 22 |
self.num_classes = num_classes
|
| 23 |
self.layer_map = get_layer_map() # Get the map Dict[str: Module]
|
| 24 |
+
self.model: List[YOLOLayer] = nn.ModuleList()
|
| 25 |
self.build_model(model_cfg.model)
|
| 26 |
|
| 27 |
def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
|