Spaces:
Running
Running
| from collections import OrderedDict | |
| from pathlib import Path | |
| from typing import Dict, List, Union | |
| import torch | |
| from omegaconf import ListConfig, OmegaConf | |
| from torch import nn | |
| from yolo.config.config import ModelConfig, YOLOLayer | |
| from yolo.tools.dataset_preparation import prepare_weight | |
| from yolo.utils.logger import logger | |
| from yolo.utils.module_utils import get_layer_map | |
class YOLO(nn.Module):
    """
    A preliminary YOLO (You Only Look Once) model class still under development.

    The network is assembled from a configuration. Each layer spec maps a layer
    type (looked up in the mapping returned by ``get_layer_map()``) to a dict
    that may contain ``args`` (constructor arguments), ``source`` (where the
    layer's input comes from), ``tags`` (a name other layers can refer to) and
    ``output`` (whether the layer's result is returned by ``forward``).

    Parameters:
        model_cfg: Configuration for the YOLO model. Expected to define the layers,
            parameters, and any other relevant configuration details.
        class_num: Number of classes, forwarded to head layers as ``num_classes``.
    """

    def __init__(self, model_cfg: ModelConfig, class_num: int = 80):
        super().__init__()
        self.num_classes = class_num
        self.layer_map = get_layer_map()  # Get the map Dict[str: Module]
        self.model: List[YOLOLayer] = nn.ModuleList()
        # Fall back to 16 when the anchor config does not define reg_max.
        self.reg_max = getattr(model_cfg.anchor, "reg_max", 16)
        self.build_model(model_cfg.model)

    def build_model(self, model_arch: Dict[str, List[Dict[str, Dict[str, Dict]]]]):
        """
        Instantiate every layer described in ``model_arch`` and append it to ``self.model``.

        Layers are numbered from 1 across all sections; index 0 stands for the
        network input. Empty (or ``None``) sections are skipped entirely so the
        numbering always stays aligned with ``self.model``.

        Args:
            model_arch: Mapping of section name to an ordered list of layer specs.

        Raises:
            ValueError: If two layers share a tag or a layer type is unsupported.
        """
        self.layer_index = {}
        # output_dim[i] holds the channel count produced by layer i; entry 0 is the 3-channel input.
        output_dim, layer_idx = [3], 1
        logger.info(":tractor: Building YOLO")
        for arch_name in model_arch:
            layer_specs = model_arch[arch_name]
            if not layer_specs:
                # Nothing to build: do not consume a layer index for an empty section.
                continue
            logger.info(f" :building_construction: Building {arch_name}")
            for layer_spec in layer_specs:
                # Each spec is a single-entry mapping: {layer_type: layer_info}.
                layer_type, layer_info = next(iter(layer_spec.items()))
                layer_args = layer_info.get("args", {})

                # Get input source
                source = self.get_source_idx(layer_info.get("source", -1), layer_idx)

                # Find in channels
                if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
                    layer_args["in_channels"] = output_dim[source]
                if any(module in layer_type for module in ["Detection", "Segmentation", "Classification"]):
                    if isinstance(source, list):
                        layer_args["in_channels"] = [output_dim[idx] for idx in source]
                    else:
                        # NOTE(review): singular "in_channel" is kept as-is; presumably it matches
                        # the single-source head's constructor signature - confirm before renaming.
                        layer_args["in_channel"] = output_dim[source]
                    layer_args["num_classes"] = self.num_classes
                    layer_args["reg_max"] = self.reg_max

                # create layers
                layer = self.create_layer(layer_type, source, layer_info, **layer_args)
                self.model.append(layer)

                if layer.tags:
                    if layer.tags in self.layer_index:
                        raise ValueError(f"Duplicate tag '{layer_info['tags']}' found.")
                    self.layer_index[layer.tags] = layer_idx

                out_channels = self.get_out_channels(layer_type, layer_args, output_dim, source)
                output_dim.append(out_channels)
                setattr(layer, "out_c", out_channels)
                layer_idx += 1

    def forward(self, x):
        """
        Run the layers in order and collect the results of layers flagged as outputs.

        Returns:
            dict: Maps the tag of every layer whose ``output`` flag is set to its result.
        """
        # y caches intermediate results: 0 is the input, -1 is always the latest result,
        # and a positive index is kept only for layers that a later layer reads from.
        y = {0: x}
        output = dict()
        for index, layer in enumerate(self.model, start=1):
            if isinstance(layer.source, list):
                model_input = [y[idx] for idx in layer.source]
            else:
                model_input = y[layer.source]
            x = layer(model_input)
            y[-1] = x
            if layer.usable:
                y[index] = x
            if layer.output:
                output[layer.tags] = x
        return output

    def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
        """
        Work out how many channels a layer produces.

        An explicit ``out_channels`` argument wins. Otherwise ``CBFuse`` takes the
        channels of its last source, a single source passes its channels through,
        and a list of sources yields the sum of their channels.

        Returns:
            The channel count, or ``None`` when ``source`` is neither an int nor a list.
        """
        # Key membership (not hasattr) so that plain dicts work as well as config mappings.
        if "out_channels" in layer_args:
            return layer_args["out_channels"]
        if layer_type == "CBFuse":
            return output_dim[source[-1]]
        if isinstance(source, int):
            return output_dim[source]
        if isinstance(source, list):
            return sum(output_dim[idx] for idx in source)

    def get_source_idx(self, source: Union[ListConfig, str, int], layer_idx: int):
        """
        Resolve a layer's ``source`` setting into layer indices.

        Args:
            source: A tag name, an index, or a sequence of either.
            layer_idx: Index of the layer that is currently being built.

        Returns:
            An int, or a list of ints when ``source`` is a sequence. ``-1`` is returned
            unchanged and means "the latest result"; ``0`` means the network input.

        Raises:
            KeyError: If ``source`` is a tag that no earlier layer declared.
        """
        if isinstance(source, (ListConfig, list, tuple)):
            return [self.get_source_idx(index, layer_idx) for index in source]
        if isinstance(source, str):
            source = self.layer_index[source]
        if source < -1:
            # Relative reference further back than the previous layer: make it absolute.
            source += layer_idx
        if source > 0:  # Using Previous Layer's Output
            # Layer numbers are 1-based while self.model is 0-based.
            self.model[source - 1].usable = True
        return source

    def create_layer(self, layer_type: str, source: Union[int, list], layer_info: Dict, **kwargs) -> YOLOLayer:
        """
        Instantiate ``layer_type`` with ``kwargs`` and attach the bookkeeping attributes
        (``layer_type``, ``source``, ``in_c``, ``output``, ``tags``, ``usable``) used by
        ``build_model`` and ``forward``.

        Raises:
            ValueError: If ``layer_type`` is not present in ``self.layer_map``.
        """
        if layer_type not in self.layer_map:
            raise ValueError(f"Unsupported layer type: {layer_type}")

        layer = self.layer_map[layer_type](**kwargs)
        setattr(layer, "layer_type", layer_type)
        setattr(layer, "source", source)
        setattr(layer, "in_c", kwargs.get("in_channels", None))
        setattr(layer, "output", layer_info.get("output", False))
        setattr(layer, "tags", layer_info.get("tags", None))
        # Switched on later by get_source_idx when another layer reads this one's result.
        setattr(layer, "usable", 0)
        return layer

    def save_load_weights(self, weights: Union[Path, OrderedDict]):
        """
        Update the model's weights with the provided weights.

        Entries that are missing from ``weights`` or whose shape differs keep the
        model's current values; each affected module is reported with a warning.

        args:
            weights: An OrderedDict containing the new weights, or a path (``Path`` or
                ``str``) to a checkpoint file. A checkpoint may hold the weights
                directly or under the ``"model_state_dict"`` key.
        """
        if isinstance(weights, (str, Path)):
            # SECURITY: weights_only=False lets torch.load unpickle arbitrary objects;
            # only load checkpoint files that come from a trusted source.
            weights = torch.load(weights, map_location=torch.device("cpu"), weights_only=False)
        if "model_state_dict" in weights:
            weights = weights["model_state_dict"]

        model_state_dict = self.model.state_dict()

        # TODO1: autoload old version weight
        # TODO2: weight transform if num_class difference

        error_dict = {"Mismatch": set(), "Not Found": set()}
        for model_key, model_weight in model_state_dict.items():
            if model_key not in weights:
                # Drop the last two key components so related entries are reported once.
                error_dict["Not Found"].add(tuple(model_key.split(".")[:-2]))
                continue
            if model_weight.shape != weights[model_key].shape:
                error_dict["Mismatch"].add(tuple(model_key.split(".")[:-2]))
                continue
            model_state_dict[model_key] = weights[model_key]

        for error_name, error_set in error_dict.items():
            for weight_name in error_set:
                logger.warning(f":warning: Weight {error_name} for key: {'.'.join(weight_name)}")

        self.model.load_state_dict(model_state_dict)
