Spaces:
Build error
Build error
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved
# Copyright (c) 2021, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.
import sys
import copy
import traceback
import numpy as np
import torch
import torch.fft
import torch.nn
import matplotlib.cm
import dnnlib
import torch.nn.functional as F
from torch_utils import misc
from torch_utils.ops import upfirdn2d
from training.networks import Generator
import legacy # pylint: disable=import-error
#----------------------------------------------------------------------------
class CapturedException(Exception):
    """Exception that snapshots the currently handled error as a string.

    Constructed with no arguments inside an ``except`` block, it records the
    active exception: the plain message when that exception is itself a
    CapturedException, otherwise the full formatted traceback. This lets the
    renderer cache and re-raise failures later without keeping live traceback
    objects around.
    """
    def __init__(self, msg=None):
        if msg is None:
            _type, value, _traceback = sys.exc_info()
            assert value is not None
            msg = str(value) if isinstance(value, CapturedException) else traceback.format_exc()
        assert isinstance(msg, str)
        super().__init__(msg)
#----------------------------------------------------------------------------
class CaptureSuccess(Exception):
    """Control-flow exception used to abort a forward pass early.

    Raised from a forward hook once the requested layer's output has been
    captured; the caught instance carries that tensor in ``out``.
    """
    def __init__(self, out):
        Exception.__init__(self)
        self.out = out
#----------------------------------------------------------------------------
| def _sinc(x): | |
| y = (x * np.pi).abs() | |
| z = torch.sin(y) / y.clamp(1e-30, float('inf')) | |
| return torch.where(y < 1e-30, torch.ones_like(x), z) | |
def _lanczos_window(x, a):
    """Lanczos window of order `a`: sinc(|x| / a) inside the support |x| < a, zero outside."""
    scaled = x.abs() / a
    return torch.where(scaled >= 1, torch.zeros_like(scaled), _sinc(scaled))
#----------------------------------------------------------------------------
def _construct_affine_bandlimit_filter(mat, a=3, amax=16, aflt=64, up=4, cutoff_in=1, cutoff_out=1):
    """Construct a 2D Lanczos-windowed sinc FIR filter for affine resampling.

    The filter is the convolution (computed as a product in the frequency
    domain) of an axis-aligned sinc low-pass in the input space and a
    `mat`-oriented sinc low-pass in the output space, windowed by the
    corresponding Lanczos windows.

    Args:
        mat: Affine matrix; only mat[:2, :2] is used to orient the output-space taps.
        a: Lanczos window order (number of lobes).
        amax: Half-width of the final filter support, in pixels.
        aflt: Half-width of the oversized working filter before cropping (must exceed amax).
        up: Upsampling factor the filter is designed for.
        cutoff_in, cutoff_out: Low-pass cutoffs in input / output space.

    Returns:
        A square 2D tensor of taps with odd side length (amax * 2 * up - 1),
        normalized so that each of the up**2 sampling phases sums to 1 / up**2.
    """
    assert a <= amax < aflt
    mat = torch.as_tensor(mat).to(torch.float32)
    # Construct 2D filter taps in input & output coordinate spaces.
    # roll() re-centers the taps so index 0 corresponds to coordinate 0 (FFT convention).
    taps = ((torch.arange(aflt * up * 2 - 1, device=mat.device) + 1) / up - aflt).roll(1 - aflt * up)
    yi, xi = torch.meshgrid(taps, taps)
    xo, yo = (torch.stack([xi, yi], dim=2) @ mat[:2, :2].t()).unbind(2)
    # Convolution of two oriented 2D sinc filters.
    fi = _sinc(xi * cutoff_in) * _sinc(yi * cutoff_in)
    fo = _sinc(xo * cutoff_out) * _sinc(yo * cutoff_out)
    f = torch.fft.ifftn(torch.fft.fftn(fi) * torch.fft.fftn(fo)).real
    # Convolution of two oriented 2D Lanczos windows.
    wi = _lanczos_window(xi, a) * _lanczos_window(yi, a)
    wo = _lanczos_window(xo, a) * _lanczos_window(yo, a)
    w = torch.fft.ifftn(torch.fft.fftn(wi) * torch.fft.fftn(wo)).real
    # Construct windowed FIR filter.
    f = f * w
    # Finalize: undo the FFT-centering roll, crop to the [amax] support,
    # then normalize each up-phase separately so DC gain is preserved.
    c = (aflt - amax) * up
    f = f.roll([aflt * up - 1] * 2, dims=[0,1])[c:-c, c:-c]
    f = torch.nn.functional.pad(f, [0, 1, 0, 1]).reshape(amax * 2, up, amax * 2, up)
    f = f / f.sum([0,2], keepdim=True) / (up ** 2)
    f = f.reshape(amax * 2 * up, amax * 2 * up)[:-1, :-1]
    return f
