Spaces:
Running
on
Zero
Running
on
Zero
| from typing import Iterator, List, Tuple | |
| import numpy as np | |
| import torch | |
| from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline | |
| from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps | |
| from PIL import Image | |
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
) -> float:
    """
    Linearly interpolate the shift value (mu) for a given image sequence length.

    The result lies on the straight line through the points
    (base_seq_len, base_shift) and (max_seq_len, max_shift); values outside
    that range are extrapolated, not clamped.
    Args:
        image_seq_len (int): The sequence length to evaluate the line at.
        base_seq_len (int): The sequence length that maps to `base_shift`.
        max_seq_len (int): The sequence length that maps to `max_shift`.
        base_shift (float): The shift at `base_seq_len`.
        max_shift (float): The shift at `max_seq_len`.
    Returns:
        float: The interpolated shift.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
def calc_v_flux(
    pipe: FluxPipeline,
    latents: torch.Tensor,
    prompt_embeds: torch.Tensor,
    pooled_prompt_embeds: torch.Tensor,
    guidance: torch.Tensor,
    text_ids: torch.Tensor,
    latent_image_ids: torch.Tensor,
    t: torch.Tensor,
) -> torch.Tensor:
    """
    Run a single forward pass of the FLUX transformer and return its prediction.
    Args:
        pipe (FluxPipeline): The FLUX pipeline whose `transformer` is called.
        latents (torch.Tensor): The latent tensor at the current timestep.
        prompt_embeds (torch.Tensor): The prompt embeddings.
        pooled_prompt_embeds (torch.Tensor): The pooled prompt embeddings.
        guidance (torch.Tensor): The guidance scale tensor, forwarded unchanged.
        text_ids (torch.Tensor): The text token IDs.
        latent_image_ids (torch.Tensor): The latent image token IDs.
        t (torch.Tensor): The current timestep; it is divided by 1000 before
            being handed to the transformer.
    Returns:
        torch.Tensor: The first element of the transformer output, used by the
            callers as the velocity.
    """
    # Broadcast the timestep to one entry per batch row, then rescale it.
    batch_size = latents.shape[0]
    scaled_timestep = t.expand(batch_size) / 1000

    transformer_inputs = dict(
        hidden_states=latents,
        timestep=scaled_timestep,
        guidance=guidance,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=latent_image_ids,
        pooled_projections=pooled_prompt_embeds,
        joint_attention_kwargs=None,
        return_dict=False,
    )
    # With return_dict=False the call returns a sequence; only item 0 is used.
    outputs = pipe.transformer(**transformer_inputs)
    return outputs[0]
def prep_input(
    pipe: FluxPipeline,
    scheduler: FlowMatchEulerDiscreteScheduler,
    T_steps: int,
    x0_src: torch.Tensor,
    src_prompt: str,
    src_guidance_scale: float,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, int, int,
    torch.Tensor, torch.Tensor, torch.Tensor,
]:
    """
    Pack the source latents, build the timestep schedule and encode the source prompt.

    Side effects: sets `pipe._num_timesteps` and `pipe._guidance_scale`, and
    configures `scheduler` through `retrieve_timesteps`.
    Args:
        pipe (FluxPipeline): The FLUX pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        T_steps (int): The total number of timesteps for the diffusion process.
        x0_src (torch.Tensor): The source latent tensor; dims 2 and 3 are
            multiplied by `pipe.vae_scale_factor` to obtain the image size.
        src_prompt (str): The source text prompt.
        src_guidance_scale (float): The guidance scale stored on the pipeline.
    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, torch.Tensor, torch.Tensor, torch.Tensor]:
            - Packed source latent tensor.
            - Latent image token IDs.
            - Timesteps tensor for the diffusion process.
            - Original height of the input image.
            - Original width of the input image.
            - Source prompt embeddings.
            - Source pooled prompt embeddings.
            - Source text token IDs.
    """
    # Image-space size implied by the latent's spatial dims.
    orig_height = x0_src.shape[2] * pipe.vae_scale_factor
    orig_width = x0_src.shape[3] * pipe.vae_scale_factor
    num_channels_latents = pipe.transformer.config.in_channels // 4

    pipe.check_inputs(
        prompt=src_prompt,
        prompt_2=None,
        height=orig_height,
        width=orig_width,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=512,
    )

    # The existing latents are passed in, so no generator is supplied.
    x0_src, latent_src_image_ids = pipe.prepare_latents(
        batch_size=x0_src.shape[0],
        num_channels_latents=num_channels_latents,
        height=orig_height,
        width=orig_width,
        dtype=x0_src.dtype,
        device=x0_src.device,
        generator=None,
        latents=x0_src,
    )
    batch_size, _, latent_height, latent_width = (
        x0_src.shape[0], x0_src.shape[1], x0_src.shape[2], x0_src.shape[3],
    )
    x0_src = pipe._pack_latents(
        x0_src, batch_size, num_channels_latents, latent_height, latent_width,
    )

    # Sigma grid from 1.0 down to 1 / T_steps, shifted by a length-dependent mu.
    sigmas = np.linspace(1.0, 1 / T_steps, T_steps)
    scheduler_cfg = scheduler.config
    mu = calculate_shift(
        x0_src.shape[1],
        scheduler_cfg.base_image_seq_len,
        scheduler_cfg.max_image_seq_len,
        scheduler_cfg.base_shift,
        scheduler_cfg.max_shift,
    )
    timesteps, _ = retrieve_timesteps(
        scheduler,
        T_steps,
        x0_src.device,
        timesteps=None,
        sigmas=sigmas,
        mu=mu,
    )

    pipe._num_timesteps = len(timesteps)
    pipe._guidance_scale = src_guidance_scale

    encoded_prompt = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        device=x0_src.device,
    )
    src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids = encoded_prompt

    return (
        x0_src,
        latent_src_image_ids,
        timesteps,
        orig_height,
        orig_width,
        src_prompt_embeds,
        src_pooled_prompt_embeds,
        src_text_ids,
    )
| # https://github.com/DSL-Lab/UniEdit-Flow | |
def uniinv(
    pipe: FluxPipeline,
    scheduler: FlowMatchEulerDiscreteScheduler,
    timesteps: torch.Tensor,
    n_start: int,
    x0_src: torch.Tensor,
    src_prompt_embeds: torch.Tensor,
    src_pooled_prompt_embeds: torch.Tensor,
    src_guidance: torch.Tensor,
    src_text_ids: torch.Tensor,
    latent_src_image_ids: torch.Tensor,
) -> torch.Tensor:
    """
    Invert the source latents with the UniInv procedure for FLUX.

    The schedule is walked in reverse order. Each step first takes a trial
    step with the previously computed velocity, re-evaluates the velocity at
    the trial point, and then applies that fresh velocity to the real latents.
    Args:
        pipe (FluxPipeline): The FLUX pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        timesteps (torch.Tensor): The timesteps for the diffusion process.
        n_start (int): When positive, the last `n_start` entries of the
            reversed schedule are skipped.
        x0_src (torch.Tensor): The source latent tensor (cloned, not modified).
        src_prompt_embeds (torch.Tensor): The source prompt embeddings.
        src_pooled_prompt_embeds (torch.Tensor): The source pooled prompt embeddings.
        src_guidance (torch.Tensor): The guidance scale tensor.
        src_text_ids (torch.Tensor): The source text token IDs.
        latent_src_image_ids (torch.Tensor): The latent image token IDs.
    Returns:
        torch.Tensor: The inverted latent tensor.
    """
    latents = x0_src.clone()

    reversed_steps = timesteps.flip(dims=(0,))
    if n_start > 0:
        reversed_steps = reversed_steps[:-n_start]

    # Conditioning shared by every transformer call in this loop.
    conditioning = dict(
        prompt_embeds=src_prompt_embeds,
        pooled_prompt_embeds=src_pooled_prompt_embeds,
        guidance=src_guidance,
        text_ids=src_text_ids,
        latent_image_ids=latent_src_image_ids,
    )

    carried_v = None
    for t in reversed_steps:
        scheduler._init_step_index(t)
        idx = scheduler.step_index
        sigma_here = scheduler.sigmas[idx]
        sigma_after = scheduler.sigmas[idx + 1]
        step = sigma_here - sigma_after

        if carried_v is not None:
            # Reuse the velocity evaluated at the end of the previous step.
            v_trial = carried_v
        else:
            # First step only: evaluate at the neighbouring sigma, rescaled
            # by 1000 to match the timestep convention of calc_v_flux.
            v_trial = calc_v_flux(pipe, latents=latents, t=sigma_after * 1000, **conditioning)

        # Accumulate in float32, but feed the model in the pipeline dtype.
        latents = latents.to(torch.float32)
        trial_latents = (latents + v_trial * step).to(pipe.dtype)

        carried_v = calc_v_flux(pipe, latents=trial_latents, t=t, **conditioning)
        latents = (latents + carried_v * step).to(pipe.dtype)

    return latents
def initialization(
    pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
    T_steps: int, n_start: int, x0_src: torch.Tensor, src_prompt: str,
    src_guidance_scale: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
    """
    Initialize the inversion process by preparing the latent tensor and prompt embeddings, and performing UniInv.

    The source prompt embeddings, text IDs and guidance tensor are only used
    internally for the inversion; they are NOT part of the returned tuple.
    Args:
        pipe (FluxPipeline): The FLUX pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        T_steps (int): The total number of timesteps for the diffusion process.
        n_start (int): The number of initial timesteps to skip.
        x0_src (torch.Tensor): The source latent tensor.
        src_prompt (str): The source text prompt.
        src_guidance_scale (float): The guidance scale for the source prompt.
    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
            - The inverted latent tensor.
            - The source latent tensor, as packed by `prep_input`.
            - The timesteps for the diffusion process.
            - The latent image token IDs.
            - The original height of the input image.
            - The original width of the input image.
    """
    (
        x0_src, latent_src_image_ids, timesteps, orig_height, orig_width,
        src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids
    ) = prep_input(pipe, scheduler, T_steps, x0_src, src_prompt, src_guidance_scale)

    # handle guidance: only build a guidance tensor when the transformer
    # config says it uses guidance embeddings; otherwise pass None through.
    if pipe.transformer.config.guidance_embeds:
        src_guidance = torch.tensor([src_guidance_scale], device=pipe.device)
        # One guidance entry per batch row of the packed latents.
        src_guidance = src_guidance.expand(x0_src.shape[0])
    else:
        src_guidance = None

    x_t = uniinv(
        pipe, scheduler, timesteps, n_start, x0_src,
        src_prompt_embeds, src_pooled_prompt_embeds, src_guidance,
        src_text_ids, latent_src_image_ids,
    )

    return (
        x_t, x0_src, timesteps, latent_src_image_ids, orig_height, orig_width,
    )
def flux_denoise(
    pipe: FluxPipeline,
    scheduler: FlowMatchEulerDiscreteScheduler,
    timesteps: torch.Tensor,
    n_start: int,
    x_t: torch.Tensor,
    prompt_embeds: torch.Tensor,
    pooled_prompt_embeds: torch.Tensor,
    guidance: torch.Tensor,
    text_ids: torch.Tensor,
    latent_image_ids: torch.Tensor,
) -> torch.Tensor:
    """
    Integrate the FLUX velocity field from `x_t` over `timesteps[n_start:]`.
    Args:
        pipe (FluxPipeline): The FLUX pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        timesteps (torch.Tensor): The timesteps for the diffusion process.
        n_start (int): The number of initial timesteps to skip.
        x_t (torch.Tensor): The latent tensor at the starting timestep (cloned, not modified).
        prompt_embeds (torch.Tensor): The prompt embeddings.
        pooled_prompt_embeds (torch.Tensor): The pooled prompt embeddings.
        guidance (torch.Tensor): The guidance scale tensor.
        text_ids (torch.Tensor): The text token IDs.
        latent_image_ids (torch.Tensor): The latent image token IDs.
    Returns:
        torch.Tensor: The denoised latent tensor, in `pipe.dtype` whenever at
            least one step was taken.
    """
    latents = x_t.clone()
    remaining_steps = timesteps[n_start:]

    for t in remaining_steps:
        scheduler._init_step_index(t)
        idx = scheduler.step_index
        # Step size is the next sigma minus the current one.
        step = scheduler.sigmas[idx + 1] - scheduler.sigmas[idx]

        velocity = calc_v_flux(
            pipe,
            latents=latents,
            prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            guidance=guidance,
            text_ids=text_ids,
            latent_image_ids=latent_image_ids,
            t=t,
        )

        # Do the update in float32, then return to the pipeline dtype.
        latents = (latents.to(torch.float32) + velocity * step).to(pipe.dtype)

    return latents
def flux_editing(
    pipe: FluxPipeline,
    scheduler: FlowMatchEulerDiscreteScheduler,
    T_steps: int,
    n_max: int,
    x0_src: torch.Tensor,
    src_prompt: str,
    tar_prompt: str,
    src_guidance_scale: float,
    tar_guidance_scale: float,
    flowopt_iterations: int,
    eta: float,
) -> Iterator[List[Tuple[Image.Image, str]]]:
    """
    Edit an image with FLUX using FlowOpt, yielding the decoded result after every iteration.

    The source latents are inverted once, then `flowopt_iterations + 1`
    denoising passes are run with the target prompt. After every pass except
    the last, the starting latent is moved against the difference between the
    denoised result and the source latents, scaled by `eta`.
    Args:
        pipe (FluxPipeline): The FLUX pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        T_steps (int): The total number of timesteps for the diffusion process.
        n_max (int): The number of timesteps actually used; the first
            `T_steps - n_max` timesteps are skipped.
        x0_src (torch.Tensor): The source latent tensor.
        src_prompt (str): The source text prompt.
        tar_prompt (str): The target text prompt for editing.
        src_guidance_scale (float): The guidance scale for the source prompt.
        tar_guidance_scale (float): The guidance scale for the target prompt.
        flowopt_iterations (int): The number of FlowOpt iterations to perform.
        eta (float): The step size for the FlowOpt update.
    Yields:
        List[Tuple[Image.Image, str]]: The accumulated (image, label) pairs so
            far. The same list object is yielded each time and keeps growing.
    """
    n_start = T_steps - n_max

    init_outputs = initialization(
        pipe, scheduler, T_steps, n_start, x0_src, src_prompt, src_guidance_scale,
    )
    x_t, x0_src, timesteps, latent_src_image_ids, orig_height, orig_width = init_outputs

    pipe._guidance_scale = tar_guidance_scale
    tar_prompt_embeds, pooled_tar_prompt_embeds, tar_text_ids = pipe.encode_prompt(
        prompt=tar_prompt,
        prompt_2=None,
        device=pipe.device,
    )

    # handle guidance: a per-batch-row tensor, or None when the transformer
    # config does not use guidance embeddings.
    tar_guidance = None
    if pipe.transformer.config.guidance_embeds:
        tar_guidance = torch.tensor([tar_guidance_scale], device=pipe.device)
        tar_guidance = tar_guidance.expand(x0_src.shape[0])

    history = []
    j_star = x0_src.clone().to(torch.float32)  # y

    total_passes = flowopt_iterations + 1
    for flowopt_iter in range(total_passes):
        f_xt = flux_denoise(
            pipe, scheduler, timesteps, n_start, x_t,
            tar_prompt_embeds, pooled_tar_prompt_embeds, tar_guidance,
            tar_text_ids, latent_src_image_ids,
        )  # Eq. (3)

        is_last_pass = flowopt_iter >= flowopt_iterations
        if not is_last_pass:
            # Eq. (6) with c = c_tar, computed in float32.
            updated = x_t.to(torch.float32) - eta * (f_xt - j_star)
            x_t = updated.to(x0_src.dtype)

        # Decode the current result: unpack, undo the VAE normalisation, decode.
        x0_flowopt = f_xt.clone()
        unpacked = pipe._unpack_latents(x0_flowopt, orig_height, orig_width, pipe.vae_scale_factor)
        vae_cfg = pipe.vae.config
        denormalised = (unpacked / vae_cfg.scaling_factor) + vae_cfg.shift_factor
        with torch.autocast("cuda"), torch.inference_mode():
            decoded = pipe.vae.decode(denormalised, return_dict=False)[0].clamp(-1, 1)
            pil_image = pipe.image_processor.postprocess(decoded)[0]

        history.append((pil_image, f"Iteration {flowopt_iter}"))
        yield history