Spaces:
Build error
Build error
| # Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved | |
| # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # NVIDIA CORPORATION and its licensors retain all intellectual property | |
| # and proprietary rights in and to this software, related documentation | |
| # and any modifications thereto. Any use, reproduction, disclosure or | |
| # distribution of this software and related documentation without an express | |
| # license agreement from NVIDIA CORPORATION is strictly prohibited. | |
| """Generate images using pretrained network pickle.""" | |
| import os | |
| import re | |
| import time | |
| import glob | |
| from typing import List, Optional | |
| import click | |
| import dnnlib | |
| import numpy as np | |
| import PIL.Image | |
| import torch | |
| import imageio | |
| import legacy | |
| from renderer import Renderer | |
| #---------------------------------------------------------------------------- | |
def num_range(s: str) -> List[int]:
    """Parse a textual list of integers into a list of ints.

    The input is a comma separated list whose items are either a single
    number or an inclusive range written ``a-c``. Both forms may be mixed,
    e.g. ``'1,2,5'``, ``'3-6'`` or ``'1,4-6,9'``. Whitespace around an item
    is ignored.

    Args:
        s: The specification string.

    Returns:
        The integers in the order they were written. A range ``a-c``
        expands to ``a, a+1, ..., c`` and contributes nothing when
        ``a > c``.

    Raises:
        ValueError: If an item is neither an integer nor an ``a-c`` range
            (this includes empty items such as in ``'1,,2'`` or ``''``).
    """
    # Compiled once per call, outside the per-item loop.
    range_re = re.compile(r'^(\d+)-(\d+)$')
    values: List[int] = []
    for item in s.split(','):
        item = item.strip()
        m = range_re.match(item)
        if m:
            values.extend(range(int(m.group(1)), int(m.group(2)) + 1))
        else:
            # Plain (possibly signed) integer; int() raises ValueError on junk.
            values.append(int(item))
    return values
| #---------------------------------------------------------------------------- | |
# Export PYOPENGL_PLATFORM=egl into this process's environment.
# NOTE(review): presumably this selects PyOpenGL's EGL backend so rendering
# works without a display. PyOpenGL is documented to read this variable when
# it is first imported, so confirm that nothing imported above this line
# (e.g. `renderer`) has already imported OpenGL.
os.environ.update({'PYOPENGL_PLATFORM': 'egl'})
def _write_ibrnet_metadata(out_dir, cameras, height, width, depth_range):
    """Write ``cameras.npz`` and ``meta.conf`` for one seed into *out_dir*.

    Args:
        out_dir: Existing directory that receives both files.
        cameras: Iterable of 4-item camera entries; the first two items are
            used as ``(intrinsics, pose)`` and the remaining two are ignored.
            ``intrinsics[0, 0, 0]`` is read as a scalar and ``pose`` is
            squeezed and converted to a NumPy matrix.
            NOTE(review): presumably ``pose`` squeezes to 4x4, since it is
            multiplied by a 4x4 matrix -- confirm against the renderer.
        height: Image height in pixels.
        width: Image width in pixels.
        depth_range: Value written verbatim to the ``depth_range`` entry.
    """
    # Right-multiplying by this matrix negates the 2nd and 3rd pose columns.
    column_flip = np.diag([1, -1, -1, 1]).astype(np.float32)

    intrinsics = []
    poses = []
    for intri, pose, _, _ in cameras:
        focal = (height - 1) * 0.5 / intri[0, 0, 0].item()
        intri = np.diag([focal, focal, 1.0, 1.0]).astype(np.float32)
        # Principal point at the image centre.
        intri[0, 2], intri[1, 2] = (width - 1) * 0.5, (height - 1) * 0.5
        intrinsics.append(intri)
        poses.append(pose.squeeze().detach().cpu().numpy() @ column_flip)

    np.savez(
        os.path.join(out_dir, 'cameras.npz'),
        intrinsics=np.stack(intrinsics, axis=0),
        poses=np.stack(poses, axis=0))
    with open(os.path.join(out_dir, 'meta.conf'), 'w') as f:
        f.write('depth_range = {}\ntest_hold_out = {}\nheight = {}\nwidth = {}'.
                format(depth_range, 2, height, width))