#----------------------------------------------------------------------------
def _apply_affine_transformation(x, mat, up=4, **filter_kwargs):
    """Resample an image batch `x` of shape (N, C, H, W) under the affine transform `mat`.

    The image is first upsampled through a band-limiting filter built from
    `mat` (see _construct_affine_bandlimit_filter), then warped with
    grid_sample using the inverse transform.

    Returns:
        (z, m): the transformed image and a nearest-sampled validity mask
        that is 1 where the output was produced from fully filtered,
        in-bounds content and 0 elsewhere.
    """
    _N, _C, H, W = x.shape
    mat = torch.as_tensor(mat).to(dtype=torch.float32, device=x.device)
    # Construct filter.
    f = _construct_affine_bandlimit_filter(mat, up=up, **filter_kwargs)
    assert f.ndim == 2 and f.shape[0] == f.shape[1] and f.shape[0] % 2 == 1
    p = f.shape[0] // 2  # half-width of the (odd-sized) filter => padding amount
    # Construct sampling grid. grid_sample warps by the inverse mapping,
    # adjusted for the sub-pixel shift and the padded/upsampled intermediate.
    theta = mat.inverse()
    theta[:2, 2] *= 2
    theta[0, 2] += 1 / up / W
    theta[1, 2] += 1 / up / H
    theta[0, :] *= W / (W + p / up * 2)
    theta[1, :] *= H / (H + p / up * 2)
    theta = theta[:2, :3].unsqueeze(0).repeat([x.shape[0], 1, 1])
    g = torch.nn.functional.affine_grid(theta, x.shape, align_corners=False)
    # Resample image.
    y = upfirdn2d.upsample2d(x=x, f=f, up=up, padding=p)
    z = torch.nn.functional.grid_sample(y, g, mode='bilinear', padding_mode='zeros', align_corners=False)
    # Form mask: interior of the upsampled image, excluding the filter's
    # padding border, warped with the same grid.
    m = torch.zeros_like(y)
    c = p * 2 + 1
    m[:, :, c:-c, c:-c] = 1
    m = torch.nn.functional.grid_sample(m, g, mode='nearest', padding_mode='zeros', align_corners=False)
    return z, m
#----------------------------------------------------------------------------
def set_random_seed(seed):
    """Seed both the NumPy and the torch global RNGs for reproducible rendering."""
    np.random.seed(seed)
    torch.manual_seed(seed)
