# FlowOpt / utils /flux.py
# orronai's picture
# feat: add application files
# 8d5a128
from typing import Iterator, List, Tuple
import numpy as np
import torch
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
from PIL import Image
@torch.no_grad()
def calculate_shift(
    image_seq_len: int,
    base_seq_len: int = 256,
    max_seq_len: int = 4096,
    base_shift: float = 0.5,
    max_shift: float = 1.16,
) -> float:
    """
    Linearly map an image sequence length to a shift value (mu).

    The mapping is the straight line passing through the points
    (base_seq_len, base_shift) and (max_seq_len, max_shift). The result is
    not clamped, so lengths outside that range are extrapolated.

    Args:
        image_seq_len (int): The sequence length to evaluate the line at.
        base_seq_len (int): The sequence length that maps to ``base_shift``.
        max_seq_len (int): The sequence length that maps to ``max_shift``.
        base_shift (float): The shift value at ``base_seq_len``.
        max_shift (float): The shift value at ``max_seq_len``.
    Returns:
        float: The interpolated (or extrapolated) shift value.
    """
    slope = (max_shift - base_shift) / (max_seq_len - base_seq_len)
    intercept = base_shift - slope * base_seq_len
    return image_seq_len * slope + intercept
@torch.no_grad()
def calc_v_flux(
    pipe: FluxPipeline, latents: torch.Tensor, prompt_embeds: torch.Tensor,
    pooled_prompt_embeds: torch.Tensor, guidance: torch.Tensor,
    text_ids: torch.Tensor, latent_image_ids: torch.Tensor, t: torch.Tensor,
) -> torch.Tensor:
    """
    Run the FLUX transformer once and return its velocity prediction.

    Args:
        pipe (FluxPipeline): The FLUX pipeline whose ``transformer`` is called.
        latents (torch.Tensor): The latent tensor at the current timestep.
        prompt_embeds (torch.Tensor): The prompt embeddings.
        pooled_prompt_embeds (torch.Tensor): The pooled prompt embeddings.
        guidance (torch.Tensor): The guidance scale tensor.
        text_ids (torch.Tensor): The text token IDs.
        latent_image_ids (torch.Tensor): The latent image token IDs.
        t (torch.Tensor): The current timestep. It is expanded to the batch
            size of ``latents`` and divided by 1000 before being passed on.
    Returns:
        torch.Tensor: The first element of the transformer output, used by
            the callers in this file as the velocity.
    """
    # One timestep entry per batch element, rescaled by 1/1000 for the transformer.
    batch_timestep = t.expand(latents.shape[0]) / 1000
    transformer_out = pipe.transformer(
        hidden_states=latents,
        timestep=batch_timestep,
        guidance=guidance,
        encoder_hidden_states=prompt_embeds,
        txt_ids=text_ids,
        img_ids=latent_image_ids,
        pooled_projections=pooled_prompt_embeds,
        joint_attention_kwargs=None,
        return_dict=False,
    )
    velocity = transformer_out[0]
    return velocity