def generate_images(
    ctx: click.Context,
    network_pkl: str,
    seeds: Optional[List[int]],
    truncation_psi: float,
    noise_mode: str,
    outdir: str,
    class_idx: Optional[int],
    projected_w: Optional[str],
    render_program=None,
    render_option=None,
    n_steps=8,
    no_video=False,
    relative_range_u_scale=1.0
):
    """Generate images (and optionally a video) from a pretrained network pickle.

    The pickle is loaded, its ``G_ema`` generator is rebuilt as a fresh
    ``training.networks.Generator`` and wrapped, together with ``D``, in a
    ``Renderer``. Images are then produced either from a projected latent
    file or from random seeds and written below *outdir*.

    Args:
        ctx: Click context; ``ctx.fail`` is used to report invalid options.
        network_pkl: Path or URL of the pickle. If it is a directory, the
            last ``*.pkl`` file in sorted order is used.
        seeds: Random seeds to render. Required unless *projected_w* is given.
        truncation_psi: Forwarded to the renderer.
        noise_mode: Forwarded to the renderer.
        outdir: Output directory (created if missing).
        class_idx: Class index for conditional networks (``G.c_dim != 0``);
            ignored with a warning for unconditional ones.
        projected_w: Optional ``.npy`` file of latents passed as ``styles``.
        render_program: Forwarded to ``Renderer(..., program=...)``.
        render_option: Forwarded to the renderer. If it contains
            ``"gen_ibrnet_metadata"``, ``cameras.npz`` and ``meta.conf`` are
            written next to each seed's frames.
        n_steps: Forwarded to the renderer.
        no_video: If true, frames are not collected and no video is written.
        relative_range_u_scale: Width of the ``relative_range_u`` interval,
            centred on 0.5, that is forwarded to the renderer.

    Raises:
        TypeError: If the renderer does not return a list for *projected_w*.
        RuntimeError: If IBRNet metadata is requested but the renderer
            returned no cameras.

    Invalid option combinations are reported through ``ctx.fail``.
    """
    # Fail fast, before the network is loaded.
    if projected_w is None and seeds is None:
        ctx.fail('--seeds option is required when not using --projected-w')

    device = torch.device('cuda')
    if os.path.isdir(network_pkl):
        candidates = sorted(glob.glob(network_pkl + '/*.pkl'))
        if not candidates:
            ctx.fail('No network pickle (*.pkl) found in "%s"' % network_pkl)
        network_pkl = candidates[-1]

    print('Loading networks from "%s"...' % network_pkl)
    with dnnlib.util.open_url(network_pkl) as f:
        network = legacy.load_network_pkl(f)
    G = network['G_ema'].to(device)  # type: ignore
    D = network['D'].to(device)

    os.makedirs(outdir, exist_ok=True)

    # Labels.
    label = torch.zeros([1, G.c_dim], device=device)
    if G.c_dim != 0:
        if class_idx is None:
            ctx.fail('Must specify class label with --class when using a conditional network')
        label[:, class_idx] = 1
    else:
        if class_idx is not None:
            print ('warn: --class=lbl ignored when running on an unconditional network')

    # avoid persistent classes...
    # Imported locally and rebuilt from the recorded constructor arguments so
    # the generator is an instance of the class from the current code base.
    from training.networks import Generator
    from torch_utils import misc
    with torch.no_grad():
        G2 = Generator(*G.init_args, **G.init_kwargs).to(device)
        misc.copy_params_and_buffers(G, G2, require_all=False)
    G2 = Renderer(G2, D, program=render_program)

    def stack_imgs(imgs):
        """Stack *imgs* along dim 2, then merge dims (0, 1) and (2, 3) into a 3-channel image."""
        img = torch.stack(imgs, dim=2)
        return img.reshape(img.size(0) * img.size(1), img.size(2) * img.size(3), 3)

    def proc_img(img):
        """Move dim 1 last and map values to uint8 on the CPU.

        NOTE(review): assumes an NCHW tensor roughly in [-1, 1] -- confirm.
        """
        return (img.permute(0, 2, 3, 1) * 127.5 + 128).clamp(0, 255).to(torch.uint8).cpu()

    wants_ibrnet_metadata = (render_option is not None) and ("gen_ibrnet_metadata" in render_option)

    # Generate images. ``all_imgs`` holds one list of frames per rendered item.
    all_imgs = []
    if projected_w is not None:
        ws = torch.tensor(np.load(projected_w), device=device)  # pylint: disable=not-callable
        img = G2(styles=ws, truncation_psi=truncation_psi, noise_mode=noise_mode, render_option=render_option)
        if not isinstance(img, list):
            raise TypeError('Renderer must return a list of images for --projected-w, got %s' % type(img).__name__)
        all_imgs.append([proc_img(i) for i in img])
    else:
        relative_range_u = [0.5 - 0.5 * relative_range_u_scale, 0.5 + 0.5 * relative_range_u_scale]
        for seed_idx, seed in enumerate(seeds):
            print('Generating image for seed %d (%d/%d) ...' % (seed, seed_idx, len(seeds)))
            G2.set_random_seed(seed)
            z = torch.from_numpy(np.random.RandomState(seed).randn(2, G.z_dim)).to(device)
            outputs = G2(
                z=z,
                c=label,
                truncation_psi=truncation_psi,
                noise_mode=noise_mode,
                render_option=render_option,
                n_steps=n_steps,
                relative_range_u=relative_range_u,
                return_cameras=True)
            # Reset per seed so cameras from an earlier seed are never reused.
            if isinstance(outputs, tuple):
                img, cameras = outputs
            else:
                img, cameras = outputs, None

            if not isinstance(img, list):
                # Single image: save it directly under outdir.
                single = proc_img(img)[0]
                PIL.Image.fromarray(single.numpy(), 'RGB').save(f'{outdir}/seed_{seed:0>6d}.png')
                continue

            imgs = [proc_img(i) for i in img]
            if not no_video:
                all_imgs.append(imgs)

            curr_out_dir = os.path.join(outdir, 'seed_{:0>6d}'.format(seed))
            os.makedirs(curr_out_dir, exist_ok=True)

            if wants_ibrnet_metadata:
                if cameras is None:
                    raise RuntimeError(
                        'render option "gen_ibrnet_metadata" requires the renderer to return cameras')
                _, H, W, _ = imgs[0].shape
                _write_ibrnet_metadata(
                    curr_out_dir, cameras, H, W, G2.generator.synthesis.depth_range)

            img_dir = os.path.join(curr_out_dir, 'images_raw')
            os.makedirs(img_dir, exist_ok=True)
            for step, frame in enumerate(imgs):
                PIL.Image.fromarray(frame[0].detach().cpu().numpy(), 'RGB').save(f'{img_dir}/{step:03d}.png')

    if all_imgs and not no_video:
        # write to video
        timestamp = time.strftime('%Y%m%d.%H%M%S', time.localtime(time.time()))
        seed_tag = ','.join([str(s) for s in seeds]) if seeds is not None else 'projected'
        network_name = network_pkl.split('/')[-1].split('.')[0]
        # For every frame index, combine that frame from all rendered items.
        frames = [stack_imgs([a[k] for a in all_imgs]).numpy() for k in range(len(all_imgs[0]))]
        imageio.mimwrite(f'{outdir}/{network_name}_{timestamp}_{seed_tag}.mp4', frames, fps=30, quality=8)
        frames_dir = f'{outdir}/{network_name}_{timestamp}_{seed_tag}'
        os.makedirs(frames_dir, exist_ok=True)
        for step, frame in enumerate(frames):
            PIL.Image.fromarray(frame, 'RGB').save(f'{frames_dir}/{step:04d}.png')
