Spaces:
Running
on
Zero
Running
on
Zero
wangshuai6
commited on
Commit
·
e321f58
1
Parent(s):
56238f0
init
Browse files
app.py
CHANGED
|
@@ -43,6 +43,8 @@ from src.diffusion.flow_matching.scheduling import LinearScheduler
|
|
| 43 |
from PIL import Image
|
| 44 |
import gradio as gr
|
| 45 |
import tempfile
|
|
|
|
|
|
|
| 46 |
from huggingface_hub import snapshot_download
|
| 47 |
|
| 48 |
|
|
@@ -65,9 +67,9 @@ def load_model(weight_dict, denoiser):
|
|
| 65 |
|
| 66 |
class Pipeline:
|
| 67 |
def __init__(self, vae, denoiser, conditioner, resolution):
|
| 68 |
-
self.vae = vae
|
| 69 |
-
self.denoiser = denoiser
|
| 70 |
-
self.conditioner = conditioner
|
| 71 |
self.conditioner.compile()
|
| 72 |
self.resolution = resolution
|
| 73 |
self.tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_")
|
|
@@ -76,6 +78,7 @@ class Pipeline:
|
|
| 76 |
def __del__(self):
|
| 77 |
self.tmp_dir.cleanup()
|
| 78 |
|
|
|
|
| 79 |
@torch.no_grad()
|
| 80 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 81 |
def __call__(self, y, num_images, seed, image_height, image_width, num_steps, guidance, timeshift, order):
|
|
@@ -93,8 +96,7 @@ class Pipeline:
|
|
| 93 |
self.denoiser.decoder_patch_scaling_h = image_height / 512
|
| 94 |
self.denoiser.decoder_patch_scaling_w = image_width / 512
|
| 95 |
xT = torch.randn((num_images, 3, image_height, image_width), device="cpu", dtype=torch.float32,
|
| 96 |
-
generator=generator)
|
| 97 |
-
xT = xT.to("cuda")
|
| 98 |
with torch.no_grad():
|
| 99 |
condition, uncondition = conditioner([y,]*num_images)
|
| 100 |
|
|
|
|
| 43 |
from PIL import Image
|
| 44 |
import gradio as gr
|
| 45 |
import tempfile
|
| 46 |
+
import spaces
|
| 47 |
+
|
| 48 |
from huggingface_hub import snapshot_download
|
| 49 |
|
| 50 |
|
|
|
|
| 67 |
|
| 68 |
class Pipeline:
|
| 69 |
def __init__(self, vae, denoiser, conditioner, resolution):
|
| 70 |
+
self.vae = vae
|
| 71 |
+
self.denoiser = denoiser
|
| 72 |
+
self.conditioner = conditioner
|
| 73 |
self.conditioner.compile()
|
| 74 |
self.resolution = resolution
|
| 75 |
self.tmp_dir = tempfile.TemporaryDirectory(prefix="traj_gifs_")
|
|
|
|
| 78 |
def __del__(self):
|
| 79 |
self.tmp_dir.cleanup()
|
| 80 |
|
| 81 |
+
@spaces.GPU
|
| 82 |
@torch.no_grad()
|
| 83 |
@torch.autocast(device_type="cuda", dtype=torch.bfloat16)
|
| 84 |
def __call__(self, y, num_images, seed, image_height, image_width, num_steps, guidance, timeshift, order):
|
|
|
|
| 96 |
self.denoiser.decoder_patch_scaling_h = image_height / 512
|
| 97 |
self.denoiser.decoder_patch_scaling_w = image_width / 512
|
| 98 |
xT = torch.randn((num_images, 3, image_height, image_width), device="cpu", dtype=torch.float32,
|
| 99 |
+
generator=generator).cuda()
|
|
|
|
| 100 |
with torch.no_grad():
|
| 101 |
condition, uncondition = conditioner([y,]*num_images)
|
| 102 |
|