π [Fix] some bugs, fit the create_model, device
Browse files- yolo/lazy.py +1 -1
- yolo/model/yolo.py +3 -3
- yolo/tools/solver.py +1 -2
yolo/lazy.py
CHANGED
|
@@ -22,7 +22,7 @@ def main(cfg: Config):
|
|
| 22 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
| 23 |
device = torch.device(cfg.device)
|
| 24 |
if getattr(cfg.task, "fast_inference", False):
|
| 25 |
-
model = FastModelLoader(cfg).load_model()
|
| 26 |
device = torch.device(cfg.device)
|
| 27 |
else:
|
| 28 |
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
|
|
|
|
| 22 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
| 23 |
device = torch.device(cfg.device)
|
| 24 |
if getattr(cfg.task, "fast_inference", False):
|
| 25 |
+
model = FastModelLoader(cfg, device).load_model()
|
| 26 |
device = torch.device(cfg.device)
|
| 27 |
else:
|
| 28 |
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight, device=device)
|
yolo/model/yolo.py
CHANGED
|
@@ -43,7 +43,7 @@ class YOLO(nn.Module):
|
|
| 43 |
source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
|
| 44 |
|
| 45 |
# Find in channels
|
| 46 |
-
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "CBLinear"]):
|
| 47 |
layer_args["in_channels"] = output_dim[source]
|
| 48 |
if "Detection" in layer_type:
|
| 49 |
layer_args["in_channels"] = [output_dim[idx] for idx in source]
|
|
@@ -81,7 +81,7 @@ class YOLO(nn.Module):
|
|
| 81 |
return output
|
| 82 |
|
| 83 |
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
| 84 |
-
if any(module in layer_type for module in ["Conv", "ELAN", "ADown"]):
|
| 85 |
return layer_args["out_channels"]
|
| 86 |
if layer_type == "CBFuse":
|
| 87 |
return output_dim[source[-1]]
|
|
@@ -134,7 +134,7 @@ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: dev
|
|
| 134 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
| 135 |
prepare_weight(weight_path=weight_path)
|
| 136 |
if os.path.exists(weight_path):
|
| 137 |
-
model.model.load_state_dict(torch.load(weight_path, map_location=device))
|
| 138 |
logger.info("β
Success load model weight")
|
| 139 |
|
| 140 |
log_model_structure(model.model)
|
|
|
|
| 43 |
source = self.get_source_idx(layer_info.get("source", -1), layer_idx)
|
| 44 |
|
| 45 |
# Find in channels
|
| 46 |
+
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv", "CBLinear"]):
|
| 47 |
layer_args["in_channels"] = output_dim[source]
|
| 48 |
if "Detection" in layer_type:
|
| 49 |
layer_args["in_channels"] = [output_dim[idx] for idx in source]
|
|
|
|
| 81 |
return output
|
| 82 |
|
| 83 |
def get_out_channels(self, layer_type: str, layer_args: dict, output_dim: list, source: Union[int, list]):
|
| 84 |
+
if any(module in layer_type for module in ["Conv", "ELAN", "ADown", "AConv"]):
|
| 85 |
return layer_args["out_channels"]
|
| 86 |
if layer_type == "CBFuse":
|
| 87 |
return output_dim[source[-1]]
|
|
|
|
| 134 |
logger.info(f"π Weight {weight_path} not found, try downloading")
|
| 135 |
prepare_weight(weight_path=weight_path)
|
| 136 |
if os.path.exists(weight_path):
|
| 137 |
+
model.model.load_state_dict(torch.load(weight_path, map_location=device), strict=False)
|
| 138 |
logger.info("β
Success load model weight")
|
| 139 |
|
| 140 |
log_model_structure(model.model)
|
yolo/tools/solver.py
CHANGED
|
@@ -143,8 +143,7 @@ class ModelTester:
|
|
| 143 |
break
|
| 144 |
if not self.save_predict:
|
| 145 |
continue
|
| 146 |
-
|
| 147 |
-
if self.save_predict == False:
|
| 148 |
save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
|
| 149 |
img.save(save_image_path)
|
| 150 |
logger.info(f"πΎ Saved visualize image at {save_image_path}")
|
|
|
|
| 143 |
break
|
| 144 |
if not self.save_predict:
|
| 145 |
continue
|
| 146 |
+
if self.save_predict != False:
|
|
|
|
| 147 |
save_image_path = os.path.join(self.save_path, f"frame{idx:03d}.png")
|
| 148 |
img.save(save_image_path)
|
| 149 |
logger.info(f"πΎ Saved visualize image at {save_image_path}")
|