🔧 [Update] Trainer input config
Browse files- examples/example_train.py +1 -1
- yolo/tools/trainer.py +3 -1
examples/example_train.py
CHANGED
|
@@ -28,7 +28,7 @@ def main(cfg: Config):
|
|
| 28 |
# TODO: get_device or rank, for DDP mode
|
| 29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 30 |
|
| 31 |
-
trainer = Trainer(model, cfg
|
| 32 |
trainer.train(dataloader, 10)
|
| 33 |
|
| 34 |
|
|
|
|
| 28 |
# TODO: get_device or rank, for DDP mode
|
| 29 |
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| 30 |
|
| 31 |
+
trainer = Trainer(model, cfg, device)
|
| 32 |
trainer.train(dataloader, 10)
|
| 33 |
|
| 34 |
|
yolo/tools/trainer.py
CHANGED
|
@@ -9,7 +9,9 @@ from yolo.utils.loss import get_loss_function
|
|
| 9 |
|
| 10 |
|
| 11 |
class Trainer:
|
| 12 |
-
def __init__(self, model: YOLO,
|
|
|
|
|
|
|
| 13 |
self.model = model.to(device)
|
| 14 |
self.device = device
|
| 15 |
self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
|
|
|
|
| 9 |
|
| 10 |
|
| 11 |
class Trainer:
|
| 12 |
+
def __init__(self, model: YOLO, cfg: Config, device):
|
| 13 |
+
train_cfg: TrainConfig = cfg.hyper.train
|
| 14 |
+
|
| 15 |
self.model = model.to(device)
|
| 16 |
self.device = device
|
| 17 |
self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
|