π [Merge] branch 'SETUP' into MODELv2
Browse files- yolo/lazy.py +4 -5
- yolo/tools/solver.py +9 -11
- yolo/utils/logging_utils.py +6 -3
yolo/lazy.py
CHANGED
|
@@ -13,13 +13,12 @@ from yolo.tools.data_loader import create_dataloader
|
|
| 13 |
from yolo.tools.solver import ModelTester, ModelTrainer
|
| 14 |
from yolo.utils.bounding_box_utils import Vec2Box
|
| 15 |
from yolo.utils.deploy_utils import FastModelLoader
|
| 16 |
-
from yolo.utils.logging_utils import
|
| 17 |
|
| 18 |
|
| 19 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
| 20 |
def main(cfg: Config):
|
| 21 |
-
|
| 22 |
-
save_path = validate_log_directory(cfg, exp_name=cfg.name)
|
| 23 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
| 24 |
device = torch.device(cfg.device)
|
| 25 |
if getattr(cfg.task, "fast_inference", False):
|
|
@@ -31,11 +30,11 @@ def main(cfg: Config):
|
|
| 31 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
| 32 |
|
| 33 |
if cfg.task.task == "train":
|
| 34 |
-
trainer = ModelTrainer(cfg, model, vec2box,
|
| 35 |
trainer.solve(dataloader)
|
| 36 |
|
| 37 |
if cfg.task.task == "inference":
|
| 38 |
-
tester = ModelTester(cfg, model, vec2box,
|
| 39 |
tester.solve(dataloader)
|
| 40 |
|
| 41 |
|
|
|
|
| 13 |
from yolo.tools.solver import ModelTester, ModelTrainer
|
| 14 |
from yolo.utils.bounding_box_utils import Vec2Box
|
| 15 |
from yolo.utils.deploy_utils import FastModelLoader
|
| 16 |
+
from yolo.utils.logging_utils import ProgressLogger
|
| 17 |
|
| 18 |
|
| 19 |
@hydra.main(config_path="config", config_name="config", version_base=None)
|
| 20 |
def main(cfg: Config):
|
| 21 |
+
progress = ProgressLogger(cfg, exp_name=cfg.name)
|
|
|
|
| 22 |
dataloader = create_dataloader(cfg.task.data, cfg.dataset, cfg.task.task)
|
| 23 |
device = torch.device(cfg.device)
|
| 24 |
if getattr(cfg.task, "fast_inference", False):
|
|
|
|
| 30 |
vec2box = Vec2Box(model, cfg.image_size, device)
|
| 31 |
|
| 32 |
if cfg.task.task == "train":
|
| 33 |
+
trainer = ModelTrainer(cfg, model, vec2box, progress, device)
|
| 34 |
trainer.solve(dataloader)
|
| 35 |
|
| 36 |
if cfg.task.task == "inference":
|
| 37 |
+
tester = ModelTester(cfg, model, vec2box, progress, device)
|
| 38 |
tester.solve(dataloader)
|
| 39 |
|
| 40 |
|
yolo/tools/solver.py
CHANGED
|
@@ -14,7 +14,7 @@ from yolo.tools.data_loader import StreamDataLoader, create_dataloader
|
|
| 14 |
from yolo.tools.drawer import draw_bboxes
|
| 15 |
from yolo.tools.loss_functions import create_loss_function
|
| 16 |
from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
|
| 17 |
-
from yolo.utils.logging_utils import
|
| 18 |
from yolo.utils.model_utils import (
|
| 19 |
ExponentialMovingAverage,
|
| 20 |
create_optimizer,
|
|
@@ -23,7 +23,7 @@ from yolo.utils.model_utils import (
|
|
| 23 |
|
| 24 |
|
| 25 |
class ModelTrainer:
|
| 26 |
-
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box,
|
| 27 |
train_cfg: TrainConfig = cfg.task
|
| 28 |
self.model = model
|
| 29 |
self.vec2box = vec2box
|
|
@@ -31,11 +31,11 @@ class ModelTrainer:
|
|
| 31 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
| 32 |
self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
|
| 33 |
self.loss_fn = create_loss_function(cfg, vec2box)
|
| 34 |
-
self.progress =
|
| 35 |
self.num_epochs = cfg.task.epoch
|
| 36 |
|
| 37 |
self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
|
| 38 |
-
self.validator = ModelValidator(cfg.task.validation, model, vec2box,
|
| 39 |
|
| 40 |
if getattr(train_cfg.ema, "enabled", False):
|
| 41 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
|
@@ -102,14 +102,15 @@ class ModelTrainer:
|
|
| 102 |
|
| 103 |
|
| 104 |
class ModelTester:
|
| 105 |
-
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box,
|
| 106 |
self.model = model
|
| 107 |
self.device = device
|
| 108 |
self.vec2box = vec2box
|
| 109 |
-
self.progress =
|
| 110 |
|
| 111 |
self.nms = cfg.task.nms
|
| 112 |
-
self.save_path = save_path
|
|
|
|
| 113 |
self.save_predict = getattr(cfg.task, "save_predict", None)
|
| 114 |
self.idx2label = cfg.class_list
|
| 115 |
|
|
@@ -164,16 +165,13 @@ class ModelValidator:
|
|
| 164 |
validation_cfg: ValidationConfig,
|
| 165 |
model: YOLO,
|
| 166 |
vec2box: Vec2Box,
|
| 167 |
-
save_path: str,
|
| 168 |
device,
|
| 169 |
-
|
| 170 |
-
progress: ProgressTracker,
|
| 171 |
):
|
| 172 |
self.model = model
|
| 173 |
self.vec2box = vec2box
|
| 174 |
self.device = device
|
| 175 |
self.progress = progress
|
| 176 |
-
self.save_path = save_path
|
| 177 |
|
| 178 |
self.nms = validation_cfg.nms
|
| 179 |
|
|
|
|
| 14 |
from yolo.tools.drawer import draw_bboxes
|
| 15 |
from yolo.tools.loss_functions import create_loss_function
|
| 16 |
from yolo.utils.bounding_box_utils import Vec2Box, bbox_nms, calculate_map
|
| 17 |
+
from yolo.utils.logging_utils import ProgressLogger
|
| 18 |
from yolo.utils.model_utils import (
|
| 19 |
ExponentialMovingAverage,
|
| 20 |
create_optimizer,
|
|
|
|
| 23 |
|
| 24 |
|
| 25 |
class ModelTrainer:
|
| 26 |
+
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
|
| 27 |
train_cfg: TrainConfig = cfg.task
|
| 28 |
self.model = model
|
| 29 |
self.vec2box = vec2box
|
|
|
|
| 31 |
self.optimizer = create_optimizer(model, train_cfg.optimizer)
|
| 32 |
self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
|
| 33 |
self.loss_fn = create_loss_function(cfg, vec2box)
|
| 34 |
+
self.progress = progress
|
| 35 |
self.num_epochs = cfg.task.epoch
|
| 36 |
|
| 37 |
self.validation_dataloader = create_dataloader(cfg.task.validation.data, cfg.dataset, cfg.task.validation.task)
|
| 38 |
+
self.validator = ModelValidator(cfg.task.validation, model, vec2box, progress, device, self.progress)
|
| 39 |
|
| 40 |
if getattr(train_cfg.ema, "enabled", False):
|
| 41 |
self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
|
|
|
|
| 102 |
|
| 103 |
|
| 104 |
class ModelTester:
|
| 105 |
+
def __init__(self, cfg: Config, model: YOLO, vec2box: Vec2Box, progress: ProgressLogger, device):
|
| 106 |
self.model = model
|
| 107 |
self.device = device
|
| 108 |
self.vec2box = vec2box
|
| 109 |
+
self.progress = progress
|
| 110 |
|
| 111 |
self.nms = cfg.task.nms
|
| 112 |
+
self.save_path = os.path.join(progress.save_path, "images")
|
| 113 |
+
os.makedirs(self.save_path, exist_ok=True)
|
| 114 |
self.save_predict = getattr(cfg.task, "save_predict", None)
|
| 115 |
self.idx2label = cfg.class_list
|
| 116 |
|
|
|
|
| 165 |
validation_cfg: ValidationConfig,
|
| 166 |
model: YOLO,
|
| 167 |
vec2box: Vec2Box,
|
|
|
|
| 168 |
device,
|
| 169 |
+
progress: ProgressLogger,
|
|
|
|
| 170 |
):
|
| 171 |
self.model = model
|
| 172 |
self.vec2box = vec2box
|
| 173 |
self.device = device
|
| 174 |
self.progress = progress
|
|
|
|
| 175 |
|
| 176 |
self.nms = validation_cfg.nms
|
| 177 |
|
yolo/utils/logging_utils.py
CHANGED
|
@@ -38,15 +38,18 @@ def custom_logger(quite: bool = False):
|
|
| 38 |
)
|
| 39 |
|
| 40 |
|
| 41 |
-
class
|
| 42 |
-
def __init__(self,
|
|
|
|
|
|
|
|
|
|
| 43 |
self.progress = Progress(
|
| 44 |
TextColumn("[progress.description]{task.description}"),
|
| 45 |
BarColumn(bar_width=None),
|
| 46 |
TextColumn("{task.completed:.0f}/{task.total:.0f}"),
|
| 47 |
TimeRemainingColumn(),
|
| 48 |
)
|
| 49 |
-
self.use_wandb = use_wandb
|
| 50 |
if self.use_wandb:
|
| 51 |
wandb.errors.term._log = custom_wandb_log
|
| 52 |
self.wandb = wandb.init(
|
|
|
|
| 38 |
)
|
| 39 |
|
| 40 |
|
| 41 |
+
class ProgressLogger:
|
| 42 |
+
def __init__(self, cfg: Config, exp_name: str):
|
| 43 |
+
custom_logger(getattr(cfg, "quite", False))
|
| 44 |
+
self.save_path = validate_log_directory(cfg, exp_name=cfg.name)
|
| 45 |
+
|
| 46 |
self.progress = Progress(
|
| 47 |
TextColumn("[progress.description]{task.description}"),
|
| 48 |
BarColumn(bar_width=None),
|
| 49 |
TextColumn("{task.completed:.0f}/{task.total:.0f}"),
|
| 50 |
TimeRemainingColumn(),
|
| 51 |
)
|
| 52 |
+
self.use_wandb = cfg.use_wandb
|
| 53 |
if self.use_wandb:
|
| 54 |
wandb.errors.term._log = custom_wandb_log
|
| 55 |
self.wandb = wandb.init(
|