Spaces:
Build error
Build error
| import os | |
| import sys | |
| import torch | |
| from loguru import logger | |
| from configs.train_config import TrainConfig | |
| from data.dataset import TrainDatasetDataLoader | |
| from models.model import HifiFace | |
| from utils.visualizer import Visualizer | |
# Whether to train with DistributedDataParallel, read once from the training config.
use_ddp = TrainConfig().use_ddp
if use_ddp:
    # torch.distributed is only needed (and therefore only imported) in DDP mode.
    import torch.distributed as dist
def setup() -> int:
    """Initialize the default torch.distributed process group (NCCL backend).

    Rank, world size, and the rendezvous endpoint (MASTER_ADDR / MASTER_PORT)
    are expected to be supplied by the ``torchrun`` launcher via environment
    variables, so nothing is passed explicitly here.

    Returns:
        The global rank of this process.
    """
    dist.init_process_group("nccl")
    return dist.get_rank()
def cleanup():
    """Tear down the default torch.distributed process group."""
    dist.destroy_process_group()
def train():
    """Run the HifiFace training loop until ``opt.max_iters`` is exceeded.

    In DDP mode (module-level ``use_ddp``) each process trains on
    ``cuda:<rank>``; only rank 0 creates the visualizer and writes
    checkpoints. On completion a final checkpoint is saved, the process
    group is destroyed, and the process exits with status 0.
    """
    rank = 0
    if use_ddp:
        rank = setup()
    device = torch.device(f"cuda:{rank}")
    logger.info(f"use device {device}")

    opt = TrainConfig()
    dataloader = TrainDatasetDataLoader()
    dataset_length = len(dataloader)
    logger.info(f"Dataset length: {dataset_length}")

    model = HifiFace(
        opt.identity_extractor_config, is_training=True, device=device, load_checkpoint=opt.load_checkpoint
    )
    model.train()
    logger.info("model initialized")

    # Only rank 0 (or the single non-DDP process) visualizes and checkpoints,
    # so multiple DDP workers never write the same files concurrently.
    visualizer = None
    ckpt = False
    if not use_ddp or rank == 0:
        visualizer = Visualizer(opt)
        ckpt = True

    total_iter = 0
    epoch = 0
    while True:
        if use_ddp:
            # Re-seed the distributed sampler so each epoch shuffles differently
            # while staying consistent across ranks.
            dataloader.train_sampler.set_epoch(epoch)
        for data in dataloader:
            source_image = data["source_image"].to(device)
            target_image = data["target_image"].to(device)
            target_mask = data["target_mask"].to(device)
            same = data["same"].to(device)
            loss_dict, visual_dict = model.optimize(source_image, target_image, target_mask, same)

            total_iter += 1
            if total_iter % opt.visualize_interval == 0 and visualizer is not None:
                visualizer.display_current_results(total_iter, visual_dict)

            if total_iter % opt.plot_interval == 0 and visualizer is not None:
                visualizer.plot_current_losses(total_iter, loss_dict)
                logger.info(f"Iter: {total_iter}")
                for k, v in loss_dict.items():
                    logger.info(f" {k}: {v}")
                logger.info("=" * 20)

            if total_iter % opt.checkpoint_interval == 0 and ckpt:
                logger.info(f"Saving model at iter {total_iter}")
                model.save(opt.checkpoint_dir, total_iter)

            if total_iter > opt.max_iters:
                logger.info("Maximum iterations exceeded. Stopping training.")
                if ckpt:
                    model.save(opt.checkpoint_dir, total_iter)
                if use_ddp:
                    cleanup()
                sys.exit(0)
        epoch += 1
if __name__ == "__main__":
    # DDP launch example:
    # CUDA_VISIBLE_DEVICES=2,3 torchrun --nnodes=1 --nproc_per_node=2 \
    #   --rdzv_id=100 --rdzv_backend=c10d --rdzv_endpoint=127.0.0.1:29400 -m entry.train
    if use_ddp:
        # Avoid OpenMP thread oversubscription when several DDP workers share a node.
        os.environ["OMP_NUM_THREADS"] = "1"
    train()