Spaces:
Running
on
Zero
Running
on
Zero
| import re | |
| from copy import deepcopy | |
| from dataclasses import asdict, dataclass | |
| from enum import Enum | |
| from typing import List, Optional, Union | |
| import numpy as np | |
| import torch | |
| from numpy import exp, pi, sqrt | |
| from torchvision.transforms.functional import resize | |
| from tqdm.auto import tqdm | |
| from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer | |
| from diffusers.models import AutoencoderKL, UNet2DConditionModel | |
| from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin | |
| from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker | |
| from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler | |
def preprocess_image(image):
    """Preprocess an input image

    Same as
    https://github.com/huggingface/diffusers/blob/1138d63b519e37f0ce04e027b9f4a3261d27c628/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L44

    Args:
        image: Object exposing ``size`` as ``(width, height)`` and a
            ``resize`` method (presumably a ``PIL.Image.Image`` with channels
            last once converted to an array -- confirm against callers).

    Returns:
        A float32 ``torch.Tensor`` with a leading batch axis of size 1 and the
        channel axis moved before the spatial axes, with values mapped through
        ``2 * (x / 255) - 1``.
    """
    # Imported lazily so PIL is only required when images are preprocessed.
    # The docstring must stay the first statement of the function, otherwise
    # it is just a discarded string expression and __doc__ is None.
    from PIL import Image

    w, h = image.size
    w, h = (x - x % 32 for x in (w, h))  # resize to integer multiple of 32
    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    # Add the batch axis, then reorder (batch, row, col, channel) -> (batch, channel, row, col)
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0
@dataclass
class CanvasRegion:
    """Class defining a rectangular region in the canvas

    All coordinates are given in pixel space, must be non-negative and must be
    multiples of 8. After construction the matching latent-space coordinates
    (pixel coordinate // 8) are available as ``latent_row_init``,
    ``latent_row_end``, ``latent_col_init`` and ``latent_col_end``.

    Raises:
        ValueError: if a coordinate is negative or not divisible by 8, or if
            ``noise_eps`` is negative.
    """

    row_init: int  # Region starting row in pixel space (included)
    row_end: int  # Region end row in pixel space (not included)
    col_init: int  # Region starting column in pixel space (included)
    col_end: int  # Region end column in pixel space (not included)
    region_seed: Optional[int] = None  # Seed for random operations in this region. If None, a random one is drawn
    noise_eps: float = 0.0  # Deviation of a zero-mean gaussian noise to be applied over the latents in this region. Useful for slightly "rerolling" latents

    def __post_init__(self):
        # Initialize arguments if not specified
        if self.region_seed is None:
            # Ask explicitly for int64: the upper bound does not fit in 32 bits, and the default
            # integer type of NumPy is 32-bit on some platforms. Convert to a plain int so the
            # seed remains serializable.
            self.region_seed = int(np.random.randint(9999999999, dtype=np.int64))
        coords = (self.row_init, self.row_end, self.col_init, self.col_end)
        # Check coordinates are non-negative
        if any(coord < 0 for coord in coords):
            raise ValueError(
                f"A CanvasRegion must be defined with non-negative indices, found ({self.row_init}, {self.row_end}, {self.col_init}, {self.col_end})"
            )
        # Check coordinates are divisible by 8, else we end up with nasty rounding error when mapping to latent space
        if any(coord % 8 != 0 for coord in coords):
            raise ValueError(
                f"A CanvasRegion must be defined with locations divisible by 8, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})"
            )
        # Check noise eps is non-negative
        if self.noise_eps < 0:
            raise ValueError(f"A CanvasRegion must be defined noises eps non-negative, found {self.noise_eps}")
        # Compute coordinates for this region in latent space
        self.latent_row_init = self.row_init // 8
        self.latent_row_end = self.row_end // 8
        self.latent_col_init = self.col_init // 8
        self.latent_col_end = self.col_end // 8

    @property
    def width(self):
        """Region width in pixels"""
        return self.col_end - self.col_init

    @property
    def height(self):
        """Region height in pixels"""
        return self.row_end - self.row_init

    def get_region_generator(self, device="cpu"):
        """Creates a torch.Generator based on the random seed of this region"""
        # Initialize region generator
        return torch.Generator(device).manual_seed(self.region_seed)

    @property
    def __dict__(self):
        """Dataclass fields of this region as a plain dictionary"""
        return asdict(self)
class MaskModes(Enum):
    """Kinds of weight mask that can modulate the influence of a diffuser"""

    # Members are identified across the module by their string values
    CONSTANT = "constant"
    GAUSSIAN = "gaussian"
    # Kernel definition: https://en.wikipedia.org/wiki/Kernel_(statistics)
    QUARTIC = "quartic"
