from pathlib import Path
from math import sin, pi, sqrt
from functools import partial

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from easydict import EasyDict

from guided_diffusion.script_util import (
    create_model_and_diffusion,
    model_and_diffusion_defaults,
    NUM_CLASSES,
    create_classifier,
    classifier_defaults,
    sr_create_model_and_diffusion,
    sr_model_and_diffusion_defaults,
)

from adapt import ScoreAdapter
from my.registry import Registry

PRETRAINED_REGISTRY = Registry("pretrained")

device = torch.device("cuda")


def load_ckpt(path, **kwargs):
    # with bf.BlobFile(path, "rb") as f:
    #     data = f.read()
    return torch.load(path, **kwargs)


def pick_out_cfgs(src, target_ks):
    return {k: src[k] for k in target_ks}
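
# e.g. pick_out_cfgs({"a": 1, "b": 2, "c": 3}, ("a", "c")) -> {"a": 1, "c": 3};
# used below to forward only the keys that each guided_diffusion factory accepts.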


@PRETRAINED_REGISTRY.register
def m_imgnet_64():
    return dict(
        attention_resolutions="32,16,8",
        class_cond=True,
        diffusion_steps=1000,
        dropout=0.1,
        image_size=64,
        learn_sigma=True,
        noise_schedule="cosine",
        num_channels=192,
        num_head_channels=64,
        num_res_blocks=3,
        resblock_updown=True,
        use_new_attention_order=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        classifier_depth=4,
        classifier_scale=1.0,
        model_path="models/64x64_diffusion.pt",
        classifier_path="models/64x64_classifier.pt",
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_128():
    return dict(
        attention_resolutions="32,16,8",
        class_cond=True,
        diffusion_steps=1000,
        image_size=128,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=256,
        num_heads=4,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        classifier_scale=0.5,
        model_path="models/128x128_diffusion.pt",
        classifier_path="models/128x128_classifier.pt",
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_256():
    return dict(
        attention_resolutions="32,16,8",
        class_cond=True,
        diffusion_steps=1000,
        image_size=256,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=256,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        classifier_scale=1.0,
        model_path="models/256x256_diffusion.pt",
        classifier_path="models/256x256_classifier.pt",
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_256_uncond():
    return dict(
        attention_resolutions="32,16,8",
        class_cond=False,
        diffusion_steps=1000,
        image_size=256,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=256,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        classifier_scale=10.0,
        model_path="models/256x256_diffusion_uncond.pt",
        classifier_path="models/256x256_classifier.pt",
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_512():
    return dict(
        attention_resolutions="32,16,8",
        class_cond=True,
        diffusion_steps=1000,
        image_size=512,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=256,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=False,
        use_scale_shift_norm=True,
        classifier_scale=4.0,
        model_path="models/512x512_diffusion.pt",
        classifier_path="models/512x512_classifier.pt",
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_64_256(base_samples="64_samples.npz"):
    return dict(
        attention_resolutions="32,16,8",
        class_cond=True,
        diffusion_steps=1000,
        large_size=256,
        small_size=64,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=192,
        num_heads=4,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        model_path="models/64_256_upsampler.pt",
        base_samples=base_samples,
    )


@PRETRAINED_REGISTRY.register
def m_imgnet_128_512(base_samples="128_samples.npz"):
    return dict(
        attention_resolutions="32,16",
        class_cond=True,
        diffusion_steps=1000,
        large_size=512,
        small_size=128,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=192,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        model_path="models/128_512_upsampler.pt",
        base_samples=base_samples,
    )


@PRETRAINED_REGISTRY.register
def m_lsun_256(category="bedroom"):
    return dict(
        attention_resolutions="32,16,8",
        class_cond=False,
        diffusion_steps=1000,
        dropout=0.1,
        image_size=256,
        learn_sigma=True,
        noise_schedule="linear",
        num_channels=256,
        num_head_channels=64,
        num_res_blocks=2,
        resblock_updown=True,
        use_fp16=True,
        use_scale_shift_norm=True,
        model_path=f"models/lsun_{category}.pt",
    )
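
# The config factories above are registered by name and looked up as strings in
# GuidedDDPM below; assuming my.registry.Registry follows the usual
# name -> callable convention, PRETRAINED_REGISTRY.get("m_imgnet_64")() returns
# the 64x64 config dict consumed by img_gen / GuidedDDPM.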


def img_gen(specific_cfgs, num_samples=16, batch_size=16, load_only=False, ckpt_root=Path("")):
    cfgs = EasyDict(
        clip_denoised=True,
        num_samples=num_samples,
        batch_size=batch_size,
        use_ddim=False,
        model_path="",
        classifier_path="",
        classifier_scale=1.0,
    )
    cfgs.update(model_and_diffusion_defaults())
    cfgs.update(classifier_defaults())
    cfgs.update(specific_cfgs)

    use_classifier_guidance = bool(cfgs.classifier_path)
    class_aware = cfgs.class_cond or use_classifier_guidance

    model, diffusion = create_model_and_diffusion(
        **pick_out_cfgs(cfgs, model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(
        load_ckpt(str(ckpt_root / cfgs.model_path), map_location="cpu")
    )
    model.to(device)
    if cfgs.use_fp16:
        model.convert_to_fp16()
    model.eval()

    def model_fn(x, t, y=None):
        return model(x, t, y if cfgs.class_cond else None)

    classifier = None
    cond_fn = None
    if use_classifier_guidance:
        classifier = create_classifier(
            **pick_out_cfgs(cfgs, classifier_defaults().keys())
        )
        classifier.load_state_dict(
            load_ckpt(str(ckpt_root / cfgs.classifier_path), map_location="cpu")
        )
        classifier.to(device)
        if cfgs.classifier_use_fp16:
            classifier.convert_to_fp16()
        classifier.eval()

        def cond_fn(x, t, y=None):
            assert y is not None
            with torch.enable_grad():
                x_in = x.detach().requires_grad_(True)
                logits = classifier(x_in, t)
                log_probs = F.log_softmax(logits, dim=-1)
                selected = log_probs[range(len(logits)), y.view(-1)]
                return torch.autograd.grad(selected.sum(), x_in)[0] * cfgs.classifier_scale

    if load_only:
        return model, classifier

    all_images = []
    all_labels = []
    while len(all_images) * cfgs.batch_size < cfgs.num_samples:
        model_kwargs = {}
        if class_aware:
            classes = torch.randint(
                low=0, high=NUM_CLASSES, size=(cfgs.batch_size,), device=device
            )
            model_kwargs["y"] = classes
        sample_fn = (
            diffusion.p_sample_loop if not cfgs.use_ddim else diffusion.ddim_sample_loop
        )
        sample = sample_fn(
            model_fn,
            (cfgs.batch_size, 3, cfgs.image_size, cfgs.image_size),
            clip_denoised=cfgs.clip_denoised,
            model_kwargs=model_kwargs,
            cond_fn=cond_fn,
            device=device,
            progress=True,
        )
        # map from [-1, 1] floats to uint8 HWC images
        sample = ((sample + 1) * 127.5).clamp(0, 255).to(torch.uint8)
        sample = sample.permute(0, 2, 3, 1)
        sample = sample.contiguous()
        all_images.append(sample.cpu().numpy())
        if class_aware:
            all_labels.append(classes.cpu().numpy())

    arr = np.concatenate(all_images, axis=0)
    arr = arr[:cfgs.num_samples]
    shape_str = "x".join([str(x) for x in arr.shape])
    out_path = Path("./out") / f"samples_{shape_str}.npz"
    # only save labels when the run was class-aware; otherwise all_labels is empty
    if class_aware:
        all_labels = np.concatenate(all_labels, axis=0)
        all_labels = all_labels[:cfgs.num_samples]
        np.savez(out_path, arr, all_labels)
    else:
        np.savez(out_path, arr)
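
# Usage sketch (illustrative, with an assumed checkpoint layout): given the
# files named in m_imgnet_64() under ckpt_root and an existing ./out directory,
#   img_gen(m_imgnet_64(), num_samples=16, batch_size=16)
# writes ./out/samples_16x64x64x3.npz holding uint8 HWC images plus labels.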


def img_upsamp(specific_cfgs, num_samples=16, batch_size=16, load_only=False):
    """Note: unlike img_gen, the checkpoint root is not wired in here;
    model_path is loaded as-is, so this will break unless the path resolves. Easy fix."""
    cfgs = EasyDict(
        clip_denoised=True,
        num_samples=num_samples,
        batch_size=batch_size,
        use_ddim=False,
        base_samples="",
        model_path="",
    )
    cfgs.update(sr_model_and_diffusion_defaults())
    cfgs.update(specific_cfgs)

    model, diffusion = sr_create_model_and_diffusion(
        **pick_out_cfgs(cfgs, sr_model_and_diffusion_defaults().keys())
    )
    model.load_state_dict(load_ckpt(cfgs.model_path, map_location="cpu"))
    model.to(device)
    if cfgs.use_fp16:
        model.convert_to_fp16()
    model.eval()

    if load_only:
        return model

    data = load_low_res_samples(
        cfgs.base_samples, cfgs.batch_size, cfgs.class_cond
    )

    all_images = []
    while len(all_images) * cfgs.batch_size < cfgs.num_samples:
        model_kwargs = next(data)
        model_kwargs = {k: v.to(device) for k, v in model_kwargs.items()}
        samples = diffusion.p_sample_loop(
            model,
            (cfgs.batch_size, 3, cfgs.large_size, cfgs.large_size),
            clip_denoised=cfgs.clip_denoised,
            model_kwargs=model_kwargs,
            progress=True,
        )
        samples = ((samples + 1) * 127.5).clamp(0, 255).to(torch.uint8)
        samples = samples.permute(0, 2, 3, 1)
        samples = samples.contiguous()
        all_images.append(samples.cpu().numpy())

    arr = np.concatenate(all_images, axis=0)
    arr = arr[:cfgs.num_samples]
    shape_str = "x".join([str(x) for x in arr.shape])
    out_path = Path("./out") / f"samples_{shape_str}.npz"
    np.savez(out_path, arr)
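
# Usage sketch (illustrative): upsample 64x64 samples produced by img_gen to
# 256x256, assuming "64_samples.npz" and the upsampler checkpoint resolve:
#   img_upsamp(m_imgnet_64_256(base_samples="64_samples.npz"),
#              num_samples=16, batch_size=16)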


def load_low_res_samples(base_samples, batch_size, class_cond):
    obj = np.load(base_samples)
    image_arr = obj["arr_0"]
    if class_cond:
        label_arr = obj["arr_1"]

    buffer = []
    label_buffer = []
    while True:
        for i in range(len(image_arr)):
            buffer.append(image_arr[i])
            if class_cond:
                label_buffer.append(label_arr[i])
            if len(buffer) == batch_size:
                batch = torch.from_numpy(np.stack(buffer)).float()
                batch = batch / 127.5 - 1.0
                batch = batch.permute(0, 3, 1, 2)
                res = {}
                res["low_res"] = batch
                if class_cond:
                    res["y"] = torch.from_numpy(np.stack(label_buffer))
                yield res
                buffer, label_buffer = [], []
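
# The generator above cycles over the .npz indefinitely; the caller decides how
# many batches to draw. A minimal sketch (assumes "64_samples.npz" was written
# by a class-aware img_gen run):
#   data = load_low_res_samples("64_samples.npz", batch_size=16, class_cond=True)
#   batch = next(data)  # {"low_res": (16, 3, 64, 64) floats in [-1, 1], "y": (16,)}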


def class_cond_info(imgnet_cat):

    def rand_cond_fn(batch_size):
        cats = torch.randint(
            low=0, high=NUM_CLASSES, size=(batch_size,), device=device
        )
        return {"y": cats}

    def class_specific_cond(batch_size):
        cats = torch.tensor([imgnet_cat, ] * batch_size, device=device)
        return {"y": cats}

    if imgnet_cat == -1:
        return rand_cond_fn
    else:
        return class_specific_cond
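
# e.g. class_cond_info(-1)(4) samples 4 random ImageNet labels, while
# class_cond_info(207)(4) returns {"y": tensor([207, 207, 207, 207])}.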


def _sqrt(x):
    if isinstance(x, float):
        return sqrt(x)
    else:
        assert isinstance(x, torch.Tensor)
        return torch.sqrt(x)


class GuidedDDPM(ScoreAdapter):
    def __init__(self, model, lsun_cat, imgnet_cat):
        print(PRETRAINED_REGISTRY)
        cfgs = PRETRAINED_REGISTRY.get(model)(
            **({"category": lsun_cat} if model.startswith("m_lsun") else {})
        )

        self.unet, self.classifier = img_gen(
            cfgs, load_only=True, ckpt_root=self.checkpoint_root() / "guided_ddpm"
        )

        H, W = cfgs['image_size'], cfgs['image_size']
        self._data_shape = (3, H, W)

        if cfgs['class_cond'] or (self.classifier is not None):
            cond_func = class_cond_info(imgnet_cat)
        else:
            cond_func = lambda *args, **kwargs: {}
        self.cond_func = cond_func

        self._unet_is_cond = bool(cfgs['class_cond'])

        noise_schedule = cfgs['noise_schedule']
        assert noise_schedule in ("linear", "cosine")
        self.M = 1000
        if noise_schedule == "linear":
            self.us = self.linear_us(self.M)
            self._σ_min = 0.01
        else:
            self.us = self.cosine_us(self.M)
            self._σ_min = 0.0064
        self.noise_schedule = noise_schedule

        self._device = next(self.unet.parameters()).device

    def data_shape(self):
        return self._data_shape

    def σ_max(self):
        return self.us[0]

    def σ_min(self):
        return self.us[-1]

    def denoise(self, xs, σ, **model_kwargs):
        N = xs.shape[0]
        cond_t, σ = self.time_cond_vec(N, σ)
        # the variance-exploding input x0 + σ·n is rescaled by 1/sqrt(1 + σ²)
        # before being fed to the variance-preserving unet
        output = self.unet(
            xs / _sqrt(1 + σ**2), cond_t, **model_kwargs
        )
        # not using the var pred
        n_hat = torch.split(output, xs.shape[1], dim=1)[0]
        Ds = xs - σ * n_hat
        return Ds

    def cond_info(self, batch_size):
        return self.cond_func(batch_size)

    def unet_is_cond(self):
        return self._unet_is_cond

    def use_cls_guidance(self):
        return (self.classifier is not None)

    def classifier_grad(self, xs, σ, ys):
        N = xs.shape[0]
        cond_t, σ = self.time_cond_vec(N, σ)
        with torch.enable_grad():
            x_in = xs.detach().requires_grad_(True)
            logits = self.classifier(x_in, cond_t)
            log_probs = F.log_softmax(logits, dim=-1)
            selected = log_probs[range(len(logits)), ys.view(-1)]
            grad = torch.autograd.grad(selected.sum(), x_in)[0]

        grad = grad * (1 / _sqrt(1 + σ**2))  # _sqrt handles both float and tensor σ
        return grad

    def snap_t_to_nearest_tick(self, t):
        j = np.abs(t - self.us).argmin()
        return self.us[j], j

    def time_cond_vec(self, N, σ):
        if isinstance(σ, float):
            σ, j = self.snap_t_to_nearest_tick(σ)  # σ might change due to snapping
            cond_t = (self.M - 1) - j
            cond_t = torch.tensor([cond_t] * N, device=self.device)
            return cond_t, σ
        else:
            assert isinstance(σ, torch.Tensor)
            σ = σ.reshape(-1).cpu().numpy()
            σs = []
            js = []
            for elem in σ:
                _σ, _j = self.snap_t_to_nearest_tick(elem)
                σs.append(_σ)
                js.append((self.M - 1) - _j)

            cond_t = torch.tensor(js, device=self.device)
            σs = torch.tensor(σs, device=self.device, dtype=torch.float32).reshape(-1, 1, 1, 1)
            return cond_t, σs
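
    # Illustration of the tick mapping (a gloss, not from the original source):
    # us is sorted from the largest σ down to the smallest, so us[0] maps to
    # unet timestep M - 1 (noisiest) and us[-1] to timestep 0; time_cond_vec
    # snaps a continuous σ to the nearest tick and returns the matching
    # discrete timestep vector together with the snapped σ.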

    @staticmethod
    def cosine_us(M=1000):
        assert M == 1000

        def α_bar(j):
            return sin(pi / 2 * j / (M * (0.008 + 1))) ** 2

        us = [0, ]
        for j in reversed(range(0, M)):  # [M-1, 0], inclusive
            u_j = sqrt(((us[-1] ** 2) + 1) / (max(α_bar(j) / α_bar(j + 1), 0.001)) - 1)
            us.append(u_j)

        us = np.array(us)
        us = us[1:]
        us = us[::-1]
        return us

    @staticmethod
    def linear_us(M=1000):
        assert M == 1000
        β_start = 0.0001
        β_end = 0.02
        βs = np.linspace(β_start, β_end, M, dtype=np.float64)
        αs = np.cumprod(1 - βs)
        # us = sqrt((1 - ᾱ) / ᾱ): the variance-exploding noise level equivalent
        # to each variance-preserving step, reversed so us[0] is the largest σ
        us = np.sqrt((1 - αs) / αs)
        us = us[::-1]
        return us
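

if __name__ == "__main__":
    # Smoke-test sketch (an addition for illustration, not part of the original
    # pipeline): sample a small class-conditional batch from the ImageNet 64x64
    # model. Assumes the checkpoints named in m_imgnet_64() exist under ./models
    # and that ./out exists for the output .npz.
    img_gen(PRETRAINED_REGISTRY.get("m_imgnet_64")(), num_samples=4, batch_size=4)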