import torch
import time
import torch.optim as optim
from collections import OrderedDict
from utils.utils import print_current_loss
from os.path import join as pjoin
from diffusers import DDPMScheduler
from torch.utils.tensorboard import SummaryWriter
import pdb
import sys
import os
from torch.optim.lr_scheduler import ExponentialLR

class DDPMTrainer(object):

    def __init__(self, args, model, accelerator, model_ema=None):
        self.opt = args
        self.accelerator = accelerator
        self.device = self.accelerator.device
        self.model = model
        self.diffusion_steps = args.diffusion_steps
        self.noise_scheduler = DDPMScheduler(
            num_train_timesteps=self.diffusion_steps,
            beta_schedule=args.beta_schedule,
            variance_type="fixed_small",
            prediction_type=args.prediction_type,
            clip_sample=False,
        )
        self.model_ema = model_ema
        if args.is_train:
            self.mse_criterion = torch.nn.MSELoss(reduction="none")
        accelerator.print("Diffusion_config:\n", self.noise_scheduler.config)

        if self.accelerator.is_main_process:
            starttime = time.strftime("%Y-%m-%d_%H:%M:%S")
            print("Start experiment:", starttime)
            self.writer = SummaryWriter(
                log_dir=pjoin(args.save_root, "logs_") + starttime[:16],
                comment=starttime[:16],
                flush_secs=60,
            )
        self.accelerator.wait_for_everyone()

        self.optimizer = optim.AdamW(
            self.model.parameters(), lr=self.opt.lr, weight_decay=self.opt.weight_decay
        )
        self.scheduler = (
            ExponentialLR(self.optimizer, gamma=args.decay_rate)
            if args.decay_rate > 0
            else None
        )

    @staticmethod
    def zero_grad(opt_list):
        for opt in opt_list:
            opt.zero_grad()

    def clip_norm(self, network_list):
        for network in network_list:
            self.accelerator.clip_grad_norm_(
                network.parameters(), self.opt.clip_grad_norm
            )

    @staticmethod
    def step(opt_list):
        for opt in opt_list:
            opt.step()

    def forward(self, batch_data):
        caption, motions, m_lens = batch_data
        motions = motions.detach().float()

        x_start = motions
        B, T = x_start.shape[:2]
        cur_len = torch.LongTensor([min(T, m_len) for m_len in m_lens]).to(self.device)
        self.src_mask = self.generate_src_mask(T, cur_len).to(x_start.device)

        # 1. Sample noise that we'll add to the motion
        real_noise = torch.randn_like(x_start)

        # 2. Sample a random timestep for each motion
        t = torch.randint(0, self.diffusion_steps, (B,), device=self.device)
        self.timesteps = t

        # 3. Add noise to the motion according to the noise magnitude at each timestep
        #    (this is the forward diffusion process)
        x_t = self.noise_scheduler.add_noise(x_start, real_noise, t)

        # 4. Network prediction
        self.prediction = self.model(x_t, t, text=caption)
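
        # Choose the regression target according to the configured prediction_type:
        #   "sample"       -> the clean motion x_start
        #   "epsilon"      -> the injected Gaussian noise
        #   "v_prediction" -> the velocity target v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * x_start,
        #                     as returned by DDPMScheduler.get_velocity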
        if self.opt.prediction_type == "sample":
            self.target = x_start
        elif self.opt.prediction_type == "epsilon":
            self.target = real_noise
        elif self.opt.prediction_type == "v_prediction":
            self.target = self.noise_scheduler.get_velocity(x_start, real_noise, t)

    def masked_l2(self, a, b, mask, weights):
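        # Squared error averaged over the feature dimension, restricted to the valid
        # frames of each sequence by `mask`, then averaged over the batch with
        # per-sample `weights`.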
        loss = self.mse_criterion(a, b).mean(dim=-1)  # (batch_size, motion_length)
        loss = (loss * mask).sum(-1) / mask.sum(-1)  # (batch_size, )
        loss = (loss * weights).mean()
        return loss

    def backward_G(self):
        loss_logs = OrderedDict({})
        mse_loss_weights = torch.ones_like(self.timesteps)
        loss_logs["loss_mot_rec"] = self.masked_l2(
            self.prediction, self.target, self.src_mask, mse_loss_weights
        )
        self.loss = loss_logs["loss_mot_rec"]
        return loss_logs

    def update(self):
        self.zero_grad([self.optimizer])
        loss_logs = self.backward_G()
        self.accelerator.backward(self.loss)
        self.clip_norm([self.model])
        self.step([self.optimizer])
        return loss_logs

    def generate_src_mask(self, T, length):
        B = len(length)
        src_mask = torch.ones(B, T)
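        # An equivalent vectorized construction (assuming `length` is a 1-D integer tensor):
        #   src_mask = (torch.arange(T, device=length.device)[None, :] < length[:, None]).float()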
        for i in range(B):
            for j in range(length[i], T):
                src_mask[i, j] = 0
        return src_mask

    def train_mode(self):
        self.model.train()
        if self.model_ema:
            self.model_ema.train()

    def eval_mode(self):
        self.model.eval()
        if self.model_ema:
            self.model_ema.eval()

    def save(self, file_name, total_it):
        state = {
            "opt_encoder": self.optimizer.state_dict(),
            "total_it": total_it,
            "encoder": self.accelerator.unwrap_model(self.model).state_dict(),
        }
        if self.model_ema:
            state["model_ema"] = self.accelerator.unwrap_model(
                self.model_ema
            ).module.state_dict()
        torch.save(state, file_name)

    def load(self, model_dir):
        checkpoint = torch.load(model_dir, map_location=self.device)
        self.optimizer.load_state_dict(checkpoint["opt_encoder"])
        if self.model_ema:
            self.model_ema.load_state_dict(checkpoint["model_ema"], strict=True)
        self.model.load_state_dict(checkpoint["encoder"], strict=True)
        return checkpoint.get("total_it", 0)

    def train(self, train_loader):
        it = 0
        if self.opt.is_continue:
            model_path = pjoin(self.opt.model_dir, self.opt.continue_ckpt)
            it = self.load(model_path)
            self.accelerator.print(f"continuing training from iteration {it} of {model_path}")
        start_time = time.time()

        logs = OrderedDict()
        self.dataset = train_loader.dataset
        self.model, self.mse_criterion, self.optimizer, train_loader, self.model_ema = (
            self.accelerator.prepare(
                self.model,
                self.mse_criterion,
                self.optimizer,
                train_loader,
                self.model_ema,
            )
        )

        num_epochs = (self.opt.num_train_steps - it) // len(train_loader) + 1
        self.accelerator.print(f"need to train for {num_epochs} epochs...")
        for epoch in range(0, num_epochs):
            self.train_mode()
            for i, batch_data in enumerate(train_loader):
                self.forward(batch_data)
                log_dict = self.update()
                it += 1

                if self.model_ema and it % self.opt.model_ema_steps == 0:
                    self.accelerator.unwrap_model(self.model_ema).update_parameters(
                        self.model
                    )

                # update logger
                for k, v in log_dict.items():
                    if k not in logs:
                        logs[k] = v
                    else:
                        logs[k] += v

                if it % self.opt.log_every == 0:
                    mean_loss = OrderedDict({})
                    for tag, value in logs.items():
                        mean_loss[tag] = value / self.opt.log_every
                    logs = OrderedDict()
                    print_current_loss(
                        self.accelerator, start_time, it, mean_loss, epoch, inner_iter=i
                    )
                    if self.accelerator.is_main_process:
                        self.writer.add_scalar("loss", mean_loss["loss_mot_rec"], it)
                    self.accelerator.wait_for_everyone()

                if (
                    it % self.opt.save_interval == 0
                    and self.accelerator.is_main_process
                ):  # Save model
                    self.save(pjoin(self.opt.model_dir, "latest.tar"), it)
                self.accelerator.wait_for_everyone()

                if (self.scheduler is not None) and (
                    it % self.opt.update_lr_steps == 0
                ):
                    self.scheduler.step()

        # Save the last checkpoint if it wasn't already saved.
        if it % self.opt.save_interval != 0 and self.accelerator.is_main_process:
            self.save(pjoin(self.opt.model_dir, "latest.tar"), it)
        self.accelerator.wait_for_everyone()
        self.accelerator.print("FINISH")
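

# Example usage (a sketch, not part of the trainer itself; it assumes an argparse-style
# `args` namespace exposing the fields referenced above -- diffusion_steps, beta_schedule,
# prediction_type, is_train, save_root, lr, weight_decay, decay_rate, clip_grad_norm,
# model_dir, continue_ckpt, is_continue, num_train_steps, model_ema_steps, log_every,
# save_interval, update_lr_steps -- a motion denoiser with forward(x_t, t, text=...),
# an optional EMA wrapper `model_ema`, and a DataLoader yielding (caption, motions, m_lens) batches):
#
#   from accelerate import Accelerator
#
#   accelerator = Accelerator()
#   trainer = DDPMTrainer(args, model, accelerator, model_ema=model_ema)
#   trainer.train(train_loader)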