multimodalart (HF Staff) committed
Commit d683c92 · verified · 1 parent: d3b1ec0

swap comfyui to diffusers

Files changed (1): app.py (+202, -303)
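
In short, the ComfyUI node graph (model/LoRA loaders, two KSamplerAdvanced passes, VAE decode, SaveVideo) is replaced by a single call to the diffusers WanImageToVideoPipeline. Below is a minimal sketch of the new flow, assuming a recent diffusers release with Wan 2.2 support and a CUDA GPU; the model ID, the last_image argument, and the defaults (8 steps, guidance 1.0, 16 fps) come from the diff, and the example frames are taken from the Space's gr.Examples. The real app.py additionally swaps in bf16 transformer weights from cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers and compiles the pipeline with optimization.optimize_pipeline_.

    # Hedged sketch only, not the full app.py below.
    import torch
    from PIL import Image
    from diffusers import WanImageToVideoPipeline
    from diffusers.utils import export_to_video

    pipe = WanImageToVideoPipeline.from_pretrained(
        "Wan-AI/Wan2.2-I2V-A14B-Diffusers", torch_dtype=torch.bfloat16
    ).to("cuda")

    first = Image.open("poli_tower.png")      # start frame
    last = Image.open("tower_takes_off.png")  # end frame
    frames = pipe(
        image=first,
        last_image=last,            # first/last-frame conditioning
        prompt="the man turns around",
        height=480, width=832,      # multiples of 16, as in the app
        num_frames=34,              # ~2.1 s at 16 fps (the app's default)
        num_inference_steps=8,
        guidance_scale=1.0,         # high-noise expert
        guidance_scale_2=1.0,       # low-noise expert
    ).frames[0]
    export_to_video(frames, "output.mp4", fps=16)
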
app.py CHANGED
@@ -1,179 +1,126 @@
  import os
- import shutil
- import sys
- import subprocess
- import asyncio
- import uuid
- import random
- import tempfile
- from typing import Sequence, Mapping, Any, Union

  import torch
  import gradio as gr
  from PIL import Image
- from huggingface_hub import hf_hub_download
- import spaces
-
- # --- 1. Model Download and Setup ---

- def hf_hub_download_local(repo_id, filename, local_dir, **kwargs):
-     """Downloads a file from Hugging Face Hub and symlinks it to a local directory."""
-     downloaded_path = hf_hub_download(repo_id=repo_id, filename=filename, **kwargs)
-     os.makedirs(local_dir, exist_ok=True)
-     base_filename = os.path.basename(filename)
-     target_path = os.path.join(local_dir, base_filename)
-
-     # Remove existing symlink or file to avoid errors
-     if os.path.exists(target_path) or os.path.islink(target_path):
-         os.remove(target_path)
-
-     os.symlink(downloaded_path, target_path)
-     return target_path
-
- print("Downloading models from Hugging Face Hub...")
- hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/text_encoders/umt5_xxl_fp8_e4m3fn_scaled.safetensors", local_dir="models/text_encoders")
- hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
- hf_hub_download_local(repo_id="Comfy-Org/Wan_2.2_ComfyUI_Repackaged", filename="split_files/diffusion_models/wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", local_dir="models/unet")
- hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/vae/wan_2.1_vae.safetensors", local_dir="models/vae")
- hf_hub_download_local(repo_id="Comfy-Org/Wan_2.1_ComfyUI_repackaged", filename="split_files/clip_vision/clip_vision_h.safetensors", local_dir="models/clip_vision")
- hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors", local_dir="models/loras")
- hf_hub_download_local(repo_id="Kijai/WanVideo_comfy", filename="Wan22-Lightning/Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors", local_dir="models/loras")
- print("Downloads complete.")
-
-
- # --- 2. ComfyUI Backend Initialization ---
-
- def find_path(name: str, path: str = None) -> str:
-     """Recursively finds a directory with a given name."""
-     if path is None: path = os.getcwd()
-     if name in os.listdir(path): return os.path.join(path, name)
-     parent_directory = os.path.dirname(path)
-     return find_path(name, parent_directory) if parent_directory != path else None
-
- def add_comfyui_directory_to_sys_path() -> None:
-     """Adds the ComfyUI directory to sys.path for imports."""
-     comfyui_path = find_path("ComfyUI")
-     if comfyui_path and os.path.isdir(comfyui_path):
-         sys.path.append(comfyui_path)
-         print(f"'{comfyui_path}' added to sys.path")
-
- def add_extra_model_paths() -> None:
-     """Initializes ComfyUI's folder_paths with custom paths."""
-     from main import apply_custom_paths
-     apply_custom_paths()
-
- def import_custom_nodes() -> None:
-     """Initializes all ComfyUI custom nodes."""
-     import nodes
-     loop = asyncio.new_event_loop()
-     asyncio.set_event_loop(loop)
-     loop.run_until_complete(nodes.init_extra_nodes(init_custom_nodes=True))
-
- print("Setting up ComfyUI paths and nodes...")
- add_comfyui_directory_to_sys_path()
- add_extra_model_paths()
- import_custom_nodes()
- print("ComfyUI setup complete.")
-
-
- # --- 3. Global Model & Node Loading and Patching ---
-
- from nodes import NODE_CLASS_MAPPINGS
- import folder_paths
- from comfy import model_management
-
- # Set VRAM mode to HIGH to prevent models from being offloaded from GPU after use.
- # model_management.vram_state = model_management.VRAMState.HIGH_VRAM
-
- MODELS_AND_NODES = {}
-
- def get_value_at_index(obj: Union[Sequence, Mapping], index: int) -> Any:
-     """Helper to safely access outputs from ComfyUI nodes, which are often tuples."""
-     try:
-         return obj[index]
-     except (KeyError, TypeError):
-         # Fallback for custom nodes that might return a dictionary with a 'result' key
-         if isinstance(obj, Mapping) and "result" in obj:
-             return obj["result"][index]
-         raise
-
- print("Loading models and instantiating nodes into memory. This may take a few minutes...")
-
- # Instantiate Node Classes that will be used for loading and patching
- cliploader = NODE_CLASS_MAPPINGS["CLIPLoader"]()
- unetloader = NODE_CLASS_MAPPINGS["UNETLoader"]()
- vaeloader = NODE_CLASS_MAPPINGS["VAELoader"]()
- clipvisionloader = NODE_CLASS_MAPPINGS["CLIPVisionLoader"]()
- loraloadermodelonly = NODE_CLASS_MAPPINGS["LoraLoaderModelOnly"]()
- modelsamplingsd3 = NODE_CLASS_MAPPINGS["ModelSamplingSD3"]()
- pathchsageattentionkj = NODE_CLASS_MAPPINGS["PathchSageAttentionKJ"]()
-
- # Load base models into CPU RAM initially
- MODELS_AND_NODES["clip"] = cliploader.load_clip(clip_name="umt5_xxl_fp8_e4m3fn_scaled.safetensors", type="wan")
- unet_low_noise = unetloader.load_unet(unet_name="wan2.2_i2v_low_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
- unet_high_noise = unetloader.load_unet(unet_name="wan2.2_i2v_high_noise_14B_fp8_scaled.safetensors", weight_dtype="default")
- MODELS_AND_NODES["vae"] = vaeloader.load_vae(vae_name="wan_2.1_vae.safetensors")
- MODELS_AND_NODES["clip_vision"] = clipvisionloader.load_clip(clip_name="clip_vision_h.safetensors")
-
- # Chain all patching operations together for the final models
- print("Applying all patches to models...")
-
- # --- Low Noise Model Chain ---
- model_low_with_lora = loraloadermodelonly.load_lora_model_only(
-     lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_LOW_fp16.safetensors",
-     strength_model=0.8, model=get_value_at_index(unet_low_noise, 0))
- model_low_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_low_with_lora, 0))
- MODELS_AND_NODES["model_low_noise"] = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_low_patched, 0))
-
- # --- High Noise Model Chain ---
- model_high_with_lora = loraloadermodelonly.load_lora_model_only(
-     lora_name="Wan2.2-Lightning_I2V-A14B-4steps-lora_HIGH_fp16.safetensors",
-     strength_model=0.8, model=get_value_at_index(unet_high_noise, 0))
- model_high_patched = modelsamplingsd3.patch(shift=8, model=get_value_at_index(model_high_with_lora, 0))
- MODELS_AND_NODES["model_high_noise"] = pathchsageattentionkj.patch(sage_attention="auto", model=get_value_at_index(model_high_patched, 0))
-
- # Instantiate all other node classes ONCE and store them
- MODELS_AND_NODES["CLIPTextEncode"] = NODE_CLASS_MAPPINGS["CLIPTextEncode"]()
- MODELS_AND_NODES["LoadImage"] = NODE_CLASS_MAPPINGS["LoadImage"]()
- MODELS_AND_NODES["CLIPVisionEncode"] = NODE_CLASS_MAPPINGS["CLIPVisionEncode"]()
- MODELS_AND_NODES["WanFirstLastFrameToVideo"] = NODE_CLASS_MAPPINGS["WanFirstLastFrameToVideo"]()
- MODELS_AND_NODES["KSamplerAdvanced"] = NODE_CLASS_MAPPINGS["KSamplerAdvanced"]()
- MODELS_AND_NODES["VAEDecode"] = NODE_CLASS_MAPPINGS["VAEDecode"]()
- MODELS_AND_NODES["CreateVideo"] = NODE_CLASS_MAPPINGS["CreateVideo"]()
- MODELS_AND_NODES["SaveVideo"] = NODE_CLASS_MAPPINGS["SaveVideo"]()
-
- # Move all final, fully-patched models to the GPU
- print("Moving final models to GPU...")
- model_loaders_final = [
-     MODELS_AND_NODES["clip"],
-     # MODELS_AND_NODES["vae"],
-     MODELS_AND_NODES["model_low_noise"],
-     MODELS_AND_NODES["model_high_noise"],
-     MODELS_AND_NODES["clip_vision"],
- ]
- model_management.load_models_gpu([
-     loader[0].patcher if hasattr(loader[0], 'patcher') else loader[0] for loader in model_loaders_final
- ], force_patch_weights=True)  # force_patch_weights permanently merges the LoRA
-
- print("All models loaded, patched, and on GPU. Gradio app is ready.")
-
-
- # --- 4. Application Logic and Gradio Interface ---
-
- def calculate_video_dimensions(width, height, max_size=832, min_size=480):
-     """Calculates video dimensions, ensuring they are multiples of 16."""
      if width == height:
-         return min_size, min_size
      aspect_ratio = width / height
-     if width > height:
-         video_width = max_size
-         video_height = int(max_size / aspect_ratio)
-     else:
-         video_height = max_size
-         video_width = int(max_size * aspect_ratio)
-     video_width = max(16, round(video_width / 16) * 16)
-     video_height = max(16, round(video_height / 16) * 16)
-     return video_width, video_height

  def resize_and_crop_to_match(target_image, reference_image):
      """Resizes and center-crops the target image to match the reference image's dimensions."""
@@ -190,125 +137,64 @@ def generate_video(
      start_image_pil,
      end_image_pil,
      prompt,
-     negative_prompt="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
-     duration=33,
      progress=gr.Progress(track_tqdm=True)
  ):
      """
-     Generates a video by interpolating between a start and end image, guided by a text prompt.
-     This function relies on globally pre-loaded models and pre-instantiated ComfyUI nodes.
      """
-     FPS = 16
-
-     # --- 1. Retrieve Pre-loaded and Pre-patched Models & Node Instances ---
-     # These are not re-instantiated; we are just getting references to the global objects.
-     clip = MODELS_AND_NODES["clip"]
-     vae = MODELS_AND_NODES["vae"]
-     model_low_final = MODELS_AND_NODES["model_low_noise"]
-     model_high_final = MODELS_AND_NODES["model_high_noise"]
-     clip_vision = MODELS_AND_NODES["clip_vision"]
-
-     cliptextencode = MODELS_AND_NODES["CLIPTextEncode"]
-     loadimage = MODELS_AND_NODES["LoadImage"]
-     clipvisionencode = MODELS_AND_NODES["CLIPVisionEncode"]
-     wanfirstlastframetovideo = MODELS_AND_NODES["WanFirstLastFrameToVideo"]
-     ksampleradvanced = MODELS_AND_NODES["KSamplerAdvanced"]
-     vaedecode = MODELS_AND_NODES["VAEDecode"]
-     createvideo = MODELS_AND_NODES["CreateVideo"]
-     savevideo = MODELS_AND_NODES["SaveVideo"]
-
-     # --- 2. Image Preprocessing for the Current Run ---
-     print("Preprocessing images with Pillow...")
-     processed_start_image = start_image_pil.copy()
-     processed_end_image = resize_and_crop_to_match(end_image_pil, start_image_pil)
-     video_width, video_height = calculate_video_dimensions(processed_start_image.width, processed_start_image.height)
-
-     # Save processed images to temporary files for the LoadImage node
-     temp_dir = "input"  # ComfyUI's default input directory
-     os.makedirs(temp_dir, exist_ok=True)

-     with tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=temp_dir) as start_file, \
-          tempfile.NamedTemporaryFile(suffix=".png", delete=False, dir=temp_dir) as end_file:
-         processed_start_image.save(start_file.name)
-         processed_end_image.save(end_file.name)
-         start_image_path = os.path.basename(start_file.name)
-         end_image_path = os.path.basename(end_file.name)
-     print(f"Images resized to {video_width}x{video_height} and saved temporarily.")

-     # --- 3. Execute the ComfyUI Workflow in Inference Mode ---
-     with torch.inference_mode():
-         progress(0.1, desc="Encoding text and images...")
-
-         # Encode prompts and vision models
-         positive_conditioning = cliptextencode.encode(text=prompt, clip=get_value_at_index(clip, 0))
-         negative_conditioning = cliptextencode.encode(text=negative_prompt, clip=get_value_at_index(clip, 0))
-
-         start_image_loaded = loadimage.load_image(image=start_image_path)
-         end_image_loaded = loadimage.load_image(image=end_image_path)
-
-         clip_vision_encoded_start = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(start_image_loaded, 0))
-         clip_vision_encoded_end = clipvisionencode.encode(crop="none", clip_vision=get_value_at_index(clip_vision, 0), image=get_value_at_index(end_image_loaded, 0))
-
-         progress(0.2, desc="Preparing initial latents...")
-         initial_latents = wanfirstlastframetovideo.EXECUTE_NORMALIZED(
-             width=video_width, height=video_height, length=duration, batch_size=1,
-             positive=get_value_at_index(positive_conditioning, 0),
-             negative=get_value_at_index(negative_conditioning, 0),
-             vae=get_value_at_index(vae, 0),
-             clip_vision_start_image=get_value_at_index(clip_vision_encoded_start, 0),
-             clip_vision_end_image=get_value_at_index(clip_vision_encoded_end, 0),
-             start_image=get_value_at_index(start_image_loaded, 0),
-             end_image=get_value_at_index(end_image_loaded, 0),
-         )
-
-         ksampler_positive = get_value_at_index(initial_latents, 0)
-         ksampler_negative = get_value_at_index(initial_latents, 1)
-         ksampler_latent = get_value_at_index(initial_latents, 2)
-
-         progress(0.5, desc="Denoising (Step 1/2)...")
-         latent_step1 = ksampleradvanced.sample(
-             add_noise="enable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
-             sampler_name="euler", scheduler="simple", start_at_step=0, end_at_step=4,
-             return_with_leftover_noise="enable", model=get_value_at_index(model_high_final, 0),
-             positive=ksampler_positive,
-             negative=ksampler_negative,
-             latent_image=ksampler_latent,
-         )
-
-         progress(0.7, desc="Denoising (Step 2/2)...")
-         latent_step2 = ksampleradvanced.sample(
-             add_noise="disable", noise_seed=random.randint(1, 2**64), steps=8, cfg=1,
-             sampler_name="euler", scheduler="simple", start_at_step=4, end_at_step=10000,
-             return_with_leftover_noise="disable", model=get_value_at_index(model_low_final, 0),
-             positive=ksampler_positive,
-             negative=ksampler_negative,
-             latent_image=get_value_at_index(latent_step1, 0),
-         )
-
-         progress(0.8, desc="Decoding VAE...")
-         decoded_images = vaedecode.decode(samples=get_value_at_index(latent_step2, 0), vae=get_value_at_index(vae, 0))
-
-         progress(0.9, desc="Creating and saving video...")
-         video_data = createvideo.create_video(fps=FPS, images=get_value_at_index(decoded_images, 0))
-
-         # Save the video to ComfyUI's default output directory
-         save_result = savevideo.save_video(
-             filename_prefix="GradioVideo", format="mp4", codec="h264",
-             video=get_value_at_index(video_data, 0),
-         )
-
-         progress(1.0, desc="Done!")
-
-     # --- 4. Cleanup and Return ---
-     try:
-         os.remove(start_file.name)
-         os.remove(end_file.name)
-     except Exception as e:
-         print(f"Error cleaning up temporary files: {e}")
-
-     # Gradio video component expects a filepath relative to the root of the app
-     return f"output/{save_result['ui']['images'][0]['filename']}"


  css = '''
  .fillable{max-width: 1100px !important}
@@ -316,46 +202,59 @@ css = '''
  '''
  with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
      gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
-     gr.Markdown("Running the [Wan 2.2 First/Last Frame ComfyUI workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) and the [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) 8-step LoRA on ZeroGPU")
-
      with gr.Row():
          with gr.Column():
              with gr.Group():
                  with gr.Row():
-                     start_image = gr.Image(type="pil", label="Start Frame")
-                     end_image = gr.Image(type="pil", label="End Frame")
-
                  prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
-
-             with gr.Accordion("Advanced Settings", open=False, visible=False):
-                 duration = gr.Radio(
-                     [("Short (2s)", 33), ("Mid (4s)", 66)],
-                     value=33,
-                     label="Video Duration",
-                     visible=False
-                 )
-                 negative_prompt = gr.Textbox(
-                     label="Negative Prompt",
-                     value="色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,",
-                     visible=False
-                 )
-
              generate_button = gr.Button("Generate Video", variant="primary")
-
          with gr.Column():
              output_video = gr.Video(label="Generated Video", autoplay=True)

      generate_button.click(
          fn=generate_video,
-         inputs=[start_image, end_image, prompt, negative_prompt, duration],
-         outputs=output_video
      )

      gr.Examples(
          examples=[
              ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
              ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
-             ["capyabara_zoomed.png", "capybara.webp", "a dramatic dolly zoom"],
          ],
          inputs=[start_image, end_image, prompt],
          outputs=output_video,
 
  import os
+ # PyTorch 2.8 (temporary hack)
+ os.system('pip install --upgrade --pre --extra-index-url https://download.pytorch.org/whl/nightly/cu126 "torch<2.9" spaces')

+ # --- 1. Model Download and Setup (Diffusers Backend) ---
+ import spaces
  import torch
+ from diffusers.pipelines.wan.pipeline_wan_i2v import WanImageToVideoPipeline
+ from diffusers.models.transformers.transformer_wan import WanTransformer3DModel
+ from diffusers.utils.export_utils import export_to_video
  import gradio as gr
+ import tempfile
+ import numpy as np
  from PIL import Image
+ import random
+ import gc
+
+ # Import the optimization function from the separate file
+ from optimization import optimize_pipeline_
+
+ # --- Constants and Model Loading ---
+ MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
+
+ # --- NEW: Flexible Dimension Constants ---
+ MAX_DIMENSION = 832
+ MIN_DIMENSION = 480
+ DIMENSION_MULTIPLE = 16
+ SQUARE_SIZE = 480
+
+ MAX_SEED = np.iinfo(np.int32).max
+
+ FIXED_FPS = 16
+ MIN_FRAMES_MODEL = 8
+ MAX_FRAMES_MODEL = 81
+
+ MIN_DURATION = round(MIN_FRAMES_MODEL/FIXED_FPS, 1)
+ MAX_DURATION = round(MAX_FRAMES_MODEL/FIXED_FPS, 1)
+
+ default_negative_prompt = "色调艳丽,过曝,静态,细节模糊不清,字幕,风格,作品,画作,画面,静止,整体发灰,最差质量,低质量,JPEG压缩残留,丑陋的,残缺的,多余的手指,画得不好的手部,画得不好的脸部,畸形的,毁容的,形态畸形的肢体,手指融合,静止不动的画面,杂乱的背景,三条腿,背景人很多,倒着走,过曝,"
+
+ print("Loading models into memory. This may take a few minutes...")
+
+ pipe = WanImageToVideoPipeline.from_pretrained(
+     MODEL_ID,
+     transformer=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
+         subfolder='transformer',
+         torch_dtype=torch.bfloat16,
+         device_map='cuda',
+     ),
+     transformer_2=WanTransformer3DModel.from_pretrained('cbensimon/Wan2.2-I2V-A14B-bf16-Diffusers',
+         subfolder='transformer_2',
+         torch_dtype=torch.bfloat16,
+         device_map='cuda',
+     ),
+     torch_dtype=torch.bfloat16,
+ ).to('cuda')
+
+ print("Optimizing pipeline...")
+ for i in range(3):
+     gc.collect()
+     torch.cuda.synchronize()
+     torch.cuda.empty_cache()
+
+ # Calling the imported optimization function with a placeholder image for compilation tracing
+ optimize_pipeline_(pipe,
+     image=Image.new('RGB', (MAX_DIMENSION, MIN_DIMENSION)),  # Use representative dims
+     prompt='prompt',
+     height=MIN_DIMENSION,
+     width=MAX_DIMENSION,
+     num_frames=MAX_FRAMES_MODEL,
+ )
+ print("All models loaded and optimized. Gradio app is ready.")
+
+
+ # --- 2. Image Processing and Application Logic ---
+
+ def process_image_for_video(image: Image.Image) -> Image.Image:
+     """
+     Resizes an image based on the following rules for video generation:
+     1. The longest side will be scaled down to MAX_DIMENSION if it's larger.
+     2. The shortest side will be scaled up to MIN_DIMENSION if it's smaller.
+     3. The final dimensions will be rounded to the nearest multiple of DIMENSION_MULTIPLE.
+     4. Square images are resized to a fixed SQUARE_SIZE.
+     The aspect ratio is preserved as closely as possible.
+     """
+     width, height = image.size

+     # Rule 4: Handle square images
      if width == height:
+         return image.resize((SQUARE_SIZE, SQUARE_SIZE), Image.Resampling.LANCZOS)
+
+     # Determine target dimensions while preserving aspect ratio
      aspect_ratio = width / height
+     new_width, new_height = width, height
+
+     # Rule 1: Scale down if too large
+     if new_width > MAX_DIMENSION or new_height > MAX_DIMENSION:
+         if aspect_ratio > 1:  # Landscape
+             scale = MAX_DIMENSION / new_width
+         else:  # Portrait
+             scale = MAX_DIMENSION / new_height
+         new_width *= scale
+         new_height *= scale
+
+     # Rule 2: Scale up if too small
+     if new_width < MIN_DIMENSION or new_height < MIN_DIMENSION:
+         if aspect_ratio > 1:  # Landscape
+             scale = MIN_DIMENSION / new_height
+         else:  # Portrait
+             scale = MIN_DIMENSION / new_width
+         new_width *= scale
+         new_height *= scale
+
+     # Rule 3: Round to the nearest multiple of DIMENSION_MULTIPLE
+     final_width = int(round(new_width / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
+     final_height = int(round(new_height / DIMENSION_MULTIPLE) * DIMENSION_MULTIPLE)
+
+     # Ensure final dimensions are at least the minimum
+     final_width = max(final_width, MIN_DIMENSION if aspect_ratio < 1 else SQUARE_SIZE)
+     final_height = max(final_height, MIN_DIMENSION if aspect_ratio > 1 else SQUARE_SIZE)
+
+
+     return image.resize((final_width, final_height), Image.Resampling.LANCZOS)

  def resize_and_crop_to_match(target_image, reference_image):
      """Resizes and center-crops the target image to match the reference image's dimensions."""

      start_image_pil,
      end_image_pil,
      prompt,
+     negative_prompt,
+     duration_seconds,
+     steps,
+     guidance_scale,
+     guidance_scale_2,
+     seed,
+     randomize_seed,
      progress=gr.Progress(track_tqdm=True)
  ):
      """
+     Generates a video by interpolating between a start and end image, guided by a text prompt,
+     using the diffusers Wan2.2 pipeline.
      """
+     if start_image_pil is None or end_image_pil is None:
+         raise gr.Error("Please upload both a start and an end image.")
+
+     progress(0.1, desc="Preprocessing images...")
+
+     # Step 1: Process the start image to get our target dimensions based on the new rules.
+     processed_start_image = process_image_for_video(start_image_pil)

+     # Step 2: Make the end image match the *exact* dimensions of the processed start image.
+     processed_end_image = resize_and_crop_to_match(end_image_pil, processed_start_image)

+     target_height, target_width = processed_start_image.height, processed_start_image.width
+
+     # Handle seed and frame count
+     current_seed = random.randint(0, MAX_SEED) if randomize_seed else int(seed)
+     num_frames = np.clip(int(round(duration_seconds * FIXED_FPS)), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL)
+
+     progress(0.2, desc=f"Generating {num_frames} frames at {target_width}x{target_height} (seed: {current_seed})...")
+
+     output_frames_list = pipe(
+         image=processed_start_image,
+         last_image=processed_end_image,
+         prompt=prompt,
+         negative_prompt=negative_prompt,
+         height=target_height,
+         width=target_width,
+         num_frames=num_frames,
+         guidance_scale=float(guidance_scale),
+         guidance_scale_2=float(guidance_scale_2),
+         num_inference_steps=int(steps),
+         generator=torch.Generator(device="cuda").manual_seed(current_seed),
+     ).frames[0]
+
+     progress(0.9, desc="Encoding and saving video...")
+
+     with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmpfile:
+         video_path = tmpfile.name
+
+     export_to_video(output_frames_list, video_path, fps=FIXED_FPS)
+
+     progress(1.0, desc="Done!")
+     return video_path, current_seed
+

+ # --- 3. Gradio User Interface --- (No changes needed here)

  css = '''
  .fillable{max-width: 1100px !important}

  '''
  with gr.Blocks(theme=gr.themes.Citrus(), css=css) as app:
      gr.Markdown("# Wan 2.2 First/Last Frame Video Fast")
+     gr.Markdown("Running the [Wan 2.2 First/Last Frame workflow](https://www.reddit.com/r/StableDiffusion/comments/1me4306/psa_wan_22_does_first_frame_last_frame_out_of_the/) via 🧨 Diffusers and the [lightx2v/Wan2.2-Lightning](https://huggingface.co/lightx2v/Wan2.2-Lightning) principles on ZeroGPU")
+
      with gr.Row():
          with gr.Column():
              with gr.Group():
                  with gr.Row():
+                     start_image = gr.Image(type="pil", label="Start Frame", sources=["upload", "clipboard"])
+                     end_image = gr.Image(type="pil", label="End Frame", sources=["upload", "clipboard"])
+
                  prompt = gr.Textbox(label="Prompt", info="Describe the transition between the two images")
+
+             with gr.Accordion("Advanced Settings", open=False):
+                 duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=2.1, label="Video Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
+                 negative_prompt_input = gr.Textbox(label="Negative Prompt", value=default_negative_prompt, lines=3)
+                 steps_slider = gr.Slider(minimum=1, maximum=30, step=1, value=8, label="Inference Steps")
+                 guidance_scale_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - high noise")
+                 guidance_scale_2_input = gr.Slider(minimum=0.0, maximum=10.0, step=0.5, value=1.0, label="Guidance Scale - low noise")
+                 with gr.Row():
+                     seed_input = gr.Slider(label="Seed", minimum=0, maximum=MAX_SEED, step=1, value=42)
+                     randomize_seed_checkbox = gr.Checkbox(label="Randomize seed", value=True)
+
              generate_button = gr.Button("Generate Video", variant="primary")
+
          with gr.Column():
              output_video = gr.Video(label="Generated Video", autoplay=True)

+     # Define the inputs list for the click event
+     ui_inputs = [
+         start_image,
+         end_image,
+         prompt,
+         negative_prompt_input,
+         duration_seconds_input,
+         steps_slider,
+         guidance_scale_input,
+         guidance_scale_2_input,
+         seed_input,
+         randomize_seed_checkbox
+     ]
+     # The seed_input is both an input and an output to reflect the randomly generated seed
+     ui_outputs = [output_video, seed_input]
+
      generate_button.click(
          fn=generate_video,
+         inputs=ui_inputs,
+         outputs=ui_outputs
      )

      gr.Examples(
          examples=[
              ["poli_tower.png", "tower_takes_off.png", "the man turns around"],
              ["ugly_sonic.jpeg", "squatting_sonic.png", "the character dodges the missiles"],
+             ["capyabara_zoomed.png", "capyabara.webp", "a dramatic dolly zoom"],
          ],
          inputs=[start_image, end_image, prompt],
          outputs=output_video,
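
For reference, a hedged sketch of how the new UI's duration slider maps to a frame count at the fixed 16 fps; the clamping mirrors generate_video above (duration_to_frames is an illustrative helper, the constants come from app.py):

    import numpy as np

    FIXED_FPS = 16
    MIN_FRAMES_MODEL, MAX_FRAMES_MODEL = 8, 81

    def duration_to_frames(duration_seconds: float) -> int:
        # Same rule as generate_video: round to the nearest frame,
        # then clamp to the 8-81 frames the model supports.
        return int(np.clip(round(duration_seconds * FIXED_FPS), MIN_FRAMES_MODEL, MAX_FRAMES_MODEL))

    assert duration_to_frames(2.1) == 34  # default slider value
    assert duration_to_frames(0.5) == 8   # slider minimum (MIN_DURATION)
    assert duration_to_frames(5.1) == 81  # slider maximum, clamped to the 81-frame cap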