Spaces:
Sleeping
Sleeping
💬 [Add] Progress class, handle progress bar
Browse files- yolo/tools/log_helper.py +29 -0
- yolo/tools/trainer.py +30 -20
yolo/tools/log_helper.py
CHANGED
|
@@ -16,6 +16,7 @@ from typing import List
|
|
| 16 |
|
| 17 |
from loguru import logger
|
| 18 |
from rich.console import Console
|
|
|
|
| 19 |
from rich.table import Table
|
| 20 |
|
| 21 |
from yolo.config.config import YOLOLayer
|
|
@@ -29,6 +30,34 @@ def custom_logger():
|
|
| 29 |
)
|
| 30 |
|
| 31 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
def log_model(model: List[YOLOLayer]):
|
| 33 |
console = Console()
|
| 34 |
table = Table(title="Model Layers")
|
|
|
|
| 16 |
|
| 17 |
from loguru import logger
|
| 18 |
from rich.console import Console
|
| 19 |
+
from rich.progress import BarColumn, Progress, TextColumn, TimeRemainingColumn
|
| 20 |
from rich.table import Table
|
| 21 |
|
| 22 |
from yolo.config.config import YOLOLayer
|
|
|
|
| 30 |
)
|
| 31 |
|
| 32 |
|
| 33 |
+
class CustomProgress:
    """Rich-based live progress display: one bar for epochs, one for batches.

    The underlying ``rich.progress.Progress`` renderer is exposed as
    ``self.progress`` so callers can use it as a context manager
    (``with progress.progress: ...``) to start/stop live rendering.
    """

    def __init__(self):
        # One shared renderer; epoch and batch bars are separate tasks on it.
        self.progress = Progress(
            TextColumn("[progress.description]{task.description}"),
            BarColumn(bar_width=None),
            TextColumn("{task.completed}/{task.total}"),
            TimeRemainingColumn(),
        )

    def start_train(self, num_epochs: int):
        """Create the epoch-level task sized to the total number of epochs."""
        self.task_epoch = self.progress.add_task("[cyan]Epochs", total=num_epochs)

    def one_epoch(self):
        """Advance the epoch bar by one."""
        self.progress.update(self.task_epoch, advance=1)

    def start_batch(self, num_batches: int):
        """Create a fresh batch-level task for the coming epoch."""
        self.batch_task = self.progress.add_task("[green]Batches", total=num_batches)

    def one_batch(self, loss_each):
        """Advance the batch bar and show the current per-component losses.

        ``loss_each`` is expected to be a 3-tuple of (IoU, DFL, CLS) losses.
        """
        loss_iou, loss_dfl, loss_cls = loss_each
        # TODO: make it flexible? if need add more loss
        loss_str = f"Loss IoU: {loss_iou:.3f}, DFL: {loss_dfl:.3f}, CLS: {loss_cls:.3f}"
        self.progress.update(self.batch_task, advance=1, description=f"[green]Batches {loss_str}")

    def finish_batch(self):
        """Remove the batch bar at the end of an epoch (epoch bar remains)."""
        self.progress.remove_task(self.batch_task)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
