Spaces:
Sleeping
Sleeping
✨ [New] wandb, progress class for handling the process
Browse files- requirements.txt +3 -1
- yolo/tools/log_helper.py +31 -9
- yolo/tools/trainer.py +9 -9
- yolo/utils/loss.py +9 -9
requirements.txt
CHANGED
|
@@ -1,4 +1,5 @@
|
|
| 1 |
einops
|
|
|
|
| 2 |
hydra-core
|
| 3 |
loguru
|
| 4 |
numpy
|
|
@@ -9,4 +10,5 @@ requests
|
|
| 9 |
rich
|
| 10 |
torch
|
| 11 |
torchvision
|
| 12 |
-
tqdm
|
|
|
|
|
|
| 1 |
einops
|
| 2 |
+
graphviz
|
| 3 |
hydra-core
|
| 4 |
loguru
|
| 5 |
numpy
|
|
|
|
| 10 |
rich
|
| 11 |
torch
|
| 12 |
torchvision
|
| 13 |
+
tqdm
|
| 14 |
+
wandb
|
yolo/tools/log_helper.py
CHANGED
|
@@ -12,32 +12,39 @@ Example:
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import sys
|
| 15 |
-
from typing import List
|
| 16 |
|
|
|
|
|
|
|
| 17 |
from loguru import logger
|
| 18 |
from rich.console import Console
|
| 19 |
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
| 20 |
from rich.table import Table
|
|
|
|
| 21 |
|
| 22 |
-
from yolo.config.config import YOLOLayer
|
| 23 |
|
| 24 |
|
| 25 |
def custom_logger():
|
| 26 |
logger.remove()
|
| 27 |
logger.add(
|
| 28 |
sys.stderr,
|
| 29 |
-
format="<
|
| 30 |
)
|
| 31 |
|
| 32 |
|
| 33 |
class CustomProgress:
|
| 34 |
-
def __init__(self):
|
| 35 |
self.progress = Progress(
|
| 36 |
TextColumn("[progress.description]{task.description}"),
|
| 37 |
BarColumn(bar_width=None),
|
| 38 |
TextColumn("{task.completed}/{task.total}"),
|
| 39 |
TimeRemainingColumn(),
|
| 40 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
def start_train(self, num_epochs: int):
|
| 43 |
self.task_epoch = self.progress.add_task("[cyan]Epochs", total=num_epochs)
|
|
@@ -45,19 +52,34 @@ class CustomProgress:
|
|
| 45 |
def one_epoch(self):
|
| 46 |
self.progress.update(self.task_epoch, advance=1)
|
| 47 |
|
|
|
|
|
|
|
|
|
|
| 48 |
def start_batch(self, num_batches):
|
| 49 |
self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)
|
| 50 |
|
| 51 |
-
def one_batch(self,
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 56 |
|
| 57 |
def finish_batch(self):
|
| 58 |
self.progress.remove_task(self.batch_task)
|
| 59 |
|
| 60 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 61 |
def log_model(model: List[YOLOLayer]):
|
| 62 |
console = Console()
|
| 63 |
table = Table(title="Model Layers")
|
|
|
|
| 12 |
"""
|
| 13 |
|
| 14 |
import sys
|
| 15 |
+
from typing import Dict, List
|
| 16 |
|
| 17 |
+
import wandb
|
| 18 |
+
import wandb.errors
|
| 19 |
from loguru import logger
|
| 20 |
from rich.console import Console
|
| 21 |
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
| 22 |
from rich.table import Table
|
| 23 |
+
from torch import Tensor
|
| 24 |
|
| 25 |
+
from yolo.config.config import Config, YOLOLayer
|
| 26 |
|
| 27 |
|
| 28 |
def custom_logger():
    """Install the project's loguru sink.

    Removes the handlers currently registered on the shared ``logger`` and
    adds one stderr sink whose lines show a ``MM/DD HH:mm:ss`` timestamp,
    the level name centred in 8 columns, and the message.
    """
    # Timestamp is tinted #003385; level and message use loguru's per-level colour.
    log_format = "<fg #003385>[{time:MM/DD HH:mm:ss}]</fg #003385><level>{level: ^8}</level>| <level>{message}</level>"

    logger.remove()
    logger.add(sys.stderr, format=log_format)
|
| 34 |
|
| 35 |
|
| 36 |
class CustomProgress:
    """Rich progress display for training, optionally mirroring losses to wandb.

    Wraps a single ``rich.progress.Progress`` that carries an epoch task and a
    per-epoch batch task. When ``use_wandb`` is true, per-batch losses are also
    logged to a Weights & Biases run.
    """

    def __init__(self, cfg: Config, use_wandb: bool = False):
        """Build the progress display and, if requested, start a wandb run.

        Args:
            cfg: Project configuration; only ``cfg.name`` is read here (used as
                the wandb run name).
            use_wandb: When True, redirect wandb's terminal logging to
                ``custom_wandb_log`` and open a run in project "YOLO".
        """
        self.progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=None),
            TextColumn("{task.completed}/{task.total}"),
            TimeRemainingColumn(),
        )
        self.use_wandb = use_wandb
        # Always define the attribute so wandb-related methods never hit an
        # AttributeError when wandb is disabled.
        self.wandb = None
        if self.use_wandb:
            wandb.errors.term._log = custom_wandb_log
            self.wandb = wandb.init(project="YOLO", resume="allow", mode="online", dir="runs", name=cfg.name)

    def start_train(self, num_epochs: int):
        """Register the epoch task with ``num_epochs`` total steps."""
        self.task_epoch = self.progress.add_task("[cyan]Epochs", total=num_epochs)

    def one_epoch(self):
        """Advance the epoch task by one step."""
        self.progress.update(self.task_epoch, advance=1)

    def finish_epoch(self):
        """Finish the wandb run, if one was started.

        NOTE(review): despite the name, this calls ``finish()`` on the run
        itself rather than closing a single epoch; confirm intended call site.
        """
        if self.use_wandb and self.wandb is not None:
            self.wandb.finish()

    def start_batch(self, num_batches):
        """Register the batch task with ``num_batches`` total steps."""
        self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)

    def one_batch(self, loss_dict: Dict[str, Tensor]):
        """Record one finished batch.

        Args:
            loss_dict: Mapping of loss name to loss value. Names are shown with
                their last four characters removed (presumably a "Loss"
                suffix such as "BoxLoss"; verify against the caller).
        """
        if self.use_wandb and loss_dict:
            # Log all losses in ONE call: each wandb log call advances the
            # run's step by default, so per-key calls would spread a single
            # batch over several steps.
            self.wandb.log({f"Loss/{loss_name}": loss_value for loss_name, loss_value in loss_dict.items()})

        loss_str = "Loss" + "".join(
            f" {loss_name[:-4]}: {loss_val:.2f} |" for loss_name, loss_val in loss_dict.items()
        )

        self.progress.update(self.batch_task, advance=1, description=f"[green]Batches [white]{loss_str}")

    def finish_batch(self):
        """Remove the batch task from the display."""
        self.progress.remove_task(self.batch_task)
|
| 74 |
|
| 75 |
|
| 76 |
+
def custom_wandb_log(string="", level: int = 20, newline=True, repeat=True, prefix=True, silent=False):
    """Forward wandb terminal output to loguru, one log record per line.

    Intended as a drop-in replacement for wandb's internal terminal logger, so
    the parameter list mirrors that function even though several parameters
    are not used here.

    Args:
        string: Text to log; split on newlines so each line is its own record.
        level: Unused. Defaults to 20, the numeric value of ``logging.INFO``
            (the original default was the builtin type ``int`` by mistake).
        newline: When False, the record is emitted through loguru's raw mode,
            which bypasses the sink's format.
        repeat: Unused; accepted for signature compatibility.
        prefix: Unused; accepted for signature compatibility.
        silent: When True, nothing is logged.
    """
    if silent:
        return
    for line in string.split("\n"):
        logger.opt(raw=not newline).info("🌐 " + line)
|
| 81 |
+
|
| 82 |
+
|
| 83 |
def log_model(model: List[YOLOLayer]):
|
| 84 |
console = Console()
|
| 85 |
table = Table(title="Model Layers")
|
yolo/tools/trainer.py
CHANGED
|
@@ -21,6 +21,7 @@ class Trainer:
|
|
| 21 |
self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
|
| 22 |
self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
|
| 23 |
self.loss_fn = get_loss_function(cfg)
|
|
|
|
| 24 |
|
| 25 |
if getattr(train_cfg.ema, "enabled", False):
|
| 26 |
self.ema = EMA(model, decay=train_cfg.ema.decay)
|
|
@@ -42,21 +43,21 @@ class Trainer:
|
|
| 42 |
|
| 43 |
return loss.item(), loss_item
|
| 44 |
|
| 45 |
-
def train_one_epoch(self, dataloader
|
| 46 |
self.model.train()
|
| 47 |
total_loss = 0
|
| 48 |
-
progress.start_batch(len(dataloader))
|
| 49 |
|
| 50 |
for data, targets in dataloader:
|
| 51 |
loss, loss_each = self.train_one_batch(data, targets)
|
| 52 |
|
| 53 |
total_loss += loss
|
| 54 |
-
progress.one_batch(loss_each)
|
| 55 |
|
| 56 |
if self.scheduler:
|
| 57 |
self.scheduler.step()
|
| 58 |
|
| 59 |
-
progress.finish_batch()
|
| 60 |
return total_loss / len(dataloader)
|
| 61 |
|
| 62 |
def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
|
|
@@ -73,14 +74,13 @@ class Trainer:
|
|
| 73 |
|
| 74 |
def train(self, dataloader, num_epochs):
|
| 75 |
logger.info("🚄 Start Training!")
|
| 76 |
-
progress = CustomProgress()
|
| 77 |
|
| 78 |
-
with progress.progress:
|
| 79 |
-
progress.start_train(num_epochs)
|
| 80 |
for epoch in range(num_epochs):
|
| 81 |
|
| 82 |
-
epoch_loss = self.train_one_epoch(dataloader, progress)
|
| 83 |
-
progress.one_epoch()
|
| 84 |
|
| 85 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
| 86 |
if (epoch + 1) % 5 == 0:
|
|
|
|
| 21 |
self.optimizer = get_optimizer(model.parameters(), train_cfg.optimizer)
|
| 22 |
self.scheduler = get_scheduler(self.optimizer, train_cfg.scheduler)
|
| 23 |
self.loss_fn = get_loss_function(cfg)
|
| 24 |
+
self.progress = CustomProgress(cfg, use_wandb=True)
|
| 25 |
|
| 26 |
if getattr(train_cfg.ema, "enabled", False):
|
| 27 |
self.ema = EMA(model, decay=train_cfg.ema.decay)
|
|
|
|
| 43 |
|
| 44 |
return loss.item(), loss_item
|
| 45 |
|
| 46 |
+
def train_one_epoch(self, dataloader, progress=None):
    """Run one pass over ``dataloader`` and return the mean batch loss.

    Args:
        dataloader: Sized iterable yielding ``(data, targets)`` pairs.
        progress: Optional progress reporter. Defaults to ``self.progress``.
            Accepted so that callers passing the reporter explicitly (as the
            training loop does) keep working alongside single-argument calls.

    Returns:
        Sum of the per-batch losses divided by ``len(dataloader)``.
    """
    if progress is None:
        progress = self.progress

    self.model.train()
    total_loss = 0
    progress.start_batch(len(dataloader))

    for data, targets in dataloader:
        loss, loss_each = self.train_one_batch(data, targets)

        total_loss += loss
        progress.one_batch(loss_each)

    # NOTE(review): the source lost its indentation; the scheduler is stepped
    # once per epoch here (after the batch loop) - confirm against the repo.
    if self.scheduler:
        self.scheduler.step()

    progress.finish_batch()
    return total_loss / len(dataloader)
|
| 62 |
|
| 63 |
def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
|
|
|
|
| 74 |
|
| 75 |
def train(self, dataloader, num_epochs):
|
| 76 |
logger.info("🚄 Start Training!")
|
|
|
|
| 77 |
|
| 78 |
+
with self.progress.progress:
|
| 79 |
+
self.progress.start_train(num_epochs)
|
| 80 |
for epoch in range(num_epochs):
|
| 81 |
|
| 82 |
+
epoch_loss = self.train_one_epoch(dataloader, self.progress)
|
| 83 |
+
self.progress.one_epoch()
|
| 84 |
|
| 85 |
logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
|
| 86 |
if (epoch + 1) % 5 == 0:
|
yolo/utils/loss.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
| 1 |
-
import
|
| 2 |
-
from typing import Any, List, Tuple
|
| 3 |
|
| 4 |
import torch
|
| 5 |
import torch.nn.functional as F
|
|
@@ -169,7 +168,7 @@ class DualLoss:
|
|
| 169 |
self.dfl_rate = cfg.hyper.train.loss.objective["DFLoss"]
|
| 170 |
self.cls_rate = cfg.hyper.train.loss.objective["BCELoss"]
|
| 171 |
|
| 172 |
-
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor,
|
| 173 |
targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up
|
| 174 |
|
| 175 |
# TODO: Need Refactor this region, make it flexible!
|
|
@@ -177,12 +176,13 @@ class DualLoss:
|
|
| 177 |
aux_iou, aux_dfl, aux_cls = self.loss(predicts[0], targets)
|
| 178 |
main_iou, main_dfl, main_cls = self.loss(predicts[1], targets)
|
| 179 |
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
|
|
|
| 186 |
|
| 187 |
|
| 188 |
def get_loss_function(cfg: Config) -> YOLOLoss:
|
|
|
|
| 1 |
+
from typing import Any, Dict, List, Tuple
|
|
|
|
| 2 |
|
| 3 |
import torch
|
| 4 |
import torch.nn.functional as F
|
|
|
|
| 168 |
self.dfl_rate = cfg.hyper.train.loss.objective["DFLoss"]
|
| 169 |
self.cls_rate = cfg.hyper.train.loss.objective["BCELoss"]
|
| 170 |
|
| 171 |
+
def __call__(self, predicts: List[Tensor], targets: Tensor) -> Tuple[Tensor, Dict[str, Tensor]]:
|
| 172 |
targets[:, :, 1:] = targets[:, :, 1:] * self.loss.scale_up
|
| 173 |
|
| 174 |
# TODO: Need Refactor this region, make it flexible!
|
|
|
|
| 176 |
aux_iou, aux_dfl, aux_cls = self.loss(predicts[0], targets)
|
| 177 |
main_iou, main_dfl, main_cls = self.loss(predicts[1], targets)
|
| 178 |
|
| 179 |
+
loss_dict = {
|
| 180 |
+
"BoxLoss": self.iou_rate * (aux_iou * self.aux_rate + main_iou),
|
| 181 |
+
"DFLoss": self.dfl_rate * (aux_dfl * self.aux_rate + main_dfl),
|
| 182 |
+
"BCELoss": self.cls_rate * (aux_cls * self.aux_rate + main_cls),
|
| 183 |
+
}
|
| 184 |
+
loss_sum = sum(list(loss_dict.values())) / len(loss_dict)
|
| 185 |
+
return loss_sum, loss_dict
|
| 186 |
|
| 187 |
|
| 188 |
def get_loss_function(cfg: Config) -> YOLOLoss:
|