# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# NVIDIA CORPORATION and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION is strictly prohibited.

"""Project given image to the latent space of pretrained network pickle."""

import copy
import os
from time import perf_counter

import click
import imageio
import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F

import dnnlib
import legacy

_MODELS = {
    "RN50": "https://openaipublic.azureedge.net/clip/models/afeb0e10f9e5a86da6080e35cf09123aca3b358a0c3e3b6c78a7b63bc04b6762/RN50.pt",
    "RN101": "https://openaipublic.azureedge.net/clip/models/8fa8567bab74a42d41c5915025a8e4538c3bdbe8804a470a72f30b0d94fab599/RN101.pt",
    "RN50x4": "https://openaipublic.azureedge.net/clip/models/7e526bd135e493cef0776de27d5f42653e6b4c8bf9e0f653bb11773263205fdd/RN50x4.pt",
    "RN50x16": "https://openaipublic.azureedge.net/clip/models/52378b407f34354e150460fe41077663dd5b39c54cd0bfd2b27167a4a06ec9aa/RN50x16.pt",
    "RN50x64": "https://openaipublic.azureedge.net/clip/models/be1cfb55d75a9666199fb2206c106743da0f6468c9d327f3e0d0a543a9919d9c/RN50x64.pt",
    "ViT-B/32": "https://openaipublic.azureedge.net/clip/models/40d365715913c9da98579312b702a82c18be219cc2a73407c4526f58eba950af/ViT-B-32.pt",
    "ViT-B/16": "https://openaipublic.azureedge.net/clip/models/5806e77cd80f8b59890b7e101eabd078d9fb84e6937f9e85e4ecb61988df416f/ViT-B-16.pt",
    "ViT-L/14": "https://openaipublic.azureedge.net/clip/models/b8cca3fd41ae0c99ba7e8951adf17d267cdb84cd88be6f7c2e0eca1737a03836/ViT-L-14.pt",
    "ViT-L/14@336px": "https://openaipublic.azureedge.net/clip/models/3035c92b350959924f9f00213499208652fc7ea050643e8b385c2dac08641f02/ViT-L-14-336px.pt",
}
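# The keys above are the accepted values for project(..., model_name=...);
# the default 'vgg16' bypasses this table and selects NVIDIA's LPIPS-ready
# VGG16 feature extractor instead of a CLIP image encoder.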
def project(
    G,
    target: torch.Tensor, # [C,H,W] and dynamic range [0,255], W & H must match G output resolution
    *,
    num_steps               = 1000,
    w_avg_samples           = 10000,
    initial_learning_rate   = 0.1,
    initial_noise_factor    = 0.05,
    lr_rampdown_length      = 0.25,
    lr_rampup_length        = 0.05,
    noise_ramp_length       = 0.75,
    regularize_noise_weight = 1e5,
    verbose                 = False,
    model_name              = 'vgg16',
    loss_type               = 'l2',
    normalize_for_clip      = True,
    device: torch.device
):
    assert target.shape == (G.img_channels, G.img_resolution, G.img_resolution)

    def logprint(*args):
        if verbose:
            print(*args)

    G = copy.deepcopy(G).eval().requires_grad_(False).to(device) # type: ignore

    # Compute w stats.
    logprint(f'Computing W midpoint and stddev using {w_avg_samples} samples...')
    z_samples = np.random.RandomState(123).randn(w_avg_samples, G.z_dim)
    w_samples = G.mapping(torch.from_numpy(z_samples).to(device), None)  # [N, L, C]
    w_samples = w_samples[:, :1, :].cpu().numpy().astype(np.float32)     # [N, 1, C]
    w_avg = np.mean(w_samples, axis=0, keepdims=True)                    # [1, 1, C]
    w_std = (np.sum((w_samples - w_avg) ** 2) / w_avg_samples) ** 0.5

    # Setup noise inputs.
    noise_bufs = {name: buf for (name, buf) in G.synthesis.named_buffers() if 'noise_const' in name}

    USE_CLIP = model_name != 'vgg16'

    # Load the feature extractor: NVIDIA's LPIPS-ready VGG16 by default,
    # or one of the CLIP image encoders listed in _MODELS.
    url = 'https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/metrics/vgg16.pt'
    if USE_CLIP:
        url = _MODELS[model_name]
    with dnnlib.util.open_url(url) as f:
        vgg16 = torch.jit.load(f).eval().to(device)

    # Features for target image.
    target_images = target.unsqueeze(0).to(device).to(torch.float32)
    if USE_CLIP:
        # Per-channel normalization constants used by CLIP's preprocessing.
        image_mean = torch.tensor([0.48145466, 0.4578275, 0.40821073]).to(device)[:, None, None]
        image_std = torch.tensor([0.26862954, 0.26130258, 0.27577711]).to(device)[:, None, None]
        target_images = F.interpolate(target_images, size=(vgg16.input_resolution.item(), vgg16.input_resolution.item()), mode='area')
        logprint("target_images.shape:", target_images.shape)

        def _encode_image(image):
            image = image / 255.
            if normalize_for_clip:
                image = (image - image_mean) / image_std
            return vgg16.encode_image(image)

        target_features = _encode_image(target_images.clamp(0, 255)).detach()
    else:
        if target_images.shape[2] > 256:
            target_images = F.interpolate(target_images, size=(256, 256), mode='area')
        target_features = vgg16(target_images, resize_images=False, return_lpips=True)
    w_opt = torch.tensor(w_avg, dtype=torch.float32, device=device, requires_grad=True) # pylint: disable=not-callable
    w_out = torch.zeros([num_steps] + list(w_opt.shape[1:]), dtype=torch.float32, device=device)
    optimizer = torch.optim.Adam([w_opt] + list(noise_bufs.values()), betas=(0.9, 0.999), lr=initial_learning_rate)

    # Init noise.
    for buf in noise_bufs.values():
        buf[:] = torch.randn_like(buf)
        buf.requires_grad = True

    for step in range(num_steps):
        # Learning rate schedule.
        t = step / num_steps
        w_noise_scale = w_std * initial_noise_factor * max(0.0, 1.0 - t / noise_ramp_length) ** 2
        lr_ramp = min(1.0, (1.0 - t) / lr_rampdown_length)
        lr_ramp = 0.5 - 0.5 * np.cos(lr_ramp * np.pi)
        lr_ramp = lr_ramp * min(1.0, t / lr_rampup_length)
        lr = initial_learning_rate * lr_ramp
        for param_group in optimizer.param_groups:
            param_group['lr'] = lr

        # Synth images from opt_w.
        w_noise = torch.randn_like(w_opt) * w_noise_scale
        ws = (w_opt + w_noise).repeat([1, G.mapping.num_ws, 1])
        synth_images = G.synthesis(ws, noise_mode='const')

        # Downsample image to 256x256 if it's larger than that. VGG was built for 224x224 images.
        synth_images = (synth_images + 1) * (255/2)
        if synth_images.shape[2] > 256:
            synth_images = F.interpolate(synth_images, size=(256, 256), mode='area')

        # Features for synth images.
        if USE_CLIP:
            synth_images = F.interpolate(synth_images, size=(vgg16.input_resolution.item(), vgg16.input_resolution.item()), mode='area')
            synth_features = _encode_image(synth_images)
            if loss_type == 'cosine':
                target_features_normalized = target_features / target_features.norm(dim=-1, keepdim=True).detach()
                synth_features_normalized = synth_features / synth_features.norm(dim=-1, keepdim=True).detach()
                dist = 1.0 - torch.sum(synth_features_normalized * target_features_normalized)
            elif loss_type == 'l1':
                dist = (target_features - synth_features).abs().sum()
            else:
                dist = (target_features - synth_features).square().sum()
        else:
            synth_features = vgg16(synth_images, resize_images=False, return_lpips=True)
            dist = (target_features - synth_features).square().sum()

        # Noise regularization.
        reg_loss = 0.0
        for v in noise_bufs.values():
            noise = v[None, None, :, :] # must be [1,1,H,W] for F.avg_pool2d()
            while True:
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=3)).mean() ** 2
                reg_loss += (noise * torch.roll(noise, shifts=1, dims=2)).mean() ** 2
                if noise.shape[2] <= 8:
                    break
                noise = F.avg_pool2d(noise, kernel_size=2)
        loss = dist + reg_loss * regularize_noise_weight

        # Step
        optimizer.zero_grad(set_to_none=True)
        loss.backward()
        optimizer.step()
        logprint(f'step {step+1:>4d}/{num_steps}: dist {dist:<4.2f} loss {float(loss):<5.2f}')

        # Save projected W for each optimization step.
        w_out[step] = w_opt.detach()[0]

        # Normalize noise.
        with torch.no_grad():
            for buf in noise_bufs.values():
                buf -= buf.mean()
                buf *= buf.square().mean().rsqrt()

    return w_out.repeat([1, G.mapping.num_ws, 1])
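# Programmatic usage (a minimal sketch, not part of the original CLI; `G` and
# `target` are assumed to be a loaded generator and a [C,H,W] uint8 tensor
# matching G.img_resolution, as prepared in run_projection() below):
#
#   projected_w_steps = project(
#       G, target=target, num_steps=500,
#       model_name='ViT-B/32',   # any key of _MODELS, or 'vgg16' for LPIPS
#       loss_type='cosine',      # 'l2' (default), 'l1', or 'cosine'
#       normalize_for_clip=True,
#       device=torch.device('cuda'),
#   )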
#----------------------------------------------------------------------------

@click.command()
@click.option('--network', 'network_pkl', help='Network pickle filename', required=True)
@click.option('--target', 'target_fname', help='Target image file to project to', required=True, metavar='FILE')
@click.option('--num-steps', help='Number of optimization steps', type=int, default=1000, show_default=True)
@click.option('--seed', help='Random seed', type=int, default=303, show_default=True)
@click.option('--save-video', help='Save an mp4 video of optimization progress', type=bool, default=True, show_default=True)
@click.option('--outdir', help='Where to save the output images', required=True, metavar='DIR')
def run_projection(
    network_pkl: str,
    target_fname: str,
    outdir: str,
    save_video: bool,
    seed: int,
    num_steps: int
):
    """Project given image to the latent space of pretrained network pickle.

    Examples:

    \b
    python projector.py --outdir=out --target=~/mytargetimg.png \\
        --network=https://nvlabs-fi-cdn.nvidia.com/stylegan2-ada-pytorch/pretrained/ffhq.pkl
    """
    np.random.seed(seed)
    torch.manual_seed(seed)

    # Load networks.
    print('Loading networks from "%s"...' % network_pkl)
    device = torch.device('cuda')
    with dnnlib.util.open_url(network_pkl) as fp:
        G = legacy.load_network_pkl(fp)['G_ema'].requires_grad_(False).to(device) # type: ignore

    # Load target image: center-crop to a square, then resize to the generator resolution.
    target_pil = PIL.Image.open(target_fname).convert('RGB')
    w, h = target_pil.size
    s = min(w, h)
    target_pil = target_pil.crop(((w - s) // 2, (h - s) // 2, (w + s) // 2, (h + s) // 2))
    target_pil = target_pil.resize((G.img_resolution, G.img_resolution), PIL.Image.LANCZOS)
    target_uint8 = np.array(target_pil, dtype=np.uint8)

    # Optimize projection.
    start_time = perf_counter()
    projected_w_steps = project(
        G,
        target=torch.tensor(target_uint8.transpose([2, 0, 1]), device=device), # pylint: disable=not-callable
        num_steps=num_steps,
        device=device,
        verbose=True
    )
    print(f'Elapsed: {(perf_counter()-start_time):.1f} s')

    # Render debug output: optional video and projected image and W vector.
    os.makedirs(outdir, exist_ok=True)
    if save_video:
        video = imageio.get_writer(f'{outdir}/proj.mp4', mode='I', fps=10, codec='libx264', bitrate='16M')
        print(f'Saving optimization progress video "{outdir}/proj.mp4"')
        for projected_w in projected_w_steps:
            synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
            synth_image = (synth_image + 1) * (255/2)
            synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
            video.append_data(np.concatenate([target_uint8, synth_image], axis=1))
        video.close()

    # Save final projected frame and W vector.
    target_pil.save(f'{outdir}/target.png')
    projected_w = projected_w_steps[-1]
    synth_image = G.synthesis(projected_w.unsqueeze(0), noise_mode='const')
    synth_image = (synth_image + 1) * (255/2)
    synth_image = synth_image.permute(0, 2, 3, 1).clamp(0, 255).to(torch.uint8)[0].cpu().numpy()
    PIL.Image.fromarray(synth_image, 'RGB').save(f'{outdir}/proj.png')
    np.savez(f'{outdir}/projected_w.npz', w=projected_w.unsqueeze(0).cpu().numpy())

#----------------------------------------------------------------------------

if __name__ == "__main__":
    run_projection() # pylint: disable=no-value-for-parameter

#----------------------------------------------------------------------------
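# Reloading the result (a sketch under the same assumptions as above): the
# saved .npz holds the projected W with shape [1, num_ws, C], which can be
# fed straight back into G.synthesis().
#
#   ws = np.load('out/projected_w.npz')['w']
#   img = G.synthesis(torch.tensor(ws, device=device), noise_mode='const')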
