🚧 [WIP] DDP for model training
Browse files
- yolo/lazy.py +4 -5
- yolo/model/yolo.py +5 -4
- yolo/utils/deploy_utils.py +1 -2
- yolo/utils/model_utils.py +33 -1
yolo/lazy.py
CHANGED
|
@@ -14,19 +14,18 @@ from yolo.tools.solver import ModelTester, ModelTrainer
|
|
| 14 |
from yolo.utils.bounding_box_utils import Vec2Box
|
| 15 |
from yolo.utils.deploy_utils import FastModelLoader
|
| 16 |
from yolo.utils.logging_utils import ProgressLogger
|
|
|
|
| 17 |
|
| 18 |
|
| 19 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
| 20 |
def main(cfg: Config):
|
| 21 |
progress = ProgressLogger(cfg, exp_name=cfg.name)
|
| 22 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
| 23 |
-
device = torch.device(cfg.device)
|
| 24 |
if getattr(cfg.task, "fast_inference", False):
|
| 25 |
-
model = FastModelLoader(cfg
|
| 26 |
-
device = torch.device(cfg.device)
|
| 27 |
else:
|
| 28 |
-
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight
|
| 29 |
-
|
| 30 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
| 31 |
|
| 32 |
if cfg.task.task == "train":
|
|
|
|
| 14 |
from yolo.utils.bounding_box_utils import Vec2Box
|
| 15 |
from yolo.utils.deploy_utils import FastModelLoader
|
| 16 |
from yolo.utils.logging_utils import ProgressLogger
|
| 17 |
+
from yolo.utils.model_utils import send_to_device
|
| 18 |
|
| 19 |
|
| 20 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
| 21 |
def main(cfg: Config):
|
| 22 |
progress = ProgressLogger(cfg, exp_name=cfg.name)
|
| 23 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
|
|
|
| 24 |
if getattr(cfg.task, "fast_inference", False):
|
| 25 |
+
model = FastModelLoader(cfg).load_model()
|
|
|
|
| 26 |
else:
|
| 27 |
+
model = create_model(cfg.model, class_num=cfg.class_num, weight_path=cfg.weight)
|
| 28 |
+
device, model = send_to_device(model, cfg.device)
|
| 29 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
| 30 |
|
| 31 |
if cfg.task.task == "train":
|
yolo/model/yolo.py
CHANGED
|
@@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union
|
|
| 4 |
import torch
|
| 5 |
from loguru import logger
|
| 6 |
from omegaconf import ListConfig, OmegaConf
|
| 7 |
-
from torch import
|
| 8 |
|
| 9 |
from yolo.config.config import Config, ModelConfig, YOLOLayer
|
| 10 |
from yolo.tools.dataset_preparation import prepare_weight
|
|
@@ -117,7 +117,7 @@ class YOLO(nn.Module):
|
|
| 117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
| 118 |
|
| 119 |
|
| 120 |
-
def create_model(model_cfg: ModelConfig, weight_path: Optional[str],
|
| 121 |
"""Constructs and returns a model from a Dictionary configuration file.
|
| 122 |
|
| 123 |
Args:
|
|
@@ -134,9 +134,10 @@ def create_model(model_cfg: ModelConfig, weight_path: Optional[str], device: dev
|
|
| 134 |
logger.info(f"🌐 Weight {weight_path} not found, try downloading")
|
| 135 |
prepare_weight(weight_path=weight_path)
|
| 136 |
if os.path.exists(weight_path):
|
| 137 |
-
|
|
|
|
| 138 |
logger.info("✅ Success load model weight")
|
| 139 |
|
| 140 |
log_model_structure(model.model)
|
| 141 |
draw_model(model=model)
|
| 142 |
-
return model
|
|
|
|
| 4 |
import torch
|
| 5 |
from loguru import logger
|
| 6 |
from omegaconf import ListConfig, OmegaConf
|
| 7 |
+
from torch import nn
|
| 8 |
|
| 9 |
from yolo.config.config import Config, ModelConfig, YOLOLayer
|
| 10 |
from yolo.tools.dataset_preparation import prepare_weight
|
|
|
|
| 117 |
raise ValueError(f"Unsupported layer type: {layer_type}")
|
| 118 |
|
| 119 |
|
| 120 |
+
def create_model(model_cfg: ModelConfig, weight_path: Optional[str], class_num: int = 80) -> YOLO:
|
| 121 |
"""Constructs and returns a model from a Dictionary configuration file.
|
| 122 |
|
| 123 |
Args:
|
|
|
|
| 134 |
logger.info(f"🌐 Weight {weight_path} not found, try downloading")
|
| 135 |
prepare_weight(weight_path=weight_path)
|
| 136 |
if os.path.exists(weight_path):
|
| 137 |
+
# TODO: fix map_location
|
| 138 |
+
model.model.load_state_dict(torch.load(weight_path), strict=False)
|
| 139 |
logger.info("✅ Success load model weight")
|
| 140 |
|
| 141 |
log_model_structure(model.model)
|
| 142 |
draw_model(model=model)
|
| 143 |
+
return model
|
yolo/utils/deploy_utils.py
CHANGED
|
@@ -9,9 +9,8 @@ from yolo.model.yolo import create_model
|
|
| 9 |
|
| 10 |
|
| 11 |
class FastModelLoader:
|
| 12 |
-
def __init__(self, cfg: Config
|
| 13 |
self.cfg = cfg
|
| 14 |
-
self.device = device
|
| 15 |
self.compiler = cfg.task.fast_inference
|
| 16 |
self._validate_compiler()
|
| 17 |
self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class FastModelLoader:
|
| 12 |
+
def __init__(self, cfg: Config):
|
| 13 |
self.cfg = cfg
|
|
|
|
| 14 |
self.compiler = cfg.task.fast_inference
|
| 15 |
self._validate_compiler()
|
| 16 |
self.model_path = f"{os.path.splitext(cfg.weight)[0]}.{self.compiler}"
|
yolo/utils/model_utils.py
CHANGED
|
@@ -1,6 +1,10 @@
|
|
| 1 |
-
from typing import Any, Dict, Type
|
| 2 |
|
| 3 |
import torch
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
from torch.optim import Optimizer
|
| 5 |
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
|
| 6 |
|
|
@@ -67,3 +71,31 @@ def create_scheduler(optimizer: Optimizer, schedule_cfg: SchedulerConfig) -> _LR
|
|
| 67 |
warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
|
| 68 |
schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
|
| 69 |
return schedule
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Type, Union
|
| 2 |
|
| 3 |
import torch
|
| 4 |
+
import torch.distributed as dist
|
| 5 |
+
from omegaconf import ListConfig
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
| 8 |
from torch.optim import Optimizer
|
| 9 |
from torch.optim.lr_scheduler import LambdaLR, SequentialLR, _LRScheduler
|
| 10 |
|
|
|
|
| 71 |
warmup_schedule = LambdaLR(optimizer, lr_lambda=[lambda1, lambda2, lambda1])
|
| 72 |
schedule = SequentialLR(optimizer, schedulers=[warmup_schedule, schedule], milestones=[2])
|
| 73 |
return schedule
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def get_device():
|
| 77 |
+
if torch.cuda.is_available():
|
| 78 |
+
return torch.device("cuda")
|
| 79 |
+
elif torch.backends.mps.is_available():
|
| 80 |
+
return torch.device("mps")
|
| 81 |
+
else:
|
| 82 |
+
return torch.device("cpu")
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def send_to_device(model: nn.Module, device: Union[str, int, List[int]]):
|
| 86 |
+
if not isinstance(device, (List, ListConfig)):
|
| 87 |
+
device = torch.device(device)
|
| 88 |
+
print("runing man")
|
| 89 |
+
return device, model.to(device)
|
| 90 |
+
|
| 91 |
+
device = torch.device("cuda")
|
| 92 |
+
world_size = dist.get_world_size()
|
| 93 |
+
print("runing man")
|
| 94 |
+
dist.init_process_group(
|
| 95 |
+
backend="gloo" if torch.cuda.is_available() else "gloo", rank=dist.get_rank(), world_size=world_size
|
| 96 |
+
)
|
| 97 |
+
print(f"Initialized process group; rank: {dist.get_rank()}, size: {world_size}")
|
| 98 |
+
|
| 99 |
+
model = model.cuda(device)
|
| 100 |
+
model = DDP(model, device_ids=[device])
|
| 101 |
+
return device, model.to(device)
|