class Renderer:
    """Loads generator networks from pickles and renders visualization frames.

    render() returns results as a dnnlib.EasyDict and never raises: failures
    are captured into res.error. Heavy resources — unpickled networks, tweaked
    device copies, pinned CPU staging buffers, colormap lookup tables and
    per-network layer lists — are cached on the instance across calls. All
    tensor work runs on the CUDA device chosen at construction.
    """
    def __init__(self):
        self._device = torch.device('cuda')  # fixed render device
        self._pkl_data = dict()     # {pkl: dict | CapturedException, ...}
        self._networks = dict()     # {cache_key: torch.nn.Module, ...}
        self._pinned_bufs = dict()  # {(shape, dtype): torch.Tensor, ...}
        self._cmaps = dict()        # {name: torch.Tensor, ...}
        self._is_timing = False
        # CUDA events used to measure GPU time of a render() call.
        self._start_event = torch.cuda.Event(enable_timing=True)
        self._end_event = torch.cuda.Event(enable_timing=True)
        self._net_layers = dict()   # {cache_key: [dnnlib.EasyDict, ...], ...}

    def render(self, **args):
        """Render one frame via _render_impl, capturing any failure.

        Returns an EasyDict that may contain `error` (stringified) and
        `render_time` in seconds (GPU time between the two CUDA events,
        omitted if timing was invalidated through _ignore_timing()).
        """
        self._is_timing = True
        self._start_event.record(torch.cuda.current_stream(self._device))
        res = dnnlib.EasyDict()
        try:
            self._render_impl(res, **args)
        except:  # deliberately broad: every failure is reported via res.error
            res.error = CapturedException()
        self._end_event.record(torch.cuda.current_stream(self._device))
        if 'error' in res:
            res.error = str(res.error)  # stringify for display
        if self._is_timing:
            self._end_event.synchronize()
            res.render_time = self._start_event.elapsed_time(self._end_event) * 1e-3  # ms -> s
        self._is_timing = False
        return res

    def get_network(self, pkl, key, **tweak_kwargs):
        """Return `data[key]` from network pickle `pkl`, cached and moved to the device.

        Both the unpickled data and the per-(network, device, tweaks) deep copy
        are cached. Failures are cached too (as CapturedException) and re-raised
        on every subsequent request, so a bad pickle is only read once.
        """
        data = self._pkl_data.get(pkl, None)
        if data is None:
            # First request for this pickle: load it and remember the outcome.
            print(f'Loading "{pkl}"... ', end='', flush=True)
            try:
                with dnnlib.util.open_url(pkl, verbose=False) as f:
                    data = legacy.load_network_pkl(f)
                print('Done.')
            except:
                data = CapturedException()
                print('Failed!')
            self._pkl_data[pkl] = data
            self._ignore_timing()  # loading time should not count as render time
        if isinstance(data, CapturedException):
            raise data
        orig_net = data[key]
        cache_key = (orig_net, self._device, tuple(sorted(tweak_kwargs.items())))
        net = self._networks.get(cache_key, None)
        if net is None:
            # Deep-copy so tweaks never mutate the cached original network.
            try:
                net = copy.deepcopy(orig_net)
                net = self._tweak_network(net, **tweak_kwargs)
                net.to(self._device)
            except:
                net = CapturedException()
            self._networks[cache_key] = net
            self._ignore_timing()
        if isinstance(net, CapturedException):
            raise net
        return net

    def get_camera_traj(self, gen, pitch, yaw, fov=12, batch_size=1, model_name='FFHQ512'):
        """Convert (pitch, yaw) angles into a camera matrix via the generator's camera model.

        For non-car models the angles are damped and remapped into the
        generator's (range_u, range_v); car models use a plain [-1, 1] -> [0, 1]
        mapping. `gen` must expose `C.range_u`, `C.range_v` and `get_camera`
        (project-specific NeRF generator interface).
        """
        range_u, range_v = gen.C.range_u, gen.C.range_v
        if not (('car' in model_name) or ('Car' in model_name)):  # TODO: hack, better option?
            yaw, pitch = 0.5 * yaw, 0.3 * pitch
            pitch = pitch + np.pi/2
            u = (yaw - range_u[0]) / (range_u[1] - range_u[0])
            v = (pitch - range_v[0]) / (range_v[1] - range_v[0])
        else:
            u = (yaw + 1) / 2
            v = (pitch + 1) / 2
        cam = gen.get_camera(batch_size=batch_size, mode=[u, v, 0.5], device=self._device, fov=fov)
        return cam

    def _tweak_network(self, net):
        """Hook for adjusting a freshly copied network before use; currently a no-op.

        NOTE(review): get_network() forwards **tweak_kwargs here, but this
        signature accepts none — passing any tweak kwargs would raise TypeError
        (caught there and cached as a CapturedException). Confirm intended.
        """
        # Print diagnostics.
        #for name, value in misc.named_params_and_buffers(net):
        #    if name.endswith('.magnitude_ema'):
        #        value = value.rsqrt().numpy()
        #        print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}')
        #    if name.endswith('.weight') and value.ndim == 4:
        #        value = value.square().mean([1,2,3]).sqrt().numpy()
        #        print(f'{name:<50s}{np.min(value):<16g}{np.max(value):g}')
        return net

    def _get_pinned_buf(self, ref):
        """Return a page-locked CPU staging buffer matching `ref`'s shape and dtype (cached)."""
        key = (tuple(ref.shape), ref.dtype)
        buf = self._pinned_bufs.get(key, None)
        if buf is None:
            buf = torch.empty(ref.shape, dtype=ref.dtype).pin_memory()
            self._pinned_bufs[key] = buf
        return buf

    def to_device(self, buf):
        """Copy a CPU tensor to the render device through a pinned staging buffer."""
        return self._get_pinned_buf(buf).copy_(buf).to(self._device)

    def to_cpu(self, buf):
        """Copy a device tensor to the CPU through a pinned staging buffer.

        clone() detaches the result from the shared staging buffer so the
        caller receives an independent tensor.
        """
        return self._get_pinned_buf(buf).copy_(buf).clone()

    def _ignore_timing(self):
        """Invalidate timing for the current render (e.g. when it included network loading)."""
        self._is_timing = False

    def _apply_cmap(self, x, name='viridis'):
        """Map values in [0, 1] through a matplotlib colormap to uint8 RGB.

        The 1024-entry RGB lookup table is built once per colormap name and
        cached on the device.
        """
        cmap = self._cmaps.get(name, None)
        if cmap is None:
            cmap = matplotlib.cm.get_cmap(name)
            cmap = cmap(np.linspace(0, 1, num=1024), bytes=True)[:, :3]  # drop alpha channel
            cmap = self.to_device(torch.from_numpy(cmap))
            self._cmaps[name] = cmap
        hi = cmap.shape[0] - 1
        x = (x * hi + 0.5).clamp(0, hi).to(torch.int64)  # round & clamp to table indices
        x = torch.nn.functional.embedding(x, cmap)       # table lookup
        return x

    def _render_impl(self, res,
        pkl             = None,       # Network pickle to visualize.
        w0_seeds        = [[0, 1]],   # [[seed, weight], ...] blended into the base latent.
                                      # NOTE(review): mutable default; not mutated here, but fragile.
        stylemix_idx    = [],         # W indices replaced from the style-mix seed.
        stylemix_seed   = 0,
        trunc_psi       = 1,          # Truncation psi for the mapping network.
        trunc_cutoff    = 0,
        random_seed     = 0,
        noise_mode      = 'const',
        force_fp32      = False,
        layer_name      = None,       # If set, capture this layer's output instead of the image.
        sel_channels    = 3,          # Number of channels shown.
        base_channel    = 0,          # First channel shown.
        img_scale_db    = 0,          # Display gain in decibels.
        img_normalize   = False,
        fft_show        = False,      # Append a power-spectrum visualization.
        fft_all         = True,
        fft_range_db    = 50,
        fft_beta        = 8,          # Kaiser window beta for the FFT.
        input_transform = None,       # 3x3 matrix applied to the synthesis input grid.
        untransform     = False,      # Undo input_transform on the output image.
        camera          = None,       # Camera request forwarded to get_camera_traj for NeRF models.
        output_lowres   = False,      # Also show the low-res NeRF image, if available.
        **unused,
    ):
        """Core rendering: load network, build W, synthesize, post-process into res."""
        # Dig up network details.
        _G = self.get_network(pkl, 'G_ema')
        try:
            # Re-instantiate through the local Generator class so current code
            # runs with the pickled weights; fall back to the pickled module
            # if the constructor arguments are incompatible.
            G = Generator(*_G.init_args, **_G.init_kwargs).to(self._device)
            misc.copy_params_and_buffers(_G, G, require_all=False)
        except Exception:
            G = _G
        G.eval()
        res.img_resolution = G.img_resolution
        res.num_ws = G.num_ws
        res.has_noise = any('noise_const' in name for name, _buf in G.synthesis.named_buffers())
        res.has_input_transform = (hasattr(G.synthesis, 'input') and hasattr(G.synthesis.input, 'transform'))

        # Set input transform.
        if res.has_input_transform:
            m = np.eye(3)
            try:
                if input_transform is not None:
                    m = np.linalg.inv(np.asarray(input_transform))
            except np.linalg.LinAlgError:
                res.error = CapturedException()  # singular matrix: report but keep identity
            G.synthesis.input.transform.copy_(torch.from_numpy(m))

        # Generate random latents.
        all_seeds = [seed for seed, _weight in w0_seeds] + [stylemix_seed]
        all_seeds = list(set(all_seeds))
        all_zs = np.zeros([len(all_seeds), G.z_dim], dtype=np.float32)
        all_cs = np.zeros([len(all_seeds), G.c_dim], dtype=np.float32)
        for idx, seed in enumerate(all_seeds):
            rnd = np.random.RandomState(seed)
            all_zs[idx] = rnd.randn(G.z_dim)
            if G.c_dim > 0:
                all_cs[idx, rnd.randint(G.c_dim)] = 1  # random one-hot class label per seed

        # Run mapping network.
        w_avg = G.mapping.w_avg
        all_zs = self.to_device(torch.from_numpy(all_zs))
        all_cs = self.to_device(torch.from_numpy(all_cs))
        # Work with offsets from w_avg so the per-seed weights blend around the average.
        all_ws = G.mapping(z=all_zs, c=all_cs, truncation_psi=trunc_psi, truncation_cutoff=trunc_cutoff) - w_avg
        all_ws = dict(zip(all_seeds, all_ws))

        # Calculate final W.
        w = torch.stack([all_ws[seed] * weight for seed, weight in w0_seeds]).sum(dim=0, keepdim=True)
        stylemix_idx = [idx for idx in stylemix_idx if 0 <= idx < G.num_ws]
        if len(stylemix_idx) > 0:
            w[:, stylemix_idx] = all_ws[stylemix_seed][np.newaxis, stylemix_idx]
        w += w_avg

        # Run synthesis network.
        synthesis_kwargs = dnnlib.EasyDict(noise_mode=noise_mode, force_fp32=force_fp32)
        set_random_seed(random_seed)
        if hasattr(G.synthesis, 'C'):
            # NeRF-style synthesis: forward the camera request; run_synthesis_net
            # converts it into actual camera matrices via get_camera_traj.
            synthesis_kwargs.update({'camera_matrices': camera})
        out, out_lowres, layers = self.run_synthesis_net(G.synthesis, w, capture_layer=layer_name, **synthesis_kwargs)

        # Update layer list.
        cache_key = (G.synthesis, tuple(sorted(synthesis_kwargs.items())))
        if cache_key not in self._net_layers:
            # NOTE(review): this resets the whole cache whenever a new key
            # appears, so at most one entry is kept — confirm this is intended.
            self._net_layers = dict()
            if layer_name is not None:
                # A layer was captured above, so run again uncaptured to
                # enumerate the complete layer list.
                torch.manual_seed(random_seed)
                _out, _out2, layers = self.run_synthesis_net(G.synthesis, w, **synthesis_kwargs)
            self._net_layers[cache_key] = layers
        res.layers = self._net_layers[cache_key]

        # Untransform.
        if untransform and res.has_input_transform:
            out, _mask = _apply_affine_transformation(out.to(torch.float32), G.synthesis.input.transform, amax=6)  # Override amax to hit the fast path in upfirdn2d.

        # Select channels and compute statistics.
        if output_lowres and out_lowres is not None:
            # Show the low-res NeRF image next to the final image.
            out = torch.cat([out, F.interpolate(out_lowres, out.size(-1), mode='nearest')], -1)
        out = out[0].to(torch.float32)  # drop batch dimension
        if sel_channels > out.shape[0]:
            sel_channels = 1
        base_channel = max(min(base_channel, out.shape[0] - sel_channels), 0)
        sel = out[base_channel : base_channel + sel_channels]
        res.stats = torch.stack([
            out.mean(), sel.mean(),
            out.std(), sel.std(),
            out.norm(float('inf')), sel.norm(float('inf')),
        ])
        res.stats = self.to_cpu(res.stats).numpy()  # move to cpu

        # Scale and convert to uint8.
        img = sel
        if img_normalize:
            img = img / img.norm(float('inf'), dim=[1,2], keepdim=True).clip(1e-8, 1e8)
        img = img * (10 ** (img_scale_db / 20))  # apply dB display gain
        img = (img * 127.5 + 128).clamp(0, 255).to(torch.uint8).permute(1, 2, 0)  # CHW -> HWC uint8
        res.image = img

        # FFT.
        if fft_show:
            sig = out if fft_all else sel
            sig = sig.to(torch.float32)
            sig = sig - sig.mean(dim=[1,2], keepdim=True)  # remove DC component
            # Apply a separable 2D Kaiser window to reduce spectral leakage.
            sig = sig * torch.kaiser_window(sig.shape[1], periodic=False, beta=fft_beta, device=self._device)[None, :, None]
            sig = sig * torch.kaiser_window(sig.shape[2], periodic=False, beta=fft_beta, device=self._device)[None, None, :]
            fft = torch.fft.fftn(sig, dim=[1,2]).abs().square().sum(dim=0)
            fft = fft.roll(shifts=[fft.shape[0] // 2, fft.shape[1] // 2], dims=[0,1])  # center DC
            fft = (fft / fft.mean()).log10() * 10  # dB
            fft = self._apply_cmap((fft / fft_range_db + 1) / 2)
            res.image = torch.cat([img.expand_as(fft), fft], dim=1)
        res.image = self.to_cpu(res.image).numpy()  # move to cpu

    def run_synthesis_net(self, net, *args, capture_layer=None, **kwargs):  # => out, out_lowres, layers
        """Run the synthesis network while recording every intermediate layer.

        Forward hooks collect (name, shape, dtype) for each 4D/5D tensor output.
        If `capture_layer` names one of them, the forward pass is aborted via
        CaptureSuccess and that tensor is returned as `out` instead of the image.
        Returns (out, out_lowres, layers); out_lowres is the low-res image when
        the network returns a dict containing 'img_nerf', otherwise None.
        """
        submodule_names = {mod: name for name, mod in net.named_modules()}
        unique_names = set()
        layers = []

        def module_hook(module, _inputs, outputs):
            # Keep only image-like (4D) or grouped (5D) tensor outputs.
            outputs = list(outputs) if isinstance(outputs, (tuple, list)) else [outputs]
            outputs = [out for out in outputs if isinstance(out, torch.Tensor) and out.ndim in [4, 5]]
            for idx, out in enumerate(outputs):
                if out.ndim == 5: # G-CNN => remove group dimension.
                    out = out.mean(2)
                name = submodule_names[module]
                if name == '':
                    name = 'output'
                if len(outputs) > 1:
                    name += f':{idx}'
                if name in unique_names:
                    # Disambiguate repeated invocations: name_2, name_3, ...
                    suffix = 2
                    while f'{name}_{suffix}' in unique_names:
                        suffix += 1
                    name += f'_{suffix}'
                unique_names.add(name)
                shape = [int(x) for x in out.shape]
                dtype = str(out.dtype).split('.')[-1]
                layers.append(dnnlib.EasyDict(name=name, shape=shape, dtype=dtype))
                if name == capture_layer:
                    raise CaptureSuccess(out)

        hooks = []  # (immediately rebound below)
        hooks = [module.register_forward_hook(module_hook) for module in net.modules()]
        try:
            if 'camera_matrices' in kwargs:
                # The caller passes raw camera parameters; convert to matrices.
                kwargs['camera_matrices'] = self.get_camera_traj(net, *kwargs['camera_matrices'])
            out = net(*args, **kwargs)
            out_lowres = None
            if isinstance(out, dict):
                if 'img_nerf' in out:
                    out_lowres = out['img_nerf']
                out = out['img']
        except CaptureSuccess as e:
            out = e.out
            out_lowres = None
        for hook in hooks:
            hook.remove()  # always detach hooks, even after an early capture
        return out, out_lowres, layers
#----------------------------------------------------------------------------