orronai committed on
Commit 8d5a128 · 1 Parent(s): 492742b

feat: add application files

Files changed (5)
  1. README.md +6 -9
  2. app.py +207 -0
  3. requirements.txt +6 -0
  4. utils/flux.py +374 -0
  5. utils/sd3.py +264 -0
README.md CHANGED
@@ -1,14 +1,11 @@
- ---
  title: FlowOpt
- emoji: 📈
- colorFrom: purple
- colorTo: pink
+ emoji: 📚
+ colorFrom: blue
+ colorTo: green
  sdk: gradio
- sdk_version: 5.48.0
+ sdk_version: 5.8.0
  app_file: app.py
  pinned: false
  license: mit
- short_description: 'FlowOpt Gradio: Fast Optimization for Training-Free Editing'
- ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+ hf_oauth: true
+ short_description: 'FlowOpt Gradio: Fast-Optimization for Training-Free Editing.'
app.py ADDED
@@ -0,0 +1,207 @@
+ import os
+ import random
+ from typing import Tuple
+
+ import numpy as np
+ import spaces
+ import torch
+ from diffusers import FluxPipeline, StableDiffusion3Pipeline
+ from PIL import Image
+
+ import gradio as gr
+ from utils.flux import flux_editing
+ from utils.sd3 import sd3_editing
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ pipe_sd3 = StableDiffusion3Pipeline.from_pretrained("stabilityai/stable-diffusion-3-medium-diffusers", torch_dtype=torch.float16, token=os.getenv('HF_ACCESS_TOK'))
+ pipe_flux = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-dev", torch_dtype=torch.float16, token=os.getenv('HF_ACCESS_TOK'))
+
+
+ def seed_everything(seed: int) -> None:
+     """
+     Set the random seed for reproducibility.
+
+     Args:
+         seed (int): The seed value to set.
+     """
+     random.seed(seed)
+     np.random.seed(seed)
+     torch.manual_seed(seed)
+     torch.cuda.manual_seed_all(seed)
+
+ def on_T_steps_change(T_steps: int) -> gr.update:
+     """
+     Update the maximum value of the n_max slider based on the T_steps value.
+
+     Args:
+         T_steps (int): The current value of the T_steps slider.
+     Returns:
+         gr.update: An update object to modify the n_max slider's maximum value.
+     """
+     return gr.update(maximum=T_steps)
+
+ def on_model_change(model_type: str) -> Tuple[int, int, float]:
+     if model_type == 'SD3':
+         T_steps_value = 15
+         n_max_value = 12
+         eta_value = 0.01
+     elif model_type == 'FLUX':
+         T_steps_value = 15
+         n_max_value = 13
+         eta_value = 0.0025
+     else:
+         raise NotImplementedError(f"Model type {model_type} not implemented")
+
+     return T_steps_value, n_max_value, eta_value
+
+ def get_examples():
+     case = [
+         ["inputs/corgi_walking.png", "FLUX", 15, 13, 0.0025, 7, "A cute brown and white dog walking on a sidewalk near a body of water. The dog is wearing a pink vest, adding a touch of color to the scene.", "A cute brown and white guinea pig walking on a sidewalk near a body of water. The guinea pig is wearing a pink vest, adding a touch of color to the scene.", 1.0, 3.5, [(f"example_outputs/corgi_walking/guinea_pig/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
+         ["inputs/corgi_walking.png", "SD3", 15, 12, 0.01, 7, "A cute brown and white dog walking on a sidewalk near a body of water. The dog is wearing a pink vest, adding a touch of color to the scene.", "A cute brown and white rabbit walking on a sidewalk near a body of water. The rabbit is wearing a pink vest, adding a touch of color to the scene.", 1.0, 3.5, [(f"example_outputs/corgi_walking/rabbit/sd3_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
+         ["inputs/puppies.png", "FLUX", 15, 13, 0.0025, 7, "Two adorable golden retriever puppies sitting in a grassy field. They are positioned close to each other, with one dog on the left and the other on the right. Both dogs have their mouths open, possibly panting.", "Two adorable crochet golden retriever puppies sitting in a grassy field. They are positioned close to each other, with one dog on the left and the other on the right. Both dogs have their mouths open, possibly panting or enjoying the outdoor environment.", 1.0, 3.5, [(f"example_outputs/puppies/crochet/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
+         ["inputs/puppies.png", "SD3", 15, 12, 0.01, 5, "Two adorable golden retriever puppies sitting in a grassy field. They are positioned close to each other, with one dog on the left and the other on the right. Both dogs have their mouths open, possibly panting.", "Two adorable husky puppies sitting in a grassy field. They are positioned close to each other, with one dog on the left and the other on the right. Both dogs have their mouths open, possibly panting or enjoying the outdoor environment.", 1.0, 3.5, [(f"example_outputs/puppies/husky/sd3_iterations={i}.png", f"Iteration {i}") for i in range(6)]],
+         ["inputs/iguana.png", "FLUX", 15, 13, 0.0025, 7, "A large orange lizard sitting on a rock near the ocean. The lizard is positioned in the center of the scene, with the ocean waves visible in the background. The rock is located close to the water, providing a picturesque setting for the lizard's resting spot.", "A large crochet lizard sitting on a rock near the ocean. The lizard is positioned in the center of the scene, with the ocean waves visible in the background. The rock is located close to the water, providing a picturesque setting for the lizard's resting spot.", 1.0, 3.5, [(f"example_outputs/iguana/crochet/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
+         ["inputs/iguana.png", "FLUX", 15, 13, 0.0025, 7, "A large orange lizard sitting on a rock near the ocean. The lizard is positioned in the center of the scene, with the ocean waves visible in the background. The rock is located close to the water, providing a picturesque setting for the lizard's resting spot.", "A large lizard made out of lego bricks sitting on a rock near the ocean. The lizard is positioned in the center of the scene, with the ocean waves visible in the background. The rock is located close to the water, providing a picturesque setting for the lizard's resting spot.", 1.0, 3.5, [(f"example_outputs/iguana/lego_bricks/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
+         ["inputs/cow_grass2.png", "FLUX", 15, 12, 0.0025, 6, "A large brown and white cow standing in a grassy field. The cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", "A large cow made out of colorful toy bricks standing in a grassy field. The colorful toy brick cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", 1.0, 3.5, [(f"example_outputs/cow_grass2/colorful_toy_bricks/flux_iterations={i}.png", f"Iteration {i}") for i in range(7)]],
+         ["inputs/cow_grass2.png", "FLUX", 15, 13, 0.0025, 5, "A large brown and white cow standing in a grassy field. The cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", "A large cow made out of flowers standing in a grassy field. The flower cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", 1.0, 3.5, [(f"example_outputs/cow_grass2/flowers/flux_iterations={i}.png", f"Iteration {i}") for i in range(6)]],
+         ["inputs/cow_grass2.png", "SD3", 15, 12, 0.01, 8, "A large brown and white cow standing in a grassy field. The cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", "A large cow made out of wooden blocks standing in a grassy field. The wooden block cow is positioned towards the center of the scene. The field is lush and green, providing a perfect environment for the cow to graze.", 1.0, 3.5, [(f"example_outputs/cow_grass2/wooden_blocks/sd3_iterations={i}.png", f"Iteration {i}") for i in range(9)]],
+         ["inputs/cat_fridge.png", "SD3", 15, 12, 0.01, 8, "A cat sitting on top of a counter in a store. The cat is positioned towards the right side of the counter, and it appears to be looking at the camera. The store has a variety of items displayed, including several bottles scattered around the counter.", "A cat sitting on top of a counter in a store, with the cat and counter crafted using origami folded paper art techniques. The cat has a delicate and intricate appearance, with paper folds used to create its fur and features. The store has a variety of items displayed, including several bottles scattered around the counter.", 1.0, 3.5, [(f"example_outputs/cat_fridge/origami/sd3_iterations={i}.png", f"Iteration {i}") for i in range(9)]],
+         ["inputs/cat.png", "FLUX", 15, 13, 0.0025, 7, "A small, fluffy kitten sitting in a grassy field. The kitten is positioned in the center of the scene, surrounded by a field. The kitten appears to be looking at something in the field.", "A small bear cub sitting in a grassy field. The bear cub is positioned in the center of the scene, surrounded by a field. The bear cub appears to be looking at something in the field.", 1.0, 3.5, [(f"example_outputs/cat/bear/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
+         ["inputs/cat.png", "SD3", 15, 12, 0.01, 6, "A small, fluffy kitten sitting in a grassy field. The kitten is positioned in the center of the scene, surrounded by a field. The kitten appears to be looking at something in the field.", "A small puppy sitting in a grassy field. The puppy is positioned in the center of the scene, surrounded by a field. The puppy appears to be looking at something in the field.", 1.0, 3.5, [(f"example_outputs/cat/puppy/sd3_iterations={i}.png", f"Iteration {i}") for i in range(7)]],
+         ["inputs/wolf_grass.png", "FLUX", 15, 13, 0.0025, 7, "A wolf standing in a grassy field with yellow flowers. The wolf is positioned towards the center of the scene, and its body is facing the camera. The field is filled with grass, and the yellow flowers are scattered throughout the area.", "A fox standing in a grassy field with yellow flowers. The fox is positioned towards the center of the scene, and its body is facing the camera. The field is filled with grass, and the yellow flowers are scattered throughout the area.", 1.0, 3.5, [(f"example_outputs/wolf_grass/fox/flux_iterations={i}.png", f"Iteration {i}") for i in range(8)]],
+         ["inputs/wolf_grass.png", "SD3", 15, 12, 0.01, 4, "A wolf standing in a grassy field with yellow flowers. The wolf is positioned towards the center of the scene, and its body is facing the camera. The field is filled with grass, and the yellow flowers are scattered throughout the area.", "A baby deer standing in a grassy field with yellow flowers. The baby deer is positioned towards the center of the scene, and its body is facing the camera. The field is filled with grass, and the yellow flowers are scattered throughout the area.", 1.0, 3.5, [(f"example_outputs/wolf_grass/deer/sd3_iterations={i}.png", f"Iteration {i}") for i in range(5)]],
+     ]
+     return case
+
+ @spaces.GPU(duration=200)
+ def FlowOpt_run(
+     image_src_val: str, model_type_val: str, T_steps_val: int,
+     n_max_val: int, eta_val: float, flowopt_iterations_val: int,
+     src_prompt_val: str, tar_prompt_val: str,
+     src_guidance_scale_val: float, tar_guidance_scale_val: float,
+ ):
+     if not len(src_prompt_val):
+         raise gr.Error("Source prompt cannot be empty")
+     if not len(tar_prompt_val):
+         raise gr.Error("Target prompt cannot be empty")
+
+     if model_type_val == 'FLUX':
+         pipe = pipe_flux.to(device)
+     elif model_type_val == 'SD3':
+         pipe = pipe_sd3.to(device)
+     else:
+         raise NotImplementedError(f"Model type {model_type_val} not implemented")
+
+     scheduler = pipe.scheduler
+
+     # set seed
+     seed = 1024
+     seed_everything(seed)
+     # load image
+     image = Image.open(image_src_val)
+     # crop image to have both dimensions divisible by 16 - avoids issues with resizing
+     image = image.crop((0, 0, image.width - image.width % 16, image.height - image.height % 16))
+     image_src_val = pipe.image_processor.preprocess(image)
+
+     # cast image to half precision
+     image_src_val = image_src_val.to(device).half()
+     with torch.autocast("cuda"), torch.inference_mode():
+         x0_src_denorm = pipe.vae.encode(image_src_val).latent_dist.mode()
+     x0_src = (x0_src_denorm - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor
+     # send to cuda
+     x0_src = x0_src.to(device)
+     negative_prompt = "" # (SD3)
+
+     if model_type_val == 'SD3':
+         yield from sd3_editing(
+             pipe, scheduler, T_steps_val, n_max_val, x0_src,
+             src_prompt_val, tar_prompt_val, negative_prompt,
+             src_guidance_scale_val, tar_guidance_scale_val,
+             flowopt_iterations_val, eta_val,
+         )
+     elif model_type_val == 'FLUX':
+         yield from flux_editing(
+             pipe, scheduler, T_steps_val, n_max_val, x0_src,
+             src_prompt_val, tar_prompt_val,
+             src_guidance_scale_val, tar_guidance_scale_val,
+             flowopt_iterations_val, eta_val,
+         )
+     else:
+         raise NotImplementedError(f"Sampler type {model_type_val} not implemented")
+
+
+ intro = """
+ <h1 style="font-weight: 1000; text-align: center; margin: 0px;">FlowOpt: Fast Optimization Through Whole Flow Processes for Training-Free Editing</h1>
+ <h3 style="margin-bottom: 10px; text-align: center;">
+     <a href="">[Paper]</a>&nbsp;|&nbsp;
+     <a href="https://orronai.github.io/FlowOpt/">[Project Page]</a>&nbsp;|&nbsp;
+     <a href="https://github.com/orronai/FlowOpt">[Code]</a>
+ </h3>
+ <br> 🎨 Edit your image using FlowOpt for Flow models! Upload an image, add a description of it, and specify the edits you want to make.
+ <h3>Notes:</h3>
+ <ol>
+     <li>We use FLUX.1 dev and SD3 for the demo. The models are large and may take a while to load.</li>
+     <li>We recommend 1024x1024 images for the best results. If the input images are too large, there may be out-of-memory errors. For other resolutions, we encourage you to find a suitable set of hyperparameters.</li>
+     <li>Default hyperparameters for each model used in the paper are provided as examples.</li>
+ </ol>
+ """
+
+ css="""
+ #col-container {
+     margin: 0 auto;
+     max-width: 960px;
+ }
+ """
+ with gr.Blocks(css=css) as demo:
+     with gr.Column(elem_id="col-container"):
+         gr.HTML(intro)
+
+         with gr.Row():
+             with gr.Column():
+                 image_src = gr.Image(type="filepath", label="Source Image", value="inputs/cat.png",)
+                 src_prompt = gr.Textbox(lines=2, label="Source Prompt", value="A cat sitting in the grass")
+                 tar_prompt = gr.Textbox(lines=2, label="Target Prompt", value="A puppy sitting in the grass")
+                 submit_button = gr.Button("Run FlowOpt", variant="primary")
+
+                 with gr.Row():
+                     model_type = gr.Dropdown(["FLUX", "SD3"], label="Model Type", value="FLUX")
+                     T_steps = gr.Slider(value=15, minimum=10, maximum=50, step=1, label="Total Steps", info="Total number of discretization steps.")
+                     n_max = gr.Slider(value=13, minimum=1, maximum=15, step=1, label="n_max", info="Control the strength of the edit.")
+                     eta = gr.Slider(value=0.0025, minimum=0.0001, maximum=0.05, label="eta", info="Control the optimization step-size.")
+                     flowopt_iterations = gr.Number(value=10, minimum=1, maximum=15, label="flowopt_iterations", info="Max number of FlowOpt iterations")
+
+             with gr.Column():
+                 image_tar = gr.Gallery(
+                     label="Outputs", show_label=True, format="png",
+                     columns=[3], rows=[3], height="auto",
+                 )
+                 with gr.Accordion(label="Advanced Settings", open=False):
+                     src_guidance_scale = gr.Slider(value=1.0, minimum=0.0, maximum=15.0, label="src_guidance_scale", info="Source prompt CFG scale.")
+                     tar_guidance_scale = gr.Slider(value=3.5, minimum=1.0, maximum=15.0, label="tar_guidance_scale", info="Target prompt CFG scale.")
+
+         submit_button.click(
+             fn=FlowOpt_run,
+             inputs=[
+                 image_src, model_type, T_steps, n_max, eta, flowopt_iterations,
+                 src_prompt, tar_prompt, src_guidance_scale, tar_guidance_scale,
+             ],
+             outputs=[image_tar],
+         )
+
+         gr.Examples(
+             label="Examples",
+             examples=get_examples(),
+             inputs=[
+                 image_src, model_type, T_steps, n_max, eta,
+                 flowopt_iterations, src_prompt, tar_prompt,
+                 src_guidance_scale, tar_guidance_scale, image_tar,
+             ],
+             outputs=[image_tar],
+         )
+
+     model_type.input(fn=on_model_change, inputs=[model_type], outputs=[T_steps, n_max, eta])
+     T_steps.change(fn=on_T_steps_change, inputs=[T_steps], outputs=[n_max])
+
+ demo.queue()
+ demo.launch()
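
The `FlowOpt_run` handler above is a generator: each `yield history` pushes the growing list of intermediate edits to the `gr.Gallery` as soon as an iteration finishes, instead of waiting for all FlowOpt iterations. A minimal, self-contained sketch of that streaming pattern is shown below; the `fake_edit` function and its placeholder images are hypothetical stand-ins and not part of this Space.

```python
# Minimal sketch of a generator-style Gradio handler that streams results
# to a Gallery, mirroring FlowOpt_run's `yield history` pattern.
import gradio as gr
from PIL import Image

def fake_edit(n_iters: float):
    history = []
    for i in range(int(n_iters)):
        # placeholder image standing in for the decoded FlowOpt iterate
        img = Image.new("RGB", (64, 64), color=(20 * i % 255, 100, 150))
        history.append((img, f"Iteration {i}"))
        yield history  # each yield refreshes the Gallery with all iterates so far

with gr.Blocks() as demo:
    n = gr.Number(value=5, label="iterations")
    gallery = gr.Gallery(label="Outputs")
    gr.Button("Run").click(fn=fake_edit, inputs=[n], outputs=[gallery])

if __name__ == "__main__":
    demo.launch()
```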
requirements.txt ADDED
@@ -0,0 +1,6 @@
+ torch
+ diffusers
+ transformers
+ accelerate
+ sentencepiece
+ protobuf
utils/flux.py ADDED
@@ -0,0 +1,374 @@
+ from typing import Iterator, List, Tuple
+
+ import numpy as np
+ import torch
+ from diffusers import FlowMatchEulerDiscreteScheduler, FluxPipeline
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+ from PIL import Image
+
+
+ @torch.no_grad()
+ def calculate_shift(
+     image_seq_len: int,
+     base_seq_len: int = 256,
+     max_seq_len: int = 4096,
+     base_shift: float = 0.5,
+     max_shift: float = 1.16,
+ ) -> float:
+     m = (max_shift - base_shift) / (max_seq_len - base_seq_len)
+     b = base_shift - m * base_seq_len
+     mu = image_seq_len * m + b
+     return mu
+
+ @torch.no_grad()
+ def calc_v_flux(
+     pipe: FluxPipeline, latents: torch.Tensor, prompt_embeds: torch.Tensor,
+     pooled_prompt_embeds: torch.Tensor, guidance: torch.Tensor,
+     text_ids: torch.Tensor, latent_image_ids: torch.Tensor, t: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Calculate the velocity (v) for FLUX.
+
+     Args:
+         pipe (FluxPipeline): The FLUX pipeline.
+         latents (torch.Tensor): The latent tensor at the current timestep.
+         prompt_embeds (torch.Tensor): The prompt embeddings.
+         pooled_prompt_embeds (torch.Tensor): The pooled prompt embeddings.
+         guidance (torch.Tensor): The guidance scale tensor.
+         text_ids (torch.Tensor): The text token IDs.
+         latent_image_ids (torch.Tensor): The latent image token IDs.
+         t (torch.Tensor): The current timestep.
+     Returns:
+         torch.Tensor: The predicted noise (velocity).
+     """
+     timestep = t.expand(latents.shape[0])
+
+     noise_pred = pipe.transformer(
+         hidden_states=latents,
+         timestep=timestep / 1000,
+         guidance=guidance,
+         encoder_hidden_states=prompt_embeds,
+         txt_ids=text_ids,
+         img_ids=latent_image_ids,
+         pooled_projections=pooled_prompt_embeds,
+         joint_attention_kwargs=None,
+         return_dict=False,
+     )[0]
+
+     return noise_pred
+
+ @torch.no_grad()
+ def prep_input(
+     pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
+     T_steps: int, x0_src: torch.Tensor, src_prompt: str,
+     src_guidance_scale: float,
+ ) -> Tuple[
+     torch.Tensor, torch.Tensor, torch.Tensor, int, int,
+     torch.Tensor, torch.Tensor, torch.Tensor,
+ ]:
+     """
+     Prepare the input tensors for the FLUX pipeline.
+
+     Args:
+         pipe (FluxPipeline): The FLUX pipeline.
+         scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
+         T_steps (int): The total number of timesteps for the diffusion process.
+         x0_src (torch.Tensor): The source latent tensor.
+         src_prompt (str): The source text prompt.
+         src_guidance_scale (float): The guidance scale for classifier-free guidance.
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor, torch.Tensor, int, int, torch.Tensor, torch.Tensor, torch.Tensor]:
+             - Prepared source latent tensor.
+             - Latent image token IDs.
+             - Timesteps tensor for the diffusion process.
+             - Original height of the input image.
+             - Original width of the input image.
+             - Source prompt embeddings.
+             - Source pooled prompt embeddings.
+             - Source text token IDs.
+     """
+     orig_height, orig_width = x0_src.shape[2] * pipe.vae_scale_factor, x0_src.shape[3] * pipe.vae_scale_factor
+     num_channels_latents = pipe.transformer.config.in_channels // 4
+
+     pipe.check_inputs(
+         prompt=src_prompt,
+         prompt_2=None,
+         height=orig_height,
+         width=orig_width,
+         callback_on_step_end_tensor_inputs=None,
+         max_sequence_length=512,
+     )
+
+     x0_src, latent_src_image_ids = pipe.prepare_latents(
+         batch_size=x0_src.shape[0], num_channels_latents=num_channels_latents,
+         height=orig_height, width=orig_width,
+         dtype=x0_src.dtype, device=x0_src.device, generator=None, latents=x0_src,
+     )
+     x0_src = pipe._pack_latents(x0_src, x0_src.shape[0], num_channels_latents, x0_src.shape[2], x0_src.shape[3])
+
+     sigmas = np.linspace(1.0, 1 / T_steps, T_steps)
+     image_seq_len = x0_src.shape[1]
+     mu = calculate_shift(
+         image_seq_len,
+         scheduler.config.base_image_seq_len,
+         scheduler.config.max_image_seq_len,
+         scheduler.config.base_shift,
+         scheduler.config.max_shift,
+     )
+     timesteps, T_steps = retrieve_timesteps(
+         scheduler,
+         T_steps,
+         x0_src.device,
+         timesteps=None,
+         sigmas=sigmas,
+         mu=mu,
+     )
+     pipe._num_timesteps = len(timesteps)
+
+     pipe._guidance_scale = src_guidance_scale
+     (
+         src_prompt_embeds,
+         src_pooled_prompt_embeds,
+         src_text_ids,
+     ) = pipe.encode_prompt(
+         prompt=src_prompt,
+         prompt_2=None,
+         device=x0_src.device,
+     )
+
+     return (
+         x0_src, latent_src_image_ids, timesteps, orig_height, orig_width,
+         src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids
+     )
+
+ # https://github.com/DSL-Lab/UniEdit-Flow
+ @torch.no_grad()
+ def uniinv(
+     pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
+     timesteps: torch.Tensor, n_start: int, x0_src: torch.Tensor,
+     src_prompt_embeds: torch.Tensor, src_pooled_prompt_embeds: torch.Tensor,
+     src_guidance: torch.Tensor, src_text_ids: torch.Tensor,
+     latent_src_image_ids: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Perform the UniInv inversion process for FLUX.
+
+     Args:
+         pipe (FluxPipeline): The FLUX pipeline.
+         scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
+         timesteps (torch.Tensor): The timesteps for the diffusion process.
+         n_start (int): The number of initial timesteps to skip.
+         x0_src (torch.Tensor): The source latent tensor.
+         src_prompt_embeds (torch.Tensor): The source prompt embeddings.
+         src_pooled_prompt_embeds (torch.Tensor): The source pooled prompt embeddings.
+         src_guidance (torch.Tensor): The guidance scale tensor.
+         src_text_ids (torch.Tensor): The source text token IDs.
+         latent_src_image_ids (torch.Tensor): The latent image token IDs.
+     Returns:
+         torch.Tensor: The inverted latent tensor.
+     """
+     x_t = x0_src.clone()
+     timesteps_inv = timesteps.flip(dims=(0,))[:-n_start] if n_start > 0 else timesteps.flip(dims=(0,))
+     next_v = None
+     for _i, t in enumerate(timesteps_inv):
+         scheduler._init_step_index(t)
+         t_i = scheduler.sigmas[scheduler.step_index]
+         t_ip1 = scheduler.sigmas[scheduler.step_index + 1]
+         dt = t_i - t_ip1
+
+         if next_v is None:
+             v_tar = calc_v_flux(
+                 pipe, latents=x_t, prompt_embeds=src_prompt_embeds,
+                 pooled_prompt_embeds=src_pooled_prompt_embeds, guidance=src_guidance,
+                 text_ids=src_text_ids, latent_image_ids=latent_src_image_ids, t=t_ip1 * 1000,
+             )
+         else:
+             v_tar = next_v
+         x_t = x_t.to(torch.float32)
+         x_t_next = x_t + v_tar * dt
+         x_t_next = x_t_next.to(pipe.dtype)
+
+         v_tar_next = calc_v_flux(
+             pipe, latents=x_t_next, prompt_embeds=src_prompt_embeds,
+             pooled_prompt_embeds=src_pooled_prompt_embeds, guidance=src_guidance,
+             text_ids=src_text_ids, latent_image_ids=latent_src_image_ids, t=t,
+         )
+         next_v = v_tar_next
+         x_t = x_t + v_tar_next * dt
+         x_t = x_t.to(pipe.dtype)
+
+     return x_t
+
+ @torch.no_grad()
+ def initialization(
+     pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
+     T_steps: int, n_start: int, x0_src: torch.Tensor, src_prompt: str,
+     src_guidance_scale: float,
+ ) -> Tuple[
+     torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int,
+ ]:
+     """
+     Initialize the inversion process by preparing the latent tensor and prompt embeddings, and performing UniInv.
+
+     Args:
+         pipe (FluxPipeline): The FLUX pipeline.
+         scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
+         T_steps (int): The total number of timesteps for the diffusion process.
+         n_start (int): The number of initial timesteps to skip.
+         x0_src (torch.Tensor): The source latent tensor.
+         src_prompt (str): The source text prompt.
+         src_guidance_scale (float): The guidance scale for classifier-free guidance.
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, int, int]:
+             - The inverted latent tensor.
+             - The packed source latent tensor.
+             - The timesteps for the diffusion process.
+             - The latent image token IDs.
+             - The original height of the input image.
+             - The original width of the input image.
+     """
+     (
+         x0_src, latent_src_image_ids, timesteps, orig_height, orig_width,
+         src_prompt_embeds, src_pooled_prompt_embeds, src_text_ids
+     ) = prep_input(pipe, scheduler, T_steps, x0_src, src_prompt, src_guidance_scale)
+
+     # handle guidance
+     if pipe.transformer.config.guidance_embeds:
+         src_guidance = torch.tensor([src_guidance_scale], device=pipe.device)
+         src_guidance = src_guidance.expand(x0_src.shape[0])
+     else:
+         src_guidance = None
+
+     x_t = uniinv(
+         pipe, scheduler, timesteps, n_start, x0_src,
+         src_prompt_embeds, src_pooled_prompt_embeds, src_guidance,
+         src_text_ids, latent_src_image_ids,
+     )
+
+     return (
+         x_t, x0_src, timesteps, latent_src_image_ids, orig_height, orig_width,
+     )
+
+ @torch.no_grad()
+ def flux_denoise(
+     pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
+     timesteps: torch.Tensor, n_start: int, x_t: torch.Tensor,
+     prompt_embeds: torch.Tensor, pooled_prompt_embeds: torch.Tensor,
+     guidance: torch.Tensor, text_ids: torch.Tensor,
+     latent_image_ids: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Perform the denoising process for FLUX.
+
+     Args:
+         pipe (FluxPipeline): The FLUX pipeline.
+         scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
+         timesteps (torch.Tensor): The timesteps for the diffusion process.
+         n_start (int): The number of initial timesteps to skip.
+         x_t (torch.Tensor): The latent tensor at the starting timestep.
+         prompt_embeds (torch.Tensor): The prompt embeddings.
+         pooled_prompt_embeds (torch.Tensor): The pooled prompt embeddings.
+         guidance (torch.Tensor): The guidance scale tensor.
+         text_ids (torch.Tensor): The text token IDs.
+         latent_image_ids (torch.Tensor): The latent image token IDs.
+     Returns:
+         torch.Tensor: The denoised latent tensor.
+     """
+     f_xt = x_t.clone()
+     for _i, t in enumerate(timesteps[n_start:]):
+         scheduler._init_step_index(t)
+         t_i = scheduler.sigmas[scheduler.step_index]
+         t_im1 = scheduler.sigmas[scheduler.step_index + 1]
+         dt = t_im1 - t_i
+
+         v_tar = calc_v_flux(
+             pipe, latents=f_xt, prompt_embeds=prompt_embeds,
+             pooled_prompt_embeds=pooled_prompt_embeds, guidance=guidance,
+             text_ids=text_ids, latent_image_ids=latent_image_ids, t=t,
+         )
+         f_xt = f_xt.to(torch.float32)
+         f_xt = f_xt + v_tar * dt
+         f_xt = f_xt.to(pipe.dtype)
+
+     return f_xt
+
+ @torch.no_grad()
+ def flux_editing(
+     pipe: FluxPipeline, scheduler: FlowMatchEulerDiscreteScheduler,
+     T_steps: int, n_max: int, x0_src: torch.Tensor, src_prompt: str,
+     tar_prompt: str, src_guidance_scale: float, tar_guidance_scale: float,
+     flowopt_iterations: int, eta: float,
+ ) -> Iterator[List[Tuple[Image.Image, str]]]:
+     """
+     Perform the editing process for FLUX using FlowOpt.
+
+     Args:
+         pipe (FluxPipeline): The FLUX pipeline.
+         scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
+         T_steps (int): The total number of timesteps for the diffusion process.
+         n_max (int): The maximum number of timesteps to consider.
+         x0_src (torch.Tensor): The source latent tensor.
+         src_prompt (str): The source text prompt.
+         tar_prompt (str): The target text prompt for editing.
+         src_guidance_scale (float): The guidance scale for the source prompt.
+         tar_guidance_scale (float): The guidance scale for the target prompt.
+         flowopt_iterations (int): The number of FlowOpt iterations to perform.
+         eta (float): The step size for the FlowOpt update.
+     Yields:
+         Iterator[List[Tuple[Image.Image, str]]]: A list of tuples containing the generated images and their corresponding iteration labels.
+     """
+     n_start = T_steps - n_max
+     (
+         x_t, x0_src, timesteps, latent_src_image_ids, orig_height, orig_width,
+     ) = initialization(
+         pipe, scheduler, T_steps, n_start, x0_src, src_prompt, src_guidance_scale,
+     )
+
+     pipe._guidance_scale = tar_guidance_scale
+     (
+         tar_prompt_embeds,
+         pooled_tar_prompt_embeds,
+         tar_text_ids,
+     ) = pipe.encode_prompt(
+         prompt=tar_prompt,
+         prompt_2=None,
+         device=pipe.device,
+     )
+
+     # handle guidance
+     if pipe.transformer.config.guidance_embeds:
+         tar_guidance = torch.tensor([tar_guidance_scale], device=pipe.device)
+         tar_guidance = tar_guidance.expand(x0_src.shape[0])
+     else:
+         tar_guidance = None
+
+     history = []
+     j_star = x0_src.clone().to(torch.float32) # y
+     for flowopt_iter in range(flowopt_iterations + 1):
+         f_xt = flux_denoise(
+             pipe, scheduler, timesteps, n_start, x_t,
+             tar_prompt_embeds, pooled_tar_prompt_embeds, tar_guidance,
+             tar_text_ids, latent_src_image_ids,
+         ) # Eq. (3)
+
+         if flowopt_iter < flowopt_iterations:
+             x_t = x_t.to(torch.float32)
+             x_t = x_t - eta * (f_xt - j_star) # Eq. (6) with c = c_tar
+             x_t = x_t.to(x0_src.dtype)
+
+         x0_flowopt = f_xt.clone()
+         unpacked_x0_flowopt = pipe._unpack_latents(x0_flowopt, orig_height, orig_width, pipe.vae_scale_factor)
+         x0_flowopt_denorm = (unpacked_x0_flowopt / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
+         with torch.autocast("cuda"), torch.inference_mode():
+             x0_flowopt_image = pipe.vae.decode(x0_flowopt_denorm, return_dict=False)[0].clamp(-1, 1)
+         x0_flowopt_image_pil = pipe.image_processor.postprocess(x0_flowopt_image)[0]
+         history.append((x0_flowopt_image_pil, f"Iteration {flowopt_iter}"))
+         yield history
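
The loop at the end of `flux_editing` is the core FlowOpt step: the frozen denoising pass `flux_denoise` is treated as a black-box map f, and the inverted latent is nudged by `x_t <- x_t - eta * (f_xt - j_star)`, where `j_star` is the packed source latent. Below is a toy sketch of that fixed-point update using a linear stand-in for f rather than the FLUX pipeline; all names here (`A`, `f`, `y`, `x`) are illustrative assumptions, not part of the Space.

```python
# Toy illustration of the FlowOpt-style update x <- x - eta * (f(x) - y),
# with a linear map standing in for the frozen flow process.
import torch

torch.manual_seed(0)
A = torch.eye(4) + 0.1 * torch.randn(4, 4)   # stand-in for the black-box map f
f = lambda x: A @ x
y = torch.randn(4)                           # stand-in for the source latent ("j_star")
x = torch.zeros(4)                           # stand-in for the inverted latent x_t

eta = 0.5
for it in range(50):
    residual = f(x) - y                      # analogous to `f_xt - j_star`
    x = x - eta * residual                   # Eq. (6)-style step
    if residual.norm() < 1e-6:
        break

print(it, residual.norm().item())            # the residual shrinks as x approaches a solution of f(x) = y
```

In the actual demo, the same residual is driven toward zero while f uses the target prompt, which is what pulls the reconstruction toward the edit.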
utils/sd3.py ADDED
@@ -0,0 +1,264 @@
+ from typing import Iterator, List, Tuple
+
+ import torch
+ from diffusers import FlowMatchEulerDiscreteScheduler, StableDiffusion3Pipeline
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
+ from PIL import Image
+
+ @torch.no_grad()
+ def calc_v_sd3(
+     pipe: StableDiffusion3Pipeline, latent_model_input: torch.Tensor,
+     prompt_embeds: torch.Tensor, pooled_prompt_embeds: torch.Tensor,
+     guidance_scale: float, t: torch.Tensor,
+ ) -> torch.Tensor:
+     """
+     Calculate the velocity (v) for Stable Diffusion 3.
+
+     Args:
+         pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
+         latent_model_input (torch.Tensor): The input latent tensor.
+         prompt_embeds (torch.Tensor): The text embeddings for the prompt.
+         pooled_prompt_embeds (torch.Tensor): The pooled text embeddings for the prompt.
+         guidance_scale (float): The guidance scale for classifier-free guidance.
+         t (torch.Tensor): The current timestep.
+     Returns:
+         torch.Tensor: The predicted noise (velocity).
+     """
+     timestep = t.expand(latent_model_input.shape[0])
+
+     noise_pred = pipe.transformer(
+         hidden_states=latent_model_input,
+         timestep=timestep,
+         encoder_hidden_states=prompt_embeds,
+         pooled_projections=pooled_prompt_embeds,
+         joint_attention_kwargs=None,
+         return_dict=False,
+     )[0]
+
+     # perform source guidance
+     if pipe.do_classifier_free_guidance:
+         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+     return noise_pred
+
+ # https://github.com/DSL-Lab/UniEdit-Flow
+ @torch.no_grad()
+ def uniinv(
+     pipe: StableDiffusion3Pipeline, timesteps: torch.Tensor, n_start: int,
+     x0_src: torch.Tensor, src_prompt_embeds_all: torch.Tensor,
+     src_pooled_prompt_embeds_all: torch.Tensor, src_guidance_scale: float,
+ ) -> torch.Tensor:
+     """
+     Perform the UniInv inversion process for Stable Diffusion 3.
+
+     Args:
+         pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
+         timesteps (torch.Tensor): The timesteps for the diffusion process.
+         n_start (int): The number of initial timesteps to skip.
+         x0_src (torch.Tensor): The source latent tensor.
+         src_prompt_embeds_all (torch.Tensor): The text embeddings for the source prompt.
+         src_pooled_prompt_embeds_all (torch.Tensor): The pooled text embeddings for the source prompt.
+         src_guidance_scale (float): The guidance scale for classifier-free guidance.
+     Returns:
+         torch.Tensor: The inverted latent tensor.
+     """
+     x_t = x0_src.clone()
+     timesteps_inv = torch.cat([torch.tensor([0.0], device=pipe.device), timesteps.flip(dims=(0,))], dim=0)
+     if n_start > 0:
+         zipped_timesteps_inv = zip(timesteps_inv[:-n_start - 1], timesteps_inv[1:-n_start])
+     else:
+         zipped_timesteps_inv = zip(timesteps_inv[:-1], timesteps_inv[1:])
+     next_v = None
+     for _i, (t_cur, t_prev) in enumerate(zipped_timesteps_inv):
+         t_i = t_cur / 1000
+         t_ip1 = t_prev / 1000
+         dt = t_ip1 - t_i
+
+         if next_v is None:
+             latent_model_input = torch.cat([x_t, x_t]) if pipe.do_classifier_free_guidance else (x_t)
+             v_tar = calc_v_sd3(
+                 pipe, latent_model_input, src_prompt_embeds_all,
+                 src_pooled_prompt_embeds_all, src_guidance_scale, t_cur,
+             )
+         else:
+             v_tar = next_v
+
+         x_t = x_t.to(torch.float32)
+         x_t_next = x_t + v_tar * dt
+         x_t_next = x_t_next.to(pipe.dtype)
+
+         latent_model_input = torch.cat([x_t_next, x_t_next]) if pipe.do_classifier_free_guidance else (x_t_next)
+         v_tar_next = calc_v_sd3(
+             pipe, latent_model_input, src_prompt_embeds_all,
+             src_pooled_prompt_embeds_all, src_guidance_scale, t_prev,
+         )
+         next_v = v_tar_next
+         x_t = x_t + v_tar_next * dt
+         x_t = x_t.to(pipe.dtype)
+
+     return x_t
+
+ @torch.no_grad()
+ def initialization(
+     pipe: StableDiffusion3Pipeline, scheduler: FlowMatchEulerDiscreteScheduler,
+     T_steps: int, n_start: int, x0_src: torch.Tensor,
+     src_prompt: str, negative_prompt: str, src_guidance_scale: float,
+ ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+     """
+     Initialize the inversion process by preparing the latent tensor and prompt embeddings, and performing UniInv.
+
+     Args:
+         pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
+         scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
+         T_steps (int): The total number of timesteps for the diffusion process.
+         n_start (int): The number of initial timesteps to skip.
+         x0_src (torch.Tensor): The source latent tensor.
+         src_prompt (str): The source text prompt.
+         negative_prompt (str): The negative text prompt for classifier-free guidance.
+         src_guidance_scale (float): The guidance scale for classifier-free guidance.
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+             - The inverted latent tensor.
+             - The original source latent tensor.
+             - The timesteps for the diffusion process.
+     """
+     pipe._guidance_scale = src_guidance_scale
+     (
+         src_prompt_embeds,
+         src_negative_prompt_embeds,
+         src_pooled_prompt_embeds,
+         src_negative_pooled_prompt_embeds,
+     ) = pipe.encode_prompt(
+         prompt=src_prompt,
+         prompt_2=None,
+         prompt_3=None,
+         negative_prompt=negative_prompt,
+         do_classifier_free_guidance=pipe.do_classifier_free_guidance,
+         device=pipe.device,
+     )
+     src_prompt_embeds_all = torch.cat([src_negative_prompt_embeds, src_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else src_prompt_embeds
+     src_pooled_prompt_embeds_all = torch.cat([src_negative_pooled_prompt_embeds, src_pooled_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else src_pooled_prompt_embeds
+
+     timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, x0_src.device, timesteps=None)
+     pipe._num_timesteps = len(timesteps)
+
+     x_t = uniinv(
+         pipe, timesteps, n_start, x0_src, src_prompt_embeds_all,
+         src_pooled_prompt_embeds_all, src_guidance_scale,
+     )
+
+     return x_t, x0_src, timesteps
+
+ @torch.no_grad()
+ def sd3_denoise(
+     pipe: StableDiffusion3Pipeline, timesteps: torch.Tensor, n_start: int,
+     x_t: torch.Tensor, prompt_embeds_all: torch.Tensor,
+     pooled_prompt_embeds_all: torch.Tensor, guidance_scale: float,
+ ) -> torch.Tensor:
+     """
+     Perform the denoising process for Stable Diffusion 3.
+
+     Args:
+         pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
+         timesteps (torch.Tensor): The timesteps for the diffusion process.
+         n_start (int): The number of initial timesteps to skip.
+         x_t (torch.Tensor): The latent tensor at the starting timestep.
+         prompt_embeds_all (torch.Tensor): The text embeddings for the prompt.
+         pooled_prompt_embeds_all (torch.Tensor): The pooled text embeddings for the prompt.
+         guidance_scale (float): The guidance scale for classifier-free guidance.
+     Returns:
+         torch.Tensor: The denoised latent tensor.
+     """
+     f_xt = x_t.clone()
+     for i, t in enumerate(timesteps[n_start:]):
+         t_i = t / 1000
+         if i + 1 < len(timesteps[n_start:]):
+             t_im1 = (timesteps[n_start + i + 1]) / 1000
+         else:
+             t_im1 = torch.zeros_like(t_i).to(t_i.device)
+         dt = t_im1 - t_i
+
+         latent_model_input = torch.cat([f_xt, f_xt]) if pipe.do_classifier_free_guidance else (f_xt)
+         v_tar = calc_v_sd3(
+             pipe, latent_model_input, prompt_embeds_all,
+             pooled_prompt_embeds_all, guidance_scale, t,
+         )
+         f_xt = f_xt.to(torch.float32)
+         f_xt = f_xt + v_tar * dt
+         f_xt = f_xt.to(pipe.dtype)
+
+     return f_xt
+
+ @torch.no_grad()
+ def sd3_editing(
+     pipe: StableDiffusion3Pipeline, scheduler: FlowMatchEulerDiscreteScheduler,
+     T_steps: int, n_max: int, x0_src: torch.Tensor, src_prompt: str,
+     tar_prompt: str, negative_prompt: str, src_guidance_scale: float,
+     tar_guidance_scale: float, flowopt_iterations: int, eta: float,
+ ) -> Iterator[List[Tuple[Image.Image, str]]]:
+     """
+     Perform the editing process for Stable Diffusion 3 using FlowOpt.
+
+     Args:
+         pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
+         scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
+         T_steps (int): The total number of timesteps for the diffusion process.
+         n_max (int): The maximum number of timesteps to consider.
+         x0_src (torch.Tensor): The source latent tensor.
+         src_prompt (str): The source text prompt.
+         tar_prompt (str): The target text prompt for editing.
+         negative_prompt (str): The negative text prompt for classifier-free guidance.
+         src_guidance_scale (float): The guidance scale for the source prompt.
+         tar_guidance_scale (float): The guidance scale for the target prompt.
+         flowopt_iterations (int): The number of FlowOpt iterations to perform.
+         eta (float): The step size for the FlowOpt update.
+     Yields:
+         Iterator[List[Tuple[Image.Image, str]]]: A list of tuples containing the generated images and their corresponding iteration labels.
+     """
+     n_start = T_steps - n_max
+     x_t, x0_src, timesteps = initialization(
+         pipe, scheduler, T_steps, n_start, x0_src, src_prompt,
+         negative_prompt, src_guidance_scale,
+     )
+
+     pipe._guidance_scale = tar_guidance_scale
+     (
+         tar_prompt_embeds,
+         tar_negative_prompt_embeds,
+         tar_pooled_prompt_embeds,
+         tar_negative_pooled_prompt_embeds,
+     ) = pipe.encode_prompt(
+         prompt=tar_prompt,
+         prompt_2=None,
+         prompt_3=None,
+         negative_prompt=negative_prompt,
+         do_classifier_free_guidance=pipe.do_classifier_free_guidance,
+         device=pipe.device,
+     )
+
+     tar_prompt_embeds_all = torch.cat([tar_negative_prompt_embeds, tar_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else tar_prompt_embeds
+     tar_pooled_prompt_embeds_all = torch.cat([tar_negative_pooled_prompt_embeds, tar_pooled_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else tar_pooled_prompt_embeds
+
+     history = []
+     j_star = x0_src.clone().to(torch.float32) # y
+     for flowopt_iter in range(flowopt_iterations + 1):
+         f_xt = sd3_denoise(
+             pipe, timesteps, n_start, x_t, tar_prompt_embeds_all,
+             tar_pooled_prompt_embeds_all, tar_guidance_scale,
+         ) # Eq. (3)
+
+         if flowopt_iter < flowopt_iterations:
+             x_t = x_t.to(torch.float32)
+             x_t = x_t - eta * (f_xt - j_star) # Eq. (6) with c = c_tar
+             x_t = x_t.to(x0_src.dtype)
+
+         x0_flowopt = f_xt.clone()
+         x0_flowopt_denorm = (x0_flowopt / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
+         with torch.autocast("cuda"), torch.inference_mode():
+             x0_flowopt_image = pipe.vae.decode(x0_flowopt_denorm, return_dict=False)[0].clamp(-1, 1)
+         x0_flowopt_image_pil = pipe.image_processor.postprocess(x0_flowopt_image)[0]
+         history.append((x0_flowopt_image_pil, f"Iteration {flowopt_iter}"))
+         yield history
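
Both `uniinv` implementations (FLUX and SD3) follow the same predictor/corrector pattern: take a provisional step with the current velocity, re-evaluate the velocity at the landing point, take the actual step with that re-evaluated velocity, and reuse it (`next_v`) on the following step. Below is a toy sketch of that numerical scheme on a simple, known one-dimensional flow; it is purely illustrative and not tied to SD3, and the `velocity` field and variable names are assumptions introduced for the example.

```python
# Toy sketch of a UniInv-style inversion step: predict, re-evaluate the
# velocity at the predicted point, then step with the re-evaluated velocity.
import torch

def velocity(x: torch.Tensor, t: float) -> torch.Tensor:
    # hypothetical velocity field of the known flow x(t) = x0 * (1 + t)
    return x / (1.0 + t)

x = torch.tensor([1.0, 2.0])        # plays the role of x0_src
ts = torch.linspace(0.0, 1.0, 16)   # inversion moves from t = 0 toward t = 1
next_v = None
for t_cur, t_prev in zip(ts[:-1], ts[1:]):
    dt = (t_prev - t_cur).item()
    v = velocity(x, t_cur.item()) if next_v is None else next_v  # predictor velocity
    x_pred = x + v * dt                                          # provisional step
    v_next = velocity(x_pred, t_prev.item())                     # corrector velocity
    x = x + v_next * dt                                          # actual step, like `x_t = x_t + v_tar_next * dt`
    next_v = v_next                                              # reused on the next iteration

print(x)  # close to x0 * 2, the state of the known flow at t = 1
```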