Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Iterator, List, Tuple | |
| import torch | |
| from diffusers import FlowMatchEulerDiscreteScheduler, StableDiffusion3Pipeline | |
| from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps | |
| from PIL import Image | |
def calc_v_sd3(
    pipe: StableDiffusion3Pipeline,
    latent_model_input: torch.Tensor,
    prompt_embeds: torch.Tensor,
    pooled_prompt_embeds: torch.Tensor,
    guidance_scale: float,
    t: torch.Tensor,
) -> torch.Tensor:
    """
    Predict the flow-matching velocity with the SD3 transformer.

    The scalar timestep ``t`` is broadcast to the batch size of
    ``latent_model_input`` and passed to ``pipe.transformer`` together with
    the (pooled) prompt embeddings. When the pipeline has classifier-free
    guidance enabled, the batch is split in half: the first half is treated as
    the unconditional prediction and the second half as the text-conditioned
    one, and they are combined with ``guidance_scale``.

    Args:
        pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
        latent_model_input (torch.Tensor): Latents fed to the transformer
            (already duplicated by the caller when CFG is enabled).
        prompt_embeds (torch.Tensor): Text embeddings, batch-aligned with the latents.
        pooled_prompt_embeds (torch.Tensor): Pooled text embeddings, batch-aligned with the latents.
        guidance_scale (float): Classifier-free guidance weight.
        t (torch.Tensor): The current timestep.

    Returns:
        torch.Tensor: The (guided, if CFG is enabled) velocity prediction.
    """
    batch_size = latent_model_input.shape[0]
    outputs = pipe.transformer(
        hidden_states=latent_model_input,
        timestep=t.expand(batch_size),
        encoder_hidden_states=prompt_embeds,
        pooled_projections=pooled_prompt_embeds,
        joint_attention_kwargs=None,
        return_dict=False,
    )
    velocity = outputs[0]

    if not pipe.do_classifier_free_guidance:
        return velocity

    # First half of the batch: unconditional; second half: text-conditioned.
    v_uncond, v_text = velocity.chunk(2)
    return v_uncond + guidance_scale * (v_text - v_uncond)
