Spaces:
Running
on
Zero
Running
on
Zero
wangshuai6
committed on
Commit
·
4d3bd2d
1
Parent(s):
de07e5c
init
Browse files
app.py
CHANGED
|
@@ -33,6 +33,7 @@
|
|
| 33 |
# step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
|
| 34 |
import random
|
| 35 |
import os
|
|
|
|
| 36 |
import spaces
|
| 37 |
import torch
|
| 38 |
import argparse
|
|
@@ -69,8 +70,8 @@ def load_model(weight_dict, denoiser):
|
|
| 69 |
class Pipeline:
|
| 70 |
def __init__(self, vae, denoiser, conditioner, resolution):
|
| 71 |
self.vae = vae
|
| 72 |
-
self.denoiser = denoiser
|
| 73 |
-
self.conditioner = conditioner
|
| 74 |
self.resolution = resolution
|
| 75 |
self.tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_")
|
| 76 |
# self.denoiser.compile()
|
|
@@ -78,10 +79,10 @@ class Pipeline:
|
|
| 78 |
def __del__(self):
|
| 79 |
self.tmp_dir.cleanup()
|
| 80 |
|
| 81 |
-
@spaces.GPU
|
| 82 |
@torch.no_grad()
|
| 83 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 84 |
-
def __call__(self, y, seed,
|
| 85 |
diffusion_sampler = AdamLMSampler(
|
| 86 |
order=order,
|
| 87 |
scheduler=LinearScheduler(),
|
|
@@ -91,18 +92,18 @@ class Pipeline:
|
|
| 91 |
timeshift=timeshift
|
| 92 |
)
|
| 93 |
generator = torch.Generator(device="cpu").manual_seed(seed)
|
| 94 |
-
|
| 95 |
-
image_width = image_width // 32 * 32
|
| 96 |
-
self.denoiser.decoder_patch_scaling_h = image_height / 512
|
| 97 |
-
self.denoiser.decoder_patch_scaling_w = image_width / 512
|
| 98 |
-
xT = torch.randn((1, 3, image_height, image_width), device="cpu", dtype=torch.float32,
|
| 99 |
generator=generator).cuda()
|
|
|
|
|
|
|
| 100 |
with torch.no_grad():
|
| 101 |
condition, uncondition = conditioner([y,]*1)
|
|
|
|
| 102 |
|
| 103 |
-
|
| 104 |
# Sample images:
|
| 105 |
samples, trajs = diffusion_sampler(denoiser, xT, condition, uncondition, return_x_trajs=True)
|
|
|
|
| 106 |
|
| 107 |
def decode_images(samples):
|
| 108 |
samples = vae.decode(samples)
|
|
@@ -114,35 +115,35 @@ class Pipeline:
|
|
| 114 |
images.append(image)
|
| 115 |
return images
|
| 116 |
|
| 117 |
-
def decode_trajs(trajs):
|
| 118 |
-
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
|
| 122 |
-
|
| 123 |
-
|
| 124 |
-
|
| 125 |
-
|
| 126 |
-
|
| 127 |
-
|
| 128 |
-
|
| 129 |
-
|
| 130 |
-
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
|
| 134 |
-
|
| 135 |
-
|
| 136 |
-
|
| 137 |
|
| 138 |
images = decode_images(samples)
|
| 139 |
-
animations = decode_trajs(trajs)
|
| 140 |
|
| 141 |
-
return images
|
| 142 |
|
| 143 |
if __name__ == "__main__":
|
| 144 |
parser = argparse.ArgumentParser()
|
| 145 |
-
parser.add_argument("--config", type=str, default="configs_t2i/
|
| 146 |
parser.add_argument("--resolution", type=int, default=512)
|
| 147 |
parser.add_argument("--model_id", type=str, default="MCG-NJU/PixNerd-XXL-P16-T2I")
|
| 148 |
parser.add_argument("--ckpt_path", type=str, default="models")
|
|
@@ -167,6 +168,7 @@ if __name__ == "__main__":
|
|
| 167 |
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 168 |
denoiser = load_model(ckpt, denoiser)
|
| 169 |
denoiser = denoiser.cuda()
|
|
|
|
| 170 |
vae = vae.cuda()
|
| 171 |
denoiser.eval()
|
| 172 |
|
|
@@ -179,27 +181,23 @@ if __name__ == "__main__":
|
|
| 179 |
with gr.Column(scale=1):
|
| 180 |
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=25)
|
| 181 |
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
| 182 |
-
image_height = gr.Slider(minimum=128, maximum=1024, step=32, label="image height", value=512)
|
| 183 |
-
image_width = gr.Slider(minimum=128, maximum=1024, step=32, label="image width", value=512)
|
| 184 |
label = gr.Textbox(label="positive prompt", value="a photo of a cat")
|
| 185 |
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
| 186 |
timeshift = gr.Slider(minimum=0.1, maximum=5.0, step=0.1, label="timeshift", value=3.0)
|
| 187 |
order = gr.Slider(minimum=1, maximum=4, step=1, label="order", value=2)
|
| 188 |
with gr.Column(scale=2):
|
| 189 |
btn = gr.Button("Generate")
|
| 190 |
-
output_sample = gr.
|
| 191 |
-
with gr.Column(scale=2):
|
| 192 |
-
|
| 193 |
|
| 194 |
btn.click(fn=pipeline,
|
| 195 |
inputs=[
|
| 196 |
label,
|
| 197 |
seed,
|
| 198 |
-
image_height,
|
| 199 |
-
image_width,
|
| 200 |
num_steps,
|
| 201 |
guidance,
|
| 202 |
timeshift,
|
| 203 |
order
|
| 204 |
-
], outputs=[output_sample
|
| 205 |
demo.launch()
|
|
|
|
| 33 |
# step_fn: src.diffusion.stateful_flow_matching.sampling.ode_step_fn
|
| 34 |
import random
|
| 35 |
import os
|
| 36 |
+
import time
|
| 37 |
import spaces
|
| 38 |
import torch
|
| 39 |
import argparse
|
|
|
|
| 70 |
class Pipeline:
|
| 71 |
def __init__(self, vae, denoiser, conditioner, resolution):
|
| 72 |
self.vae = vae
|
| 73 |
+
self.denoiser = denoiser.cuda()
|
| 74 |
+
self.conditioner = conditioner.cuda()
|
| 75 |
self.resolution = resolution
|
| 76 |
self.tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_")
|
| 77 |
# self.denoiser.compile()
|
|
|
|
| 79 |
def __del__(self):
|
| 80 |
self.tmp_dir.cleanup()
|
| 81 |
|
| 82 |
+
# @spaces.GPU
|
| 83 |
@torch.no_grad()
|
| 84 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 85 |
+
def __call__(self, y, seed, num_steps, guidance, timeshift, order):
|
| 86 |
diffusion_sampler = AdamLMSampler(
|
| 87 |
order=order,
|
| 88 |
scheduler=LinearScheduler(),
|
|
|
|
| 92 |
timeshift=timeshift
|
| 93 |
)
|
| 94 |
generator = torch.Generator(device="cpu").manual_seed(seed)
|
| 95 |
+
xT = torch.randn((1, 3, 512, 512), device="cpu", dtype=torch.float32,
|
|
|
|
|
|
|
|
|
|
|
|
|
| 96 |
generator=generator).cuda()
|
| 97 |
+
|
| 98 |
+
start = time.time()
|
| 99 |
with torch.no_grad():
|
| 100 |
condition, uncondition = conditioner([y,]*1)
|
| 101 |
+
print("conditioner:",time.time() - start)
|
| 102 |
|
| 103 |
+
start = time.time()
|
| 104 |
# Sample images:
|
| 105 |
samples, trajs = diffusion_sampler(denoiser, xT, condition, uncondition, return_x_trajs=True)
|
| 106 |
+
print("diffusion:",time.time() - start)
|
| 107 |
|
| 108 |
def decode_images(samples):
|
| 109 |
samples = vae.decode(samples)
|
|
|
|
| 115 |
images.append(image)
|
| 116 |
return images
|
| 117 |
|
| 118 |
+
# def decode_trajs(trajs):
|
| 119 |
+
# cat_trajs = torch.stack(trajs, dim=0).permute(1, 0, 2, 3, 4)
|
| 120 |
+
# animations = []
|
| 121 |
+
# for i in range(cat_trajs.shape[0]):
|
| 122 |
+
# frames = decode_images(
|
| 123 |
+
# cat_trajs[i]
|
| 124 |
+
# )
|
| 125 |
+
# # 生成唯一文件名(结合seed和样本索引,避免冲突)
|
| 126 |
+
# gif_filename = f"{random.randint(0, 100000)}.gif"
|
| 127 |
+
# gif_path = os.path.join(self.tmp_dir.name, gif_filename)
|
| 128 |
+
# frames[0].save(
|
| 129 |
+
# gif_path,
|
| 130 |
+
# format="GIF",
|
| 131 |
+
# append_images=frames[1:],
|
| 132 |
+
# save_all=True,
|
| 133 |
+
# duration=200,
|
| 134 |
+
# loop=0
|
| 135 |
+
# )
|
| 136 |
+
# animations.append(gif_path)
|
| 137 |
+
# return animations
|
| 138 |
|
| 139 |
images = decode_images(samples)
|
| 140 |
+
# animations = decode_trajs(trajs)
|
| 141 |
|
| 142 |
+
return images[0]
|
| 143 |
|
| 144 |
if __name__ == "__main__":
|
| 145 |
parser = argparse.ArgumentParser()
|
| 146 |
+
parser.add_argument("--config", type=str, default="configs_t2i/sft_res512.yaml")
|
| 147 |
parser.add_argument("--resolution", type=int, default=512)
|
| 148 |
parser.add_argument("--model_id", type=str, default="MCG-NJU/PixNerd-XXL-P16-T2I")
|
| 149 |
parser.add_argument("--ckpt_path", type=str, default="models")
|
|
|
|
| 168 |
ckpt = torch.load(ckpt_path, map_location="cpu")
|
| 169 |
denoiser = load_model(ckpt, denoiser)
|
| 170 |
denoiser = denoiser.cuda()
|
| 171 |
+
conditioner = conditioner.cuda()
|
| 172 |
vae = vae.cuda()
|
| 173 |
denoiser.eval()
|
| 174 |
|
|
|
|
| 181 |
with gr.Column(scale=1):
|
| 182 |
num_steps = gr.Slider(minimum=1, maximum=100, step=1, label="num steps", value=25)
|
| 183 |
guidance = gr.Slider(minimum=0.1, maximum=10.0, step=0.1, label="CFG", value=4.0)
|
|
|
|
|
|
|
| 184 |
label = gr.Textbox(label="positive prompt", value="a photo of a cat")
|
| 185 |
seed = gr.Slider(minimum=0, maximum=1000000, step=1, label="seed", value=0)
|
| 186 |
timeshift = gr.Slider(minimum=0.1, maximum=5.0, step=0.1, label="timeshift", value=3.0)
|
| 187 |
order = gr.Slider(minimum=1, maximum=4, step=1, label="order", value=2)
|
| 188 |
with gr.Column(scale=2):
|
| 189 |
btn = gr.Button("Generate")
|
| 190 |
+
output_sample = gr.Image(label="Images")
|
| 191 |
+
# with gr.Column(scale=2):
|
| 192 |
+
# output_trajs = gr.Gallery(label="Trajs of Diffusion", columns=2, rows=2)
|
| 193 |
|
| 194 |
btn.click(fn=pipeline,
|
| 195 |
inputs=[
|
| 196 |
label,
|
| 197 |
seed,
|
|
|
|
|
|
|
| 198 |
num_steps,
|
| 199 |
guidance,
|
| 200 |
timeshift,
|
| 201 |
order
|
| 202 |
+
], outputs=[output_sample])
|
| 203 |
demo.launch()
|