Spaces:
Runtime error
Runtime error
| """SAMPLING ONLY.""" | |
| import torch | |
| import numpy as np | |
| from tqdm import tqdm | |
| from functools import partial | |
| from .diffusion_utils import make_ddim_sampling_parameters, make_ddim_timesteps, noise_like | |
def append_dims(x, target_dims):
    """Return *x* with trailing singleton axes added until it has ``target_dims`` dims.

    Args:
        x: array/tensor exposing ``.ndim`` and supporting ``None``/Ellipsis indexing.
        target_dims: desired number of dimensions.

    Raises:
        ValueError: if *x* already has more than ``target_dims`` dimensions.
    """
    missing = target_dims - x.ndim
    if missing >= 0:
        # Ellipsis keeps every existing axis; each None appends a size-1 axis.
        index = (Ellipsis,) + tuple(None for _ in range(missing))
        return x[index]
    raise ValueError(f'input has {x.ndim} dims but target_dims is {target_dims}, which is less')
def default_noise_sampler(x):
    """Build a noise sampler that returns fresh standard-normal noise shaped like *x*.

    The returned callable takes ``(sigma, sigma_next)`` for interface
    compatibility but ignores both arguments.
    """
    def sampler(sigma, sigma_next):
        return torch.randn_like(x)

    return sampler
def get_ancestral_step(sigma_from, sigma_to, eta=1.):
    """Split a step from ``sigma_from`` to ``sigma_to`` into deterministic and noise parts.

    Returns:
        ``(sigma_down, sigma_up)`` with ``sigma_down**2 + sigma_up**2 == sigma_to**2``
        (up to rounding). ``sigma_up`` is capped at ``sigma_to``. When ``eta``
        is falsy the step is fully deterministic: ``(sigma_to, 0.)``.
    """
    if eta:
        up_squared = sigma_to ** 2 * (sigma_from ** 2 - sigma_to ** 2) / sigma_from ** 2
        sigma_up = min(sigma_to, eta * up_squared ** 0.5)
        sigma_down = (sigma_to ** 2 - sigma_up ** 2) ** 0.5
        return sigma_down, sigma_up
    return sigma_to, 0.
def to_d(x, sigma, denoised):
    """Return the Euler-step direction ``(x - denoised) / sigma``.

    ``sigma`` is broadcast over the trailing dimensions of ``x`` via
    ``append_dims`` before dividing.
    """
    residual = x - denoised
    return residual / append_dims(sigma, x.ndim)