@dataclass
class DiffusionRegion(CanvasRegion):
    """Abstract class defining a region where some class of diffusion process is acting

    Adds no fields of its own; it exists so that diffusion regions can be told
    apart from other canvas regions. It is declared as a dataclass so that
    subclasses can rely on ``DiffusionRegion.__dataclass_fields__``.
    """

    pass
@dataclass
class Text2ImageRegion(DiffusionRegion):
    """Class defining a region where a text guided diffusion process is acting

    Raises:
        ValueError: if ``mask_weight`` is negative or ``mask_type`` is not one
            of the values of ``MaskModes`` (in addition to the checks of the
            parent class).
    """

    prompt: str = ""  # Text prompt guiding the diffuser in this region
    guidance_scale: Optional[float] = 7.5  # Guidance scale of the diffuser in this region. If None, randomize
    mask_type: Union[str, MaskModes] = MaskModes.GAUSSIAN.value  # Kind of weight mask applied to this region
    mask_weight: float = 1.0  # Global weights multiplier of the mask
    # Not annotated on purpose: these are working attributes, not dataclass fields
    tokenized_prompt = None  # Tokenized prompt
    encoded_prompt = None  # Encoded prompt

    def __post_init__(self):
        super().__post_init__()
        # Mask weight cannot be negative
        if self.mask_weight < 0:
            raise ValueError(
                f"A Text2ImageRegion must be defined with non-negative mask weight, found {self.mask_weight}"
            )
        # Accept enum members as well as their string values; the string value is what gets stored
        if isinstance(self.mask_type, MaskModes):
            self.mask_type = self.mask_type.value
        # Mask type must be an actual known mask
        accepted_masks = [e.value for e in MaskModes]
        if self.mask_type not in accepted_masks:
            raise ValueError(
                f"A Text2ImageRegion was defined with mask {self.mask_type}, which is not an accepted mask ({accepted_masks})"
            )
        # Randomize arguments if given as None
        if self.guidance_scale is None:
            self.guidance_scale = np.random.randint(5, 30)
        # Clean prompt
        self.prompt = re.sub(" +", " ", self.prompt).replace("\n", " ")

    def tokenize_prompt(self, tokenizer):
        """Tokenizes the prompt for this diffusion region using a given tokenizer

        The result is stored in ``self.tokenized_prompt``; nothing is returned.
        """
        self.tokenized_prompt = tokenizer(
            self.prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

    def encode_prompt(self, text_encoder, device):
        """Encodes the previously tokenized prompt for this diffusion region using a given encoder

        The result is stored in ``self.encoded_prompt``; nothing is returned.

        Raises:
            ValueError: if ``tokenize_prompt`` has not been called before.
        """
        # Explicit raise instead of assert, so the check survives `python -O`
        if self.tokenized_prompt is None:
            raise ValueError("Prompt in diffusion region must be tokenized before encoding")
        self.encoded_prompt = text_encoder(self.tokenized_prompt.input_ids.to(device))[0]
@dataclass
class Image2ImageRegion(DiffusionRegion):
    """Class defining a region where an image guided diffusion process is acting

    Raises:
        ValueError: if no reference image is given or ``strength`` is outside
            [0.0, 1.0] (in addition to the checks of the parent class).
    """

    reference_image: Optional[torch.Tensor] = None  # Mandatory in practice; the None default only exists because earlier fields have defaults
    strength: float = 0.8  # Strength of the image

    def __post_init__(self):
        super().__post_init__()
        if self.reference_image is None:
            raise ValueError("Must provide a reference image when creating an Image2ImageRegion")
        if self.strength < 0 or self.strength > 1:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {self.strength}")
        # Rescale image to region shape
        self.reference_image = resize(self.reference_image, size=[self.height, self.width])

    def encode_reference_image(self, encoder, device, generator, cpu_vae=False):
        """Encodes the reference image for this Image2Image region into the latent space

        The scaled latents are stored in ``self.reference_latents``; nothing is
        returned. Note that with ``cpu_vae=True`` the given encoder itself is
        moved to the CPU by ``encoder.cpu()``.
        """
        # Place encoder in CPU or not following the parameter cpu_vae
        if cpu_vae:
            # Note here we use mean instead of sample, to avoid moving also generator to CPU, which is troublesome
            self.reference_latents = encoder.cpu().encode(self.reference_image).latent_dist.mean.to(device)
        else:
            self.reference_latents = encoder.encode(self.reference_image.to(device)).latent_dist.sample(
                generator=generator
            )
        self.reference_latents = 0.18215 * self.reference_latents

    @property
    def __dict__(self):
        """Fields of this region as a plain dictionary, with the reference image converted to nested lists"""
        # This class requires special casting to dict because of the reference_image tensor. Otherwise it cannot be casted to JSON
        # Get all basic fields from parent class
        super_fields = {key: getattr(self, key) for key in DiffusionRegion.__dataclass_fields__}
        # Pack other fields
        return {**super_fields, "reference_image": self.reference_image.cpu().tolist(), "strength": self.strength}
class RerollModes(Enum):
    """Ways in which a reroll region can act on the latents it covers"""

    # Completely reset the random noise in the region
    RESET = "reset"
    # Alter slightly the latents in the region
    EPSILON = "epsilon"
@dataclass
class RerollRegion(CanvasRegion):
    """Class defining a rectangular canvas region in which initial latent noise will be rerolled"""

    reroll_mode: Union[str, RerollModes] = RerollModes.RESET.value  # Stored as the string value of a RerollModes member

    def __post_init__(self):
        super().__post_init__()
        # Accept enum members too: consumers compare reroll_mode against the string values,
        # so a raw enum member would otherwise never match any mode
        if isinstance(self.reroll_mode, RerollModes):
            self.reroll_mode = self.reroll_mode.value
@dataclass
class MaskWeightsBuilder:
    """Auxiliary class to compute a tensor of weights for a given diffusion region

    Every builder returns a tensor of shape
    ``(nbatch, latent_space_dim, latent_height, latent_width)`` where the
    latent height and width are taken from the region's latent coordinates.
    """

    latent_space_dim: int  # Size of the U-net latent space
    nbatch: int = 1  # Batch size in the U-net

    def compute_mask_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Computes a tensor of weights for a given diffusion region

        Dispatches on ``region.mask_type``, which may be a ``MaskModes`` member
        or its string value.

        Raises:
            KeyError: if ``region.mask_type`` is not a known mask mode.
        """
        mask_builders = {
            MaskModes.CONSTANT.value: self._constant_weights,
            MaskModes.GAUSSIAN.value: self._gaussian_weights,
            MaskModes.QUARTIC.value: self._quartic_weights,
        }
        mask_type = region.mask_type
        if isinstance(mask_type, MaskModes):
            mask_type = mask_type.value
        return mask_builders[mask_type](region)

    @staticmethod
    def _latent_size(region):
        """Returns (latent_height, latent_width) of a region"""
        latent_height = region.latent_row_end - region.latent_row_init
        latent_width = region.latent_col_end - region.latent_col_init
        return latent_height, latent_width

    @staticmethod
    def _gaussian_profile(size, var=0.01):
        """1D gaussian density evaluated at `size` positions, centered on the middle position"""
        midpoint = (size - 1) / 2  # -1 because index goes from 0 to size - 1
        offsets = np.arange(size) - midpoint
        # Offsets are normalized by the size, so the profile has the same shape for any region size
        return np.exp(-(offsets * offsets) / (size * size) / (2 * var)) / np.sqrt(2 * np.pi * var)

    @staticmethod
    def _quartic_profile(size):
        """1D quartic kernel evaluated at `size` positions spanning (-0.995, 0.995)"""
        quartic_constant = 15.0 / 16.0
        if size == 1:
            # A single position cannot be spread over the support (that would be 0 / 0 -> NaN):
            # evaluate the kernel at its center instead
            support = np.zeros(1)
        else:
            # 1.99 rather than 2 keeps the end points strictly inside the kernel support,
            # so border weights are small but not exactly zero
            support = np.arange(size) / (size - 1) * 1.99 - (1.99 / 2.0)
        return quartic_constant * np.square(1 - np.square(support))

    def _tile_weights(self, weights):
        """Repeats a 2D array of weights over the batch and latent channel axes"""
        return torch.tile(torch.tensor(weights), (self.nbatch, self.latent_space_dim, 1, 1))

    def _constant_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Computes a tensor of constant for a given diffusion region"""
        latent_height, latent_width = self._latent_size(region)
        return torch.ones(self.nbatch, self.latent_space_dim, latent_height, latent_width) * region.mask_weight

    def _gaussian_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Generates a gaussian mask of weights for tile contributions"""
        latent_height, latent_width = self._latent_size(region)
        x_probs = self._gaussian_profile(latent_width)
        y_probs = self._gaussian_profile(latent_height)
        weights = np.outer(y_probs, x_probs) * region.mask_weight
        return self._tile_weights(weights)

    def _quartic_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Generates a quartic mask of weights for tile contributions

        The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits.
        """
        latent_height, latent_width = self._latent_size(region)
        x_probs = self._quartic_profile(latent_width)
        y_probs = self._quartic_profile(latent_height)
        weights = np.outer(y_probs, x_probs) * region.mask_weight
        return self._tile_weights(weights)
class StableDiffusionCanvasPipeline(DiffusionPipeline, StableDiffusionMixin):
    """Stable Diffusion pipeline that mixes several diffusers in the same canvas"""

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @staticmethod
    def _latent_window(region):
        """Returns the index selecting a region's window in a (batch, channel, row, column) latent tensor"""
        return (
            slice(None),
            slice(None),
            slice(region.latent_row_init, region.latent_row_end),
            slice(region.latent_col_init, region.latent_col_end),
        )

    @torch.no_grad()
    def decode_latents(self, latents, cpu_vae=False):
        """Decodes a given array of latents into pixel space

        Args:
            latents: latents to decode.
            cpu_vae: if True, decoding runs on the CPU using a copy of the VAE,
                so the pipeline's own VAE is left where it is.

        Returns:
            The decoded images, as produced by ``self.numpy_to_pil``.
        """
        # scale and decode the image latents with vae
        if cpu_vae:
            # No explicit copy of the latents is needed: they are only read, and the
            # rescaling below allocates a new tensor anyway
            lat = latents.cpu()
            vae = deepcopy(self.vae).cpu()
        else:
            lat = latents
            vae = self.vae
        lat = 1 / 0.18215 * lat
        image = vae.decode(lat).sample
        image = (image / 2 + 0.5).clamp(0, 1)
        # Gradients are disabled for this method, as .numpy() refuses tensors that require grad
        image = image.cpu().permute(0, 2, 3, 1).numpy()
        return self.numpy_to_pil(image)

    def get_latest_timestep_img2img(self, num_inference_steps, strength):
        """Finds the latest timesteps where an img2img strength does not impose latents anymore

        Must be called after the scheduler timesteps have been set, since the
        result is read from ``self.scheduler.timesteps``.
        """
        # get the original timestep using init_timestep
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * (1 - strength)) + offset
        init_timestep = min(init_timestep, num_inference_steps)
        # Clamp to a valid index of the timesteps schedule
        t_start = min(max(num_inference_steps - init_timestep + offset, 0), num_inference_steps - 1)
        latest_timestep = self.scheduler.timesteps[t_start]
        return latest_timestep

    @torch.no_grad()
    def __call__(
        self,
        canvas_height: int,
        canvas_width: int,
        regions: List[DiffusionRegion],
        num_inference_steps: Optional[int] = 50,
        seed: Optional[int] = 12345,
        reroll_regions: Optional[List[RerollRegion]] = None,
        cpu_vae: Optional[bool] = False,
        decode_steps: Optional[bool] = False,
    ):
        """Runs the diffusion process over the whole canvas

        Gradients are disabled for the whole call: this is inference only, and
        the final decoding converts tensors to numpy arrays.

        Args:
            canvas_height: height of the canvas in pixels.
            canvas_width: width of the canvas in pixels.
            regions: diffusion regions acting on the canvas. Text2ImageRegion
                instances drive the noise prediction, Image2ImageRegion
                instances override the latents of their window during the
                first steps of the process. The prompt tensors of the
                text2image regions are (re)computed and stored in the regions.
            num_inference_steps: number of scheduler steps.
            seed: seed of the generator producing the initial canvas noise.
            reroll_regions: regions where the initial noise is reset or
                slightly altered before diffusing.
            cpu_vae: whether latents are decoded on the CPU.
            decode_steps: whether to also decode the latents after every step.

        Returns:
            A dict with key "images" (the decoded final latents) and, when
            ``decode_steps`` is set, "steps_images" (one decoding per step).
        """
        if reroll_regions is None:
            reroll_regions = []
        batch_size = 1
        steps_images = []

        # Prepare scheduler
        self.scheduler.set_timesteps(num_inference_steps, device=self.device)

        # Split diffusion regions by their kind
        text2image_regions = [region for region in regions if isinstance(region, Text2ImageRegion)]
        image2image_regions = [region for region in regions if isinstance(region, Image2ImageRegion)]

        # Prepare text embeddings
        for region in text2image_regions:
            region.tokenize_prompt(self.tokenizer)
            region.encode_prompt(self.text_encoder, self.device)

        # Create original noisy latents using the timesteps
        latents_shape = (batch_size, self.unet.config.in_channels, canvas_height // 8, canvas_width // 8)
        generator = torch.Generator(self.device).manual_seed(seed)
        init_noise = torch.randn(latents_shape, generator=generator, device=self.device)

        # Reset latents in seed reroll regions, if requested
        for region in reroll_regions:
            if region.reroll_mode == RerollModes.RESET.value:
                region_shape = (
                    latents_shape[0],
                    latents_shape[1],
                    region.latent_row_end - region.latent_row_init,
                    region.latent_col_end - region.latent_col_init,
                )
                init_noise[self._latent_window(region)] = torch.randn(
                    region_shape, generator=region.get_region_generator(self.device), device=self.device
                )

        # Apply epsilon noise to regions: first diffusion regions, then reroll regions
        all_eps_rerolls = list(regions) + [r for r in reroll_regions if r.reroll_mode == RerollModes.EPSILON.value]
        for region in all_eps_rerolls:
            if region.noise_eps > 0:
                window = self._latent_window(region)
                eps_noise = (
                    torch.randn(
                        init_noise[window].shape,
                        generator=region.get_region_generator(self.device),
                        device=self.device,
                    )
                    * region.noise_eps
                )
                init_noise[window] += eps_noise

        # scale the initial noise by the standard deviation required by the scheduler
        latents = init_noise * self.scheduler.init_noise_sigma

        # Get unconditional embeddings for classifier free guidance in text2image regions.
        # The embedding of the empty prompt only depends on the padded length, so it is
        # encoded once per length instead of once per region.
        uncond_embeddings_by_length = {}
        for region in text2image_regions:
            max_length = region.tokenized_prompt.input_ids.shape[-1]
            if max_length not in uncond_embeddings_by_length:
                uncond_input = self.tokenizer(
                    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
                )
                uncond_embeddings_by_length[max_length] = self.text_encoder(
                    uncond_input.input_ids.to(self.device)
                )[0]
            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            region.encoded_prompt = torch.cat([uncond_embeddings_by_length[max_length], region.encoded_prompt])

        # Prepare image latents
        for region in image2image_regions:
            region.encode_reference_image(self.vae, device=self.device, generator=generator)

        # Prepare mask of weights for each region
        mask_builder = MaskWeightsBuilder(latent_space_dim=self.unet.config.in_channels, nbatch=batch_size)
        mask_weights = [mask_builder.compute_mask_weights(region).to(self.device) for region in text2image_regions]

        # The last timestep influenced by each reference image does not change along the process
        influence_steps = [
            self.get_latest_timestep_img2img(num_inference_steps, region.strength) for region in image2image_regions
        ]

        # Diffusion timesteps
        for t in tqdm(self.scheduler.timesteps):
            # Diffuse each region
            noise_preds_regions = []

            # text2image regions
            for region in text2image_regions:
                region_latents = latents[self._latent_window(region)]
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([region_latents] * 2)
                # scale model input following scheduler rules
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                # predict the noise residual
                noise_pred = self.unet(latent_model_input, t, encoder_hidden_states=region.encoded_prompt)["sample"]
                # perform guidance
                noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                noise_pred_region = noise_pred_uncond + region.guidance_scale * (noise_pred_text - noise_pred_uncond)
                noise_preds_regions.append(noise_pred_region)

            # Merge noise predictions for all tiles
            noise_pred = torch.zeros(latents.shape, device=self.device)
            contributors = torch.zeros(latents.shape, device=self.device)
            # Add each tile contribution to overall latents
            for region, noise_pred_region, mask_weights_region in zip(
                text2image_regions, noise_preds_regions, mask_weights
            ):
                window = self._latent_window(region)
                noise_pred[window] += noise_pred_region * mask_weights_region
                contributors[window] += mask_weights_region
            # Average overlapping areas with more than 1 contributor
            noise_pred /= contributors
            # Replace NaNs by zeros: NaN can appear if a position is not covered by any DiffusionRegion
            noise_pred = torch.nan_to_num(noise_pred)

            # compute the previous noisy sample x_t -> x_t-1
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            # Image2Image regions: override latents generated by the scheduler
            for region, influence_step in zip(image2image_regions, influence_steps):
                # Only override in the timesteps before the last influence step of the image (given by its strength)
                if t > influence_step:
                    timestep = t.repeat(batch_size)
                    window = self._latent_window(region)
                    region_latents = self.scheduler.add_noise(region.reference_latents, init_noise[window], timestep)
                    latents[window] = region_latents

            if decode_steps:
                steps_images.append(self.decode_latents(latents, cpu_vae))

        # scale and decode the image latents with vae
        image = self.decode_latents(latents, cpu_vae)

        output = {"images": image}
        if decode_steps:
            output = {**output, "steps_images": steps_images}
        return output