def create_model(model_cfg: ModelConfig, weight_path: Union[bool, str, Path] = True, class_num: int = 80) -> YOLO:
    """Constructs a YOLO model from a configuration and optionally loads its weights.

    Args:
        model_cfg: The configuration of the model. Its struct flag is switched off
            in place before the model is built.
        weight_path: ``True`` to use ``weights/<model_cfg.name>.pt``, a ``str`` or
            ``Path`` pointing to a weight file, or a falsy value to skip loading
            weights. A missing file triggers a call to ``prepare_weight`` first.
        class_num: Number of classes passed on to ``YOLO``.

    Returns:
        YOLO: An instance of the model defined by the given configuration.
    """
    # Turn off struct mode so keys that are not in the schema can be written to the config.
    OmegaConf.set_struct(model_cfg, False)
    model = YOLO(model_cfg, class_num)

    if not weight_path:
        logger.info(":white_check_mark: Success load model")
        return model

    if weight_path is True:
        weight_path = Path("weights") / f"{model_cfg.name}.pt"
    elif isinstance(weight_path, str):
        weight_path = Path(weight_path)

    if not weight_path.exists():
        logger.info(f"🌐 Weight {weight_path} not found, try downloading")
        prepare_weight(weight_path=weight_path)

    # Re-check: prepare_weight may not have produced the file.
    if weight_path.exists():
        model.save_load_weights(weight_path)
        logger.info(":white_check_mark: Success load model & weight")
    else:
        logger.warning(f":warning: Weight {weight_path} is still missing, model keeps its initial weights")
    return model