import torch
from loguru import logger
from torch import Tensor
# TODO: We may not be able to use CUDA?
from torch.cuda.amp import GradScaler, autocast

from yolo.config.config import Config, TrainConfig
from yolo.model.yolo import YOLO
from yolo.tools.data_loader import StreamDataLoader
from yolo.tools.drawer import draw_bboxes
from yolo.tools.loss_functions import get_loss_function
from yolo.utils.bounding_box_utils import AnchorBoxConverter, bbox_nms
from yolo.utils.logging_utils import ProgressTracker
from yolo.utils.model_utils import (
    ExponentialMovingAverage,
    create_optimizer,
    create_scheduler,
)


class ModelTrainer:
    """Owns the YOLO training loop: optimizer, scheduler, loss, optional EMA, and AMP."""

    def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
        train_cfg: TrainConfig = cfg.task
        self.model = model
        self.device = device
        self.optimizer = create_optimizer(model, train_cfg.optimizer)
        self.scheduler = create_scheduler(self.optimizer, train_cfg.scheduler)
        self.loss_fn = get_loss_function(cfg)
        self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
        self.num_epochs = cfg.task.epoch

        # Keep an exponential moving average of the weights only when enabled in the config.
        if getattr(train_cfg.ema, "enabled", False):
            self.ema = ExponentialMovingAverage(model, decay=train_cfg.ema.decay)
        else:
            self.ema = None

        # Gradient scaler for mixed-precision (AMP) training.
        self.scaler = GradScaler()

    def train_one_batch(self, data: Tensor, targets: Tensor):
        data, targets = data.to(self.device), targets.to(self.device)
        self.optimizer.zero_grad()

        # Forward pass and loss under autocast; scale the loss so small fp16
        # gradients do not underflow during the backward pass.
        with autocast():
            outputs = self.model(data)
            loss, loss_item = self.loss_fn(outputs, targets)

        self.scaler.scale(loss).backward()
        self.scaler.step(self.optimizer)
        self.scaler.update()

        return loss.item(), loss_item

    def train_one_epoch(self, dataloader):
        self.model.train()
        total_loss = 0
        for data, targets in dataloader:
            loss, loss_each = self.train_one_batch(data, targets)
            total_loss += loss
            self.progress.one_batch(loss_each)
        if self.scheduler:
            self.scheduler.step()
        return total_loss / len(dataloader)

    def save_checkpoint(self, epoch: int, filename="checkpoint.pt"):
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": self.model.state_dict(),
            "optimizer_state_dict": self.optimizer.state_dict(),
        }
        if self.ema:
            # Temporarily swap in the EMA weights so they can be saved as well,
            # then restore the live training weights.
            self.ema.apply_shadow()
            checkpoint["model_state_dict_ema"] = self.model.state_dict()
            self.ema.restore()
        torch.save(checkpoint, filename)

    def solve(self, dataloader):
        logger.info("🚄 Start Training!")
        num_epochs = self.num_epochs

        with self.progress.progress:
            self.progress.start_train(num_epochs)
            for epoch in range(num_epochs):
                self.progress.start_one_epoch(len(dataloader), self.optimizer, epoch)
                epoch_loss = self.train_one_epoch(dataloader)
                self.progress.finish_one_epoch()

                logger.info(f"Epoch {epoch + 1}/{num_epochs}, Loss: {epoch_loss:.4f}")
                # Save a checkpoint every five epochs.
                if (epoch + 1) % 5 == 0:
                    self.save_checkpoint(epoch, f"checkpoint_epoch_{epoch + 1}.pth")
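

# Illustrative usage (sketch, not part of the original training flow): assuming
# the caller has already built a Config, a YOLO model, a device, and a training
# dataloader through the project's entry point, the trainer is driven like this.
# `cfg`, `model`, `train_loader`, `device`, and the save path are placeholders.
#
#   trainer = ModelTrainer(cfg, model, save_path="runs/train", device=device)
#   trainer.solve(train_loader)  # runs cfg.task.epoch epochs with progress logging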


class ModelTester:
    """Runs inference on streamed images: decode anchors, apply NMS, draw and save boxes."""

    def __init__(self, cfg: Config, model: YOLO, save_path: str, device):
        self.model = model
        self.device = device
        self.progress = ProgressTracker(cfg, save_path, cfg.use_wandb)
        self.anchor2box = AnchorBoxConverter(cfg, device)
        self.nms = cfg.task.nms
        self.save_path = save_path

    def solve(self, dataloader: StreamDataLoader):
        logger.info("👀 Start Inference!")
        try:
            for idx, images in enumerate(dataloader):
                images = images.to(self.device)
                with torch.no_grad():
                    raw_output = self.model(images)
                # Decode anchor predictions into boxes, then suppress overlapping ones.
                predict, _ = self.anchor2box(raw_output[0][3:], with_logits=True)
                nms_out = bbox_nms(predict, self.nms)
                draw_bboxes(
                    images[0],
                    nms_out[0],
                    scaled_bbox=False,
                    save_path=self.save_path,
                    save_name=f"frame{idx:03d}.png",
                )
        except KeyboardInterrupt:
            logger.error("Interrupted by user")
            dataloader.stop_event.set()
        except Exception as e:
            logger.error(e)
            dataloader.stop_event.set()
            raise e
        finally:
            # Always shut the streaming loader down: on normal completion,
            # user interruption, or error.
            dataloader.stop()
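

# Illustrative usage (sketch): assuming `cfg`, `model`, and `device` as above and
# a StreamDataLoader built elsewhere in the project, inference is driven like this.
# `stream_loader` and the save path are placeholders, not values defined here.
#
#   tester = ModelTester(cfg, model, save_path="runs/inference", device=device)
#   tester.solve(stream_loader)  # saves frame000.png, frame001.png, ... under save_path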