Spaces:
Running
Running
| import hydra | |
| import torch | |
| from loguru import logger | |
| from config.config import Config | |
| from model.yolo import get_model | |
| from tools.log_helper import custom_logger | |
| from tools.trainer import Trainer | |
| from utils.dataloader import get_dataloader | |
| from utils.get_dataset import prepare_dataset | |
# NOTE(review): config_path/config_name are inferred from the `config.config`
# import at the top of this file -- confirm they match the project's Hydra
# config directory and root config file. `version_base=None` requires
# Hydra >= 1.2.
@hydra.main(config_path="config", config_name="config", version_base=None)
def main(cfg: Config) -> None:
    """Run one training session described by ``cfg``.

    The Hydra decorator builds ``cfg`` from the project's configuration, so
    the script entry point can call ``main()`` with no arguments; without it
    that call fails because ``cfg`` is a required parameter.

    Steps, in order:
      1. If ``cfg.download.auto`` is truthy, call ``prepare_dataset`` with
         ``cfg.download``.
      2. Build the dataloader from the full config and the model from
         ``cfg.model``.
      3. Select the CUDA device when available, otherwise the CPU.
      4. Construct a ``Trainer`` and start training.

    Args:
        cfg: Project configuration; this function reads ``cfg.download``,
            ``cfg.model`` and ``cfg.hyper.train``.
    """
    if cfg.download.auto:
        prepare_dataset(cfg.download)

    dataloader = get_dataloader(cfg)
    model = get_model(cfg.model)

    # TODO: get_device or rank, for DDP mode
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    trainer = Trainer(model, cfg.hyper.train, device)
    # The literal 10 is presumably the number of epochs -- verify against
    # Trainer.train's signature.
    trainer.train(dataloader, 10)
if __name__ == "__main__":
    # Script entry point: runs only when this file is executed directly,
    # not when it is imported.
    custom_logger()  # called before main(); presumably configures logging -- confirm
    # NOTE(review): main() is invoked without arguments, so `cfg` must be
    # supplied by an entry-point wrapper on main (e.g. a Hydra decorator)
    # -- confirm one is applied.
    main()