import os
import random
import time
from glob import glob
from typing import Callable, Dict, Optional, Tuple, Union

import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
import torchvision.transforms as transforms
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.datasets import VisionDataset
from tqdm import tqdm

from latent_models import PipelineWrapper
from util.img_utils import clear_color
def set_seed(seed: int) -> None:
    """Seed all relevant RNGs (torch, numpy, random, CUDA) for reproducibility."""
    torch.manual_seed(seed)
    np.random.seed(seed)
    random.seed(seed)
    torch.cuda.manual_seed_all(seed)
    # torch.backends.cudnn.deterministic = True
    # torch.backends.cudnn.benchmark = False
class MinusOneToOne(torch.nn.Module):
    """Rescale a tensor from [0, 1] to [-1, 1]."""

    def forward(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor * 2 - 1


class ResizePIL(torch.nn.Module):
    """Resize a PIL image to a fixed size; a no-op when image_size is None."""

    def __init__(self, image_size: Optional[Union[int, Tuple[int, int]]] = None):
        super().__init__()
        if isinstance(image_size, int):
            image_size = (image_size, image_size)
        self.image_size = image_size

    def forward(self, pil_image: Image.Image) -> Image.Image:
        if self.image_size is not None:
            pil_image = pil_image.resize(self.image_size)
        return pil_image
def get_loader(datadir: str, batch_size: int = 1,
               crop_to: Optional[Union[int, Tuple[int, int]]] = None,
               include_path: bool = False) -> DataLoader:
    """Build a DataLoader over the images in datadir, resized to crop_to
    and rescaled to [-1, 1]."""
    transform = transforms.Compose([
        ResizePIL(crop_to),
        transforms.ToTensor(),
        MinusOneToOne(),
    ])
    loader = DataLoader(FoldersDataset(datadir, transform, include_path=include_path),
                        batch_size=batch_size,
                        shuffle=True, num_workers=0, drop_last=False)
    return loader
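
# A minimal usage sketch for get_loader. The directory path is a placeholder;
# any folder (or single image file) containing .png/.jpg/.JPEG images works.
#
#     loader = get_loader("path/to/images", batch_size=4, crop_to=512)
#     for img, name, _ in loader:  # img is in [-1, 1], shape (B, 3, 512, 512)
#         print(name, img.shape)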
class FoldersDataset(VisionDataset):
    """Dataset over all .png/.JPEG/.jpg files under root (recursively), or over
    a single image file if root points to one."""

    def __init__(self, root: str, transforms: Optional[Callable] = None,
                 include_path: bool = False) -> None:
        super().__init__(root, transforms)
        self.include_path = include_path
        self.root = root
        if os.path.isdir(root):
            self.fpaths = glob(os.path.join(root, '**', '*.png'), recursive=True)
            self.fpaths += glob(os.path.join(root, '**', '*.JPEG'), recursive=True)
            self.fpaths += glob(os.path.join(root, '**', '*.jpg'), recursive=True)
            self.fpaths = sorted(self.fpaths)
            assert len(self.fpaths) > 0, "File list is empty. Check the root."
        elif os.path.exists(root):
            self.fpaths = [root]
        else:
            raise FileNotFoundError(f"File not found: {root}")

    def __len__(self) -> int:
        return len(self.fpaths)

    def __getitem__(self, index: int) -> Tuple[torch.Tensor, str, str]:
        fpath = self.fpaths[index]
        img = Image.open(fpath).convert('RGB')
        if self.transforms is not None:
            img = self.transforms(img)
        path = ""
        if self.include_path:
            dirname = os.path.dirname(fpath)
            # Keep only the part of the directory path relative to root
            path = dirname[len(self.root) + 1:]
        return img, os.path.basename(fpath).split(os.extsep)[0], path
def compress(model: PipelineWrapper,
             img_to_compress: torch.Tensor,
             num_noises: int,
             loaded_indices: Optional[torch.Tensor],
             device: torch.device,
             ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Compress an image into a sequence of codebook indices: at every timestep,
    pick (or replay, when loaded_indices is given) the candidate noise whose
    direction best aligns with the residual between the encoded image and the
    current x_0 estimate."""
    # model.set_timesteps(model.num_timesteps, device=device)
    dtype = model.dtype
    prompt_embeds = model.encode_prompt("", None)
    set_seed(88888888)
    if img_to_compress is None:
        img_to_compress = torch.zeros(1, 3, model.get_image_size(), model.get_image_size(), device=device)
    enc_im = model.encode_image(img_to_compress.to(dtype))
    kwargs = model.get_pre_kwargs(height=img_to_compress.shape[-2], width=img_to_compress.shape[-1],
                                  prompt_embeds=prompt_embeds)

    set_seed(100000)
    xt = torch.randn(1, *enc_im.shape[1:], device=device, dtype=dtype)

    result_noise_indices = []
    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        # Re-seed per step so the candidate codebook is reproducible at decode time
        set_seed(idx)
        noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype)
        _, epst, _ = model.get_epst(xt, t, prompt_embeds, 0.0, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)

        if loaded_indices is None:
            if t >= 1:
                # Pick the candidate most correlated with the remaining residual
                dot_prod = torch.matmul(noise.view(noise.shape[0], -1),
                                        (enc_im - x_0_hat).view(enc_im.shape[0], -1).transpose(0, 1))
                best_idx = torch.argmax(dot_prod)
                best_noise = noise[best_idx]
            else:
                best_noise = noise[0]
        else:
            if t >= 1:
                best_idx = loaded_indices[idx]
                best_noise = noise[best_idx]
            else:
                best_noise = noise[0]
        if t >= 1:
            result_noise_indices.append(best_idx)
        xt = model.finish_step(xt, x_0_hat, epst, t, best_noise.unsqueeze(0), eta=None)

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        img = model.decode_image(xt.to('cpu'))
    return img, torch.tensor(result_noise_indices).squeeze().cpu()
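
# A minimal usage sketch for compress. The PipelineWrapper construction is
# hypothetical (its real signature lives in latent_models); the loader path
# is a placeholder.
#
#     model = PipelineWrapper(...)  # see latent_models for the actual arguments
#     loader = get_loader("path/to/images", crop_to=model.get_image_size())
#     img, _, _ = next(iter(loader))
#     recon, noise_indices = compress(model, img.to(model.device), num_noises=64,
#                                     loaded_indices=None, device=model.device)
#     # noise_indices is the compressed representation: one codebook index per timestep.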
def generate_ours(model: PipelineWrapper,
                  num_noises: int,
                  num_noises_to_optimize: int,
                  prompt: str = "",
                  negative_prompt: Optional[str] = None,
                  indices: Optional[torch.Tensor] = None,
                  ) -> Tuple[torch.Tensor, torch.Tensor]:
    """Generate an image by greedily choosing, at each timestep, the codebook noise
    best aligned with the classifier-free-guidance direction, or by replaying a
    previously recorded index sequence."""
    device = model.device
    dtype = model.dtype
    # model.set_timesteps(model.num_timesteps, device=device)

    set_seed(88888888)
    if prompt is None:
        prompt = ""
    prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs = model.get_pre_kwargs(height=model.get_image_size(),
                                  width=model.get_image_size(),
                                  prompt_embeds=prompt_embeds)

    set_seed(100000)
    xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()), device=device, dtype=dtype)

    result_noise_indices = []
    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        set_seed(idx)
        noise = torch.randn(num_noises, *xt.shape[1:], device=device, dtype=dtype)  # Codebook
        _, epst_uncond, epst_cond = model.get_epst(xt, t, prompt_embeds, 1.0, return_everything=True, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst_uncond, t)

        if t >= 1:
            if indices is None:
                prev_classif_score = epst_uncond - epst_cond
                # Time-based seed: the randomly sampled candidate subset varies per call
                set_seed(int(time.time_ns() & 0xFFFFFFFF))
                noise_indices = torch.randint(0, num_noises, size=(num_noises_to_optimize,), device=device)
                loss = torch.matmul(noise[noise_indices].view(num_noises_to_optimize, -1),
                                    prev_classif_score.view(prev_classif_score.shape[0], -1).transpose(0, 1))
                best_idx = noise_indices[torch.argmax(loss)]
            else:
                best_idx = indices[idx]
            best_noise = noise[best_idx]
            result_noise_indices.append(best_idx)
        else:
            best_noise = torch.zeros_like(noise[0])

        xt = model.finish_step(xt, x_0_hat, epst_uncond, t, best_noise)

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        img = model.decode_image(xt.to('cpu'))
    return img, torch.stack(result_noise_indices).squeeze().cpu()
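
# decompress() below calls get_robust_randn, which is not defined or imported in
# this file. The following is a minimal sketch of one plausible implementation,
# assuming the goal is noise that reproduces bit-exactly across machines: sample
# in float32 on the CPU RNG (already seeded by the per-step set_seed call), then
# move/cast to the target device and dtype. The project's real helper may differ.
def get_robust_randn(num: int, shape: Tuple[int, ...], device, dtype) -> torch.Tensor:
    # Sampling on CPU in float32 avoids dependence on the GPU model and dtype
    return torch.randn(num, *shape, dtype=torch.float32).to(device=device, dtype=dtype)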
def decompress(model: PipelineWrapper,
               image_size: Tuple[int, int],
               indices: Dict[str, torch.Tensor],
               num_noises: int,
               prompt: str = "",
               negative_prompt: Optional[str] = None,
               tedit: int = 0,
               new_prompt: str = "",
               new_negative_prompt: Optional[str] = None,
               guidance_scale: float = 3.0,
               num_pursuit_noises: Optional[int] = 1,
               num_pursuit_coef_bits: Optional[int] = 3,
               t_range: Tuple[int, int] = (999, 0),
               robust_randn: bool = False
               ) -> torch.Tensor:
    """Reconstruct an image from stored codebook indices, optionally switching to a
    new prompt after step tedit (prompt-based editing) and optionally blending
    several pursuit noises per step with quantized mixing coefficients."""
    noise_indices = indices['noise_indices']
    coeffs_indices = indices['coeff_indices']
    num_pursuit_noises = num_pursuit_noises if num_pursuit_noises is not None else 1
    num_pursuit_coef_bits = num_pursuit_coef_bits if num_pursuit_coef_bits is not None else 1
    device = model.device
    dtype = model.dtype
    # model.set_timesteps(model.num_timesteps, device=device)

    set_seed(88888888)
    orig_prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs_orig = model.get_pre_kwargs(height=image_size[-2], width=image_size[-1],
                                       prompt_embeds=orig_prompt_embeds)
    if new_prompt != prompt or new_negative_prompt != negative_prompt:
        new_prompt_embeds = model.encode_prompt(new_prompt, new_negative_prompt)
        kwargs_new = model.get_pre_kwargs(height=image_size[-2], width=image_size[-1],
                                          prompt_embeds=new_prompt_embeds)
    else:
        new_prompt_embeds = orig_prompt_embeds
        kwargs_new = kwargs_orig

    set_seed(100000)
    xt = torch.randn(1, *model.get_latent_shape(image_size), device=device, dtype=dtype)

    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        set_seed(idx)
        dont_optimize_t = not (t_range[0] >= t >= t_range[1])
        # No intermittent support
        if robust_randn:
            noise = get_robust_randn(num_noises if not dont_optimize_t else 1, xt.shape[1:], device, dtype)
        else:
            noise = torch.randn(num_noises if not dont_optimize_t else 1, *xt.shape[1:], device=device, dtype=dtype)

        curr_embs = orig_prompt_embeds if idx < tedit else new_prompt_embeds
        curr_kwargs = kwargs_orig if idx < tedit else kwargs_new
        epst = model.get_epst(xt, t, curr_embs, guidance_scale, **curr_kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)

        curr_t_noise_indices = noise_indices[idx]
        best_noise = noise[curr_t_noise_indices[0]]
        pursuit_coefs = torch.linspace(0, 1, 2 ** num_pursuit_coef_bits + 1)[1:]
        if num_pursuit_noises > 1:
            curr_t_coeffs_indices = coeffs_indices[idx]
            if curr_t_coeffs_indices[0] == -1:
                continue
            for pursuit_idx in range(1, num_pursuit_noises):
                # Blend in the next pursuit noise with a quantized coefficient,
                # renormalizing to keep (approximately) unit variance
                pursuit_coef = pursuit_coefs[curr_t_coeffs_indices[pursuit_idx]]
                best_noise = best_noise * torch.sqrt(pursuit_coef) + noise[
                    curr_t_noise_indices[pursuit_idx]] * torch.sqrt(1 - pursuit_coef)
                best_noise /= best_noise.std()

        best_noise = best_noise.unsqueeze(0)
        xt = model.finish_step(xt, x_0_hat, epst, t, best_noise)

    img = model.decode_image(xt)
    return img
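
# A minimal usage sketch for decompress, paired with compress above. The dict keys
# are the ones decompress reads; shaping the indices to (timesteps, 1) and filling
# coeff_indices with -1 (no pursuit) is an assumption for the single-noise case,
# where coeff_indices is looked up but never used.
#
#     per_step = noise_indices.unsqueeze(-1)  # from compress(); shape (T, 1)
#     indices = {'noise_indices': per_step,
#                'coeff_indices': torch.full_like(per_step, -1)}
#     recon = decompress(model, (512, 512), indices, num_noises=64)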
def inf_generate(model: PipelineWrapper,
                 prompt: str = "",
                 negative_prompt: Optional[str] = None,
                 guidance_scale: float = 7.0,
                 record: int = 0,
                 save_root: str = "") -> torch.Tensor:
    """Plain (unconstrained) diffusion sampling with classifier-free guidance,
    optionally saving the intermediate x_0 estimate every `record` steps."""
    device = model.device
    dtype = model.dtype
    model.set_timesteps(model.num_timesteps, device=device)

    prompt_embeds = model.encode_prompt(prompt, negative_prompt)
    kwargs = model.get_pre_kwargs(height=model.get_image_size(),
                                  width=model.get_image_size(),
                                  prompt_embeds=prompt_embeds)

    xt = torch.randn(1, *model.get_latent_shape(model.get_image_size()), device=device, dtype=dtype)

    pbar = tqdm(model.timesteps)
    for idx, t in enumerate(pbar):
        noise = torch.randn(1, *xt.shape[1:], device=device, dtype=dtype)
        epst = model.get_epst(xt, t, prompt_embeds, guidance_scale, **kwargs)
        x_0_hat = model.get_x_0_hat(xt, epst, t)
        xt = model.finish_step(xt, x_0_hat, epst, t, noise)

        if record and not idx % record:
            img = model.decode_image(x_0_hat)
            plt.imsave(os.path.join(save_root, f"progress/x_0_hat_{str(t.item()).zfill(4)}.png"),
                       clear_color(img[0].unsqueeze(0), normalize=False))

    try:
        img = model.decode_image(xt)
    except torch.OutOfMemoryError:
        img = model.decode_image(xt.to('cpu'))
    return img
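
# A minimal end-to-end sketch for codebook-based generation. The PipelineWrapper
# arguments are hypothetical; see latent_models for the real constructor.
#
#     model = PipelineWrapper(...)
#     img, indices = generate_ours(model, num_noises=64, num_noises_to_optimize=8,
#                                  prompt="a photo of a cat")
#     # Replaying the recorded indices reproduces the same image, since every
#     # per-step codebook is regenerated from the same fixed seeds:
#     img2, _ = generate_ours(model, num_noises=64, num_noises_to_optimize=8,
#                             prompt="a photo of a cat", indices=indices)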