class Sampler:
    """Ancestral Euler sampler over a sigma schedule derived from ``net``.

    ``net`` must expose ``alphas_cumprod`` (tensor), ``get_device()``,
    ``get_dtype()`` and ``apply_model(x, sigma)``. The output of
    ``apply_model`` is treated as the denoised latent.
    """

    # Accepted values of ``type``; 'eular_a' is the historical spelling.
    _EULER_ANCESTRAL_TYPES = ('eular_a', 'euler_a')

    def __init__(self, net, type="ddim", steps=50, output_dim=(512, 512), n_samples=4, scale=7.5):
        """Store sampling settings and precompute per-timestep sigmas.

        Args:
            net: model wrapper (see class docstring).
            type: sampler name; only 'eular_a' / 'euler_a' are implemented,
                and any other value makes ``sample`` raise ValueError.
            steps: number of sampling steps.
            output_dim: ``(height, width)`` in pixels; latents are 1/8 of this.
            n_samples: batch size when sampling from pure noise.
            scale: guidance scale. NOTE(review): stored but never read in this class.
        """
        super().__init__()
        self.net = net
        self.type = type
        self.steps = steps
        # Copy so instances never share a default (or the caller's) list object.
        self.output_dim = list(output_dim)
        self.n_samples = n_samples
        self.scale = scale
        # sigma_t = sqrt((1 - alpha_bar_t) / alpha_bar_t) for every timestep t.
        self.sigmas = ((1 - net.alphas_cumprod) / net.alphas_cumprod) ** 0.5
        self.log_sigmas = self.sigmas.log()

    def t_to_sigma(self, t):
        """Map (possibly fractional) timesteps to sigmas by linear interpolation in log-sigma."""
        t = t.float()
        low_idx, high_idx, w = t.floor().long(), t.ceil().long(), t.frac()
        log_sigma = (1 - w) * self.log_sigmas[low_idx] + w * self.log_sigmas[high_idx]
        return log_sigma.exp()

    def get_sigmas(self, n=None):
        """Return a sigma schedule ordered from the last timestep to the first, plus a trailing 0.

        With ``n=None`` every precomputed sigma is used. Otherwise ``n``
        timesteps are spaced evenly from the last index down to 0, giving a
        tensor of length ``n + 1``.
        """
        def append_zero(x):
            return torch.cat([x, x.new_zeros([1])])

        if n is None:
            return append_zero(self.sigmas.flip(0))
        t_max = len(self.sigmas) - 1
        t = torch.linspace(t_max, 0, n, device=self.sigmas.device)
        return append_zero(self.t_to_sigma(t))

    def sample(self, x_info, c_info):
        """Draw samples with the configured sampler type.

        The starting latent is chosen from ``x_info`` and stored back into
        ``x_info['x']``:
          * ``'xt'``: used as unit-scale starting noise.
          * ``'x0'``: noised to the sigma at which the last
            ``x_info['x0_forward_timesteps']`` steps begin, then only those
            steps are run.
          * otherwise: standard-normal noise of shape
            ``[n_samples, 4, h // 8, w // 8]``.

        Raises:
            ValueError: if ``self.type`` is not an implemented sampler, or if
                ``x0_forward_timesteps`` is outside ``1..steps``.
        """
        if self.type not in self._EULER_ANCESTRAL_TYPES:
            raise ValueError(
                f"unsupported sampler type {self.type!r}; "
                f"expected one of {self._EULER_ANCESTRAL_TYPES}")

        device, dtype = self.net.get_device(), self.net.get_dtype()
        sigmas = self.get_sigmas(n=self.steps)
        prescaled = False

        if x_info.get('xt') is not None:
            # torch tensors have no .astype; .to handles device and dtype together.
            x_info['x'] = x_info['xt'].to(device=device, dtype=dtype)
        elif x_info.get('x0') is not None:
            k = int(x_info['x0_forward_timesteps'])
            if not 1 <= k <= self.steps:
                raise ValueError(
                    f"x0_forward_timesteps must be in 1..{self.steps}, got {k}")
            # Keep only the final k steps (k + 1 sigmas, ending with 0).
            sigmas = sigmas[-(k + 1):]
            x0 = x_info['x0'].to(device=device, dtype=dtype)
            # Noise x0 to sigmas[0] in the same sigma-scaled space the Euler loop
            # uses (pure noise there is randn * sigmas[0]).
            # NOTE(review): assumes x0 is a latent in the space apply_model denoises to.
            x_info['x'] = x0 + torch.randn_like(x0) * sigmas[0]
            prescaled = True
        else:
            h, w = self.output_dim
            shape = [self.n_samples, 4, h // 8, w // 8]
            x_info['x'] = torch.randn(shape, device=device, dtype=dtype)

        return self.sample_euler_ancestral(
            x_info=x_info,
            c_info=c_info,
            sigmas=sigmas,
            prescaled=prescaled)

    def sample_euler_ancestral(self, x_info, c_info, sigmas, eta=1., s_noise=1., prescaled=False):
        """Run ancestral Euler steps between consecutive entries of ``sigmas``.

        Args:
            x_info: dict whose ``'x'`` entry is the starting latent.
            c_info: conditioning info. NOTE(review): currently unused and not
                forwarded to ``net.apply_model``; confirm whether it should be.
            sigmas: noise levels to step through, ending at the target level.
            eta: scale of the injected ancestral noise (0 gives plain Euler).
            s_noise: extra multiplier on the injected noise.
            prescaled: if False, ``x_info['x']`` is treated as unit noise and
                multiplied by ``sigmas[0]``. If True, it is used as-is.

        Returns:
            The final latent after the last step.
        """
        x = x_info['x']
        if not prescaled:
            x = x * sigmas[0]
        noise_sampler = default_noise_sampler(x)
        # Per-batch-element sigma vector passed to apply_model.
        s_in = x.new_ones([x.shape[0]])
        for sigma, sigma_next in zip(sigmas[:-1], sigmas[1:]):
            denoised = self.net.apply_model(x, sigma * s_in)
            sigma_down, sigma_up = get_ancestral_step(sigma, sigma_next, eta=eta)
            # Euler step from sigma down to sigma_down.
            d = to_d(x, sigma, denoised)
            x = x + d * (sigma_down - sigma)
            if sigma_next > 0:
                # Add noise so the total level reaches sigma_next
                # (sigma_down**2 + sigma_up**2 == sigma_next**2).
                x = x + noise_sampler(sigma, sigma_next) * s_noise * sigma_up
        return x