@torch.no_grad()
def prep_input(
    pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
    T_steps: int, x0_src: torch.Tensor, src_prompt: str,
    src_guidance_scale: float,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, int, int,
    torch.Tensor, torch.Tensor, torch.Tensor,
]:
    """
    Build everything the FLUX inversion/denoising loops need from a source latent.

    Besides returning the prepared tensors, this stores ``_num_timesteps`` and
    ``_guidance_scale`` on ``pipe`` and obtains the timestep schedule through
    ``retrieve_timesteps`` (which is handed ``scheduler``).

    Args:
        pipe (FluxPipeline): The FLUX pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        T_steps (int): The total number of timesteps for the diffusion process.
        x0_src (torch.Tensor): The source latent tensor. Dimensions 2 and 3 are
            read as latent height and width.
        src_prompt (str): The source text prompt.
        src_guidance_scale (float): The guidance scale stored on the pipeline.
    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, torch.Tensor, torch.Tensor, torch.Tensor]:
            - Packed source latent tensor.
            - Latent image token IDs.
            - Timesteps tensor for the diffusion process.
            - Height in pixels (latent height times the VAE scale factor).
            - Width in pixels (latent width times the VAE scale factor).
            - Source prompt embeddings.
            - Source pooled prompt embeddings.
            - Source text token IDs.
    """
    # Pixel-space size implied by the latent's spatial dimensions.
    vae_scale = pipe.vae_scale_factor
    orig_height = x0_src.shape[2] * vae_scale
    orig_width = x0_src.shape[3] * vae_scale
    num_channels_latents = pipe.transformer.config.in_channels // 4

    pipe.check_inputs(
        prompt=src_prompt,
        prompt_2=None,
        height=orig_height,
        width=orig_width,
        callback_on_step_end_tensor_inputs=None,
        max_sequence_length=512,
    )

    prepared, latent_src_image_ids = pipe.prepare_latents(
        batch_size=x0_src.shape[0],
        num_channels_latents=num_channels_latents,
        height=orig_height,
        width=orig_width,
        dtype=x0_src.dtype,
        device=x0_src.device,
        generator=None,
        latents=x0_src,
    )
    packed = pipe._pack_latents(
        prepared, prepared.shape[0], num_channels_latents,
        prepared.shape[2], prepared.shape[3],
    )

    # Evenly spaced sigmas from 1.0 down to 1 / T_steps; mu depends on the
    # packed sequence length (dimension 1 of the packed latents).
    sigmas = np.linspace(1.0, 1 / T_steps, T_steps)
    sched_cfg = scheduler.config
    mu = calculate_shift(
        image_seq_len=packed.shape[1],
        base_seq_len=sched_cfg.base_image_seq_len,
        max_seq_len=sched_cfg.max_image_seq_len,
        base_shift=sched_cfg.base_shift,
        max_shift=sched_cfg.max_shift,
    )
    timesteps, _ = retrieve_timesteps(
        scheduler,
        T_steps,
        packed.device,
        timesteps=None,
        sigmas=sigmas,
        mu=mu,
    )

    pipe._num_timesteps = len(timesteps)
    pipe._guidance_scale = src_guidance_scale

    src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        device=packed.device,
    )

    return (
        packed, latent_src_image_ids, timesteps, orig_height, orig_width,
        src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids,
    )
# https://github.com/DSL-Lab/UniEdit-Flow
@torch.no_grad()
def uniinv(
    pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
    timesteps: torch.Tensor, n_start: int, x0_src: torch.Tensor,
    src_prompt_embeds: torch.Tensor, src_pooled_prompt_embeds: torch.Tensor,
    src_guidance: torch.Tensor, src_text_ids: torch.Tensor,
    latent_src_image_ids: torch.Tensor,
) -> torch.Tensor:
    """
    Invert a source latent with the UniInv scheme for FLUX.

    The schedule is walked in reverse order. Each step first takes a
    provisional step with the velocity carried over from the previous
    iteration (or a freshly computed one on the first iteration), evaluates
    the velocity at that provisional point, and then redoes the step from the
    original point with the new velocity.

    Args:
        pipe (FluxPipeline): The FLUX pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        timesteps (torch.Tensor): The timesteps for the diffusion process.
        n_start (int): The number of initial timesteps to skip; when positive,
            the last ``n_start`` entries of the reversed schedule are dropped.
        x0_src (torch.Tensor): The source latent tensor (not modified; a clone is used).
        src_prompt_embeds (torch.Tensor): The source prompt embeddings.
        src_pooled_prompt_embeds (torch.Tensor): The source pooled prompt embeddings.
        src_guidance (torch.Tensor): The guidance scale tensor.
        src_text_ids (torch.Tensor): The source text token IDs.
        latent_src_image_ids (torch.Tensor): The latent image token IDs.
    Returns:
        torch.Tensor: The inverted latent tensor, in ``pipe.dtype`` whenever at
            least one step was taken.
    """
    reversed_timesteps = timesteps.flip(dims=(0,))
    if n_start > 0:
        reversed_timesteps = reversed_timesteps[:-n_start]

    latents = x0_src.clone()
    carried_v = None  # velocity from the previous iteration, reused as the next predictor
    for t in reversed_timesteps:
        scheduler._init_step_index(t)
        idx = scheduler.step_index
        sigma_here = scheduler.sigmas[idx]
        sigma_after = scheduler.sigmas[idx + 1]
        dt = sigma_here - sigma_after

        if carried_v is None:
            # First iteration only: evaluate at the neighbouring sigma (scaled by 1000).
            predictor_v = calc_v_flux(
                pipe, latents=latents, prompt_embeds=src_prompt_embeds,
                pooled_prompt_embeds=src_pooled_prompt_embeds, guidance=src_guidance,
                text_ids=src_text_ids, latent_image_ids=latent_src_image_ids,
                t=sigma_after * 1000,
            )
        else:
            predictor_v = carried_v

        # Arithmetic is done in float32; the model is fed pipe.dtype.
        latents = latents.to(torch.float32)
        provisional = (latents + predictor_v * dt).to(pipe.dtype)
        carried_v = calc_v_flux(
            pipe, latents=provisional, prompt_embeds=src_prompt_embeds,
            pooled_prompt_embeds=src_pooled_prompt_embeds, guidance=src_guidance,
            text_ids=src_text_ids, latent_image_ids=latent_src_image_ids, t=t,
        )
        latents = (latents + carried_v * dt).to(pipe.dtype)
    return latents
