# Copyright (C) 2021 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# This work is made available under the Nvidia Source Code License-NC.
# To view a copy of this license, check out LICENSE.md
import os

import imageio
import numpy as np
import torch
from tqdm import tqdm

from imaginaire.model_utils.fs_vid2vid import (concat_frames, get_fg_mask,
                                               pre_process_densepose,
                                               random_roll)
from imaginaire.model_utils.pix2pixHD import get_optimizer_with_params
from imaginaire.trainers.vid2vid import Trainer as vid2vidTrainer
from imaginaire.utils.distributed import is_master
from imaginaire.utils.distributed import master_only_print as print
from imaginaire.utils.misc import to_cuda
from imaginaire.utils.visualization import tensor2flow, tensor2im


class Trainer(vid2vidTrainer):
    r"""Initialize the few-shot vid2vid trainer.

    Args:
        cfg (obj): Global configuration.
        net_G (obj): Generator network.
        net_D (obj): Discriminator network.
        opt_G (obj): Optimizer for the generator network.
        opt_D (obj): Optimizer for the discriminator network.
        sch_G (obj): Scheduler for the generator optimizer.
        sch_D (obj): Scheduler for the discriminator optimizer.
        train_data_loader (obj): Train data loader.
        val_data_loader (obj): Validation data loader.
    """

    def __init__(self, cfg, net_G, net_D, opt_G, opt_D, sch_G, sch_D,
                 train_data_loader, val_data_loader):
        super(Trainer, self).__init__(cfg, net_G, net_D, opt_G,
                                      opt_D, sch_G, sch_D,
                                      train_data_loader, val_data_loader)

    def _start_of_iteration(self, data, current_iteration):
        r"""Things to do before an iteration.

        Args:
            data (dict): Data used for the current iteration.
            current_iteration (int): Current iteration number.
        """
        data = self.pre_process(data)
        return to_cuda(data)

    def pre_process(self, data):
        r"""Do any data pre-processing here.

        Args:
            data (dict): Data used for the current iteration.
        """
        data_cfg = self.cfg.data
        if hasattr(data_cfg, 'for_pose_dataset') and \
                ('pose_maps-densepose' in data_cfg.input_labels):
            pose_cfg = data_cfg.for_pose_dataset
            data['label'] = pre_process_densepose(pose_cfg, data['label'],
                                                  self.is_inference)
            data['few_shot_label'] = pre_process_densepose(
                pose_cfg, data['few_shot_label'], self.is_inference)
        return data
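
    # Note on layout: sequence tensors in this trainer are indexed as
    # (batch, time, ...), which is why individual frames are sliced with
    # expressions like `data['label'][:, t]` in the methods below.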

    def get_test_output_images(self, data):
        r"""Get the visualization output of the test function.

        Args:
            data (dict): Training data at the current iteration.
        """
        vis_images = [
            tensor2im(data['few_shot_images'][:, 0]),
            self.visualize_label(data['label'][:, -1]),
            tensor2im(data['images'][:, -1]),
            tensor2im(self.net_G_output['fake_images']),
        ]
        return vis_images

    def get_data_t(self, data, net_G_output, data_prev, t):
        r"""Get data at current time frame given the sequence of data.

        Args:
            data (dict): Training data for current iteration.
            net_G_output (dict): Output of the generator (for previous frame).
            data_prev (dict): Data for previous frame.
            t (int): Current time.
        """
        label = data['label'][:, t] if 'label' in data else None
        image = data['images'][:, t]

        if data_prev is not None:
            nG = self.cfg.data.num_frames_G
            prev_labels = concat_frames(data_prev['prev_labels'],
                                        data_prev['label'], nG - 1)
            prev_images = concat_frames(
                data_prev['prev_images'],
                net_G_output['fake_images'].detach(), nG - 1)
        else:
            prev_labels = prev_images = None

        data_t = dict()
        data_t['label'] = label
        data_t['image'] = image
        data_t['ref_labels'] = data['few_shot_label'] if 'few_shot_label' \
            in data else None
        data_t['ref_images'] = data['few_shot_images']
        data_t['prev_labels'] = prev_labels
        data_t['prev_images'] = prev_images
        data_t['real_prev_image'] = data['images'][:, t - 1] if t > 0 else None

        # if 'landmarks_xy' in data:
        #     data_t['landmarks_xy'] = data['landmarks_xy'][:, t]
        #     data_t['ref_landmarks_xy'] = data['few_shot_landmarks_xy']
        return data_t
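
    # For reference: `concat_frames(prev, now, n)` (imported above) keeps a
    # sliding window over time, appending the newest frame along the time
    # dimension and dropping the oldest once `n` frames are buffered, so
    # `prev_labels` / `prev_images` hold at most `nG - 1` past frames.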

    def post_process(self, data, net_G_output):
        r"""Do any postprocessing of the data / output here.

        Args:
            data (dict): Training data at the current iteration.
            net_G_output (dict): Output of the generator.
        """
        if self.has_fg:
            fg_mask = get_fg_mask(data['label'], self.has_fg)
            if net_G_output['fake_raw_images'] is not None:
                net_G_output['fake_raw_images'] = \
                    net_G_output['fake_raw_images'] * fg_mask
        return data, net_G_output

    def test(self, test_data_loader, root_output_dir, inference_args):
        r"""Run inference on the specified sequence.

        Args:
            test_data_loader (object): Test data loader.
            root_output_dir (str): Location to dump outputs.
            inference_args (obj): Inference arguments (driving / few-shot
                sequence and frame indices, etc.).
        """
        self.reset()
        test_data_loader.dataset.set_sequence_length(0)
        # Set the inference sequences.
        test_data_loader.dataset.set_inference_sequence_idx(
            inference_args.driving_seq_index,
            inference_args.few_shot_seq_index,
            inference_args.few_shot_frame_index)

        video = []
        for idx, data in enumerate(tqdm(test_data_loader)):
            key = data['key']['images'][0][0]
            filename = key.split('/')[-1]

            # Create output dir for this sequence.
            if idx == 0:
                seq_name = '%03d' % inference_args.driving_seq_index
                output_dir = os.path.join(root_output_dir, seq_name)
                os.makedirs(output_dir, exist_ok=True)
                video_path = output_dir

            # Get output and save images.
            data['img_name'] = filename
            data = self.start_of_iteration(data, current_iteration=-1)
            output = self.test_single(data, output_dir, inference_args)
            video.append(output)

        # Save output as mp4.
        imageio.mimsave(video_path + '.mp4', video, fps=15)
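
    # Rough usage sketch (hypothetical values; `driving_seq_index`,
    # `few_shot_seq_index`, and `few_shot_frame_index` are the fields read
    # above):
    #   args = argparse.Namespace(driving_seq_index=2, few_shot_seq_index=0,
    #                             few_shot_frame_index=0)
    #   trainer.test(test_loader, 'output/fs_vid2vid', args)
    # With these values, `test_single` is handed 'output/fs_vid2vid/002' for
    # per-frame outputs and the video is saved as 'output/fs_vid2vid/002.mp4'.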

    def save_image(self, path, data):
        r"""Save the output images to path.

        Note that when generate_raw_output is False,
        first_net_G_output['fake_raw_images'] is None and will not be
        displayed. In model average mode, the flow visualization is plotted
        twice.

        Args:
            path (str): Save path.
            data (dict): Training data for current iteration.
        """
        self.net_G.eval()
        if self.cfg.trainer.model_average_config.enabled:
            self.net_G.module.averaged_model.eval()
        self.net_G_output = None
        with torch.no_grad():
            first_net_G_output, last_net_G_output, _ = self.gen_frames(data)
            if self.cfg.trainer.model_average_config.enabled:
                first_net_G_output_avg, last_net_G_output_avg, _ = \
                    self.gen_frames(data, use_model_average=True)

        def get_images(data, net_G_output, return_first_frame=True,
                       for_model_average=False):
            r"""Get the output images to save.

            Args:
                data (dict): Training data for current iteration.
                net_G_output (dict): Generator output.
                return_first_frame (bool): Return output for the first frame
                    in the sequence.
                for_model_average (bool): For model average output.
            Return:
                vis_images (list of numpy arrays): Visualization images.
            """
            frame_idx = 0 if return_first_frame else -1
            warped_idx = 0 if return_first_frame else 1
            vis_images = []
            if not for_model_average:
                vis_images += [
                    tensor2im(data['few_shot_images'][:, frame_idx]),
                    self.visualize_label(data['label'][:, frame_idx]),
                    tensor2im(data['images'][:, frame_idx])
                ]
            vis_images += [
                tensor2im(net_G_output['fake_images']),
                tensor2im(net_G_output['fake_raw_images'])]
            if not for_model_average:
                vis_images += [
                    tensor2im(net_G_output['warped_images'][warped_idx]),
                    tensor2flow(net_G_output['fake_flow_maps'][warped_idx]),
                    tensor2im(net_G_output['fake_occlusion_masks'][warped_idx],
                              normalize=False)
                ]
            return vis_images

        if is_master():
            vis_images_first = get_images(data, first_net_G_output)
            if self.cfg.trainer.model_average_config.enabled:
                vis_images_first += get_images(data, first_net_G_output_avg,
                                               for_model_average=True)
            if self.sequence_length > 1:
                vis_images_last = get_images(data, last_net_G_output,
                                             return_first_frame=False)
                if self.cfg.trainer.model_average_config.enabled:
                    vis_images_last += get_images(data, last_net_G_output_avg,
                                                  return_first_frame=False,
                                                  for_model_average=True)

                # If generating a video, the first row of each batch will be
                # the first generated frame and the flow/mask for warping the
                # reference image, and the second row will be the last
                # generated frame and the flow/mask for warping the previous
                # frame. If using model average, the frames generated by model
                # average will be at the rightmost columns.
                vis_images = [[np.vstack((im_first, im_last))
                               for im_first, im_last in
                               zip(imgs_first, imgs_last)]
                              for imgs_first, imgs_last in
                              zip(vis_images_first, vis_images_last)
                              if imgs_first is not None]
            else:
                vis_images = vis_images_first

            image_grid = np.hstack([np.vstack(im) for im in vis_images
                                    if im is not None])
            print('Save output images to {}'.format(path))
            os.makedirs(os.path.dirname(path), exist_ok=True)
            imageio.imwrite(path, image_grid)
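
    # The saved grid hstacks one column per visualization type, in the order
    # `get_images` builds them (reference frame, label map, ground truth,
    # fake image, raw fake image, warped image, flow map, occlusion mask),
    # with model-average results appended on the right when enabled.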

    def finetune(self, data, inference_args):
        r"""Finetune the model for a few iterations on the inference data.

        Args:
            data (dict): Inference data, including the few-shot reference
                frames.
            inference_args (obj): Inference arguments; finetune_iter controls
                the number of finetuning iterations (default 100).
        """
        # Get the list of params to finetune.
        self.net_G, self.net_D, self.opt_G, self.opt_D = \
            get_optimizer_with_params(self.cfg, self.net_G, self.net_D,
                                      param_names_start_with=[
                                          'weight_generator.fc', 'conv_img',
                                          'up'])

        data_finetune = {k: v for k, v in data.items()}
        ref_labels = data_finetune['few_shot_label']
        ref_images = data_finetune['few_shot_images']

        # Number of iterations to finetune.
        iterations = getattr(inference_args, 'finetune_iter', 100)
        for it in range(1, iterations + 1):
            # Randomly set one of the reference images as target.
            idx = np.random.randint(ref_labels.size(1))
            tgt_label, tgt_image = ref_labels[:, idx], ref_images[:, idx]
            # Randomly shift and flip the target image.
            tgt_label, tgt_image = random_roll([tgt_label, tgt_image])
            data_finetune['label'] = tgt_label.unsqueeze(1)
            data_finetune['images'] = tgt_image.unsqueeze(1)

            self.gen_update(data_finetune)
            self.dis_update(data_finetune)
            # Print progress roughly every 10% of the iterations; guard
            # against iterations < 10, which would make the modulus zero.
            if (it % max(iterations // 10, 1)) == 0:
                print(it)

        self.has_finetuned = True
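
    # Rough usage sketch (hypothetical driver code): before generating a new
    # subject, call `trainer.finetune(data, inference_args)` once, with
    # `inference_args.finetune_iter` (default 100 above) controlling the
    # number of adaptation steps on the few-shot reference frames.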