| #---------------------------------------------------------------------------- | |
# Command-line entry point. generate_images() has required parameters
# (including a click context), so it cannot be called without arguments;
# this click command parses the options and forwards them.
# NOTE(review): option names other than --class are a best guess (--class is
# the only one named in generate_images' messages) -- confirm against the
# CLI this script is expected to expose.
@click.command()
@click.pass_context
@click.option('--network', 'network_pkl', help='Network pickle filename, URL or directory', required=True)
@click.option('--seeds', type=num_range, help='List of random seeds, e.g. 1,2,5 or 3-6')
@click.option('--trunc', 'truncation_psi', type=float, help='Truncation psi', default=1, show_default=True)
@click.option('--class', 'class_idx', type=int, help='Class label (unconditional if not specified)')
@click.option('--noise-mode', help='Noise mode', type=click.Choice(['const', 'random', 'none']), default='const', show_default=True)
@click.option('--projected-w', help='Projection result file', type=str, metavar='FILE')
@click.option('--outdir', help='Where to save the output images', type=str, required=True, metavar='DIR')
@click.option('--render-program', help='Render program passed to the renderer', type=str, default=None)
@click.option('--render-option', help='Render option passed to the renderer', type=str, default=None)
@click.option('--n_steps', help='Number of steps for each seed', type=int, default=8, show_default=True)
@click.option('--no-video', help='Do not write a video', is_flag=True, default=False)
@click.option('--relative_range_u_scale', help='Relative scale of the range u', type=float, default=1.0, show_default=True)
def _cli(ctx, **options):
    """Generate images using pretrained network pickle."""
    generate_images(ctx, **options)


if __name__ == "__main__":
    _cli()  # pylint: disable=no-value-for-parameter
| #---------------------------------------------------------------------------- | |