Spaces:
Running
on
Zero
Running
on
Zero
feat: add application files
Browse files- README.md +6 -9
- app.py +207 -0
- requirements.txt +6 -0
- utils/flux.py +374 -0
- utils/sd3.py +264 -0
README.md
CHANGED
|
@@ -1,14 +1,11 @@
|
|
| 1 |
-
---
|
| 2 |
title: FlowOpt
|
| 3 |
-
emoji:
|
| 4 |
-
colorFrom:
|
| 5 |
-
colorTo:
|
| 6 |
sdk: gradio
|
| 7 |
-
sdk_version: 5.
|
| 8 |
app_file: app.py
|
| 9 |
pinned: false
|
| 10 |
license: mit
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
|
|
|
|
| 1 |
title: FlowOpt
|
| 2 |
+
emoji: π
|
| 3 |
+
colorFrom: blue
|
| 4 |
+
colorTo: green
|
| 5 |
sdk: gradio
|
| 6 |
+
sdk_version: 5.8.0
|
| 7 |
app_file: app.py
|
| 8 |
pinned: false
|
| 9 |
license: mit
|
| 10 |
+
hf_oauth: true
|
| 11 |
+
short_description: 'FlowOpt Gradio: Fast-Optimization for Training-Free Editing.'
|
|
|
|
|
|
app.py
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import random
|
| 3 |
+
from typing import Tuple
|
| 4 |
+
|
| 5 |
+
import numpy as np
|
| 6 |
+
import spaces
|
| 7 |
+
import torch
|
| 8 |
+
from diffusers import FluxPipeline, StableDiffusion3Pipeline
|
| 9 |
+
from PIL import Image
|
| 10 |
+
|
| 11 |
+
import gradio as gr
|
| 12 |
+
from utils.flux import flux_editing
|
| 13 |
+
from utils.sd3 import sd3_editing
|
| 14 |
+
|
| 15 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 16 |
+
|
| 17 |
+
pipe_sd3 = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16, token=os.getenv('HF_ACCESS_TOK'))
|
| 18 |
+
pipe_flux = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16, token=os.getenv('HF_ACCESS_TOK'))
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def seed_everything(seed: int) -> None:
|
| 22 |
+
"""
|
| 23 |
+
Set the random seed for reproducibility.
|
| 24 |
+
|
| 25 |
+
Args:
|
| 26 |
+
seed (int): The seed value to set.
|
| 27 |
+
"""
|
| 28 |
+
random.seed(seed)
|
| 29 |
+
np.random.seed(seed)
|
| 30 |
+
torch.manual_seed(seed)
|
| 31 |
+
torch.cuda.manual_seed_all(seed)
|
| 32 |
+
|
| 33 |
+
def on_T_steps_change(T_steps: int) -> gr.update:
|
| 34 |
+
"""
|
| 35 |
+
Update the maximum value of the n_max slider based on the T_steps value.
|
| 36 |
+
|
| 37 |
+
Args:
|
| 38 |
+
T_steps (int): The current value of the T_steps slider.
|
| 39 |
+
Returns:
|
| 40 |
+
gr.update: An update object to modify the n_max slider's maximum value.
|
| 41 |
+
"""
|
| 42 |
+
return gr.update(maximum=T_steps)
|
| 43 |
+
|
| 44 |
+
def on_model_change(model_type: str) -> Tuple[int, int, float]:
|
| 45 |
+
if model_type == 'SD3':
|
| 46 |
+
T_steps_value = 15
|
| 47 |
+
n_max_value = 12
|
| 48 |
+
eta_value = 0.01
|
| 49 |
+
elif model_type == 'FLUX':
|
| 50 |
+
T_steps_value = 15
|
| 51 |
+
n_max_value = 13
|
| 52 |
+
eta_value = 0.0025
|
| 53 |
+
else:
|
| 54 |
+
raise NotImplementedError(f"Model type {model_type} not implemented")
|
| 55 |
+
|
| 56 |
+
return T_steps_value, n_max_value, eta_value
|
| 57 |
+
|
| 58 |
+
def get_examples():
|
| 59 |
+
case = [
|
| 60 |
+
["inputs/corgi_walking.png", "FLUX", 15, 13, 0.0025, 7, "A cute brown and white dog walking on a sidewalk near a body of water. The dog is wearing a pink vest, adding a touch of color to the scene.", "A cute brown and white guinea pig walking on a sidewalk near a body of water. The guinea pig is wearing a pink vest, adding a touch of color to the scene.", 1.0, 3.5, [(f"example_outputs/corgi_walking/guinea_pig/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
|
| 61 |
+
["inputs/corgi_walking.png", "SD3", 15, 12, 0.01, 7, "A cute brown and white dog walking on a sidewalk near a body of water. The dog is wearing a pink vest, adding a touch of color to the scene.", "A cute brown and white rabbit walking on a sidewalk near a body of water. The rabbit is wearing a pink vest, adding a touch of color to the scene.", 1.0, 3.5, [(f"example_outputs/corgi_walking/rabbit/sd3_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
|
| 62 |
+
["inputs/puppies.png", "FLUX", 15, 13, 0.0025, 7, "Two adorable golden retriever puppies sitting in a grassy field. They are positioned close to each other, with one dog on the left and the other on the right. Both dogs have their mouths open, possibly panting.", "Two adorable crochet golden retriever puppies sitting in a grassy field. They are positioned close to each other, with one dog on the left and the other on the right. Both dogs have their mouths open, possibly panting or enjoying the outdoor environment.", 1.0, 3.5, [(f"example_outputs/puppies/crochet/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
|
| 63 |
+
["inputs/puppies.png", "SD3", 15, 12, 0.01, 5, "Two adorable golden retriever puppies sitting in a grassy field. They are positioned close to each other, with one dog on the left and the other on the right. Both dogs have their mouths open, possibly panting.", "Two adorable husky puppies sitting in a grassy field. They are positioned close to each other, with one dog on the left and the other on the right. Both dogs have their mouths open, possibly panting or enjoying the outdoor environment.", 1.0, 3.5, [(f"example_outputs/puppies/husky/sd3_iterations={i}.png", f"Iteration {i}") for i in range(6)]],
|
| 64 |
+
["inputs/iguana.png", "FLUX", 15, 13, 0.0025, 7, "A large orange lizard sitting on a rock near the ocean. The lizard is positioned in the center of the scene, with the ocean waves visible in the background. The rock is located close to the water, providing a picturesque setting for the lizard''s resting spot.", "A large crochet lizard sitting on a rock near the ocean. The lizard is positioned in the center of the scene, with the ocean waves visible in the background. The rock is located close to the water, providing a picturesque setting for the lizard''s resting spot.", 1.0, 3.5, [(f"example_outputs/iguana/crochet/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
|
| 65 |
+
["inputs/iguana.png", "FLUX", 15, 13, 0.0025, 7, "A large orange lizard sitting on a rock near the ocean. The lizard is positioned in the center of the scene, with the ocean waves visible in the background. The rock is located close to the water, providing a picturesque setting for the lizard''s resting spot.", "A large lizard made out of lego bricks sitting on a rock near the ocean. The lizard is positioned in the center of the scene, with the ocean waves visible in the background. The rock is located close to the water, providing a picturesque setting for the lizard''s resting spot.", 1.0, 3.5, [(f"example_outputs/iguana/lego_bricks/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
|
| 66 |
+
["inputs/cow_grass2.png", "FLUX", 15, 12, 0.0025, 6, "A large brown and white cow standing in a grassy field. The cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", "A large cow made out of colorful toy bricks standing in a grassy field. The colorful toy brick cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", 1.0, 3.5, [(f"example_outputs/cow_grass2/colorful_toy_bricks/flux_iterations={i}.png", f"Iteration {i}") for i in range(7)]],
|
| 67 |
+
["inputs/cow_grass2.png", "FLUX", 15, 13, 0.0025, 5, "A large brown and white cow standing in a grassy field. The cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", "A large cow made out of flowers standing in a grassy field. The flower cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", 1.0, 3.5, [(f"example_outputs/cow_grass2/flowers/flux_iterations={i}.png", f"Iteration {i}") for i in range(6)]],
|
| 68 |
+
["inputs/cow_grass2.png", "SD3", 15, 12, 0.01, 8, "A large brown and white cow standing in a grassy field. The cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", "A large cow made out of wooden blocks standing in a grassy field. The wooden block cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", 1.0, 3.5, [(f"example_outputs/cow_grass2/wooden_blocks/sd3_iterations={i}.png", f"Iteration {i}") for i in range(9)]],
|
| 69 |
+
["inputs/cat_fridge.png", "SD3", 15, 12, 0.01, 8, "A cat sitting on top of a counter in a store. The cat is positioned towards the right side of the counter, and it appears to be looking at the camera. The store has a variety of items displayed, including several bottles scattered around the counter.", "A cat sitting on top of a counter in a store, with the cat and counter crafted using origami folded paper art techniques. The cat has a delicate and intricate appearance, with paper folds used to create its fur and features. The store has a variety of items displayed, including several bottles scattered around the counter.", 1.0, 3.5, [(f"example_outputs/cat_fridge/origami/sd3_iterations={i}.png", f"Iteration {i}") for i in range(9)]],
|
| 70 |
+
["inputs/cat.png", "FLUX", 15, 13, 0.0025, 7, "A small, fluffy kitten sitting in a grassy field. The kitten is positioned in the center of the scene, surrounded by a field. The kitten appears to be looking at something in the field.", "A small bear cub sitting in a grassy field. The bear cub is positioned in the center of the scene, surrounded by a field. The bear cub appears to be looking at something in the field.", 1.0, 3.5, [(f"example_outputs/cat/bear/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
|
| 71 |
+
["inputs/cat.png", "SD3", 15, 12, 0.01, 6, "A small, fluffy kitten sitting in a grassy field. The kitten is positioned in the center of the scene, surrounded by a field. The kitten appears to be looking at something in the field.", "A small puppy sitting in a grassy field. The puppy is positioned in the center of the scene, surrounded by a field. The puppy appears to be looking at something in the field.", 1.0, 3.5, [(f"example_outputs/cat/puppy/sd3_iterations={i}.png", f"Iteration {i}") for i in range(7)]],
|
| 72 |
+
["inputs/wolf_grass.png", "FLUX", 15, 13, 0.0025, 7, "A wolf standing in a grassy field with yellow flowers. The wolf is positioned towards the center of the scene, and its body is facing the camera. The field is filled with grass, and the yellow flowers are scattered throughout the area.", "A fox standing in a grassy field with yellow flowers. The fox is positioned towards the center of the scene, and its body is facing the camera. The field is filled with grass, and the yellow flowers are scattered throughout the area.", 1.0, 3.5, [(f"example_outputs/wolf_grass/fox/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
|
| 73 |
+
["inputs/wolf_grass.png", "SD3", 15, 12, 0.01, 4, "A wolf standing in a grassy field with yellow flowers. The wolf is positioned towards the center of the scene, and its body is facing the camera. The field is filled with grass, and the yellow flowers are scattered throughout the area.", "A baby deer standing in a grassy field with yellow flowers. The baby deer is positioned towards the center of the scene, and its body is facing the camera. The field is filled with grass, and the yellow flowers are scattered throughout the area.", 1.0, 3.5, [(f"example_outputs/wolf_grass/deer/sd3_iterations={i}.png", f"Iteration {i}") for i in range(5)]],
|
| 74 |
+
]
|
| 75 |
+
return case
|
| 76 |
+
|
| 77 |
+
@spaces.GPU(duration=200)
|
| 78 |
+
def FlowOpt_run(
|
| 79 |
+
image_src_val: str, model_type_val: str, T_steps_val: int,
|
| 80 |
+
n_max_val: int, eta_val: float, flowopt_iterations_val: int,
|
| 81 |
+
src_prompt_val: str, tar_prompt_val: str,
|
| 82 |
+
src_guidance_scale_val: float, tar_guidance_scale_val: float,
|
| 83 |
+
):
|
| 84 |
+
if not len(src_prompt_val):
|
| 85 |
+
raise gr.Error("Source prompt cannot be empty")
|
| 86 |
+
if not len(tar_prompt_val):
|
| 87 |
+
raise gr.Error("Target prompt cannot be empty")
|
| 88 |
+
|
| 89 |
+
if model_type_val == 'FLUX':
|
| 90 |
+
pipe = pipe_flux.to(device)
|
| 91 |
+
elif model_type_val == 'SD3':
|
| 92 |
+
pipe = pipe_sd3.to(device)
|
| 93 |
+
else:
|
| 94 |
+
raise NotImplementedError(f"Model type {model_type_val} not implemented")
|
| 95 |
+
|
| 96 |
+
scheduler = pipe.scheduler
|
| 97 |
+
|
| 98 |
+
# set seed
|
| 99 |
+
seed = 1024
|
| 100 |
+
seed_everything(seed)
|
| 101 |
+
# load image
|
| 102 |
+
image = Image.open(image_src_val)
|
| 103 |
+
# crop image to have both dimensions divisibe by 16 - avoids issues with resizing
|
| 104 |
+
image = image.crop((0, 0, image.width - image.width % 16, image.height - image.height % 16))
|
| 105 |
+
image_src_val = pipe.image_processor.preprocess(image)
|
| 106 |
+
|
| 107 |
+
# cast image to half precision
|
| 108 |
+
image_src_val = image_src_val.to(device).half()
|
| 109 |
+
with torch.autocast("cuda"), torch.inference_mode():
|
| 110 |
+
x0_src_denorm = pipe.vae.encode(image_src_val).latent_dist.mode()
|
| 111 |
+
x0_src = (x0_src_denorm - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
|
| 112 |
+
# send to cuda
|
| 113 |
+
x0_src = x0_src.to(device)
|
| 114 |
+
negative_prompt = "" # (SD3)
|
| 115 |
+
|
| 116 |
+
if model_type_val == 'SD3':
|
| 117 |
+
yield from sd3_editing(
|
| 118 |
+
pipe, scheduler, T_steps_val, n_max_val, x0_src,
|
| 119 |
+
src_prompt_val, tar_prompt_val, negative_prompt,
|
| 120 |
+
src_guidance_scale_val, tar_guidance_scale_val,
|
| 121 |
+
flowopt_iterations_val, eta_val,
|
| 122 |
+
)
|
| 123 |
+
elif model_type_val == 'FLUX':
|
| 124 |
+
yield from flux_editing(
|
| 125 |
+
pipe, scheduler, T_steps_val, n_max_val, x0_src,
|
| 126 |
+
src_prompt_val, tar_prompt_val,
|
| 127 |
+
src_guidance_scale_val, tar_guidance_scale_val,
|
| 128 |
+
flowopt_iterations_val, eta_val,
|
| 129 |
+
)
|
| 130 |
+
else:
|
| 131 |
+
raise NotImplementedError(f"Sampler type {model_type_val} not implemented")
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
intro = """
|
| 135 |
+
<h1 style="font-weight: 1000; text-align: center; margin: 0px;">FlowOpt: Fast Optimization Through Whole Flow Processes for Training-Free Editing</h1>
|
| 136 |
+
<h3 style="margin-bottom: 10px; text-align: center;">
|
| 137 |
+
<a href="">[Paper]</a> |
|
| 138 |
+
<a href="https://orronai.github.io/FlowOpt/">[Project Page]</a> |
|
| 139 |
+
<a href="https://github.com/orronai/FlowOpt">[Code]</a>
|
| 140 |
+
</h3>
|
| 141 |
+
<br> π¨ Edit your image using FlowOpt for Flow models! Upload an image, add a description of it, and specify the edits you want to make.
|
| 142 |
+
<h3>Notes:</h3>
|
| 143 |
+
<ol>
|
| 144 |
+
<li>We use FLUX.1 dev and SD3 for the demo. The models are large and may take a while to load.</li>
|
| 145 |
+
<li>We recommend 1024x1024 images for the best results. If the input images are too large, there may be out-of-memory errors. For other resolutions, we encourage you to find a suitable set of hyperparameters.</li>
|
| 146 |
+
<li>Default hyperparameters for each model used in the paper are provided as examples.</li>
|
| 147 |
+
</ol>
|
| 148 |
+
"""
|
| 149 |
+
|
| 150 |
+
css="""
|
| 151 |
+
#col-container {
|
| 152 |
+
margin: 0 auto;
|
| 153 |
+
max-width: 960px;
|
| 154 |
+
}
|
| 155 |
+
"""
|
| 156 |
+
with gr.Blocks(css=css) as demo:
|
| 157 |
+
with gr.Column(elem_id="col-container"):
|
| 158 |
+
gr.HTML(intro)
|
| 159 |
+
|
| 160 |
+
with gr.Row():
|
| 161 |
+
with gr.Column():
|
| 162 |
+
image_src = gr.Image(type="filepath", label="Source Image", value="inputs/cat.png",)
|
| 163 |
+
src_prompt = gr.Textbox(lines=2, label="Source Prompt", value="A cat sitting in the grass")
|
| 164 |
+
tar_prompt = gr.Textbox(lines=2, label="Target Prompt", value="A puppy sitting in the grass")
|
| 165 |
+
submit_button = gr.Button("Run FlowOpt", variant="primary")
|
| 166 |
+
|
| 167 |
+
with gr.Row():
|
| 168 |
+
model_type = gr.Dropdown(["FLUX", "SD3"], label="Model Type", value="FLUX")
|
| 169 |
+
T_steps = gr.Slider(value=15, minimum=10, maximum=50, step=1, label="Total Steps", info="Total number of discretization steps.")
|
| 170 |
+
n_max = gr.Slider(value=13, minimum=1, maximum=15, step=1, label="n_max", info="Control the strength of the edit.")
|
| 171 |
+
eta = gr.Slider(value=0.0025, minimum=0.0001, maximum=0.05, label="eta", info="Control the optimization step-size.")
|
| 172 |
+
flowopt_iterations = gr.Number(value=10, minimum=1, maximum=15, label="flowopt_iterations", info="Max number of FlowOpt iterations")
|
| 173 |
+
|
| 174 |
+
with gr.Column():
|
| 175 |
+
image_tar = gr.Gallery(
|
| 176 |
+
label="Outputs", show_label=True, format="png",
|
| 177 |
+
columns=[3], rows=[3], height="auto",
|
| 178 |
+
)
|
| 179 |
+
with gr.Accordion(label="Advanced Settings", open=False):
|
| 180 |
+
src_guidance_scale = gr.Slider(value=1.0, minimum=0.0, maximum=15.0, label="src_guidance_scale", info="Source prompt CFG scale.")
|
| 181 |
+
tar_guidance_scale = gr.Slider(value=3.5, minimum=1.0, maximum=15.0, label="tar_guidance_scale", info="Target prompt CFG scale.")
|
| 182 |
+
|
| 183 |
+
submit_button.click(
|
| 184 |
+
fn=FlowOpt_run,
|
| 185 |
+
inputs=[
|
| 186 |
+
image_src, model_type, T_steps, n_max, eta, flowopt_iterations,
|
| 187 |
+
src_prompt, tar_prompt, src_guidance_scale, tar_guidance_scale,
|
| 188 |
+
],
|
| 189 |
+
outputs=[image_tar],
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
gr.Examples(
|
| 193 |
+
label="Examples",
|
| 194 |
+
examples=get_examples(),
|
| 195 |
+
inputs=[
|
| 196 |
+
image_src, model_type, T_steps, n_max, eta,
|
| 197 |
+
flowopt_iterations, src_prompt, tar_prompt,
|
| 198 |
+
src_guidance_scale, tar_guidance_scale, image_tar,
|
| 199 |
+
],
|
| 200 |
+
outputs=[image_tar],
|
| 201 |
+
)
|
| 202 |
+
|
| 203 |
+
model_type.input(fn=on_model_change, inputs=[model_type], outputs=[T_steps, n_max, eta])
|
| 204 |
+
T_steps.change(fn=on_T_steps_change, inputs=[T_steps], outputs=[n_max])
|
| 205 |
+
|
| 206 |
+
demo.queue()
|
| 207 |
+
demo.launch()
|
requirements.txt
ADDED
|
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
diffusers
|
| 3 |
+
transformers
|
| 4 |
+
accelerate
|
| 5 |
+
sentencepiece
|
| 6 |
+
protobuf
|
utils/flux.py
ADDED
|
@@ -0,0 +1,374 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Iterator, List, Tuple
|
| 2 |
+
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline
|
| 6 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
|
| 7 |
+
from PIL import Image
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
@torch.no_grad()
|
| 11 |
+
def calculate_shift(
|
| 12 |
+
image_seq_len: int,
|
| 13 |
+
base_seq_len: int = 256,
|
| 14 |
+
max_seq_len: int = 4096,
|
| 15 |
+
base_shift: float = 0.5,
|
| 16 |
+
max_shift: float = 1.16,
|
| 17 |
+
) -> float:
|
| 18 |
+
m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
|
| 19 |
+
b = base_shift - m * base_seq_len
|
| 20 |
+
mu = image_seq_len * m + b
|
| 21 |
+
return mu
|
| 22 |
+
|
| 23 |
+
@torch.no_grad()
|
| 24 |
+
def calc_v_flux(
|
| 25 |
+
pipe: FluxPipeline, latents: torch.Tensor, prompt_embeds: torch.Tensor,
|
| 26 |
+
pooled_prompt_embeds: torch.Tensor, guidance: torch.Tensor,
|
| 27 |
+
text_ids: torch.Tensor, latent_image_ids: torch.Tensor, t: torch.Tensor,
|
| 28 |
+
) -> torch.Tensor:
|
| 29 |
+
"""
|
| 30 |
+
Calculate the velocity (v) for FLUX.
|
| 31 |
+
|
| 32 |
+
Args:
|
| 33 |
+
pipe (FluxPipeline): The FLUX pipeline.
|
| 34 |
+
latents (torch.Tensor): The latent tensor at the current timestep.
|
| 35 |
+
prompt_embeds (torch.Tensor): The prompt embeddings.
|
| 36 |
+
pooled_prompt_embeds (torch.Tensor): The pooled prompt embeddings.
|
| 37 |
+
guidance (torch.Tensor): The guidance scale tensor.
|
| 38 |
+
text_ids (torch.Tensor): The text token IDs.
|
| 39 |
+
latent_image_ids (torch.Tensor): The latent image token IDs.
|
| 40 |
+
t (torch.Tensor): The current timestep.
|
| 41 |
+
Returns:
|
| 42 |
+
torch.Tensor: The predicted noise (velocity).
|
| 43 |
+
"""
|
| 44 |
+
timestep = t.expand(latents.shape[0])
|
| 45 |
+
|
| 46 |
+
noise_pred = pipe.transformer(
|
| 47 |
+
hidden_states=latents,
|
| 48 |
+
timestep=timestep / 1000,
|
| 49 |
+
guidance=guidance,
|
| 50 |
+
encoder_hidden_states=prompt_embeds,
|
| 51 |
+
txt_ids=text_ids,
|
| 52 |
+
img_ids=latent_image_ids,
|
| 53 |
+
pooled_projections=pooled_prompt_embeds,
|
| 54 |
+
joint_attention_kwargs=None,
|
| 55 |
+
return_dict=False,
|
| 56 |
+
)[0]
|
| 57 |
+
|
| 58 |
+
return noise_pred
|
| 59 |
+
|
| 60 |
+
@torch.no_grad()
|
| 61 |
+
def prep_input(
|
| 62 |
+
pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
|
| 63 |
+
T_steps: int, x0_src: torch.Tensor, src_prompt: str,
|
| 64 |
+
src_guidance_scale: float,
|
| 65 |
+
) -> Tuple[
|
| 66 |
+
torch.Tensor, torch.Tensor, torch.Tensor, int, int,
|
| 67 |
+
torch.Tensor, torch.Tensor, torch.Tensor,
|
| 68 |
+
]:
|
| 69 |
+
"""
|
| 70 |
+
Prepare the input tensors for the FLUX pipeline.
|
| 71 |
+
|
| 72 |
+
Args:
|
| 73 |
+
pipe (FluxPipeline): The FLUX pipeline.
|
| 74 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
|
| 75 |
+
T_steps (int): The total number of timesteps for the diffusion process.
|
| 76 |
+
x0_src (torch.Tensor): The source latent tensor.
|
| 77 |
+
src_prompt (str): The source text prompt.
|
| 78 |
+
src_guidance_scale (float): The guidance scale for classifier-free guidance.
|
| 79 |
+
Returns:
|
| 80 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 81 |
+
- Prepared source latent tensor.
|
| 82 |
+
- Latent image token IDs.
|
| 83 |
+
- Timesteps tensor for the diffusion process.
|
| 84 |
+
- Original height of the input image.
|
| 85 |
+
- Original width of the input image.
|
| 86 |
+
- Source prompt embeddings.
|
| 87 |
+
- Source pooled prompt embeddings.
|
| 88 |
+
- Source text token IDs.
|
| 89 |
+
"""
|
| 90 |
+
orig_height, orig_width = x0_src.shape[2] * pipe.vae_scale_factor, x0_src.shape[3] * pipe.vae_scale_factor
|
| 91 |
+
num_channels_latents = pipe.transformer.config.in_channels // 4
|
| 92 |
+
|
| 93 |
+
pipe.check_inputs(
|
| 94 |
+
prompt=src_prompt,
|
| 95 |
+
prompt_2=None,
|
| 96 |
+
height=orig_height,
|
| 97 |
+
width=orig_width,
|
| 98 |
+
callback_on_step_end_tensor_inputs=None,
|
| 99 |
+
max_sequence_length=512,
|
| 100 |
+
)
|
| 101 |
+
|
| 102 |
+
x0_src, latent_src_image_ids = pipe.prepare_latents(
|
| 103 |
+
batch_size=x0_src.shape[0], num_channels_latents=num_channels_latents,
|
| 104 |
+
height=orig_height, width=orig_width,
|
| 105 |
+
dtype=x0_src.dtype, device=x0_src.device, generator=None, latents=x0_src,
|
| 106 |
+
)
|
| 107 |
+
x0_src = pipe._pack_latents(x0_src, x0_src.shape[0], num_channels_latents, x0_src.shape[2], x0_src.shape[3])
|
| 108 |
+
|
| 109 |
+
sigmas = np.linspace(1.0, 1 / T_steps, T_steps)
|
| 110 |
+
image_seq_len = x0_src.shape[1]
|
| 111 |
+
mu = calculate_shift(
|
| 112 |
+
image_seq_len,
|
| 113 |
+
scheduler.config.base_image_seq_len,
|
| 114 |
+
scheduler.config.max_image_seq_len,
|
| 115 |
+
scheduler.config.base_shift,
|
| 116 |
+
scheduler.config.max_shift,
|
| 117 |
+
)
|
| 118 |
+
timesteps, T_steps = retrieve_timesteps(
|
| 119 |
+
scheduler,
|
| 120 |
+
T_steps,
|
| 121 |
+
x0_src.device,
|
| 122 |
+
timesteps=None,
|
| 123 |
+
sigmas=sigmas,
|
| 124 |
+
mu=mu,
|
| 125 |
+
)
|
| 126 |
+
pipe._num_timesteps = len(timesteps)
|
| 127 |
+
|
| 128 |
+
pipe._guidance_scale = src_guidance_scale
|
| 129 |
+
(
|
| 130 |
+
src_prompt_embeds,
|
| 131 |
+
src_pooled_prompt_embeds,
|
| 132 |
+
src_text_ids,
|
| 133 |
+
) = pipe.encode_prompt(
|
| 134 |
+
prompt=src_prompt,
|
| 135 |
+
prompt_2=None,
|
| 136 |
+
device=x0_src.device,
|
| 137 |
+
)
|
| 138 |
+
|
| 139 |
+
return (
|
| 140 |
+
x0_src, latent_src_image_ids, timesteps, orig_height, orig_width,
|
| 141 |
+
src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids
|
| 142 |
+
)
|
| 143 |
+
|
| 144 |
+
# https://github.com/DSL-Lab/UniEdit-Flow
|
| 145 |
+
@torch.no_grad()
|
| 146 |
+
def uniinv(
|
| 147 |
+
pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
|
| 148 |
+
timesteps: torch.Tensor, n_start: int, x0_src: torch.Tensor,
|
| 149 |
+
src_prompt_embeds: torch.Tensor, src_pooled_prompt_embeds: torch.Tensor,
|
| 150 |
+
src_guidance: torch.Tensor, src_text_ids: torch.Tensor,
|
| 151 |
+
latent_src_image_ids: torch.Tensor,
|
| 152 |
+
) -> torch.Tensor:
|
| 153 |
+
"""
|
| 154 |
+
Perform the UniInv inversion process for FLUX.
|
| 155 |
+
|
| 156 |
+
Args:
|
| 157 |
+
pipe (FluxPipeline): The FLUX pipeline.
|
| 158 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
|
| 159 |
+
timesteps (torch.Tensor): The timesteps for the diffusion process.
|
| 160 |
+
n_start (int): The number of initial timesteps to skip.
|
| 161 |
+
x0_src (torch.Tensor): The source latent tensor.
|
| 162 |
+
src_prompt_embeds (torch.Tensor): The source prompt embeddings.
|
| 163 |
+
src_pooled_prompt_embeds (torch.Tensor): The source pooled prompt embeddings.
|
| 164 |
+
src_guidance (torch.Tensor): The guidance scale tensor.
|
| 165 |
+
src_text_ids (torch.Tensor): The source text token IDs.
|
| 166 |
+
latent_src_image_ids (torch.Tensor): The latent image token IDs.
|
| 167 |
+
Returns:
|
| 168 |
+
torch.Tensor: The inverted latent tensor.
|
| 169 |
+
"""
|
| 170 |
+
x_t = x0_src.clone()
|
| 171 |
+
timesteps_inv = timesteps.flip(dims=(0,))[:-n_start] if n_start > 0 else timesteps.flip(dims=(0,))
|
| 172 |
+
next_v = None
|
| 173 |
+
for _i, t in enumerate(timesteps_inv):
|
| 174 |
+
scheduler._init_step_index(t)
|
| 175 |
+
t_i = scheduler.sigmas[scheduler.step_index]
|
| 176 |
+
t_ip1 = scheduler.sigmas[scheduler.step_index + 1]
|
| 177 |
+
dt = t_i - t_ip1
|
| 178 |
+
|
| 179 |
+
if next_v is None:
|
| 180 |
+
v_tar = calc_v_flux(
|
| 181 |
+
pipe, latents=x_t, prompt_embeds=src_prompt_embeds,
|
| 182 |
+
pooled_prompt_embeds=src_pooled_prompt_embeds, guidance=src_guidance,
|
| 183 |
+
text_ids=src_text_ids, latent_image_ids=latent_src_image_ids, t=t_ip1 * 1000,
|
| 184 |
+
)
|
| 185 |
+
else:
|
| 186 |
+
v_tar = next_v
|
| 187 |
+
x_t = x_t.to(torch.float32)
|
| 188 |
+
x_t_next = x_t + v_tar * dt
|
| 189 |
+
x_t_next = x_t_next.to(pipe.dtype)
|
| 190 |
+
|
| 191 |
+
v_tar_next = calc_v_flux(
|
| 192 |
+
pipe, latents=x_t_next, prompt_embeds=src_prompt_embeds,
|
| 193 |
+
pooled_prompt_embeds=src_pooled_prompt_embeds, guidance=src_guidance,
|
| 194 |
+
text_ids=src_text_ids, latent_image_ids=latent_src_image_ids, t=t,
|
| 195 |
+
)
|
| 196 |
+
next_v = v_tar_next
|
| 197 |
+
x_t = x_t + v_tar_next * dt
|
| 198 |
+
x_t = x_t.to(pipe.dtype)
|
| 199 |
+
|
| 200 |
+
return x_t
|
| 201 |
+
|
| 202 |
+
@torch.no_grad()
|
| 203 |
+
def initialization(
|
| 204 |
+
pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
|
| 205 |
+
T_steps: int, n_start: int, x0_src: torch.Tensor, src_prompt: str,
|
| 206 |
+
src_guidance_scale: float,
|
| 207 |
+
) -> Tuple[
|
| 208 |
+
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int,
|
| 209 |
+
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
| 210 |
+
]:
|
| 211 |
+
"""
|
| 212 |
+
Initialize the inversion process by preparing the latent tensor and prompt embeddings, and performing UniInv.
|
| 213 |
+
|
| 214 |
+
Args:
|
| 215 |
+
pipe (FluxPipeline): The FLUX pipeline.
|
| 216 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
|
| 217 |
+
T_steps (int): The total number of timesteps for the diffusion process.
|
| 218 |
+
n_start (int): The number of initial timesteps to skip.
|
| 219 |
+
x0_src (torch.Tensor): The source latent tensor.
|
| 220 |
+
src_prompt (str): The source text prompt.
|
| 221 |
+
src_guidance_scale (float): The guidance scale for classifier-free guidance.
|
| 222 |
+
Returns:
|
| 223 |
+
Tuple[
|
| 224 |
+
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int,
|
| 225 |
+
torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor,
|
| 226 |
+
]:
|
| 227 |
+
- The inverted latent tensor.
|
| 228 |
+
- The source latent tensor.
|
| 229 |
+
- The timesteps for the diffusion process.
|
| 230 |
+
- The latent image token IDs.
|
| 231 |
+
- The original height of the input image.
|
| 232 |
+
- The original width of the input image.
|
| 233 |
+
- The source prompt embeddings.
|
| 234 |
+
- The source pooled prompt embeddings.
|
| 235 |
+
- The source text token IDs.
|
| 236 |
+
- The guidance scale tensor.
|
| 237 |
+
"""
|
| 238 |
+
(
|
| 239 |
+
x0_src, latent_src_image_ids, timesteps, orig_height, orig_width,
|
| 240 |
+
src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids
|
| 241 |
+
) = prep_input(pipe, scheduler, T_steps, x0_src, src_prompt, src_guidance_scale)
|
| 242 |
+
|
| 243 |
+
# handle guidance
|
| 244 |
+
if pipe.transformer.config.guidance_embeds:
|
| 245 |
+
src_guidance = torch.tensor([src_guidance_scale], device=pipe.device)
|
| 246 |
+
src_guidance = src_guidance.expand(x0_src.shape[0])
|
| 247 |
+
else:
|
| 248 |
+
src_guidance = None
|
| 249 |
+
|
| 250 |
+
x_t = uniinv(
|
| 251 |
+
pipe, scheduler, timesteps, n_start, x0_src,
|
| 252 |
+
src_prompt_embeds, src_pooled_prompt_embeds, src_guidance,
|
| 253 |
+
src_text_ids, latent_src_image_ids,
|
| 254 |
+
)
|
| 255 |
+
|
| 256 |
+
return (
|
| 257 |
+
x_t, x0_src, timesteps, latent_src_image_ids, orig_height, orig_width,
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
@torch.no_grad()
|
| 261 |
+
def flux_denoise(
|
| 262 |
+
pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
|
| 263 |
+
timesteps: torch.Tensor, n_start: int, x_t: torch.Tensor,
|
| 264 |
+
prompt_embeds: torch.Tensor, pooled_prompt_embeds: torch.Tensor,
|
| 265 |
+
guidance: torch.Tensor, text_ids: torch.Tensor,
|
| 266 |
+
latent_image_ids: torch.Tensor,
|
| 267 |
+
) -> torch.Tensor:
|
| 268 |
+
"""
|
| 269 |
+
Perform the denoising process for FLUX.
|
| 270 |
+
|
| 271 |
+
Args:
|
| 272 |
+
pipe (FluxPipeline): The FLUX pipeline.
|
| 273 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
|
| 274 |
+
timesteps (torch.Tensor): The timesteps for the diffusion process.
|
| 275 |
+
n_start (int): The number of initial timesteps to skip.
|
| 276 |
+
x_t (torch.Tensor): The latent tensor at the starting timestep.
|
| 277 |
+
prompt_embeds (torch.Tensor): The prompt embeddings.
|
| 278 |
+
pooled_prompt_embeds (torch.Tensor): The pooled prompt embeddings.
|
| 279 |
+
guidance (torch.Tensor): The guidance scale tensor.
|
| 280 |
+
text_ids (torch.Tensor): The text token IDs.
|
| 281 |
+
latent_image_ids (torch.Tensor): The latent image token IDs.
|
| 282 |
+
Returns:
|
| 283 |
+
torch.Tensor: The denoised latent tensor.
|
| 284 |
+
"""
|
| 285 |
+
f_xt = x_t.clone()
|
| 286 |
+
for _i, t in enumerate(timesteps[n_start:]):
|
| 287 |
+
scheduler._init_step_index(t)
|
| 288 |
+
t_i = scheduler.sigmas[scheduler.step_index]
|
| 289 |
+
t_im1 = scheduler.sigmas[scheduler.step_index + 1]
|
| 290 |
+
dt = t_im1 - t_i
|
| 291 |
+
|
| 292 |
+
v_tar = calc_v_flux(
|
| 293 |
+
pipe, latents=f_xt, prompt_embeds=prompt_embeds,
|
| 294 |
+
pooled_prompt_embeds=pooled_prompt_embeds, guidance=guidance,
|
| 295 |
+
text_ids=text_ids, latent_image_ids=latent_image_ids, t=t,
|
| 296 |
+
)
|
| 297 |
+
f_xt = f_xt.to(torch.float32)
|
| 298 |
+
f_xt = f_xt + v_tar * dt
|
| 299 |
+
f_xt = f_xt.to(pipe.dtype)
|
| 300 |
+
|
| 301 |
+
return f_xt
|
| 302 |
+
|
| 303 |
+
@torch.no_grad()
|
| 304 |
+
def flux_editing(
|
| 305 |
+
pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
|
| 306 |
+
T_steps: int, n_max: int, x0_src: torch.Tensor, src_prompt: str,
|
| 307 |
+
tar_prompt: str, src_guidance_scale: float, tar_guidance_scale: float,
|
| 308 |
+
flowopt_iterations: int, eta: float,
|
| 309 |
+
) -> Iterator[List[Tuple[Image.Image, str]]]:
|
| 310 |
+
"""
|
| 311 |
+
Perform the editing process for FLUX using FlowOpt.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
pipe (FluxPipeline): The FLUX pipeline.
|
| 315 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
|
| 316 |
+
T_steps (int): The total number of timesteps for the diffusion process.
|
| 317 |
+
n_max (int): The maximum number of timesteps to consider.
|
| 318 |
+
x0_src (torch.Tensor): The source latent tensor.
|
| 319 |
+
src_prompt (str): The source text prompt.
|
| 320 |
+
tar_prompt (str): The target text prompt for editing.
|
| 321 |
+
src_guidance_scale (float): The guidance scale for the source prompt.
|
| 322 |
+
tar_guidance_scale (float): The guidance scale for the target prompt.
|
| 323 |
+
flowopt_iterations (int): The number of FlowOpt iterations to perform.
|
| 324 |
+
eta (float): The step size for the FlowOpt update.
|
| 325 |
+
Yields:
|
| 326 |
+
Iterator[List[Tuple[Image.Image, str]]]: A list of tuples containing the generated images and their corresponding iteration labels.
|
| 327 |
+
"""
|
| 328 |
+
n_start = T_steps - n_max
|
| 329 |
+
(
|
| 330 |
+
x_t, x0_src, timesteps, latent_src_image_ids, orig_height, orig_width,
|
| 331 |
+
) = initialization(
|
| 332 |
+
pipe, scheduler, T_steps, n_start, x0_src, src_prompt, src_guidance_scale,
|
| 333 |
+
)
|
| 334 |
+
|
| 335 |
+
pipe._guidance_scale = tar_guidance_scale
|
| 336 |
+
(
|
| 337 |
+
tar_prompt_embeds,
|
| 338 |
+
pooled_tar_prompt_embeds,
|
| 339 |
+
tar_text_ids,
|
| 340 |
+
) = pipe.encode_prompt(
|
| 341 |
+
prompt=tar_prompt,
|
| 342 |
+
prompt_2=None,
|
| 343 |
+
device=pipe.device,
|
| 344 |
+
)
|
| 345 |
+
|
| 346 |
+
# handle guidance
|
| 347 |
+
if pipe.transformer.config.guidance_embeds:
|
| 348 |
+
tar_guidance = torch.tensor([tar_guidance_scale], device=pipe.device)
|
| 349 |
+
tar_guidance = tar_guidance.expand(x0_src.shape[0])
|
| 350 |
+
else:
|
| 351 |
+
tar_guidance = None
|
| 352 |
+
|
| 353 |
+
history = []
|
| 354 |
+
j_star = x0_src.clone().to(torch.float32) # y
|
| 355 |
+
for flowopt_iter in range(flowopt_iterations + 1):
|
| 356 |
+
f_xt = flux_denoise(
|
| 357 |
+
pipe, scheduler, timesteps, n_start, x_t,
|
| 358 |
+
tar_prompt_embeds, pooled_tar_prompt_embeds, tar_guidance,
|
| 359 |
+
tar_text_ids, latent_src_image_ids,
|
| 360 |
+
) # Eq. (3)
|
| 361 |
+
|
| 362 |
+
if flowopt_iter < flowopt_iterations:
|
| 363 |
+
x_t = x_t.to(torch.float32)
|
| 364 |
+
x_t = x_t - eta * (f_xt - j_star) # Eq. (6) with c = c_tar
|
| 365 |
+
x_t = x_t.to(x0_src.dtype)
|
| 366 |
+
|
| 367 |
+
x0_flowopt = f_xt.clone()
|
| 368 |
+
unpacked_x0_flowopt = pipe._unpack_latents(x0_flowopt, orig_height, orig_width, pipe.vae_scale_factor)
|
| 369 |
+
x0_flowopt_denorm = (unpacked_x0_flowopt / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
|
| 370 |
+
with torch.autocast("cuda"), torch.inference_mode():
|
| 371 |
+
x0_flowopt_image = pipe.vae.decode(x0_flowopt_denorm, return_dict=False)[0].clamp(-1, 1)
|
| 372 |
+
x0_flowopt_image_pil = pipe.image_processor.postprocess(x0_flowopt_image)[0]
|
| 373 |
+
history.append((x0_flowopt_image_pil, f"Iteration {flowopt_iter}"))
|
| 374 |
+
yield history
|
utils/sd3.py
ADDED
|
@@ -0,0 +1,264 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Iterator, List, Tuple
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
from diffusers import FlowMatchEulerDiscreteScheduler, StableDiffusion3Pipeline
|
| 5 |
+
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
|
| 6 |
+
from PIL import Image
|
| 7 |
+
|
| 8 |
+
@torch.no_grad()
|
| 9 |
+
def calc_v_sd3(
|
| 10 |
+
pipe: StableDiffusion3Pipeline, latent_model_input: torch.Tensor,
|
| 11 |
+
prompt_embeds: torch.Tensor, pooled_prompt_embeds: torch.Tensor,
|
| 12 |
+
guidance_scale: float, t: torch.Tensor,
|
| 13 |
+
) -> torch.Tensor:
|
| 14 |
+
"""
|
| 15 |
+
Calculate the velocity (v) for Stable Diffusion 3.
|
| 16 |
+
|
| 17 |
+
Args:
|
| 18 |
+
pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
|
| 19 |
+
latent_model_input (torch.Tensor): The input latent tensor.
|
| 20 |
+
prompt_embeds (torch.Tensor): The text embeddings for the prompt.
|
| 21 |
+
pooled_prompt_embeds (torch.Tensor): The pooled text embeddings for the prompt.
|
| 22 |
+
guidance_scale (float): The guidance scale for classifier-free guidance.
|
| 23 |
+
t (torch.Tensor): The current timestep.
|
| 24 |
+
Returns:
|
| 25 |
+
torch.Tensor: The predicted noise (velocity).
|
| 26 |
+
"""
|
| 27 |
+
timestep = t.expand(latent_model_input.shape[0])
|
| 28 |
+
|
| 29 |
+
noise_pred = pipe.transformer(
|
| 30 |
+
hidden_states=latent_model_input,
|
| 31 |
+
timestep=timestep,
|
| 32 |
+
encoder_hidden_states=prompt_embeds,
|
| 33 |
+
pooled_projections=pooled_prompt_embeds,
|
| 34 |
+
joint_attention_kwargs=None,
|
| 35 |
+
return_dict=False,
|
| 36 |
+
)[0]
|
| 37 |
+
|
| 38 |
+
# perform guidance source
|
| 39 |
+
if pipe.do_classifier_free_guidance:
|
| 40 |
+
noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
|
| 41 |
+
noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
|
| 42 |
+
|
| 43 |
+
return noise_pred
|
| 44 |
+
|
| 45 |
+
# https://github.com/DSL-Lab/UniEdit-Flow
|
| 46 |
+
@torch.no_grad()
|
| 47 |
+
def uniinv(
|
| 48 |
+
pipe: StableDiffusion3Pipeline, timesteps: torch.Tensor, n_start: int,
|
| 49 |
+
x0_src: torch.Tensor, src_prompt_embeds_all: torch.Tensor,
|
| 50 |
+
src_pooled_prompt_embeds_all: torch.Tensor, src_guidance_scale: float,
|
| 51 |
+
) -> torch.Tensor:
|
| 52 |
+
"""
|
| 53 |
+
Perform the UniInv inversion process for Stable Diffusion 3.
|
| 54 |
+
|
| 55 |
+
Args:
|
| 56 |
+
pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
|
| 57 |
+
timesteps (torch.Tensor): The timesteps for the diffusion process.
|
| 58 |
+
n_start (int): The number of initial timesteps to skip.
|
| 59 |
+
x0_src (torch.Tensor): The source latent tensor.
|
| 60 |
+
src_prompt_embeds_all (torch.Tensor): The text embeddings for the source prompt.
|
| 61 |
+
src_pooled_prompt_embeds_all (torch.Tensor): The pooled text embeddings for the source prompt.
|
| 62 |
+
src_guidance_scale (float): The guidance scale for classifier-free guidance.
|
| 63 |
+
Returns:
|
| 64 |
+
torch.Tensor: The inverted latent tensor.
|
| 65 |
+
"""
|
| 66 |
+
x_t = x0_src.clone()
|
| 67 |
+
timesteps_inv = torch.cat([torch.tensor([0.0], device=pipe.device), timesteps.flip(dims=(0,))], dim=0)
|
| 68 |
+
if n_start > 0:
|
| 69 |
+
zipped_timesteps_inv = zip(timesteps_inv[:-n_start - 1], timesteps_inv[1:-n_start])
|
| 70 |
+
else:
|
| 71 |
+
zipped_timesteps_inv = zip(timesteps_inv[:-1], timesteps_inv[1:])
|
| 72 |
+
next_v = None
|
| 73 |
+
for _i, (t_cur, t_prev) in enumerate(zipped_timesteps_inv):
|
| 74 |
+
t_i = t_cur / 1000
|
| 75 |
+
t_ip1 = t_prev / 1000
|
| 76 |
+
dt = t_ip1 - t_i
|
| 77 |
+
|
| 78 |
+
if next_v is None:
|
| 79 |
+
latent_model_input = torch.cat([x_t, x_t]) if pipe.do_classifier_free_guidance else (x_t)
|
| 80 |
+
v_tar = calc_v_sd3(
|
| 81 |
+
pipe, latent_model_input, src_prompt_embeds_all,
|
| 82 |
+
src_pooled_prompt_embeds_all, src_guidance_scale, t_cur,
|
| 83 |
+
)
|
| 84 |
+
else:
|
| 85 |
+
v_tar = next_v
|
| 86 |
+
|
| 87 |
+
x_t = x_t.to(torch.float32)
|
| 88 |
+
x_t_next = x_t + v_tar * dt
|
| 89 |
+
x_t_next = x_t_next.to(pipe.dtype)
|
| 90 |
+
|
| 91 |
+
latent_model_input = torch.cat([x_t_next, x_t_next]) if pipe.do_classifier_free_guidance else (x_t_next)
|
| 92 |
+
v_tar_next = calc_v_sd3(
|
| 93 |
+
pipe, latent_model_input, src_prompt_embeds_all,
|
| 94 |
+
src_pooled_prompt_embeds_all, src_guidance_scale, t_prev,
|
| 95 |
+
)
|
| 96 |
+
next_v = v_tar_next
|
| 97 |
+
x_t = x_t + v_tar_next * dt
|
| 98 |
+
x_t = x_t.to(pipe.dtype)
|
| 99 |
+
|
| 100 |
+
return x_t
|
| 101 |
+
|
| 102 |
+
@torch.no_grad()
|
| 103 |
+
def initialization(
|
| 104 |
+
pipe: StableDiffusion3Pipeline, scheduler: FlowMatchEulerDiscreteScheduler,
|
| 105 |
+
T_steps: int, n_start: int, x0_src: torch.Tensor,
|
| 106 |
+
src_prompt: str, negative_prompt: str, src_guidance_scale: float,
|
| 107 |
+
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 108 |
+
"""
|
| 109 |
+
Initialize the inversion process by preparing the latent tensor and prompt embeddings, and performing UniInv.
|
| 110 |
+
|
| 111 |
+
Args:
|
| 112 |
+
pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
|
| 113 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
|
| 114 |
+
T_steps (int): The total number of timesteps for the diffusion process.
|
| 115 |
+
n_start (int): The number of initial timesteps to skip.
|
| 116 |
+
x0_src (torch.Tensor): The source latent tensor.
|
| 117 |
+
src_prompt (str): The source text prompt.
|
| 118 |
+
negative_prompt (str): The negative text prompt for classifier-free guidance.
|
| 119 |
+
src_guidance_scale (float): The guidance scale for classifier-free guidance.
|
| 120 |
+
Returns:
|
| 121 |
+
Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 122 |
+
- The inverted latent tensor.
|
| 123 |
+
- The original source latent tensor.
|
| 124 |
+
- The timesteps for the diffusion process.
|
| 125 |
+
- The text embeddings for the source prompt.
|
| 126 |
+
- The pooled text embeddings for the source prompt.
|
| 127 |
+
"""
|
| 128 |
+
pipe._guidance_scale = src_guidance_scale
|
| 129 |
+
(
|
| 130 |
+
src_prompt_embeds,
|
| 131 |
+
src_negative_prompt_embeds,
|
| 132 |
+
src_pooled_prompt_embeds,
|
| 133 |
+
src_negative_pooled_prompt_embeds,
|
| 134 |
+
) = pipe.encode_prompt(
|
| 135 |
+
prompt=src_prompt,
|
| 136 |
+
prompt_2=None,
|
| 137 |
+
prompt_3=None,
|
| 138 |
+
negative_prompt=negative_prompt,
|
| 139 |
+
do_classifier_free_guidance=pipe.do_classifier_free_guidance,
|
| 140 |
+
device=pipe.device,
|
| 141 |
+
)
|
| 142 |
+
src_prompt_embeds_all = torch.cat([src_negative_prompt_embeds, src_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else src_prompt_embeds
|
| 143 |
+
src_pooled_prompt_embeds_all = torch.cat([src_negative_pooled_prompt_embeds, src_pooled_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else src_pooled_prompt_embeds
|
| 144 |
+
|
| 145 |
+
timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, x0_src.device, timesteps=None)
|
| 146 |
+
pipe._num_timesteps = len(timesteps)
|
| 147 |
+
|
| 148 |
+
x_t = uniinv(
|
| 149 |
+
pipe, timesteps, n_start, x0_src, src_prompt_embeds_all,
|
| 150 |
+
src_pooled_prompt_embeds_all, src_guidance_scale,
|
| 151 |
+
)
|
| 152 |
+
|
| 153 |
+
return x_t, x0_src, timesteps
|
| 154 |
+
|
| 155 |
+
@torch.no_grad()
|
| 156 |
+
def sd3_denoise(
|
| 157 |
+
pipe: StableDiffusion3Pipeline, timesteps: torch.Tensor, n_start: int,
|
| 158 |
+
x_t: torch.Tensor, prompt_embeds_all: torch.Tensor,
|
| 159 |
+
pooled_prompt_embeds_all: torch.Tensor, guidance_scale: float,
|
| 160 |
+
) -> torch.Tensor:
|
| 161 |
+
"""
|
| 162 |
+
Perform the denoising process for Stable Diffusion 3.
|
| 163 |
+
|
| 164 |
+
Args:
|
| 165 |
+
pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
|
| 166 |
+
timesteps (torch.Tensor): The timesteps for the diffusion process.
|
| 167 |
+
n_start (int): The number of initial timesteps to skip.
|
| 168 |
+
x_t (torch.Tensor): The latent tensor at the starting timestep.
|
| 169 |
+
prompt_embeds_all (torch.Tensor): The text embeddings for the prompt.
|
| 170 |
+
pooled_prompt_embeds_all (torch.Tensor): The pooled text embeddings for the prompt.
|
| 171 |
+
guidance_scale (float): The guidance scale for classifier-free guidance.
|
| 172 |
+
Returns:
|
| 173 |
+
torch.Tensor: The denoised latent tensor.
|
| 174 |
+
"""
|
| 175 |
+
f_xt = x_t.clone()
|
| 176 |
+
for i, t in enumerate(timesteps[n_start:]):
|
| 177 |
+
t_i = t / 1000
|
| 178 |
+
if i + 1 < len(timesteps[n_start:]):
|
| 179 |
+
t_im1 = (timesteps[n_start + i + 1]) / 1000
|
| 180 |
+
else:
|
| 181 |
+
t_im1 = torch.zeros_like(t_i).to(t_i.device)
|
| 182 |
+
dt = t_im1 - t_i
|
| 183 |
+
|
| 184 |
+
latent_model_input = torch.cat([f_xt, f_xt]) if pipe.do_classifier_free_guidance else (f_xt)
|
| 185 |
+
v_tar = calc_v_sd3(
|
| 186 |
+
pipe, latent_model_input, prompt_embeds_all,
|
| 187 |
+
pooled_prompt_embeds_all, guidance_scale, t,
|
| 188 |
+
)
|
| 189 |
+
f_xt = f_xt.to(torch.float32)
|
| 190 |
+
f_xt = f_xt + v_tar * dt
|
| 191 |
+
f_xt = f_xt.to(pipe.dtype)
|
| 192 |
+
|
| 193 |
+
return f_xt
|
| 194 |
+
|
| 195 |
+
@torch.no_grad()
|
| 196 |
+
def sd3_editing(
|
| 197 |
+
pipe: StableDiffusion3Pipeline, scheduler: FlowMatchEulerDiscreteScheduler,
|
| 198 |
+
T_steps: int, n_max: int, x0_src: torch.Tensor, src_prompt: str,
|
| 199 |
+
tar_prompt: str, negative_prompt: str, src_guidance_scale: float,
|
| 200 |
+
tar_guidance_scale: float, flowopt_iterations: int, eta: float,
|
| 201 |
+
) -> Iterator[List[Tuple[Image.Image, str]]]:
|
| 202 |
+
"""
|
| 203 |
+
Perform the editing process for Stable Diffusion 3 using FlowOpt.
|
| 204 |
+
|
| 205 |
+
Args:
|
| 206 |
+
pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
|
| 207 |
+
scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
|
| 208 |
+
T_steps (int): The total number of timesteps for the diffusion process.
|
| 209 |
+
n_max (int): The maximum number of timesteps to consider.
|
| 210 |
+
x0_src (torch.Tensor): The source latent tensor.
|
| 211 |
+
src_prompt (str): The source text prompt.
|
| 212 |
+
tar_prompt (str): The target text prompt for editing.
|
| 213 |
+
negative_prompt (str): The negative text prompt for classifier-free guidance.
|
| 214 |
+
src_guidance_scale (float): The guidance scale for the source prompt.
|
| 215 |
+
tar_guidance_scale (float): The guidance scale for the target prompt.
|
| 216 |
+
flowopt_iterations (int): The number of FlowOpt iterations to perform.
|
| 217 |
+
eta (float): The step size for the FlowOpt update.
|
| 218 |
+
Yields:
|
| 219 |
+
Iterator[List[Tuple[Image.Image, str]]]: A list of tuples containing the generated images and their corresponding iteration labels.
|
| 220 |
+
"""
|
| 221 |
+
n_start = T_steps - n_max
|
| 222 |
+
x_t, x0_src, timesteps = initialization(
|
| 223 |
+
pipe, scheduler, T_steps, n_start, x0_src, src_prompt,
|
| 224 |
+
negative_prompt, src_guidance_scale,
|
| 225 |
+
)
|
| 226 |
+
|
| 227 |
+
pipe._guidance_scale = tar_guidance_scale
|
| 228 |
+
(
|
| 229 |
+
tar_prompt_embeds,
|
| 230 |
+
tar_negative_prompt_embeds,
|
| 231 |
+
tar_pooled_prompt_embeds,
|
| 232 |
+
tar_negative_pooled_prompt_embeds,
|
| 233 |
+
) = pipe.encode_prompt(
|
| 234 |
+
prompt=tar_prompt,
|
| 235 |
+
prompt_2=None,
|
| 236 |
+
prompt_3=None,
|
| 237 |
+
negative_prompt=negative_prompt,
|
| 238 |
+
do_classifier_free_guidance=pipe.do_classifier_free_guidance,
|
| 239 |
+
device=pipe.device,
|
| 240 |
+
)
|
| 241 |
+
|
| 242 |
+
tar_prompt_embeds_all = torch.cat([tar_negative_prompt_embeds, tar_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else tar_prompt_embeds
|
| 243 |
+
tar_pooled_prompt_embeds_all = torch.cat([tar_negative_pooled_prompt_embeds, tar_pooled_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else tar_pooled_prompt_embeds
|
| 244 |
+
|
| 245 |
+
history = []
|
| 246 |
+
j_star = x0_src.clone().to(torch.float32) # y
|
| 247 |
+
for flowopt_iter in range(flowopt_iterations + 1):
|
| 248 |
+
f_xt = sd3_denoise(
|
| 249 |
+
pipe, timesteps, n_start, x_t, tar_prompt_embeds_all,
|
| 250 |
+
tar_pooled_prompt_embeds_all, tar_guidance_scale,
|
| 251 |
+
) # Eq. (3)
|
| 252 |
+
|
| 253 |
+
if flowopt_iter < flowopt_iterations:
|
| 254 |
+
x_t = x_t.to(torch.float32)
|
| 255 |
+
x_t = x_t - eta * (f_xt - j_star) # Eq. (6) with c = c_tar
|
| 256 |
+
x_t = x_t.to(x0_src.dtype)
|
| 257 |
+
|
| 258 |
+
x0_flowopt = f_xt.clone()
|
| 259 |
+
x0_flowopt_denorm = (x0_flowopt / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
|
| 260 |
+
with torch.autocast("cuda"), torch.inference_mode():
|
| 261 |
+
x0_flowopt_image = pipe.vae.decode(x0_flowopt_denorm, return_dict=False)[0].clamp(-1, 1)
|
| 262 |
+
x0_flowopt_image_pil = pipe.image_processor.postprocess(x0_flowopt_image)[0]
|
| 263 |
+
history.append((x0_flowopt_image_pil, f"Iteration {flowopt_iter}"))
|
| 264 |
+
yield history
|