@torch.no_grad()
def initialization(
    pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
    T_steps: int, n_start: int, x0_src: torch.Tensor, src_prompt: str,
    src_guidance_scale: float,
) -> Tuple[
    torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int,
]:
    """
    Initialize the inversion process by preparing the latent tensor and prompt embeddings, and performing UniInv.

    The source prompt embeddings, text IDs and guidance tensor are only used
    internally for the inversion; they are not part of the returned tuple.

    Args:
        pipe (FluxPipeline): The FLUX pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        T_steps (int): The total number of timesteps for the diffusion process.
        n_start (int): The number of initial timesteps to skip.
        x0_src (torch.Tensor): The source latent tensor.
        src_prompt (str): The source text prompt.
        src_guidance_scale (float): The guidance scale for the source prompt.
    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
            - The inverted latent tensor.
            - The source latent tensor as returned by ``prep_input`` (packed).
            - The timesteps for the diffusion process.
            - The latent image token IDs.
            - The original height of the input image.
            - The original width of the input image.
    """
    (
        x0_src, latent_src_image_ids, timesteps, orig_height, orig_width,
        src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids,
    ) = prep_input(pipe, scheduler, T_steps, x0_src, src_prompt, src_guidance_scale)

    # Only build a guidance tensor when the transformer config asks for one;
    # it is expanded to one entry per batch element.
    if pipe.transformer.config.guidance_embeds:
        src_guidance = torch.tensor([src_guidance_scale], device=pipe.device)
        src_guidance = src_guidance.expand(x0_src.shape[0])
    else:
        src_guidance = None

    x_t = uniinv(
        pipe, scheduler, timesteps, n_start, x0_src,
        src_prompt_embeds, src_pooled_prompt_embeds, src_guidance,
        src_text_ids, latent_src_image_ids,
    )

    return (
        x_t, x0_src, timesteps, latent_src_image_ids, orig_height, orig_width,
    )
@torch.no_grad()
def flux_denoise(
    pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
    timesteps: torch.Tensor, n_start: int, x_t: torch.Tensor,
    prompt_embeds: torch.Tensor, pooled_prompt_embeds: torch.Tensor,
    guidance: torch.Tensor, text_ids: torch.Tensor,
    latent_image_ids: torch.Tensor,
) -> torch.Tensor:
    """
    Perform the denoising process for FLUX.

    Starting from ``x_t``, one Euler step is taken per remaining timestep:
    the velocity is evaluated at the current latent and the latent is moved by
    ``velocity * (sigma_next - sigma_current)``.

    Args:
        pipe (FluxPipeline): The FLUX pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        timesteps (torch.Tensor): The timesteps for the diffusion process.
        n_start (int): The number of initial timesteps to skip. Values below
            zero are treated as zero (nothing skipped), matching ``uniinv``.
        x_t (torch.Tensor): The latent tensor at the starting timestep (not
            modified; a clone is used).
        prompt_embeds (torch.Tensor): The prompt embeddings.
        pooled_prompt_embeds (torch.Tensor): The pooled prompt embeddings.
        guidance (torch.Tensor): The guidance scale tensor.
        text_ids (torch.Tensor): The text token IDs.
        latent_image_ids (torch.Tensor): The latent image token IDs.
    Returns:
        torch.Tensor: The denoised latent tensor.
    """
    f_xt = x_t.clone()
    # A negative slice start would count from the END of the schedule and run
    # only the last few steps; clamp so "skip fewer than zero" means "skip none".
    first_step = max(n_start, 0)
    for t in timesteps[first_step:]:
        scheduler._init_step_index(t)
        t_i = scheduler.sigmas[scheduler.step_index]
        t_im1 = scheduler.sigmas[scheduler.step_index + 1]
        dt = t_im1 - t_i
        v_tar = calc_v_flux(
            pipe, latents=f_xt, prompt_embeds=prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds, guidance=guidance,
            text_ids=text_ids, latent_image_ids=latent_image_ids, t=t,
        )
        # Do the update in float32, then return to the pipeline dtype for the next model call.
        f_xt = f_xt.to(torch.float32)
        f_xt = f_xt + v_tar * dt
        f_xt = f_xt.to(pipe.dtype)
    return f_xt
