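# FlowOpt image editing with Stable Diffusion 3: invert the source image with
# UniInv, then iteratively optimize the starting latent so the result follows
# the target prompt while staying close to the source.
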
from typing import Iterator, List, Tuple

import torch
from diffusers import FlowMatchEulerDiscreteScheduler, StableDiffusion3Pipeline
from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import retrieve_timesteps
from PIL import Image

@torch.no_grad()
def calc_v_sd3(
    pipe: StableDiffusion3Pipeline, latent_model_input: torch.Tensor,
    prompt_embeds: torch.Tensor, pooled_prompt_embeds: torch.Tensor,
    guidance_scale: float, t: torch.Tensor,
) -> torch.Tensor:
    """
    Calculate the velocity (v) for Stable Diffusion 3.

    Args:
        pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
        latent_model_input (torch.Tensor): The input latent tensor.
        prompt_embeds (torch.Tensor): The text embeddings for the prompt.
        pooled_prompt_embeds (torch.Tensor): The pooled text embeddings for the prompt.
        guidance_scale (float): The guidance scale for classifier-free guidance.
        t (torch.Tensor): The current timestep.
    Returns:
        torch.Tensor: The predicted velocity (with classifier-free guidance applied, if enabled).
    """
    timestep = t.expand(latent_model_input.shape[0])

    noise_pred = pipe.transformer(
        hidden_states=latent_model_input,
        timestep=timestep,
        encoder_hidden_states=prompt_embeds,
        pooled_projections=pooled_prompt_embeds,
        joint_attention_kwargs=None,
        return_dict=False,
    )[0]

    # perform classifier-free guidance
    if pipe.do_classifier_free_guidance:
        noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
        noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

    return noise_pred

# https://github.com/DSL-Lab/UniEdit-Flow
@torch.no_grad()
def uniinv(
    pipe: StableDiffusion3Pipeline, timesteps: torch.Tensor, n_start: int,
    x0_src: torch.Tensor, src_prompt_embeds_all: torch.Tensor,
    src_pooled_prompt_embeds_all: torch.Tensor, src_guidance_scale: float,
) -> torch.Tensor:
    """
    Perform the UniInv inversion process for Stable Diffusion 3.

    Args:
        pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
        timesteps (torch.Tensor): The timesteps for the diffusion process.
        n_start (int): The number of initial timesteps to skip.
        x0_src (torch.Tensor): The source latent tensor.
        src_prompt_embeds_all (torch.Tensor): The text embeddings for the source prompt.
        src_pooled_prompt_embeds_all (torch.Tensor): The pooled text embeddings for the source prompt.
        src_guidance_scale (float): The guidance scale for classifier-free guidance.
    Returns:
        torch.Tensor: The inverted latent tensor.
    """
    x_t = x0_src.clone()
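    # Reverse the denoising schedule and prepend t=0 so integration runs
    # forward in time (clean latent -> noise); when n_start > 0, the last
    # n_start steps are skipped and the inversion stops at an intermediate
    # noise level.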
    timesteps_inv = torch.cat([torch.tensor([0.0], device=pipe.device), timesteps.flip(dims=(0,))], dim=0)
    if n_start > 0:
        zipped_timesteps_inv = zip(timesteps_inv[:-n_start - 1], timesteps_inv[1:-n_start])
    else:
        zipped_timesteps_inv = zip(timesteps_inv[:-1], timesteps_inv[1:])
    next_v = None
    for _i, (t_cur, t_prev) in enumerate(zipped_timesteps_inv):
        t_i = t_cur / 1000
        t_ip1 = t_prev / 1000
        dt = t_ip1 - t_i

        if next_v is None:
            # First step: no cached velocity yet, so query the model at t_cur.
            latent_model_input = torch.cat([x_t, x_t]) if pipe.do_classifier_free_guidance else x_t
            v_tar = calc_v_sd3(
                pipe, latent_model_input, src_prompt_embeds_all,
                src_pooled_prompt_embeds_all, src_guidance_scale, t_cur,
            )
        else:
            # Predictor: reuse the velocity computed at this time level by the
            # previous iteration's corrector pass (one model call per step).
            v_tar = next_v

        x_t = x_t.to(torch.float32)
        x_t_next = x_t + v_tar * dt
        x_t_next = x_t_next.to(pipe.dtype)

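        # Corrector: take a provisional Euler step to t_prev, evaluate the
        # velocity there, and advance x_t with that lookahead velocity (the
        # UniInv update), caching it for the next iteration.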
        latent_model_input = torch.cat([x_t_next, x_t_next]) if pipe.do_classifier_free_guidance else x_t_next
        v_tar_next = calc_v_sd3(
            pipe, latent_model_input, src_prompt_embeds_all,
            src_pooled_prompt_embeds_all, src_guidance_scale, t_prev,
        )
        next_v = v_tar_next
        x_t = x_t + v_tar_next * dt
        x_t = x_t.to(pipe.dtype)

    return x_t

@torch.no_grad()
def initialization(
    pipe: StableDiffusion3Pipeline, scheduler: FlowMatchEulerDiscreteScheduler,
    T_steps: int, n_start: int, x0_src: torch.Tensor,
    src_prompt: str, negative_prompt: str, src_guidance_scale: float,
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Initialize the inversion process by preparing the latent tensor and prompt embeddings, and performing UniInv.

    Args:
        pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        T_steps (int): The total number of timesteps for the diffusion process.
        n_start (int): The number of initial timesteps to skip.
        x0_src (torch.Tensor): The source latent tensor.
        src_prompt (str): The source text prompt.
        negative_prompt (str): The negative text prompt for classifier-free guidance.
        src_guidance_scale (float): The guidance scale for classifier-free guidance.
    Returns:
        Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
            - The inverted latent tensor.
            - The original source latent tensor (unchanged).
            - The timesteps for the diffusion process.
    """
    pipe._guidance_scale = src_guidance_scale
    (
        src_prompt_embeds,
        src_negative_prompt_embeds,
        src_pooled_prompt_embeds,
        src_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=src_prompt,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=pipe.device,
    )
    src_prompt_embeds_all = torch.cat([src_negative_prompt_embeds, src_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else src_prompt_embeds
    src_pooled_prompt_embeds_all = torch.cat([src_negative_pooled_prompt_embeds, src_pooled_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else src_pooled_prompt_embeds

    timesteps, T_steps = retrieve_timesteps(scheduler, T_steps, x0_src.device, timesteps=None)
    pipe._num_timesteps = len(timesteps)

    x_t = uniinv(
        pipe, timesteps, n_start, x0_src, src_prompt_embeds_all,
        src_pooled_prompt_embeds_all, src_guidance_scale,
    )

    return x_t, x0_src, timesteps

@torch.no_grad()
def sd3_denoise(
    pipe: StableDiffusion3Pipeline, timesteps: torch.Tensor, n_start: int,
    x_t: torch.Tensor, prompt_embeds_all: torch.Tensor,
    pooled_prompt_embeds_all: torch.Tensor, guidance_scale: float,
) -> torch.Tensor:
    """
    Perform the denoising process for Stable Diffusion 3.

    Args:
        pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
        timesteps (torch.Tensor): The timesteps for the diffusion process.
        n_start (int): The number of initial timesteps to skip.
        x_t (torch.Tensor): The latent tensor at the starting timestep.
        prompt_embeds_all (torch.Tensor): The text embeddings for the prompt.
        pooled_prompt_embeds_all (torch.Tensor): The pooled text embeddings for the prompt.
        guidance_scale (float): The guidance scale for classifier-free guidance.
    Returns:
        torch.Tensor: The denoised latent tensor.
    """
    f_xt = x_t.clone()
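    # Plain Euler sampling of the flow ODE from timesteps[n_start] down to
    # t=0; dt is negative at each step (noise -> data).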
    for i, t in enumerate(timesteps[n_start:]):
        t_i = t / 1000
        if i + 1 < len(timesteps[n_start:]):
            t_im1 = (timesteps[n_start + i + 1]) / 1000
        else:
            t_im1 = torch.zeros_like(t_i)
        dt = t_im1 - t_i

        latent_model_input = torch.cat([f_xt, f_xt]) if pipe.do_classifier_free_guidance else f_xt
        v_tar = calc_v_sd3(
            pipe, latent_model_input, prompt_embeds_all,
            pooled_prompt_embeds_all, guidance_scale, t,
        )
        f_xt = f_xt.to(torch.float32)
        f_xt = f_xt + v_tar * dt
        f_xt = f_xt.to(pipe.dtype)

    return f_xt

@torch.no_grad()
def sd3_editing(
    pipe: StableDiffusion3Pipeline, scheduler: FlowMatchEulerDiscreteScheduler,
    T_steps: int, n_max: int, x0_src: torch.Tensor, src_prompt: str,
    tar_prompt: str, negative_prompt: str, src_guidance_scale: float,
    tar_guidance_scale: float, flowopt_iterations: int, eta: float,
) -> Iterator[List[Tuple[Image.Image, str]]]:
    """
    Perform the editing process for Stable Diffusion 3 using FlowOpt.

    Args:
        pipe (StableDiffusion3Pipeline): The Stable Diffusion 3 pipeline.
        scheduler (FlowMatchEulerDiscreteScheduler): The scheduler for the diffusion process.
        T_steps (int): The total number of timesteps for the diffusion process.
        n_max (int): The number of denoising steps used for editing; the first T_steps - n_max steps are skipped.
        x0_src (torch.Tensor): The source latent tensor.
        src_prompt (str): The source text prompt.
        tar_prompt (str): The target text prompt for editing.
        negative_prompt (str): The negative text prompt for classifier-free guidance.
        src_guidance_scale (float): The guidance scale for the source prompt.
        tar_guidance_scale (float): The guidance scale for the target prompt.
        flowopt_iterations (int): The number of FlowOpt iterations to perform.
        eta (float): The step size for the FlowOpt update.
    Yields:
        List[Tuple[Image.Image, str]]: The cumulative history of (image, iteration label) pairs, extended after each FlowOpt iteration.
    """
    n_start = T_steps - n_max
    x_t, x0_src, timesteps = initialization(
        pipe, scheduler, T_steps, n_start, x0_src, src_prompt,
        negative_prompt, src_guidance_scale,
    )

    pipe._guidance_scale = tar_guidance_scale
    (
        tar_prompt_embeds,
        tar_negative_prompt_embeds,
        tar_pooled_prompt_embeds,
        tar_negative_pooled_prompt_embeds,
    ) = pipe.encode_prompt(
        prompt=tar_prompt,
        prompt_2=None,
        prompt_3=None,
        negative_prompt=negative_prompt,
        do_classifier_free_guidance=pipe.do_classifier_free_guidance,
        device=pipe.device,
    )

    tar_prompt_embeds_all = torch.cat([tar_negative_prompt_embeds, tar_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else tar_prompt_embeds
    tar_pooled_prompt_embeds_all = torch.cat([tar_negative_pooled_prompt_embeds, tar_pooled_prompt_embeds], dim=0) if pipe.do_classifier_free_guidance else tar_pooled_prompt_embeds

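    # FlowOpt loop: denoise x_t under the target prompt, then nudge x_t
    # against the reconstruction error (f_xt - y) so the edit stays anchored
    # to the source latent.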
    history = []
    j_star = x0_src.clone().to(torch.float32)  # y: the reconstruction target (source latent)
    for flowopt_iter in range(flowopt_iterations + 1):
        f_xt = sd3_denoise(
            pipe, timesteps, n_start, x_t, tar_prompt_embeds_all,
            tar_pooled_prompt_embeds_all, tar_guidance_scale,
        )  # Eq. (3)

        if flowopt_iter < flowopt_iterations:
            x_t = x_t.to(torch.float32)
            x_t = x_t - eta * (f_xt - j_star)  # Eq. (6) with c = c_tar
            x_t = x_t.to(x0_src.dtype)

        x0_flowopt = f_xt.clone()
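        # Undo SD3 latent normalization before decoding with the VAE.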
        x0_flowopt_denorm = (x0_flowopt / pipe.vae.config.scaling_factor) + pipe.vae.config.shift_factor
        with torch.autocast("cuda"), torch.inference_mode():
            x0_flowopt_image = pipe.vae.decode(x0_flowopt_denorm, return_dict=False)[0].clamp(-1, 1)
        x0_flowopt_image_pil = pipe.image_processor.postprocess(x0_flowopt_image)[0]
        history.append((x0_flowopt_image_pil, f"Iteration {flowopt_iter}"))
        yield history
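
# ---------------------------------------------------------------------------
# Usage sketch (not part of the original module). A minimal way to drive
# sd3_editing end to end: load the pipeline, VAE-encode a source image into
# SD3 latents, and stream the edited images out of the FlowOpt loop. The
# checkpoint id, file paths, prompts, and hyperparameter values below are
# illustrative assumptions, not prescribed defaults.
if __name__ == "__main__":
    pipe = StableDiffusion3Pipeline.from_pretrained(
        "stabilityai/stable-diffusion-3-medium-diffusers",  # assumed checkpoint
        torch_dtype=torch.float16,
    ).to("cuda")

    # Encode the source image into normalized SD3 latents (the inverse of the
    # denormalization applied before VAE decoding in sd3_editing).
    image = Image.open("source.png").convert("RGB").resize((1024, 1024))
    image_pt = pipe.image_processor.preprocess(image).to(device=pipe.device, dtype=pipe.dtype)
    with torch.inference_mode():
        x0_src = pipe.vae.encode(image_pt).latent_dist.mode()
    x0_src = (x0_src - pipe.vae.config.shift_factor) * pipe.vae.config.scaling_factor

    for history in sd3_editing(
        pipe, pipe.scheduler, T_steps=50, n_max=33, x0_src=x0_src,
        src_prompt="a photo of a cat", tar_prompt="a photo of a dog",
        negative_prompt="", src_guidance_scale=1.5, tar_guidance_scale=5.5,
        flowopt_iterations=8, eta=0.5,
    ):
        latest_image, label = history[-1]
        latest_image.save(f"flowopt_{label.replace(' ', '_').lower()}.png")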