Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
|
@@ -149,7 +149,7 @@ class AnimateController:
|
|
| 149 |
for key in f.keys():
|
| 150 |
self.lora_model_state_dict[key] = f.get_tensor(key)
|
| 151 |
return gr.Dropdown.update()
|
| 152 |
-
|
| 153 |
def animate(
|
| 154 |
self,
|
| 155 |
lora_alpha_slider,
|
|
@@ -174,8 +174,8 @@ class AnimateController:
|
|
| 174 |
**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
|
| 175 |
).to("cuda")
|
| 176 |
|
| 177 |
-
|
| 178 |
-
|
| 179 |
|
| 180 |
pipeline.to("cuda")
|
| 181 |
|
|
@@ -185,15 +185,19 @@ class AnimateController:
|
|
| 185 |
torch.seed()
|
| 186 |
seed = torch.initial_seed()
|
| 187 |
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
|
| 191 |
-
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
|
| 195 |
-
|
| 196 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 197 |
|
| 198 |
save_sample_path = os.path.join(
|
| 199 |
self.savedir_sample, f"{sample_idx}.mp4")
|
|
|
|
| 149 |
for key in f.keys():
|
| 150 |
self.lora_model_state_dict[key] = f.get_tensor(key)
|
| 151 |
return gr.Dropdown.update()
|
| 152 |
+
@torch.no_grad()
|
| 153 |
def animate(
|
| 154 |
self,
|
| 155 |
lora_alpha_slider,
|
|
|
|
| 174 |
**OmegaConf.to_container(self.inference_config.noise_scheduler_kwargs))
|
| 175 |
).to("cuda")
|
| 176 |
|
| 177 |
+
original_state_dict = {k: v.cpu().clone() for k, v in model.state_dict().items()}
|
| 178 |
+
pipeline.unet = convert_lcm_lora(pipeline.unet, self.lcm_lora_path, spatial_lora_slider)
|
| 179 |
|
| 180 |
pipeline.to("cuda")
|
| 181 |
|
|
|
|
| 185 |
torch.seed()
|
| 186 |
seed = torch.initial_seed()
|
| 187 |
|
| 188 |
+
with torch.autocast("cuda"):
|
| 189 |
+
sample = pipeline(
|
| 190 |
+
prompt_textbox,
|
| 191 |
+
negative_prompt=negative_prompt_textbox,
|
| 192 |
+
num_inference_steps=sample_step_slider,
|
| 193 |
+
guidance_scale=cfg_scale_slider,
|
| 194 |
+
width=width_slider,
|
| 195 |
+
height=height_slider,
|
| 196 |
+
video_length=length_slider,
|
| 197 |
+
).videos
|
| 198 |
+
|
| 199 |
+
pipeline.unet.load(original_state_dict)
|
| 200 |
+
del original_state_dict
|
| 201 |
|
| 202 |
save_sample_path = os.path.join(
|
| 203 |
self.savedir_sample, f"{sample_idx}.mp4")
|