Spaces: Running on Zero

Dynamic Axis #10
by multimodalart (HF Staff) - opened

Files changed:
- app.py +53 -25
- optimization.py +21 -45
    	
app.py
CHANGED

@@ -19,8 +19,11 @@ from optimization import optimize_pipeline_
 
 MODEL_ID = "Wan-AI/Wan2.2-I2V-A14B-Diffusers"
 
-
-
+MAX_DIM = 832
+MIN_DIM = 480
+SQUARE_DIM = 640
+MULTIPLE_OF = 16
+
 MAX_SEED = np.iinfo(np.int32).max
 
 FIXED_FPS = 16
@@ -50,11 +53,14 @@ for i in range(3):
     torch.cuda.synchronize()
     torch.cuda.empty_cache()
 
+OPTIMIZE_WIDTH = 832
+OPTIMIZE_HEIGHT = 624
+
 optimize_pipeline_(pipe,
-    image=Image.new('RGB', (
+    image=Image.new('RGB', (OPTIMIZE_WIDTH, OPTIMIZE_HEIGHT)),
     prompt='prompt',
-    height=
-    width=
+    height=OPTIMIZE_HEIGHT,
+    width=OPTIMIZE_WIDTH,
     num_frames=MAX_FRAMES_MODEL,
 )
 
@@ -62,28 +68,51 @@ optimize_pipeline_(pipe,
 default_prompt_i2v = "make this image come alive, cinematic motion, smooth animation"
 default_negative_prompt = "色调艳丽, 过曝, 静态, 细节模糊不清, 字幕, 风格, 作品, 画作, 画面, 静止, 整体发灰, 最差质量, 低质量, JPEG压缩残留, 丑陋的, 残缺的, 多余的手指, 画得不好的手部, 画得不好的脸部, 畸形的, 毁容的, 形态畸形的肢体, 手指融合, 静止不动的画面, 杂乱的背景, 三条腿, 背景人很多, 倒着走"
 
-
 def resize_image(image: Image.Image) -> Image.Image:
-
-
-
-
-    return resize_image_landscape(image)
+    """
+    Resizes an image to fit within the model's constraints, preserving aspect ratio as much as possible.
+    """
+    width, height = image.size
 
+    # Handle square case
+    if width == height:
+        return image.resize((SQUARE_DIM, SQUARE_DIM), Image.LANCZOS)
 
-
-
-
-
-
-
-
+    aspect_ratio = width / height
+
+    MAX_ASPECT_RATIO = MAX_DIM / MIN_DIM
+    MIN_ASPECT_RATIO = MIN_DIM / MAX_DIM
+
+    image_to_resize = image
+
+    if aspect_ratio > MAX_ASPECT_RATIO:
+        # Very wide image -> crop width to fit 832x480 aspect ratio
+        target_w, target_h = MAX_DIM, MIN_DIM
+        crop_width = int(round(height * MAX_ASPECT_RATIO))
+        left = (width - crop_width) // 2
+        image_to_resize = image.crop((left, 0, left + crop_width, height))
+    elif aspect_ratio < MIN_ASPECT_RATIO:
+        # Very tall image -> crop height to fit 480x832 aspect ratio
+        target_w, target_h = MIN_DIM, MAX_DIM
+        crop_height = int(round(width / MIN_ASPECT_RATIO))
+        top = (height - crop_height) // 2
+        image_to_resize = image.crop((0, top, width, top + crop_height))
     else:
-
-
-
-
+        if width > height:  # Landscape
+            target_w = MAX_DIM
+            target_h = int(round(target_w / aspect_ratio))
+        else:  # Portrait
+            target_h = MAX_DIM
+            target_w = int(round(target_h * aspect_ratio))
+
+    final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
+    final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
+
+    final_w = max(MIN_DIM, min(MAX_DIM, final_w))
+    final_h = max(MIN_DIM, min(MAX_DIM, final_h))
+
+    return image_to_resize.resize((final_w, final_h), Image.LANCZOS)
+
 
 def get_duration(
     input_image,
@@ -147,7 +176,6 @@ def generate_video(
         gr.Error: If input_image is None (no image uploaded).
 
     Note:
-        - The function automatically resizes the input image to the target dimensions
         - Frame count is calculated as duration_seconds * FIXED_FPS (24)
         - Output dimensions are adjusted to be multiples of MOD_VALUE (32)
         - The function uses GPU acceleration via the @spaces.GPU decorator
@@ -185,7 +213,7 @@ with gr.Blocks() as demo:
     gr.Markdown("run Wan 2.2 in just 4-8 steps, with [Lightning LoRA](https://huggingface.co/Kijai/WanVideo_comfy/tree/main/Wan22-Lightning), fp8 quantization & AoT compilation - compatible with 🧨 diffusers and ZeroGPU⚡️")
     with gr.Row():
         with gr.Column():
-            input_image_component = gr.Image(type="pil", label="Input Image
+            input_image_component = gr.Image(type="pil", label="Input Image")
             prompt_input = gr.Textbox(label="Prompt", value=default_prompt_i2v)
             duration_seconds_input = gr.Slider(minimum=MIN_DURATION, maximum=MAX_DURATION, step=0.1, value=3.5, label="Duration (seconds)", info=f"Clamped to model's {MIN_FRAMES_MODEL}-{MAX_FRAMES_MODEL} frames at {FIXED_FPS}fps.")
 
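For reference, here is a minimal standalone sketch of the sizing rules the new resize_image() encodes, so the rounding and clamping can be checked without running the Space. The helper name target_dims and the sample sizes are illustrative only (not part of this PR), and the PIL cropping/resizing is omitted; only the dimension math is reproduced.

MAX_DIM = 832
MIN_DIM = 480
SQUARE_DIM = 640
MULTIPLE_OF = 16


def target_dims(width: int, height: int) -> tuple:
    # Square inputs snap to SQUARE_DIM x SQUARE_DIM.
    if width == height:
        return SQUARE_DIM, SQUARE_DIM

    aspect_ratio = width / height
    max_ar = MAX_DIM / MIN_DIM  # ~1.73
    min_ar = MIN_DIM / MAX_DIM  # ~0.58

    if aspect_ratio > max_ar:      # very wide: center-cropped to the 832x480 envelope
        target_w, target_h = MAX_DIM, MIN_DIM
    elif aspect_ratio < min_ar:    # very tall: center-cropped to the 480x832 envelope
        target_w, target_h = MIN_DIM, MAX_DIM
    elif width > height:           # landscape: pin the long side to MAX_DIM
        target_w = MAX_DIM
        target_h = int(round(target_w / aspect_ratio))
    else:                          # portrait
        target_h = MAX_DIM
        target_w = int(round(target_h * aspect_ratio))

    # Round to a multiple of 16, then clamp into [MIN_DIM, MAX_DIM].
    final_w = round(target_w / MULTIPLE_OF) * MULTIPLE_OF
    final_h = round(target_h / MULTIPLE_OF) * MULTIPLE_OF
    final_w = max(MIN_DIM, min(MAX_DIM, final_w))
    final_h = max(MIN_DIM, min(MAX_DIM, final_h))
    return final_w, final_h


print(target_dims(800, 600))     # (832, 624)
print(target_dims(1920, 1080))   # (832, 480)
print(target_dims(1080, 1920))   # (480, 832)
print(target_dims(1000, 1000))   # (640, 640)

Note that a 4:3 input such as 800x600 lands exactly on the 832x624 shape that the warm-up call to optimize_pipeline_ now uses (OPTIMIZE_WIDTH x OPTIMIZE_HEIGHT).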
    	
optimization.py
CHANGED

@@ -14,18 +14,21 @@ from torchao.quantization import Int8WeightOnlyConfig
 
 from optimization_utils import capture_component_call
 from optimization_utils import aoti_compile
-from optimization_utils import ZeroGPUCompiledModel
 from optimization_utils import drain_module_parameters
 
 
 P = ParamSpec('P')
 
+LATENT_FRAMES_DIM = torch.export.Dim('num_latent_frames', min=8, max=81)
 
-
+LATENT_PATCHED_HEIGHT_DIM = torch.export.Dim('latent_patched_height', min=30, max=52)
+LATENT_PATCHED_WIDTH_DIM = torch.export.Dim('latent_patched_width', min=30, max=52)
 
 TRANSFORMER_DYNAMIC_SHAPES = {
     'hidden_states': {
-        2: 
+        2: LATENT_FRAMES_DIM,
+        3: 2 * LATENT_PATCHED_HEIGHT_DIM,
+        4: 2 * LATENT_PATCHED_WIDTH_DIM,
     },
 }
 
@@ -44,6 +47,7 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
     @spaces.GPU(duration=1500)
     def compile_transformer():
 
+        # This LoRA fusion part remains the same
         pipeline.load_lora_weights(
             "Kijai/WanVideo_comfy",
             weight_name="Lightx2v/lightx2v_I2V_14B_480p_cfg_step_distill_rank128_bf16.safetensors",
@@ -70,61 +74,33 @@ def optimize_pipeline_(pipeline: Callable[P, Any], *args: P.args, **kwargs: P.kw
         quantize_(pipeline.transformer, Float8DynamicActivationFloat8WeightConfig())
         quantize_(pipeline.transformer_2, Float8DynamicActivationFloat8WeightConfig())
 
-
-
-        if hidden_states.shape[-1] > hidden_states.shape[-2]:
-            hidden_states_landscape = hidden_states
-            hidden_states_portrait = hidden_states_transposed
-        else:
-            hidden_states_landscape = hidden_states_transposed
-            hidden_states_portrait = hidden_states
-
-        exported_landscape_1 = torch.export.export(
+
+        exported_1 = torch.export.export(
             mod=pipeline.transformer,
             args=call.args,
-            kwargs=call.kwargs
+            kwargs=call.kwargs,
             dynamic_shapes=dynamic_shapes,
         )
 
-
+        exported_2 = torch.export.export(
             mod=pipeline.transformer_2,
             args=call.args,
-            kwargs=call.kwargs
+            kwargs=call.kwargs,
             dynamic_shapes=dynamic_shapes,
         )
 
-
-
+        compiled_1 = aoti_compile(exported_1, INDUCTOR_CONFIGS)
+        compiled_2 = aoti_compile(exported_2, INDUCTOR_CONFIGS)
+
+        return compiled_1, compiled_2
 
-        compiled_landscape_2 = ZeroGPUCompiledModel(compiled_landscape_1.archive_file, compiled_portrait_2.weights)
-        compiled_portrait_1 = ZeroGPUCompiledModel(compiled_portrait_2.archive_file, compiled_landscape_1.weights)
-
-        return (
-            compiled_landscape_1,
-            compiled_landscape_2,
-            compiled_portrait_1,
-            compiled_portrait_2,
-        )
 
     quantize_(pipeline.text_encoder, Int8WeightOnlyConfig())
-
-
-
-
-        if hidden_states.shape[-1] > hidden_states.shape[-2]:
-            return cl1(*args, **kwargs)
-        else:
-            return cp1(*args, **kwargs)
-
-    def combined_transformer_2(*args, **kwargs):
-        hidden_states: torch.Tensor = kwargs['hidden_states']
-        if hidden_states.shape[-1] > hidden_states.shape[-2]:
-            return cl2(*args, **kwargs)
-        else:
-            return cp2(*args, **kwargs)
-
-    pipeline.transformer.forward = combined_transformer_1
+
+    compiled_transformer_1, compiled_transformer_2 = compile_transformer()
+
+    pipeline.transformer.forward = compiled_transformer_1
     drain_module_parameters(pipeline.transformer)
 
-    pipeline.transformer_2.forward = 
+    pipeline.transformer_2.forward = compiled_transformer_2
    drain_module_parameters(pipeline.transformer_2)
