# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os
import time

import imageio
import numpy as np
import torch
from tqdm import tqdm

from imaginaire.losses import MaskedL1Loss
from imaginaire.model_utils.fs_vid2vid import concat_frames, resample
from imaginaire.trainers.vid2vid import Trainer as Vid2VidTrainer
from imaginaire.utils.distributed import is_master
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.misc import split_labels, to_cuda
from imaginaire.utils.visualization import tensor2flow, tensor2im


class Trainer(Vid2VidTrainer):
    r"""Initialize world-consistent vid2vid trainer.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
                                      opt_D, sch_G, sch_D,
                                      train_data_loader, val_data_loader)
        self.guidance_start_after = getattr(cfg.gen.guidance, 'start_from', 0)
        self.train_data_loader = train_data_loader

    def _define_custom_losses(self):
        r"""All other custom losses are defined here."""
        # Setup the guidance loss.
        self.criteria['Guidance'] = MaskedL1Loss(normalize_over_valid=True)
        self.weights['Guidance'] = self.cfg.trainer.loss_weight.guidance
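        # The guidance loss is an L1 penalty between the generated frame and
        # the guidance image rendered from the persistent world, computed
        # only at pixels where the guidance mask is valid. With
        # `normalize_over_valid=True`, the loss appears to be renormalized by
        # the fraction of valid pixels, so sparsely guided frames are not
        # under-weighted relative to densely guided ones.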

    def start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current iteration number.
        """
        self.net_G_module.reset_renderer(is_flipped_input=data['is_flipped'])

        # Keep unprojections on cpu to prevent unnecessary transfer.
        unprojections = data.pop('unprojections')
        data = to_cuda(data)
        data['unprojections'] = unprojections

        self.current_iteration = current_iteration
        if not self.is_inference:
            self.net_D.train()
            self.net_G.train()
        self.start_iteration_time = time.time()
        return data

    def reset(self):
        r"""Reset the trainer (for inference) at the beginning of a
        sequence."""
        # Inference time.
        self.net_G_module.reset_renderer(is_flipped_input=False)

        # print('Resetting trainer.')
        self.net_G_output = self.data_prev = None
        self.t = 0

        test_in_model_average_mode = getattr(
            self, 'test_in_model_average_mode', False)
        if test_in_model_average_mode:
            if hasattr(self.net_G.module.averaged_model, 'reset'):
                self.net_G.module.averaged_model.reset()
        else:
            if hasattr(self.net_G.module, 'reset'):
                self.net_G.module.reset()

    def create_sequence_output_dir(self, output_dir, key):
        r"""Create output subdir for this sequence.

        Args:
            output_dir (str): Root output dir.
            key (str): LMDB key which contains sequence name and file name.
        Returns:
            output_dir (str): Output subdir for this sequence.
            seq_name (str): Name of this sequence.
        """
        seq_dir = '/'.join(key.split('/')[:-1])
        output_dir = os.path.join(output_dir, seq_dir)
        os.makedirs(output_dir, exist_ok=True)
        os.makedirs(output_dir + '/all', exist_ok=True)
        os.makedirs(output_dir + '/fake', exist_ok=True)
        seq_name = seq_dir.replace('/', '-')
        return output_dir, seq_name

    def test(self, test_data_loader, root_output_dir, inference_args):
        r"""Run inference on all sequences.

        Args:
            test_data_loader (object): Test data loader.
            root_output_dir (str): Location to dump outputs.
            inference_args (optional): Optional args.
        """
        # Go over all sequences.
        loader = test_data_loader
        num_inference_sequences = loader.dataset.num_inference_sequences()
        for sequence_idx in range(num_inference_sequences):
            loader.dataset.set_inference_sequence_idx(sequence_idx)
            print('Seq id: %d, Seq length: %d' %
                  (sequence_idx + 1, len(loader)))

            # Reset model at start of new inference sequence.
            self.reset()
            self.sequence_length = len(loader)

            # Go over all frames of this sequence.
            video = []
            for idx, data in enumerate(tqdm(loader)):
                key = data['key']['images'][0][0]
                filename = key.split('/')[-1]

                # Create output dir for this sequence.
                if idx == 0:
                    output_dir, seq_name = \
                        self.create_sequence_output_dir(root_output_dir, key)
                    video_path = os.path.join(output_dir, '..', seq_name)

                # Get output, and save all vis to all/.
                data['img_name'] = filename
                data = to_cuda(data)
                output = self.test_single(data,
                                          output_dir=output_dir + '/all')

                # Dump just the fake image here.
                fake = tensor2im(output['fake_images'])[0]
                video.append(fake)
                imageio.imsave(output_dir + '/fake/%s.jpg' % (filename), fake)

            # Save the sequence as an mp4.
            imageio.mimsave(video_path + '.mp4', video, fps=15)

    def test_single(self, data, output_dir=None, save_fake_only=False):
        r"""The inference function. If output_dir is given, also save the
        output image.

        Args:
            data (dict): Training data at the current iteration.
            output_dir (str): Save image directory.
            save_fake_only (bool): Only save the fake output image.
        """
        if self.is_inference and \
                self.cfg.trainer.model_average_config.enabled:
            test_in_model_average_mode = True
        else:
            test_in_model_average_mode = getattr(
                self, 'test_in_model_average_mode', False)
        data_t = self.get_data_t(data, self.net_G_output, self.data_prev, 0)
        if self.sequence_length > 1:
            self.data_prev = data_t

        # Generator forward.
        # Reset renderer if first time step.
        if self.t == 0:
            self.net_G_module.reset_renderer(
                is_flipped_input=data['is_flipped'])
        with torch.no_grad():
            if test_in_model_average_mode:
                net_G = self.net_G.module.averaged_model
            else:
                net_G = self.net_G
            self.net_G_output = net_G(data_t)

        if output_dir is not None:
            if save_fake_only:
                image_grid = tensor2im(self.net_G_output['fake_images'])[0]
            else:
                vis_images = self.get_test_output_images(data)
                image_grid = np.hstack([np.vstack(im) for im in
                                        vis_images if im is not None])
            if 'img_name' in data:
                save_name = data['img_name'].split('.')[0] + '.jpg'
            else:
                save_name = '%04d.jpg' % self.t
            output_filename = os.path.join(output_dir, save_name)
            os.makedirs(output_dir, exist_ok=True)
            imageio.imwrite(output_filename, image_grid)

        self.t += 1
        return self.net_G_output

    def get_test_output_images(self, data):
        r"""Get the visualization output of test function.

        Args:
            data (dict): Training data at the current iteration.
        """
        # Visualize labels.
        label_lengths = self.val_data_loader.dataset.get_label_lengths()
        labels = split_labels(data['label'], label_lengths)
        vis_labels = []
        for key, value in labels.items():
            if key == 'seg_maps':
                vis_labels.append(self.visualize_label(value[:, -1]))
            else:
                vis_labels.append(tensor2im(value[:, -1]))

        # Get gt image.
        im = tensor2im(data['images'][:, -1])

        # Get guidance image and masks.
        if self.net_G_output['guidance_images_and_masks'] is not None:
            guidance_image = tensor2im(
                self.net_G_output['guidance_images_and_masks'][:, :3])
            guidance_mask = tensor2im(
                self.net_G_output['guidance_images_and_masks'][:, 3:4],
                normalize=False)
        else:
            guidance_image = [np.zeros_like(item) for item in im]
            guidance_mask = [np.zeros_like(item) for item in im]

        # Create output.
        vis_images = [
            *vis_labels,
            im,
            guidance_image, guidance_mask,
            tensor2im(self.net_G_output['fake_images']),
        ]
        return vis_images

    def gen_frames(self, data, use_model_average=False):
        r"""Generate a sequence of frames given a sequence of data.

        Args:
            data (dict): Training data at the current iteration.
            use_model_average (bool): Whether to use the averaged generator
                for the forward passes.
        """
        net_G_output = None  # Previous generator output.
        data_prev = None  # Previous data.
        if use_model_average:
            net_G = self.net_G.module.averaged_model
        else:
            net_G = self.net_G

        # Iterate through the length of sequence.
        self.net_G_module.reset_renderer(is_flipped_input=data['is_flipped'])
        all_info = {'inputs': [], 'outputs': []}
        for t in range(self.sequence_length):
            # Get the data at the current time frame.
            data_t = self.get_data_t(data, net_G_output, data_prev, t)
            data_prev = data_t

            # Generator forward.
            with torch.no_grad():
                net_G_output = net_G(data_t)

            # Do any postprocessing if necessary.
            data_t, net_G_output = self.post_process(data_t, net_G_output)

            if t == 0:
                # Get the output at beginning of sequence for visualization.
                first_net_G_output = net_G_output

            all_info['inputs'].append(data_t)
            all_info['outputs'].append(net_G_output)

        return first_net_G_output, net_G_output, all_info

    def _get_custom_gen_losses(self, data_t, net_G_output, net_D_output):
        r"""All other custom generator losses go here.

        Args:
            data_t (dict): Training data at the current time t.
            net_G_output (dict): Output of the generator.
            net_D_output (dict): Output of the discriminator.
        """
        # Compute guidance loss.
        if net_G_output['guidance_images_and_masks'] is not None:
            guidance_image = net_G_output['guidance_images_and_masks'][:, :3]
            guidance_mask = net_G_output['guidance_images_and_masks'][:, 3:]
            self.gen_losses['Guidance'] = self.criteria['Guidance'](
                net_G_output['fake_images'], guidance_image, guidance_mask)
        else:
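            # No guidance was rendered for this step; log a zero loss so the
            # 'Guidance' key stays present and loss reporting is consistent.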
            self.gen_losses['Guidance'] = self.Tensor(1).fill_(0)

    def get_data_t(self, data, net_G_output, data_prev, t):
        r"""Get data at current time frame given the sequence of data.

        Args:
            data (dict): Training data for current iteration.
            net_G_output (dict): Output of the generator (for previous frame).
            data_prev (dict): Data for previous frame.
            t (int): Current time.
        """
        label = data['label'][:, t]
        image = data['images'][:, t]

        # Get keypoint mapping.
        unprojection = None
        if t >= self.guidance_start_after:
            if 'unprojections' in data:
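                # The unprojection arrays are zero-padded to a fixed size so
                # they can be collated into a batch; by the convention used
                # here, the first element of the last row stores the number
                # of valid rows, which is used below to strip the padding.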
                try:
                    # Remove unwanted padding.
                    unprojection = {}
                    for key, value in data['unprojections'].items():
                        value = value[0, t].cpu().numpy()
                        length = value[-1][0]
                        unprojection[key] = value[:length]
                except:  # noqa
                    pass

        if data_prev is not None:
            # Concat previous labels/fake images to the ones before.
            num_frames_G = self.cfg.data.num_frames_G
            prev_labels = concat_frames(data_prev['prev_labels'],
                                        data_prev['label'], num_frames_G - 1)
            prev_images = concat_frames(
                data_prev['prev_images'],
                net_G_output['fake_images'].detach(), num_frames_G - 1)
        else:
            prev_labels = prev_images = None

        data_t = dict()
        data_t['label'] = label
        data_t['image'] = image
        data_t['prev_labels'] = prev_labels
        data_t['prev_images'] = prev_images
        data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None
        data_t['unprojection'] = unprojection
        return data_t

    def save_image(self, path, data):
        r"""Save the output images to path.
        Note: when generate_raw_output is False,
        first_net_G_output['fake_raw_images'] is None and will not be
        displayed. In model average mode, the flow visualization is plotted
        twice.

        Args:
            path (str): Save path.
            data (dict): Training data for current iteration.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average_config.enabled:
            self.net_G.module.averaged_model.eval()
        self.net_G_output = None
        with torch.no_grad():
            first_net_G_output, net_G_output, all_info = self.gen_frames(data)
            if self.cfg.trainer.model_average_config.enabled:
                # gen_frames returns a triple; discard the per-frame info
                # for the averaged model.
                first_net_G_output_avg, net_G_output_avg, _ = self.gen_frames(
                    data, use_model_average=True)

        # Visualize labels.
        label_lengths = self.train_data_loader.dataset.get_label_lengths()
        labels = split_labels(data['label'], label_lengths)
        vis_labels_start, vis_labels_end = [], []
        for key, value in labels.items():
            if 'seg_maps' in key:
                vis_labels_start.append(self.visualize_label(value[:, -1]))
                vis_labels_end.append(self.visualize_label(value[:, 0]))
            else:
                normalize = self.train_data_loader.dataset.normalize[key]
                vis_labels_start.append(
                    tensor2im(value[:, -1], normalize=normalize))
                vis_labels_end.append(
                    tensor2im(value[:, 0], normalize=normalize))

        if is_master():
            vis_images = [
                *vis_labels_start,
                tensor2im(data['images'][:, -1]),
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])]
            if self.cfg.trainer.model_average_config.enabled:
                vis_images += [
                    tensor2im(net_G_output_avg['fake_images']),
                    tensor2im(net_G_output_avg['fake_raw_images'])]

            if self.sequence_length > 1:
                if net_G_output['guidance_images_and_masks'] is not None:
                    guidance_image = tensor2im(
                        net_G_output['guidance_images_and_masks'][:, :3])
                    guidance_mask = tensor2im(
                        net_G_output['guidance_images_and_masks'][:, 3:4],
                        normalize=False)
                else:
                    im = tensor2im(data['images'][:, -1])
                    guidance_image = [np.zeros_like(item) for item in im]
                    guidance_mask = [np.zeros_like(item) for item in im]
                vis_images += [guidance_image, guidance_mask]

                vis_images_first = [
                    *vis_labels_end,
                    tensor2im(data['images'][:, 0]),
                    tensor2im(first_net_G_output['fake_images']),
                    tensor2im(first_net_G_output['fake_raw_images']),
                    [np.zeros_like(item) for item in guidance_image],
                    [np.zeros_like(item) for item in guidance_mask]
                ]
                if self.cfg.trainer.model_average_config.enabled:
                    vis_images_first += [
                        tensor2im(first_net_G_output_avg['fake_images']),
                        tensor2im(first_net_G_output_avg['fake_raw_images'])]

                if self.use_flow:
                    flow_gt, conf_gt = self.criteria['Flow'].flowNet(
                        data['images'][:, -1], data['images'][:, -2])
                    warped_image_gt = resample(data['images'][:, -1], flow_gt)
                    vis_images_first += [
                        tensor2flow(flow_gt),
                        tensor2im(conf_gt, normalize=False),
                        tensor2im(warped_image_gt),
                    ]
                    vis_images += [
                        tensor2flow(net_G_output['fake_flow_maps']),
                        tensor2im(net_G_output['fake_occlusion_masks'],
                                  normalize=False),
                        tensor2im(net_G_output['warped_images']),
                    ]
                    if self.cfg.trainer.model_average_config.enabled:
                        vis_images_first += [
                            tensor2flow(flow_gt),
                            tensor2im(conf_gt, normalize=False),
                            tensor2im(warped_image_gt),
                        ]
                        vis_images += [
                            tensor2flow(net_G_output_avg['fake_flow_maps']),
                            tensor2im(
                                net_G_output_avg['fake_occlusion_masks'],
                                normalize=False),
                            tensor2im(net_G_output_avg['warped_images'])]

                vis_images = [[np.vstack((im_first, im))
                               for im_first, im in zip(imgs_first, imgs)]
                              for imgs_first, imgs in zip(vis_images_first,
                                                          vis_images)
                              if imgs is not None]

            image_grid = np.hstack([np.vstack(im) for im in
                                    vis_images if im is not None])
            print('Save output images to {}'.format(path))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            imageio.imwrite(path, image_grid)

            # Gather all inputs and outputs for dumping into video.
            if self.sequence_length > 1:
                input_images, output_images, output_guidance = [], [], []
                for item in all_info['inputs']:
                    input_images.append(tensor2im(item['image'])[0])
                for item in all_info['outputs']:
                    output_images.append(tensor2im(item['fake_images'])[0])
                    if item['guidance_images_and_masks'] is not None:
                        output_guidance.append(tensor2im(
                            item['guidance_images_and_masks'][:, :3])[0])
                    else:
                        output_guidance.append(
                            np.zeros_like(output_images[-1]))

                imageio.mimwrite(os.path.splitext(path)[0] + '.mp4',
                                 output_images, fps=2, macro_block_size=None)
                imageio.mimwrite(os.path.splitext(path)[0] + '_guidance.mp4',
                                 output_guidance, fps=2,
                                 macro_block_size=None)

                # for idx, item in enumerate(output_guidance):
                #     imageio.imwrite(os.path.splitext(
                #         path)[0] + '_guidance_%d.jpg' % (idx), item)
                # for idx, item in enumerate(input_images):
                #     imageio.imwrite(os.path.splitext(
                #         path)[0] + '_input_%d.jpg' % (idx), item)

        self.net_G.float()

    def _compute_fid(self):
        r"""Compute FID. Skipped here for faster training."""
        return None

    def load_checkpoint(self, cfg, checkpoint_path, resume=None,
                        load_sch=True):
        r"""Load network weights, optimizer parameters, and scheduler
        parameters from the checkpoint.

        Args:
            cfg (obj): Global configuration.
            checkpoint_path (str): Path to the checkpoint.
            resume (bool, optional): Whether to resume training from this
                checkpoint.
            load_sch (bool): Whether to also load the scheduler state.
        """
        # Create the single image model.
        if self.train_data_loader is None:
            load_single_image_model_weights = False
        else:
            load_single_image_model_weights = True
        self.net_G.module._init_single_image_model(
            load_weights=load_single_image_model_weights)

        # Call the original super function.
        return super().load_checkpoint(cfg, checkpoint_path, resume, load_sch)
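

# A minimal usage sketch (illustrative only; the trainer is normally built by
# imaginaire's training/inference scripts from the config, together with the
# networks, optimizers, and schedulers). Frame-by-frame inference roughly
# follows:
#
#     trainer.reset()                        # new sequence: clear renderer
#     trainer.sequence_length = len(loader)  # one loader batch per frame
#     for data in loader:                    # data dicts as produced by the
#         data = to_cuda(data)               # wc_vid2vid dataset
#         output = trainer.test_single(data, output_dir='out/')
#
# which is what `test()` above does per sequence, plus output-directory
# bookkeeping and mp4 assembly.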