def log_model(model: List[YOLOLayer]):
|
| 62 |
console = Console()
|
| 63 |
table = Table(title="Model Layers")
|
yolo/tools/trainer.py
CHANGED
|
@@ -1,11 +1,13 @@
|
|
| 1 |
import torch
|
| 2 |
from loguru import logger
|
| 3 |
from torch import Tensor
|
|
|
|
|
|
|
| 4 |
from torch.cuda.amp import GradScaler, autocast
|
| 5 |
-
from tqdm import tqdm
|
| 6 |
|
| 7 |
from yolo.config.config import Config, TrainConfig
|
| 8 |
from yolo.model.yolo import YOLO
|
|
|
|
| 9 |
from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
|
| 10 |
from yolo.utils.loss import get_loss_function
|
| 11 |
|
|
@@ -26,16 +28,13 @@ class Trainer:
|
|
| 26 |
self.ema = None
|
| 27 |
self.scaler = GradScaler()
|
| 28 |
|
| 29 |
-
def train_one_batch(self, data: Tensor, targets: Tensor
|
| 30 |
data, targets = data.to(self.device), targets.to(self.device)
|
| 31 |
self.optimizer.zero_grad()
|
| 32 |
|
| 33 |
with autocast():
|
| 34 |
outputs = self.model(data)
|
| 35 |
loss, loss_item = self.loss_fn(outputs, targets)
|
| 36 |
-
loss_iou, loss_dfl, loss_cls = loss_item
|
| 37 |
-
|
| 38 |
-
progress.set_description(f"Loss IoU: {loss_iou:.5f}, DFL: {loss_dfl:.5f}, CLS: {loss_cls:.5f}")
|
| 39 |
|
| 40 |
self.scaler.scale(loss).backward()
|
| 41 |
self.scaler.step(self.optimizer)
|
|
@@ -43,17 +42,21 @@ class Trainer:
|
|
| 43 |
|
| 44 |
return loss.item(), loss_item
|
| 45 |
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
def train_one_epoch(self, dataloader):
|
| 49 |
self.model.train()
|
| 50 |
total_loss = 0
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
return total_loss / len(dataloader)
|
| 58 |
|
| 59 |
def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
|
|
@@ -69,9 +72,16 @@ class Trainer:
|
|
| 69 |
torch.save(checkpoint, filename)
|
| 70 |
|
| 71 |
def train(self, dataloader, num_epochs):
|
| 72 |
-
logger.info("
|
| 73 |
-
|
| 74 |
-
|
| 75 |
-
|
| 76 |
-
|
| 77 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import torch
|
| 2 |
from loguru import logger
|
| 3 |
from torch import Tensor
|
| 4 |
+
|
| 5 |
+
# TODO: We may can't use CUDA?
|
| 6 |
from torch.cuda.amp import GradScaler, autocast
|
|
|
|
| 7 |
|
| 8 |
from yolo.config.config import Config, TrainConfig
|
| 9 |
from yolo.model.yolo import YOLO
|
| 10 |
+
from yolo.tools.log_helper import CustomProgress
|
| 11 |
from yolo.tools.model_helper import EMA, get_optimizer, get_scheduler
|
| 12 |
from yolo.utils.loss import get_loss_function
|
| 13 |
|
|
|
|
| 28 |
self.ema = None
|
| 29 |
self.scaler = GradScaler()
|
| 30 |
|
| 31 |
+
def train_one_batch(self, data: Tensor, targets: Tensor):
    """Run one mixed-precision optimization step on a single batch.

    Args:
        data: input batch; moved to ``self.device`` before the forward pass.
        targets: ground-truth tensor consumed by ``self.loss_fn``.

    Returns:
        Tuple of (scalar total loss as a Python float, per-component loss items).
    """
    data, targets = data.to(self.device), targets.to(self.device)
    self.optimizer.zero_grad()

    # Mixed-precision forward pass; GradScaler keeps fp16 gradients stable.
    with autocast():
        outputs = self.model(data)
        loss, loss_item = self.loss_fn(outputs, targets)

    self.scaler.scale(loss).backward()
    self.scaler.step(self.optimizer)
    # NOTE(review): this line is hidden (unchanged context) in the diff view;
    # it is the standard AMP scale-update step — confirm against the file.
    self.scaler.update()

    return loss.item(), loss_item
|
| 44 |
|
| 45 |
+
def train_one_epoch(self, dataloader, progress: "CustomProgress"):
    """Train over one full pass of ``dataloader`` with live batch progress.

    The annotation is a string so this function has no import-time dependency
    on ``CustomProgress``; only an object with the same duck-typed methods
    (``start_batch``/``one_batch``/``finish_batch``) is actually required.

    Returns:
        Mean per-batch total loss for the epoch (0.0 for an empty loader).
    """
    self.model.train()
    total_loss = 0
    progress.start_batch(len(dataloader))

    for data, targets in dataloader:
        loss, loss_each = self.train_one_batch(data, targets)

        total_loss += loss
        progress.one_batch(loss_each)

    # Step the (optional) LR scheduler once per epoch.
    # NOTE(review): post-loop placement inferred from the diff's line
    # numbering — confirm it is not meant to run per batch.
    if self.scheduler:
        self.scheduler.step()

    progress.finish_batch()
    # Guard the mean against an empty dataloader (ZeroDivisionError).
    return total_loss / len(dataloader) if len(dataloader) else 0.0
|
| 61 |
|
| 62 |
def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
|
|
|
|
| 72 |
torch.save(checkpoint, filename)
|
| 73 |
|
| 74 |
def train(self, dataloader, num_epochs):
    """Full training loop: run ``num_epochs`` epochs with periodic checkpoints."""
    logger.info("🚄 Start Training!")
    progress = CustomProgress()

    # Entering the Progress context starts the live rendering; leaving stops it.
    with progress.progress:
        progress.start_train(num_epochs)
        for epoch in range(num_epochs):
            epoch_loss = self.train_one_epoch(dataloader, progress)
            progress.one_epoch()

            logger.info(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss:.4f}")
            # Persist a checkpoint every 5 epochs (1-based epoch count).
            if (epoch + 1) % 5 == 0:
                self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch+1}.pth")
|