@torch.no_grad()
def flux_editing(
    pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
    T_steps: int, n_max: int, x0_src: torch.Tensor, src_prompt: str,
    tar_prompt: str, src_guidance_scale: float, tar_guidance_scale: float,
    flowopt_iterations: int, eta: float,
) -> Iterator[List[Tuple[Image.Image, str]]]:
    """
    Perform the editing process for FLUX using FlowOpt.

    The source latent is first inverted with the source prompt. Then, for
    ``flowopt_iterations + 1`` rounds, the current latent is denoised with the
    target prompt, decoded to an image and appended to the history; on every
    round except the last, the latent is also updated by
    ``x_t <- x_t - eta * (f(x_t) - y)``, where ``y`` is the packed source latent.

    Args:
        pipe (FluxPipeline): The FLUX pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        T_steps (int): The total number of timesteps for the diffusion process.
        n_max (int): The number of final schedule steps to use; the first
            ``T_steps - n_max`` steps are skipped. Values larger than
            ``T_steps`` are treated as ``T_steps``.
        x0_src (torch.Tensor): The source latent tensor.
        src_prompt (str): The source text prompt.
        tar_prompt (str): The target text prompt for editing.
        src_guidance_scale (float): The guidance scale for the source prompt.
        tar_guidance_scale (float): The guidance scale for the target prompt.
        flowopt_iterations (int): The number of FlowOpt iterations to perform.
        eta (float): The step size for the FlowOpt update.
    Yields:
        List[Tuple[Image.Image, str]]: After every round, a new list holding
            all images decoded so far with their ``"Iteration k"`` labels.
    """
    # Clamp at zero: a negative skip count would make the inversion use the
    # whole schedule while a negative slice in the denoiser uses only its tail.
    n_start = max(T_steps - n_max, 0)
    (
        x_t, x0_src, timesteps, latent_src_image_ids, orig_height, orig_width,
    ) = initialization(
        pipe, scheduler, T_steps, n_start, x0_src, src_prompt, src_guidance_scale,
    )

    pipe._guidance_scale = tar_guidance_scale
    (
        tar_prompt_embeds,
        pooled_tar_prompt_embeds,
        tar_text_ids,
    ) = pipe.encode_prompt(
        prompt=tar_prompt,
        prompt_2=None,
        device=pipe.device,
    )

    # handle guidance
    if pipe.transformer.config.guidance_embeds:
        tar_guidance = torch.tensor([tar_guidance_scale], device=pipe.device)
        tar_guidance = tar_guidance.expand(x0_src.shape[0])
    else:
        tar_guidance = None

    history: List[Tuple[Image.Image, str]] = []
    j_star = x0_src.clone().to(torch.float32)  # y

    for flowopt_iter in range(flowopt_iterations + 1):
        f_xt = flux_denoise(
            pipe, scheduler, timesteps, n_start, x_t,
            tar_prompt_embeds, pooled_tar_prompt_embeds, tar_guidance,
            tar_text_ids, latent_src_image_ids,
        )  # Eq. (3)

        # The final round only renders the result; no further update is needed.
        if flowopt_iter < flowopt_iterations:
            x_t = x_t.to(torch.float32)
            x_t = x_t - eta * (f_xt - j_star)  # Eq. (6) with c = c_tar
            x_t = x_t.to(x0_src.dtype)

        # Unpack, undo the VAE latent normalisation, and decode to an image.
        unpacked_x0_flowopt = pipe._unpack_latents(f_xt, orig_height, orig_width, pipe.vae_scale_factor)
        x0_flowopt_denorm = (unpacked_x0_flowopt / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
        with torch.autocast("cuda"), torch.inference_mode():
            x0_flowopt_image = pipe.vae.decode(x0_flowopt_denorm, return_dict=False)[0].clamp(-1, 1)
        x0_flowopt_image_pil = pipe.image_processor.postprocess(x0_flowopt_image)[0]

        history.append((x0_flowopt_image_pil, f"Iteration {flowopt_iter}"))
        # Yield a snapshot so previously yielded lists are not mutated by later rounds.
        yield list(history)