from typing import Iterator, List, Tuple import torch from diffusers import FlowMatchEulerDiscreteScheduler, StableDiffusion3Pipeline from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps from PIL import Image @torch.no_grad() def calc_v_sd3( pipe: StableDiffusion3Pipeline, latent_model_input: torch.Tensor, prompt_embeds: torch.Tensor, pooled_prompt_embeds: torch.Tensor, guidance_scale: float, t: torch.Tensor, ) -> torch.Tensor: """ Calculate the velocity (v) for Stable Diffusion 3. Args: pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline. latent_model_input (torch.Tensor): The input latent tensor. prompt_embeds (torch.Tensor): The text embeddings for the prompt. pooled_prompt_embeds (torch.Tensor): The pooled text embeddings for the prompt. guidance_scale (float): The guidance scale for classifier-free guidance. t (torch.Tensor): The current timestep. Returns: torch.Tensor: The predicted noise (velocity). """ timestep = t.expand(latent_model_input.shape[0]) noise_pred = pipe.transformer( hidden_states=latent_model_input, timestep=timestep, encoder_hidden_states=prompt_embeds, pooled_projections=pooled_prompt_embeds, joint_attention_kwargs=None, return_dict=False, )[0] # perform guidance source if pipe.do_classifier_free_guidance: noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) return noise_pred # https://github.com/DSL-Lab/UniEdit-Flow @torch.no_grad() def uniinv( pipe: StableDiffusion3Pipeline, timesteps: torch.Tensor, n_start: int, x0_src: torch.Tensor, src_prompt_embeds_all: torch.Tensor, src_pooled_prompt_embeds_all: torch.Tensor, src_guidance_scale: float, ) -> torch.Tensor: """ Perform the UniInv inversion process for Stable Diffusion 3. Args: pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline. timesteps (torch.Tensor): The timesteps for the diffusion process. n_start (int): The number of initial timesteps to skip. x0_src (torch.Tensor): The source latent tensor. src_prompt_embeds_all (torch.Tensor): The text embeddings for the source prompt. src_pooled_prompt_embeds_all (torch.Tensor): The pooled text embeddings for the source prompt. src_guidance_scale (float): The guidance scale for classifier-free guidance. Returns: torch.Tensor: The inverted latent tensor. """ x_t = x0_src.clone() timesteps_inv = torch.cat([torch.tensor([0.0], device=pipe.device), timesteps.flip(dims=(0,))], dim=0) if n_start > 0: zipped_timesteps_inv = zip(timesteps_inv[:-n_start - 1], timesteps_inv[1:-n_start]) else: zipped_timesteps_inv = zip(timesteps_inv[:-1], timesteps_inv[1:]) next_v = None for _i, (t_cur, t_prev) in enumerate(zipped_timesteps_inv): t_i = t_cur / 1000 t_ip1 = t_prev / 1000 dt = t_ip1 - t_i if next_v is None: latent_model_input = torch.cat([x_t, x_t]) if pipe.do_classifier_free_guidance else (x_t) v_tar = calc_v_sd3( pipe, latent_model_input, src_prompt_embeds_all, src_pooled_prompt_embeds_all, src_guidance_scale, t_cur, ) else: v_tar = next_v x_t = x_t.to(torch.float32) x_t_next = x_t + v_tar * dt x_t_next = x_t_next.to(pipe.dtype) latent_model_input = torch.cat([x_t_next, x_t_next]) if pipe.do_classifier_free_guidance else (x_t_next) v_tar_next = calc_v_sd3( pipe, latent_model_input, src_prompt_embeds_all, src_pooled_prompt_embeds_all, src_guidance_scale, t_prev, ) next_v = v_tar_next x_t = x_t + v_tar_next * dt x_t = x_t.to(pipe.dtype) return x_t @torch.no_grad() def initialization( pipe: StableDiffusion3Pipeline, scheduler: FlowMatchEulerDiscreteScheduler, T_steps: int, n_start: int, x0_src: torch.Tensor, src_prompt: str, negative_prompt: str, src_guidance_scale: float, ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: """ Initialize the inversion process by preparing the latent tensor and prompt embeddings, and performing UniInv. Args: pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline. scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process. T_steps (int): The total number of timesteps for the diffusion process. n_start (int): The number of initial timesteps to skip. x0_src (torch.Tensor): The source latent tensor. src_prompt (str): The source text prompt. negative_prompt (str): The negative text prompt for classifier-free guidance. src_guidance_scale (float): The guidance scale for classifier-free guidance. Returns: Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: - The inverted latent tensor. - The original source latent tensor. - The timesteps for the diffusion process. - The text embeddings for the source prompt. - The pooled text embeddings for the source prompt. """ pipe._guidance_scale = src_guidance_scale ( src_prompt_embeds, src_negative_prompt_embeds, src_pooled_prompt_embeds, src_negative_pooled_prompt_embeds, ) = pipe.encode_prompt( prompt=src_prompt, prompt_2=None, prompt_3=None, negative_prompt=negative_prompt, do_classifier_free_guidance=pipe.do_classifier_free_guidance, device=pipe.device, ) src_prompt_embeds_all = torch.cat([src_negative_prompt_embeds, src_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else src_prompt_embeds src_pooled_prompt_embeds_all = torch.cat([src_negative_pooled_prompt_embeds, src_pooled_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else src_pooled_prompt_embeds timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, x0_src.device, timesteps=None) pipe._num_timesteps = len(timesteps) x_t = uniinv( pipe, timesteps, n_start, x0_src, src_prompt_embeds_all, src_pooled_prompt_embeds_all, src_guidance_scale, ) return x_t, x0_src, timesteps @torch.no_grad() def sd3_denoise( pipe: StableDiffusion3Pipeline, timesteps: torch.Tensor, n_start: int, x_t: torch.Tensor, prompt_embeds_all: torch.Tensor, pooled_prompt_embeds_all: torch.Tensor, guidance_scale: float, ) -> torch.Tensor: """ Perform the denoising process for Stable Diffusion 3. Args: pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline. timesteps (torch.Tensor): The timesteps for the diffusion process. n_start (int): The number of initial timesteps to skip. x_t (torch.Tensor): The latent tensor at the starting timestep. prompt_embeds_all (torch.Tensor): The text embeddings for the prompt. pooled_prompt_embeds_all (torch.Tensor): The pooled text embeddings for the prompt. guidance_scale (float): The guidance scale for classifier-free guidance. Returns: torch.Tensor: The denoised latent tensor. """ f_xt = x_t.clone() for i, t in enumerate(timesteps[n_start:]): t_i = t / 1000 if i + 1 < len(timesteps[n_start:]): t_im1 = (timesteps[n_start + i + 1]) / 1000 else: t_im1 = torch.zeros_like(t_i).to(t_i.device) dt = t_im1 - t_i latent_model_input = torch.cat([f_xt, f_xt]) if pipe.do_classifier_free_guidance else (f_xt) v_tar = calc_v_sd3( pipe, latent_model_input, prompt_embeds_all, pooled_prompt_embeds_all, guidance_scale, t, ) f_xt = f_xt.to(torch.float32) f_xt = f_xt + v_tar * dt f_xt = f_xt.to(pipe.dtype) return f_xt @torch.no_grad() def sd3_editing( pipe: StableDiffusion3Pipeline, scheduler: FlowMatchEulerDiscreteScheduler, T_steps: int, n_max: int, x0_src: torch.Tensor, src_prompt: str, tar_prompt: str, negative_prompt: str, src_guidance_scale: float, tar_guidance_scale: float, flowopt_iterations: int, eta: float, ) -> Iterator[List[Tuple[Image.Image, str]]]: """ Perform the editing process for Stable Diffusion 3 using FlowOpt. Args: pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline. scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process. T_steps (int): The total number of timesteps for the diffusion process. n_max (int): The maximum number of timesteps to consider. x0_src (torch.Tensor): The source latent tensor. src_prompt (str): The source text prompt. tar_prompt (str): The target text prompt for editing. negative_prompt (str): The negative text prompt for classifier-free guidance. src_guidance_scale (float): The guidance scale for the source prompt. tar_guidance_scale (float): The guidance scale for the target prompt. flowopt_iterations (int): The number of FlowOpt iterations to perform. eta (float): The step size for the FlowOpt update. Yields: Iterator[List[Tuple[Image.Image, str]]]: A list of tuples containing the generated images and their corresponding iteration labels. """ n_start = T_steps - n_max x_t, x0_src, timesteps = initialization( pipe, scheduler, T_steps, n_start, x0_src, src_prompt, negative_prompt, src_guidance_scale, ) pipe._guidance_scale = tar_guidance_scale ( tar_prompt_embeds, tar_negative_prompt_embeds, tar_pooled_prompt_embeds, tar_negative_pooled_prompt_embeds, ) = pipe.encode_prompt( prompt=tar_prompt, prompt_2=None, prompt_3=None, negative_prompt=negative_prompt, do_classifier_free_guidance=pipe.do_classifier_free_guidance, device=pipe.device, ) tar_prompt_embeds_all = torch.cat([tar_negative_prompt_embeds, tar_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else tar_prompt_embeds tar_pooled_prompt_embeds_all = torch.cat([tar_negative_pooled_prompt_embeds, tar_pooled_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else tar_pooled_prompt_embeds history = [] j_star = x0_src.clone().to(torch.float32) # y for flowopt_iter in range(flowopt_iterations + 1): f_xt = sd3_denoise( pipe, timesteps, n_start, x_t, tar_prompt_embeds_all, tar_pooled_prompt_embeds_all, tar_guidance_scale, ) # Eq. (3) if flowopt_iter < flowopt_iterations: x_t = x_t.to(torch.float32) x_t = x_t - eta * (f_xt - j_star) # Eq. (6) with c = c_tar x_t = x_t.to(x0_src.dtype) x0_flowopt = f_xt.clone() x0_flowopt_denorm = (x0_flowopt / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor with torch.autocast("cuda"), torch.inference_mode(): x0_flowopt_image = pipe.vae.decode(x0_flowopt_denorm, return_dict=False)[0].clamp(-1, 1) x0_flowopt_image_pil = pipe.image_processor.postprocess(x0_flowopt_image)[0] history.append((x0_flowopt_image_pil, f"Iteration {flowopt_iter}")) yield history