# https://github.com/DSL-Lab/UniEdit-Flow
def uniinv(
    pipe: StableDiffusion3Pipeline,
    timesteps: torch.Tensor,
    n_start: int,
    x0_src: torch.Tensor,
    src_prompt_embeds_all: torch.Tensor,
    src_pooled_prompt_embeds_all: torch.Tensor,
    src_guidance_scale: float,
) -> torch.Tensor:
    """
    Perform the UniInv inversion process for Stable Diffusion 3.

    Integrates from t=0 upward through the reversed ``timesteps``, stopping at
    ``timesteps[n_start]``, so the result can be denoised starting from that
    timestep.

    Each step works as follows:

    1. Take an Euler step with the current velocity.
    2. Evaluate the velocity at the predicted point.
    3. Re-step from the current latent using that velocity.

    The velocity evaluated in step 2 is reused as the current velocity of the
    next step, so after the first step only one transformer call is made per
    step.

    Args:
        pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
        timesteps (torch.Tensor): The timesteps for the diffusion process
            (as produced by the scheduler).
        n_start (int): The number of initial timesteps to skip; a negative
            value is treated as 0 (full inversion).
        x0_src (torch.Tensor): The source latent tensor (not modified).
        src_prompt_embeds_all (torch.Tensor): The text embeddings for the source
            prompt ([negative; positive] when CFG is enabled).
        src_pooled_prompt_embeds_all (torch.Tensor): The pooled text embeddings
            for the source prompt.
        src_guidance_scale (float): The guidance scale for classifier-free guidance.

    Returns:
        torch.Tensor: The inverted latent tensor, cast to ``pipe.dtype``
        (a plain clone of ``x0_src`` if no steps are taken).
    """

    def velocity_at(latents: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
        # Duplicate the batch to pair with the [negative; positive] embeddings under CFG.
        model_input = torch.cat([latents, latents]) if pipe.do_classifier_free_guidance else latents
        return calc_v_sd3(
            pipe, model_input, src_prompt_embeds_all,
            src_pooled_prompt_embeds_all, src_guidance_scale, t,
        )

    # Ascending schedule with the clean endpoint t=0 prepended. The zero is
    # built from `timesteps` itself so the concatenation works regardless of
    # which device `timesteps` lives on (it need not be pipe.device).
    timesteps_inv = torch.cat([timesteps.new_zeros(1), timesteps.flip(dims=(0,))], dim=0)

    # Take len(timesteps) - n_start steps (n_start < 0 treated as 0; never < 0 steps),
    # i.e. pairs (timesteps_inv[k], timesteps_inv[k + 1]) for k < num_steps.
    num_steps = max(len(timesteps) - max(n_start, 0), 0)
    step_starts = timesteps_inv[:num_steps]
    step_ends = timesteps_inv[1:num_steps + 1]

    x_t = x0_src.clone()
    next_v = None
    for t_cur, t_next in zip(step_starts, step_ends):
        dt = t_next / 1000 - t_cur / 1000
        v_cur = velocity_at(x_t, t_cur) if next_v is None else next_v
        x_t = x_t.to(torch.float32)
        # Predict with an explicit Euler step, then evaluate the velocity there.
        x_pred = (x_t + v_cur * dt).to(pipe.dtype)
        next_v = velocity_at(x_pred, t_next)
        # Re-step from x_t using the velocity at the predicted point.
        x_t = (x_t + next_v * dt).to(pipe.dtype)
    return x_t
def initialization(
    pipe: StableDiffusion3Pipeline,
    scheduler: FlowMatchEulerDiscreteScheduler,
    T_steps: int,
    n_start: int,
    x0_src: torch.Tensor,
    src_prompt: str,
    negative_prompt: str,
    src_guidance_scale: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Encode the source prompt, build the timestep schedule, and invert ``x0_src`` with UniInv.

    Args:
        pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        T_steps (int): The total number of timesteps for the diffusion process.
        n_start (int): The number of initial timesteps to skip.
        x0_src (torch.Tensor): The source latent tensor.
        src_prompt (str): The source text prompt.
        negative_prompt (str): The negative text prompt for classifier-free guidance.
        src_guidance_scale (float): The guidance scale for classifier-free guidance.

    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - The inverted latent tensor produced by ``uniinv``.
            - The original source latent tensor (the same ``x0_src`` object passed in).
            - The timesteps for the diffusion process, on ``x0_src.device``.

    Side effects:
        Sets ``pipe._guidance_scale`` to ``src_guidance_scale`` and
        ``pipe._num_timesteps`` to the schedule length.
    """
    # diffusers' SD3 pipeline derives `do_classifier_free_guidance` from
    # `_guidance_scale` (> 1), so set it before encoding/inversion.
    pipe._guidance_scale = src_guidance_scale
    (
        src_prompt_embeds,
        src_negative_prompt_embeds,
        src_pooled_prompt_embeds,
        src_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=pipe.device,
    )
    if pipe.do_classifier_free_guidance:
        # [negative; positive] order: the first batch half is the unconditional branch.
        src_prompt_embeds_all = torch.cat([src_negative_prompt_embeds, src_prompt_embeds], dim=0)
        src_pooled_prompt_embeds_all = torch.cat(
            [src_negative_pooled_prompt_embeds, src_pooled_prompt_embeds], dim=0
        )
    else:
        src_prompt_embeds_all = src_prompt_embeds
        src_pooled_prompt_embeds_all = src_pooled_prompt_embeds

    # The returned step count is not needed here; len(timesteps) is used instead.
    timesteps, _ = retrieve_timesteps(scheduler, T_steps, x0_src.device, timesteps=None)
    pipe._num_timesteps = len(timesteps)

    x_t = uniinv(
        pipe, timesteps, n_start, x0_src, src_prompt_embeds_all,
        src_pooled_prompt_embeds_all, src_guidance_scale,
    )
    return x_t, x0_src, timesteps
def sd3_denoise(
    pipe: StableDiffusion3Pipeline,
    timesteps: torch.Tensor,
    n_start: int,
    x_t: torch.Tensor,
    prompt_embeds_all: torch.Tensor,
    pooled_prompt_embeds_all: torch.Tensor,
    guidance_scale: float,
) -> torch.Tensor:
    """
    Run the Euler denoising loop for Stable Diffusion 3 starting at ``timesteps[n_start]``.

    Each step evaluates the velocity at the current timestep and advances the
    latent by ``v * dt``, where ``dt = (t_next - t_cur) / 1000``. The final
    step integrates down to ``t = 0``. The update is computed in float32, and
    the latent is cast back to ``pipe.dtype`` after every step.

    Args:
        pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
        timesteps (torch.Tensor): The timesteps for the diffusion process.
        n_start (int): The number of initial timesteps to skip.
        x_t (torch.Tensor): The latent tensor at the starting timestep (not modified).
        prompt_embeds_all (torch.Tensor): The text embeddings for the prompt.
        pooled_prompt_embeds_all (torch.Tensor): The pooled text embeddings for the prompt.
        guidance_scale (float): The guidance scale for classifier-free guidance.

    Returns:
        torch.Tensor: The denoised latent tensor (a clone of ``x_t`` if there are no steps).
    """
    remaining = timesteps[n_start:]
    last_step = len(remaining) - 1
    latents = x_t.clone()
    for step, t in enumerate(remaining):
        t_cur = t / 1000
        if step < last_step:
            t_next = timesteps[n_start + step + 1] / 1000
        else:
            t_next = torch.zeros_like(t_cur)
        dt = t_next - t_cur

        if pipe.do_classifier_free_guidance:
            model_input = torch.cat([latents, latents])
        else:
            model_input = latents
        velocity = calc_v_sd3(
            pipe, model_input, prompt_embeds_all,
            pooled_prompt_embeds_all, guidance_scale, t,
        )
        latents = (latents.to(torch.float32) + velocity * dt).to(pipe.dtype)
    return latents
def sd3_editing(
    pipe: StableDiffusion3Pipeline,
    scheduler: FlowMatchEulerDiscreteScheduler,
    T_steps: int,
    n_max: int,
    x0_src: torch.Tensor,
    src_prompt: str,
    tar_prompt: str,
    negative_prompt: str,
    src_guidance_scale: float,
    tar_guidance_scale: float,
    flowopt_iterations: int,
    eta: float,
) -> Iterator[List[Tuple[Image.Image, str]]]:
    """
    Perform the editing process for Stable Diffusion 3 using FlowOpt.

    The procedure has two stages:

    1. The source latent is inverted with UniInv, conditioned on
       ``src_prompt``, skipping the first ``T_steps - n_max`` timesteps.
    2. The inverted latent ``x_t`` is repeatedly denoised with ``tar_prompt``.
       Between denoising passes it is updated as
       ``x_t <- x_t - eta * (f(x_t) - x0_src)``.

    Each denoised result is decoded with the VAE and appended to a history
    list, which is yielded after every pass.

    Args:
        pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        T_steps (int): The total number of timesteps for the diffusion process.
        n_max (int): The number of final timesteps that are inverted and
            re-denoised. If it exceeds ``T_steps``, the full schedule is used.
        x0_src (torch.Tensor): The source latent tensor.
        src_prompt (str): The source text prompt.
        tar_prompt (str): The target text prompt for editing.
        negative_prompt (str): The negative text prompt for classifier-free guidance.
        src_guidance_scale (float): The guidance scale for the source prompt.
        tar_guidance_scale (float): The guidance scale for the target prompt.
        flowopt_iterations (int): The number of FlowOpt iterations to perform.
        eta (float): The step size for the FlowOpt update.

    Yields:
        List[Tuple[Image.Image, str]]: The history of decoded images, each
        labelled ``"Iteration k"``. The same list object is yielded
        ``flowopt_iterations + 1`` times, with one new entry appended before
        each yield.
    """
    # Clamp at 0. If n_max > T_steps, a negative n_start would make uniinv
    # invert the full trajectory, while sd3_denoise's timesteps[n_start:]
    # (a negative slice) would denoise only the last few steps.
    n_start = max(T_steps - n_max, 0)

    # The transformer is called directly rather than via pipe.__call__ (which
    # diffusers decorates with no_grad). Unless the caller disables grad mode,
    # autograd could record every call, and x_t chains across iterations.
    # The `with` blocks deliberately exclude the `yield`, so grad mode is
    # never changed for the consumer of this generator.
    with torch.no_grad():
        x_t, x0_src, timesteps = initialization(
            pipe, scheduler, T_steps, n_start, x0_src, src_prompt,
            negative_prompt, src_guidance_scale,
        )
        # Switch the pipeline's CFG state to the target guidance scale.
        pipe._guidance_scale = tar_guidance_scale
        (
            tar_prompt_embeds,
            tar_negative_prompt_embeds,
            tar_pooled_prompt_embeds,
            tar_negative_pooled_prompt_embeds,
        ) = pipe.encode_prompt(
            prompt=tar_prompt,
            prompt_2=None,
            prompt_3=None,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=pipe.do_classifier_free_guidance,
            device=pipe.device,
        )
        if pipe.do_classifier_free_guidance:
            # [negative; positive] order: the first batch half is the unconditional branch.
            tar_prompt_embeds_all = torch.cat([tar_negative_prompt_embeds, tar_prompt_embeds], dim=0)
            tar_pooled_prompt_embeds_all = torch.cat(
                [tar_negative_pooled_prompt_embeds, tar_pooled_prompt_embeds], dim=0
            )
        else:
            tar_prompt_embeds_all = tar_prompt_embeds
            tar_pooled_prompt_embeds_all = tar_pooled_prompt_embeds

    history: List[Tuple[Image.Image, str]] = []
    j_star = x0_src.clone().to(torch.float32)  # y: the source latent the denoised output is pulled toward
    vae_config = pipe.vae.config
    for flowopt_iter in range(flowopt_iterations + 1):
        with torch.no_grad():
            f_xt = sd3_denoise(
                pipe, timesteps, n_start, x_t, tar_prompt_embeds_all,
                tar_pooled_prompt_embeds_all, tar_guidance_scale,
            )  # Eq. (3)
            # The last pass only produces an image; no further update is needed.
            if flowopt_iter < flowopt_iterations:
                x_t = x_t.to(torch.float32)
                x_t = x_t - eta * (f_xt - j_star)  # Eq. (6) with c = c_tar
                x_t = x_t.to(x0_src.dtype)

        # Undo the VAE latent scaling/shift before decoding.
        x0_flowopt_denorm = (f_xt / vae_config.scaling_factor) + vae_config.shift_factor
        with torch.autocast("cuda"), torch.inference_mode():
            x0_flowopt_image = pipe.vae.decode(x0_flowopt_denorm, return_dict=False)[0].clamp(-1, 1)
        x0_flowopt_image_pil = pipe.image_processor.postprocess(x0_flowopt_image)[0]
        history.append((x0_flowopt_image_pil, f"Iteration {flowopt_iter}"))
        yield history