Spaces:
Runtime error
Runtime error
| from typing import List, Tuple | |
| from scipy import interpolate | |
| import numpy as np | |
| import torch | |
| import matplotlib.pyplot as plt | |
| from IPython.display import clear_output | |
| import abc | |
class GuideModel(torch.nn.Module, abc.ABC):
    """Base class for models that score a decoded image for guidance.

    Subclasses provide two hooks. ``Guider.modify_score`` passes the decoded
    image to ``preprocess`` and the result of that to ``compute_loss``; it
    then differentiates ``loss.sum()``, so ``compute_loss`` must return a
    tensor that is differentiable with respect to its input.
    """

    def __init__(self) -> None:
        super().__init__()

    def preprocess(self, x_img):
        """Convert a decoded image batch into the input for ``compute_loss``.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        # Fail loudly here instead of returning None, which would only
        # surface later as an unrelated error inside the guidance step.
        raise NotImplementedError(
            f"{type(self).__name__} must implement preprocess()"
        )

    def compute_loss(self, inp):
        """Return the guidance loss for the output of ``preprocess``.

        Raises:
            NotImplementedError: always, unless overridden by a subclass.
        """
        raise NotImplementedError(
            f"{type(self).__name__} must implement compute_loss()"
        )
class Guider(torch.nn.Module):
    """Apply classifier guidance to the noise prediction of a DDIM sampler.

    On each call to ``modify_score`` the predicted noise ``e_t`` is shifted
    along the gradient of ``guide_model``'s loss with respect to the current
    latent ``x``, weighted by a per-step guidance scale.

    Attributes set here:
        index: number of ``modify_score`` calls so far; it selects the entry
            of ``get_scales()`` and ``sampler.ddim_sqrt_one_minus_alphas``
            used for the current call. It is never reset automatically, so
            assign ``guider.index = 0`` before starting another sampling run.
        history: when ``verbose`` is true, one
            ``[loss, scale, correction_min, correction_max]`` entry per
            guided step.
    """

    def __init__(self, sampler, guide_model, scale=1.0, verbose=False):
        """Apply classifier guidance

        Specify a guidance scale as either a scalar
        Or a schedule as a list of tuples t = 0->1 and scale, e.g.
        [(0, 10), (0.5, 20), (1, 50)]

        Args:
            sampler: object exposing ``ddim_timesteps``,
                ``ddpm_num_timesteps`` and ``ddim_sqrt_one_minus_alphas``.
            guide_model: object exposing ``preprocess`` and ``compute_loss``.
            scale: scalar guidance scale, or a list/tuple of
                ``(time, scale)`` pairs that is linearly interpolated.
            verbose: if true, print and plot diagnostics on every step.
        """
        super().__init__()
        self.sampler = sampler
        self.index = 0
        self.show = verbose
        self.guide_model = guide_model
        self.history = []
        # Builtin types instead of the deprecated typing.Tuple / typing.List
        # aliases, which are meant for annotations rather than runtime checks.
        if isinstance(scale, (tuple, list)):
            times = np.array([point[0] for point in scale])
            values = np.array([point[1] for point in scale])
            self.scale_schedule = {"times": times, "values": values}
        else:
            self.scale_schedule = float(scale)
        self.ddim_timesteps = sampler.ddim_timesteps
        self.ddpm_num_timesteps = sampler.ddpm_num_timesteps

    def get_scales(self):
        """Return one guidance scale per entry of ``ddim_timesteps``.

        A scalar schedule yields a list repeating that value. A
        ``(time, scale)`` schedule is linearly interpolated at
        ``ddim_timesteps / ddpm_num_timesteps`` and returned as an array;
        ``interp1d`` raises ``ValueError`` if any of those fractions falls
        outside the range of times given in the schedule.
        """
        if isinstance(self.scale_schedule, float):
            return len(self.ddim_timesteps)*[self.scale_schedule]
        interpolater = interpolate.interp1d(
            self.scale_schedule["times"], self.scale_schedule["values"]
        )
        fractional_steps = np.array(self.ddim_timesteps)/self.ddpm_num_timesteps
        return interpolater(fractional_steps)

    def modify_score(self, model, e_t, x, t, c):
        """Return ``e_t`` shifted by the scaled guidance gradient.

        Args:
            model: object exposing ``predict_start_from_noise`` and
                ``first_stage_model.decode``.
            e_t: predicted noise for the current step.
            x: current latent; gradients are taken with respect to a detached
                copy of it.
            t: timestep, forwarded to ``model.predict_start_from_noise``.
            c: conditioning; accepted for interface compatibility, not used.

        Returns:
            ``e_t`` itself when the current scale is zero, otherwise
            ``e_t - sqrt_1ma * scale * d(loss)/d(x)``.
        """
        # TODO look up index by t
        # NOTE(review): entry k of the scale / sqrt(1 - alpha) tables is used
        # on the k-th call regardless of the value of t. Confirm this matches
        # the order in which the sampler visits its timesteps.
        scale = self.get_scales()[self.index]
        if scale == 0:
            # Still consume this step. Without the increment the counter
            # would stay on the zero-scale entry for every later call and
            # guidance would never be applied again.
            self.index += 1
            return e_t
        sqrt_1ma = self.sampler.ddim_sqrt_one_minus_alphas[self.index].to(x.device)
        with torch.enable_grad():
            x_in = x.detach().requires_grad_(True)
            pred_x0 = model.predict_start_from_noise(x_in, t=t, noise=e_t)
            # presumably 0.18215 is the latent scale factor of the first
            # stage model, undone before decoding -- TODO confirm
            x_img = model.first_stage_model.decode((1/0.18215)*pred_x0)
            inp = self.guide_model.preprocess(x_img)
            loss = self.guide_model.compute_loss(inp)
            total_loss = loss.sum()
            grads = torch.autograd.grad(total_loss, x_in)[0]
        correction = grads * scale
        e_t_mod = e_t - sqrt_1ma*correction
        if self.show:
            # Report the summed loss: .item() on a multi-element loss raises,
            # and for a scalar loss the sum is the same value.
            loss_value = total_loss.item()
            clear_output(wait=True)
            print(loss_value, scale, correction.abs().max().item(), e_t.abs().max().item())
            self.history.append([loss_value, scale, correction.min().item(), correction.max().item()])
            self._show_step(inp, e_t, e_t_mod, correction)
        self.index += 1
        return e_t_mod

    def _show_step(self, inp, e_t, e_t_mod, correction):
        """Plot the guide-model input, the correction, and e_t before/after."""
        plt.imshow((inp[0].detach().permute(1,2,0).clamp(-1,1).cpu()+1)/2)
        plt.axis('off')
        plt.show()
        plt.imshow(correction[0][0].detach().cpu())
        plt.axis('off')
        plt.show()
        _, axs = plt.subplots(1, 3)
        axs[0].imshow(e_t[0][0].detach().cpu(), vmin=-2, vmax=+2)
        axs[1].imshow(e_t_mod[0][0].detach().cpu(), vmin=-2, vmax=+2)
        axs[2].imshow(correction[0][0].detach().cpu(), vmin=-2, vmax=+2)
        plt.show()