# Copyright (c) 2023 Amphion.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

"""
Based on: https://github.com/crowsonkb/k-diffusion
"""

import random

import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

# from piq import LPIPS
from utils.ssim import SSIM

from modules.diffusion.karras.random_utils import get_generator


def mean_flat(tensor):
    """
    Take the mean over all non-batch dimensions.
    """
    return tensor.mean(dim=list(range(1, len(tensor.shape))))

def append_dims(x, target_dims):
    """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
    dims_to_append = target_dims - x.ndim
    if dims_to_append < 0:
        raise ValueError(
            f"input has {x.ndim} dims but target_dims is {target_dims}, which is less than x.ndim"
        )
    return x[(...,) + (None,) * dims_to_append]
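
# Example (illustrative, added comment): broadcasting per-sample noise levels
# over a batched tensor, e.g.
#   append_dims(th.tensor([0.1, 2.0]), 3).shape == th.Size([2, 1, 1])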

def append_zero(x):
    return th.cat([x, x.new_zeros([1])])


def get_weightings(weight_schedule, snrs, sigma_data):
    if weight_schedule == "snr":
        weightings = snrs
    elif weight_schedule == "snr+1":
        weightings = snrs + 1
    elif weight_schedule == "karras":
        weightings = snrs + 1.0 / sigma_data**2
    elif weight_schedule == "truncated-snr":
        weightings = th.clamp(snrs, min=1.0)
    elif weight_schedule == "uniform":
        weightings = th.ones_like(snrs)
    else:
        raise NotImplementedError()
    return weightings
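
# Note (added, illustrative): with snr = sigma**-2 (see KarrasDenoiser.get_snr),
# the "karras" schedule equals the EDM loss weighting
#   lambda(sigma) = (sigma**2 + sigma_data**2) / (sigma * sigma_data)**2,
# e.g. get_weightings("karras", th.tensor([4.0]), 0.5) -> tensor([8.])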

class KarrasDenoiser:
    def __init__(
        self,
        sigma_data: float = 0.5,
        sigma_max=80.0,
        sigma_min=0.002,
        rho=7.0,
        weight_schedule="karras",
        distillation=False,
        loss_norm="l2",
    ):
        self.sigma_data = sigma_data
        self.sigma_max = sigma_max
        self.sigma_min = sigma_min
        self.weight_schedule = weight_schedule
        self.distillation = distillation
        self.loss_norm = loss_norm
        # if loss_norm == "lpips":
        #     self.lpips_loss = LPIPS(replace_pooling=True, reduction="none")
        if loss_norm == "ssim":
            self.ssim_loss = SSIM()
        self.rho = rho
        self.num_timesteps = 40

    def get_snr(self, sigmas):
        return sigmas**-2

    def get_sigmas(self, sigmas):
        return sigmas

    def get_scalings(self, sigma):
        c_skip = self.sigma_data**2 / (sigma**2 + self.sigma_data**2)
        c_out = sigma * self.sigma_data / (sigma**2 + self.sigma_data**2) ** 0.5
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out, c_in
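
    # Note (added comment): these are the EDM preconditioning coefficients of
    # Karras et al. (2022); denoise() below combines them as
    #   D(x, sigma) = c_skip(sigma) * x + c_out(sigma) * F(c_in(sigma) * x, ...),
    # keeping the network's inputs and targets at roughly unit variance.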

    def get_scalings_for_boundary_condition(self, sigma):
        c_skip = self.sigma_data**2 / (
            (sigma - self.sigma_min) ** 2 + self.sigma_data**2
        )
        c_out = (
            (sigma - self.sigma_min)
            * self.sigma_data
            / (sigma**2 + self.sigma_data**2) ** 0.5
        )
        c_in = 1 / (sigma**2 + self.sigma_data**2) ** 0.5
        return c_skip, c_out, c_in
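
    # Note (added comment): shifting sigma by sigma_min enforces the
    # consistency-model boundary condition c_skip(sigma_min) = 1 and
    # c_out(sigma_min) = 0, so the denoiser reduces to the identity at the
    # smallest noise level.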

    def training_losses(self, model, x_start, sigmas, condition=None, noise=None):
        if noise is None:
            noise = th.randn_like(x_start)

        terms = {}

        dims = x_start.ndim
        x_t = x_start + noise * append_dims(sigmas, dims)
        model_output, denoised = self.denoise(model, x_t, sigmas, condition)

        snrs = self.get_snr(sigmas)
        weights = append_dims(
            get_weightings(self.weight_schedule, snrs, self.sigma_data), dims
        )
        # terms["xs_mse"] = mean_flat((denoised - x_start) ** 2)
        terms["mse"] = mean_flat(weights * (denoised - x_start) ** 2)
        # terms["mae"] = mean_flat(weights * th.abs(denoised - x_start))
        # terms["mse"] = nn.MSELoss(reduction="none")(denoised, x_start)

        # if "vb" in terms:
        #     terms["loss"] = terms["mse"] + terms["vb"]
        # else:
        terms["loss"] = terms["mse"]

        return terms

    def consistency_losses(
        self,
        model,
        x_start,
        num_scales,
        # model_kwargs=None,
        condition=None,
        target_model=None,
        teacher_model=None,
        teacher_diffusion=None,
        noise=None,
    ):
        if noise is None:
            noise = th.randn_like(x_start)

        dims = x_start.ndim

        def denoise_fn(x, t):
            return self.denoise(model, x, t, condition)[1]

        if target_model:

            def target_denoise_fn(x, t):
                return self.denoise(target_model, x, t, condition)[1]

        else:
            raise NotImplementedError("Must have a target model")

        if teacher_model:

            def teacher_denoise_fn(x, t):
                return teacher_diffusion.denoise(teacher_model, x, t, condition)[1]

        def heun_solver(samples, t, next_t, x0):
            x = samples
            if teacher_model is None:
                denoiser = x0
            else:
                denoiser = teacher_denoise_fn(x, t)

            d = (x - denoiser) / append_dims(t, dims)
            samples = x + d * append_dims(next_t - t, dims)
            if teacher_model is None:
                denoiser = x0
            else:
                denoiser = teacher_denoise_fn(samples, next_t)

            next_d = (samples - denoiser) / append_dims(next_t, dims)
            samples = x + (d + next_d) * append_dims((next_t - t) / 2, dims)

            return samples

        def euler_solver(samples, t, next_t, x0):
            x = samples
            if teacher_model is None:
                denoiser = x0
            else:
                denoiser = teacher_denoise_fn(x, t)
            d = (x - denoiser) / append_dims(t, dims)
            samples = x + d * append_dims(next_t - t, dims)

            return samples
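
        # Note (added comment): t and t2 below are adjacent levels of the Karras
        # discretization
        #   t_i = (sigma_max**(1/rho) + i/(N-1) * (sigma_min**(1/rho) - sigma_max**(1/rho)))**rho,
        # with t > t2, so the solver steps from the noisier level t toward t2.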
        indices = th.randint(
            0, num_scales - 1, (x_start.shape[0],), device=x_start.device
        )

        t = self.sigma_max ** (1 / self.rho) + indices / (num_scales - 1) * (
            self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
        )
        t = t**self.rho

        t2 = self.sigma_max ** (1 / self.rho) + (indices + 1) / (num_scales - 1) * (
            self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
        )
        t2 = t2**self.rho

        x_t = x_start + noise * append_dims(t, dims)

        dropout_state = th.get_rng_state()
        distiller = denoise_fn(x_t, t)

        if teacher_model is None:
            x_t2 = euler_solver(x_t, t, t2, x_start).detach()
        else:
            x_t2 = heun_solver(x_t, t, t2, x_start).detach()

        # Reset the RNG so the online and target networks see the same dropout
        # mask, then compute the stopped-gradient target.
        th.set_rng_state(dropout_state)
        distiller_target = target_denoise_fn(x_t2, t2)
        distiller_target = distiller_target.detach()

        snrs = self.get_snr(t)
        weights = get_weightings(self.weight_schedule, snrs, self.sigma_data)
        if self.loss_norm == "l1":
            diffs = th.abs(distiller - distiller_target)
            loss = mean_flat(diffs) * weights
        elif self.loss_norm == "l2":
            # diffs = (distiller - distiller_target) ** 2
            # NOTE: F.mse_loss reduces over the whole batch and does not apply
            # the per-sample `weights` computed above.
            loss = F.mse_loss(distiller, distiller_target)
            # loss = mean_flat(diffs) * weights
        elif self.loss_norm == "ssim":
            loss = self.ssim_loss(distiller, distiller_target) * weights
        # elif self.loss_norm == "l2-32":
        #     distiller = F.interpolate(distiller, size=32, mode="bilinear")
        #     distiller_target = F.interpolate(
        #         distiller_target,
        #         size=32,
        #         mode="bilinear",
        #     )
        #     diffs = (distiller - distiller_target) ** 2
        #     loss = mean_flat(diffs) * weights
        # elif self.loss_norm == "lpips":
        #     if x_start.shape[-1] < 256:
        #         distiller = F.interpolate(distiller, size=224, mode="bilinear")
        #         distiller_target = F.interpolate(
        #             distiller_target, size=224, mode="bilinear"
        #         )
        #     loss = (
        #         self.lpips_loss(
        #             (distiller + 1) / 2.0,
        #             (distiller_target + 1) / 2.0,
        #         )
        #         * weights
        #     )
        else:
            raise ValueError(f"Unknown loss norm {self.loss_norm}")

        terms = {}
        terms["loss"] = loss

        return terms

    # def progdist_losses(
    #     self,
    #     model,
    #     x_start,
    #     num_scales,
    #     model_kwargs=None,
    #     teacher_model=None,
    #     teacher_diffusion=None,
    #     noise=None,
    # ):
    #     if model_kwargs is None:
    #         model_kwargs = {}
    #     if noise is None:
    #         noise = th.randn_like(x_start)

    #     dims = x_start.ndim

    #     def denoise_fn(x, t):
    #         return self.denoise(model, x, t, **model_kwargs)[1]

    #     @th.no_grad()
    #     def teacher_denoise_fn(x, t):
    #         return teacher_diffusion.denoise(teacher_model, x, t, **model_kwargs)[1]

    #     @th.no_grad()
    #     def euler_solver(samples, t, next_t):
    #         x = samples
    #         denoiser = teacher_denoise_fn(x, t)
    #         d = (x - denoiser) / append_dims(t, dims)
    #         samples = x + d * append_dims(next_t - t, dims)

    #         return samples

    #     @th.no_grad()
    #     def euler_to_denoiser(x_t, t, x_next_t, next_t):
    #         denoiser = x_t - append_dims(t, dims) * (x_next_t - x_t) / append_dims(
    #             next_t - t, dims
    #         )
    #         return denoiser

    #     indices = th.randint(0, num_scales, (x_start.shape[0],), device=x_start.device)

    #     t = self.sigma_max ** (1 / self.rho) + indices / num_scales * (
    #         self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
    #     )
    #     t = t**self.rho

    #     t2 = self.sigma_max ** (1 / self.rho) + (indices + 0.5) / num_scales * (
    #         self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
    #     )
    #     t2 = t2**self.rho

    #     t3 = self.sigma_max ** (1 / self.rho) + (indices + 1) / num_scales * (
    #         self.sigma_min ** (1 / self.rho) - self.sigma_max ** (1 / self.rho)
    #     )
    #     t3 = t3**self.rho

    #     x_t = x_start + noise * append_dims(t, dims)

    #     denoised_x = denoise_fn(x_t, t)

    #     x_t2 = euler_solver(x_t, t, t2).detach()
    #     x_t3 = euler_solver(x_t2, t2, t3).detach()

    #     target_x = euler_to_denoiser(x_t, t, x_t3, t3).detach()

    #     snrs = self.get_snr(t)
    #     weights = get_weightings(self.weight_schedule, snrs, self.sigma_data)
    #     if self.loss_norm == "l1":
    #         diffs = th.abs(denoised_x - target_x)
    #         loss = mean_flat(diffs) * weights
    #     elif self.loss_norm == "l2":
    #         diffs = (denoised_x - target_x) ** 2
    #         loss = mean_flat(diffs) * weights
    #     elif self.loss_norm == "lpips":
    #         if x_start.shape[-1] < 256:
    #             denoised_x = F.interpolate(denoised_x, size=224, mode="bilinear")
    #             target_x = F.interpolate(target_x, size=224, mode="bilinear")

    #         loss = (
    #             self.lpips_loss(
    #                 (denoised_x + 1) / 2.0,
    #                 (target_x + 1) / 2.0,
    #             )
    #             * weights
    #         )
    #     else:
    #         raise ValueError(f"Unknown loss norm {self.loss_norm}")

    #     terms = {}
    #     terms["loss"] = loss

    #     return terms

    def denoise(self, model, x_t, sigmas, condition):
        if not self.distillation:
            c_skip, c_out, c_in = [
                append_dims(x, x_t.ndim) for x in self.get_scalings(sigmas)
            ]
        else:
            c_skip, c_out, c_in = [
                append_dims(x, x_t.ndim)
                for x in self.get_scalings_for_boundary_condition(sigmas)
            ]
        # EDM noise conditioning c_noise(sigma) = 0.25 * log(sigma), scaled by
        # 1000; the 1e-44 offset guards against log(0).
        rescaled_t = 1000 * 0.25 * th.log(sigmas + 1e-44)
        # rescaled_t = rescaled_t[:, None]
        model_output = model(c_in * x_t, rescaled_t, condition)

        denoised = c_out * model_output + c_skip * x_t
        return model_output, denoised
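
# Illustrative training sketch (added; not part of the original pipeline).
# Assumes `model` is a callable with the signature used by denoise() above,
# i.e. model(c_in * x_t, rescaled_t, condition):
#
#   diffusion = KarrasDenoiser(sigma_data=0.5, weight_schedule="karras")
#   x_start = th.randn(4, 80, 128)            # e.g. (batch, n_mels, frames)
#   sigmas = th.exp(th.randn(4) * 1.2 - 1.2)  # log-normal noise levels
#   losses = diffusion.training_losses(model, x_start, sigmas, condition=None)
#   losses["loss"].mean().backward()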

def karras_sample(
    diffusion,
    model,
    shape,
    steps,
    clip_denoised=True,
    progress=True,
    callback=None,
    # model_kwargs=None,
    condition=None,
    device=None,
    sigma_min=0.002,
    sigma_max=80,  # higher for highres?
    rho=7.0,
    sampler="heun",
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
    generator=None,
    ts=None,
):
    if generator is None:
        generator = get_generator("dummy")

    if sampler == "progdist":
        sigmas = get_sigmas_karras(steps + 1, sigma_min, sigma_max, rho, device=device)
    else:
        sigmas = get_sigmas_karras(steps, sigma_min, sigma_max, rho, device=device)

    th.manual_seed(42)  # fixed seed for reproducible sampling
    x_T = generator.randn(*shape, device=device) * sigma_max
    sigmas = sigmas.unsqueeze(-1)
    sample_fn = {
        "heun": sample_heun,
        "dpm": sample_dpm,
        "ancestral": sample_euler_ancestral,
        "onestep": sample_onestep,
        "progdist": sample_progdist,
        "euler": sample_euler,
        "multistep": stochastic_iterative_sampler,
    }[sampler]

    if sampler in ["heun", "dpm"]:
        sampler_args = dict(
            s_churn=s_churn, s_tmin=s_tmin, s_tmax=s_tmax, s_noise=s_noise
        )
    elif sampler == "multistep":
        sampler_args = dict(
            ts=ts, t_min=sigma_min, t_max=sigma_max, rho=diffusion.rho, steps=steps
        )
    else:
        sampler_args = {}

    def denoiser(x_t, sigma):
        _, denoised = diffusion.denoise(model, x_t, sigma, condition)

        if clip_denoised:
            denoised = denoised.clamp(-1, 1)

        return denoised

    x_0 = sample_fn(
        denoiser,
        x_T,
        sigmas,
        generator,
        progress=progress,
        callback=callback,
        **sampler_args,
    )
    return x_0.clamp(-1, 1)

def get_sigmas_karras(n, sigma_min, sigma_max, rho=7.0, device="cpu"):
    """Constructs the noise schedule of Karras et al. (2022)."""
    ramp = th.linspace(0, 1, n)
    min_inv_rho = sigma_min ** (1 / rho)
    max_inv_rho = sigma_max ** (1 / rho)
    sigmas = (max_inv_rho + ramp * (min_inv_rho - max_inv_rho)) ** rho
    return append_zero(sigmas).to(device)
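
# Example (illustrative): get_sigmas_karras(5, 0.002, 80.0) yields 6 values
# decreasing from sigma_max to sigma_min with a trailing zero appended,
# approximately tensor([80.0, 17.53, 2.52, 0.17, 0.002, 0.0]).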

def to_d(x, sigma, denoised):
    """Converts a denoiser output to a Karras ODE derivative."""
    return (x - denoised) / append_dims(sigma, x.ndim)


def get_ancestral_step(sigma_from, sigma_to):
    """Calculates the noise level (sigma_down) to step down to and the amount
    of noise to add (sigma_up) when doing an ancestral sampling step."""
    sigma_up = (
        sigma_to**2 * (sigma_from**2 - sigma_to**2) / sigma_from**2
    ) ** 0.5
    sigma_down = (sigma_to**2 - sigma_up**2) ** 0.5
    return sigma_down, sigma_up
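
# Note (added comment): by construction sigma_down**2 + sigma_up**2 == sigma_to**2,
# splitting each ancestral step into a deterministic move down to sigma_down
# plus fresh Gaussian noise of scale sigma_up.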

def sample_euler_ancestral(model, x, sigmas, generator, progress=False, callback=None):
    """Ancestral sampling with Euler method steps."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        denoised = model(x, sigmas[i] * s_in)
        sigma_down, sigma_up = get_ancestral_step(sigmas[i], sigmas[i + 1])
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigmas[i],
                    "denoised": denoised,
                }
            )
        d = to_d(x, sigmas[i], denoised)
        # Euler method
        dt = sigma_down - sigmas[i]
        x = x + d * dt
        x = x + generator.randn_like(x) * sigma_up
    return x


def sample_midpoint_ancestral(model, x, ts, generator, progress=False, callback=None):
    """Ancestral sampling with midpoint method steps."""
    s_in = x.new_ones([x.shape[0]])
    step_size = 1 / len(ts)
    if progress:
        from tqdm.auto import tqdm

        ts = tqdm(ts)

    for tn in ts:
        dn = model(x, tn * s_in)
        dn_2 = model(x + (step_size / 2) * dn, (tn + step_size / 2) * s_in)
        x = x + step_size * dn_2
        if callback is not None:
            callback({"x": x, "tn": tn, "dn": dn, "dn_2": dn_2})
    return x

def sample_heun(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """Implements Algorithm 2 (Heun steps) from Karras et al. (2022)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigmas[i] <= s_tmax
            else 0.0
        )
        eps = generator.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigma_hat,
                    "denoised": denoised,
                }
            )
        dt = sigmas[i + 1] - sigma_hat
        if sigmas[i + 1] == 0:
            # Euler method
            x = x + d * dt
        else:
            # Heun's method
            x_2 = x + d * dt
            denoised_2 = denoiser(x_2, sigmas[i + 1] * s_in)
            d_2 = to_d(x_2, sigmas[i + 1], denoised_2)
            d_prime = (d + d_2) / 2
            x = x + d_prime * dt
    return x

def sample_euler(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
):
    """Implements Algorithm 2 (Euler steps) from Karras et al. (2022)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        sigma = sigmas[i]
        denoised = denoiser(x, sigma * s_in)
        d = to_d(x, sigma, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "denoised": denoised,
                }
            )
        dt = sigmas[i + 1] - sigma
        x = x + d * dt
    return x

def sample_dpm(
    denoiser,
    x,
    sigmas,
    generator,
    progress=False,
    callback=None,
    s_churn=0.0,
    s_tmin=0.0,
    s_tmax=float("inf"),
    s_noise=1.0,
):
    """A sampler inspired by DPM-Solver-2 and Algorithm 2 from Karras et al. (2022)."""
    s_in = x.new_ones([x.shape[0]])
    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        gamma = (
            min(s_churn / (len(sigmas) - 1), 2**0.5 - 1)
            if s_tmin <= sigmas[i] <= s_tmax
            else 0.0
        )
        eps = generator.randn_like(x) * s_noise
        sigma_hat = sigmas[i] * (gamma + 1)
        if gamma > 0:
            x = x + eps * (sigma_hat**2 - sigmas[i] ** 2) ** 0.5
        denoised = denoiser(x, sigma_hat * s_in)
        d = to_d(x, sigma_hat, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigmas[i],
                    "sigma_hat": sigma_hat,
                    "denoised": denoised,
                }
            )
        # Midpoint method, where the midpoint is chosen according to a rho=3 Karras schedule
        sigma_mid = ((sigma_hat ** (1 / 3) + sigmas[i + 1] ** (1 / 3)) / 2) ** 3
        dt_1 = sigma_mid - sigma_hat
        dt_2 = sigmas[i + 1] - sigma_hat
        x_2 = x + d * dt_1
        denoised_2 = denoiser(x_2, sigma_mid * s_in)
        d_2 = to_d(x_2, sigma_mid, denoised_2)
        x = x + d_2 * dt_2
    return x

def sample_onestep(
    distiller,
    x,
    sigmas,
    generator=None,
    progress=False,
    callback=None,
):
    """Single-step generation from a distilled model."""
    s_in = x.new_ones([x.shape[0]])
    return distiller(x, sigmas[0] * s_in)


def stochastic_iterative_sampler(
    distiller,
    x,
    sigmas,
    generator,
    ts,
    progress=False,
    callback=None,
    t_min=0.002,
    t_max=80.0,
    rho=7.0,
    steps=40,
):
    t_max_rho = t_max ** (1 / rho)
    t_min_rho = t_min ** (1 / rho)
    s_in = x.new_ones([x.shape[0]])

    for i in range(len(ts) - 1):
        t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
        x0 = distiller(x, t * s_in)
        next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
        next_t = np.clip(next_t, t_min, t_max)
        x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)

    return x
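
# Illustrative usage (assumption, not from the original file): a two-step
# consistency sample with steps=40 could pass ts=[0, 20, 39], which the
# rho-schedule above maps to noise levels at sigma_max, a midpoint, and
# sigma_min; the final noise injection then has scale zero, so the last
# distiller output is returned clean.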

def sample_progdist(
    denoiser,
    x,
    sigmas,
    generator=None,
    progress=False,
    callback=None,
):
    s_in = x.new_ones([x.shape[0]])
    sigmas = sigmas[:-1]  # skip the zero sigma

    indices = range(len(sigmas) - 1)
    if progress:
        from tqdm.auto import tqdm

        indices = tqdm(indices)

    for i in indices:
        sigma = sigmas[i]
        denoised = denoiser(x, sigma * s_in)
        d = to_d(x, sigma, denoised)
        if callback is not None:
            callback(
                {
                    "x": x,
                    "i": i,
                    "sigma": sigma,
                    "denoised": denoised,
                }
            )
        dt = sigmas[i + 1] - sigma
        x = x + d * dt

    return x

# @th.no_grad()
# def iterative_colorization(
#     distiller,
#     images,
#     x,
#     ts,
#     t_min=0.002,
#     t_max=80.0,
#     rho=7.0,
#     steps=40,
#     generator=None,
# ):
#     def obtain_orthogonal_matrix():
#         vector = np.asarray([0.2989, 0.5870, 0.1140])
#         vector = vector / np.linalg.norm(vector)
#         matrix = np.eye(3)
#         matrix[:, 0] = vector
#         matrix = np.linalg.qr(matrix)[0]
#         if np.sum(matrix[:, 0]) < 0:
#             matrix = -matrix
#         return matrix

#     Q = th.from_numpy(obtain_orthogonal_matrix()).to(dist_util.dev()).to(th.float32)
#     mask = th.zeros(*x.shape[1:], device=dist_util.dev())
#     mask[0, ...] = 1.0

#     def replacement(x0, x1):
#         x0 = th.einsum("bchw,cd->bdhw", x0, Q)
#         x1 = th.einsum("bchw,cd->bdhw", x1, Q)

#         x_mix = x0 * mask + x1 * (1.0 - mask)
#         x_mix = th.einsum("bdhw,cd->bchw", x_mix, Q)
#         return x_mix

#     t_max_rho = t_max ** (1 / rho)
#     t_min_rho = t_min ** (1 / rho)
#     s_in = x.new_ones([x.shape[0]])
#     images = replacement(images, th.zeros_like(images))

#     for i in range(len(ts) - 1):
#         t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         x0 = distiller(x, t * s_in)
#         x0 = th.clamp(x0, -1.0, 1.0)
#         x0 = replacement(images, x0)
#         next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         next_t = np.clip(next_t, t_min, t_max)
#         x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)

#     return x, images

# @th.no_grad()
# def iterative_inpainting(
#     distiller,
#     images,
#     x,
#     ts,
#     t_min=0.002,
#     t_max=80.0,
#     rho=7.0,
#     steps=40,
#     generator=None,
# ):
#     from PIL import Image, ImageDraw, ImageFont

#     image_size = x.shape[-1]

#     # create a blank image with a white background
#     img = Image.new("RGB", (image_size, image_size), color="white")

#     # get a drawing context for the image
#     draw = ImageDraw.Draw(img)

#     # load a font
#     font = ImageFont.truetype("arial.ttf", 250)

#     # draw the letter "S" in black
#     draw.text((50, 0), "S", font=font, fill=(0, 0, 0))

#     # convert the image to a numpy array
#     img_np = np.array(img)
#     img_np = img_np.transpose(2, 0, 1)
#     img_th = th.from_numpy(img_np).to(dist_util.dev())

#     mask = th.zeros(*x.shape, device=dist_util.dev())
#     mask = mask.reshape(-1, 7, 3, image_size, image_size)

#     mask[::2, :, img_th > 0.5] = 1.0
#     mask[1::2, :, img_th < 0.5] = 1.0
#     mask = mask.reshape(-1, 3, image_size, image_size)

#     def replacement(x0, x1):
#         x_mix = x0 * mask + x1 * (1 - mask)
#         return x_mix

#     t_max_rho = t_max ** (1 / rho)
#     t_min_rho = t_min ** (1 / rho)
#     s_in = x.new_ones([x.shape[0]])
#     images = replacement(images, -th.ones_like(images))

#     for i in range(len(ts) - 1):
#         t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         x0 = distiller(x, t * s_in)
#         x0 = th.clamp(x0, -1.0, 1.0)
#         x0 = replacement(images, x0)
#         next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         next_t = np.clip(next_t, t_min, t_max)
#         x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)

#     return x, images

# @th.no_grad()
# def iterative_superres(
#     distiller,
#     images,
#     x,
#     ts,
#     t_min=0.002,
#     t_max=80.0,
#     rho=7.0,
#     steps=40,
#     generator=None,
# ):
#     patch_size = 8

#     def obtain_orthogonal_matrix():
#         vector = np.asarray([1] * patch_size**2)
#         vector = vector / np.linalg.norm(vector)
#         matrix = np.eye(patch_size**2)
#         matrix[:, 0] = vector
#         matrix = np.linalg.qr(matrix)[0]
#         if np.sum(matrix[:, 0]) < 0:
#             matrix = -matrix
#         return matrix

#     Q = th.from_numpy(obtain_orthogonal_matrix()).to(dist_util.dev()).to(th.float32)

#     image_size = x.shape[-1]

#     def replacement(x0, x1):
#         x0_flatten = (
#             x0.reshape(-1, 3, image_size, image_size)
#             .reshape(
#                 -1,
#                 3,
#                 image_size // patch_size,
#                 patch_size,
#                 image_size // patch_size,
#                 patch_size,
#             )
#             .permute(0, 1, 2, 4, 3, 5)
#             .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
#         )
#         x1_flatten = (
#             x1.reshape(-1, 3, image_size, image_size)
#             .reshape(
#                 -1,
#                 3,
#                 image_size // patch_size,
#                 patch_size,
#                 image_size // patch_size,
#                 patch_size,
#             )
#             .permute(0, 1, 2, 4, 3, 5)
#             .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
#         )
#         x0 = th.einsum("bcnd,de->bcne", x0_flatten, Q)
#         x1 = th.einsum("bcnd,de->bcne", x1_flatten, Q)
#         x_mix = x0.new_zeros(x0.shape)
#         x_mix[..., 0] = x0[..., 0]
#         x_mix[..., 1:] = x1[..., 1:]
#         x_mix = th.einsum("bcne,de->bcnd", x_mix, Q)
#         x_mix = (
#             x_mix.reshape(
#                 -1,
#                 3,
#                 image_size // patch_size,
#                 image_size // patch_size,
#                 patch_size,
#                 patch_size,
#             )
#             .permute(0, 1, 2, 4, 3, 5)
#             .reshape(-1, 3, image_size, image_size)
#         )
#         return x_mix

#     def average_image_patches(x):
#         x_flatten = (
#             x.reshape(-1, 3, image_size, image_size)
#             .reshape(
#                 -1,
#                 3,
#                 image_size // patch_size,
#                 patch_size,
#                 image_size // patch_size,
#                 patch_size,
#             )
#             .permute(0, 1, 2, 4, 3, 5)
#             .reshape(-1, 3, image_size**2 // patch_size**2, patch_size**2)
#         )
#         x_flatten[..., :] = x_flatten.mean(dim=-1, keepdim=True)
#         return (
#             x_flatten.reshape(
#                 -1,
#                 3,
#                 image_size // patch_size,
#                 image_size // patch_size,
#                 patch_size,
#                 patch_size,
#             )
#             .permute(0, 1, 2, 4, 3, 5)
#             .reshape(-1, 3, image_size, image_size)
#         )

#     t_max_rho = t_max ** (1 / rho)
#     t_min_rho = t_min ** (1 / rho)
#     s_in = x.new_ones([x.shape[0]])
#     images = average_image_patches(images)

#     for i in range(len(ts) - 1):
#         t = (t_max_rho + ts[i] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         x0 = distiller(x, t * s_in)
#         x0 = th.clamp(x0, -1.0, 1.0)
#         x0 = replacement(images, x0)
#         next_t = (t_max_rho + ts[i + 1] / (steps - 1) * (t_min_rho - t_max_rho)) ** rho
#         next_t = np.clip(next_t, t_min, t_max)
#         x = x0 + generator.randn_like(x) * np.sqrt(next_t**2 - t_min**2)

#     return x, images
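
if __name__ == "__main__":
    # Minimal smoke test (added; illustrative only, not part of the original
    # training or inference pipeline). Uses a hypothetical stand-in model that
    # ignores its conditioning, just to exercise karras_sample end to end;
    # assumes get_generator("dummy") from random_utils works on this machine.
    class _DummyModel(nn.Module):
        def forward(self, x, t, condition=None):
            return th.zeros_like(x)

    _diffusion = KarrasDenoiser(sigma_data=0.5)
    _sample = karras_sample(
        _diffusion,
        _DummyModel(),
        shape=(2, 8, 16),
        steps=10,
        progress=False,
        device=th.device("cpu"),
        sampler="heun",
    )
    print(_sample.shape)  # expected: th.Size([2, 8, 16])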