dn6 (HF Staff) committed · Commit 5178ef1 · verified · 1 Parent(s): 3907010

Upload folder using huggingface_hub

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,71 @@
+ # Matrix Game 2.0 Modular Pipeline
+
+ ## Set Up
+
+ ```shell
+ uv venv -p 3.10
+ uv pip install -r requirements.txt
+ uv pip install git+https://github.com/huggingface/diffusers.git
+ ```
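+
+ The custom blocks in this repo were exported against a development build of `diffusers` (`0.36.0.dev0` in `modular_config.json`), which is presumably why the install above pulls `diffusers` from GitHub `main`. A quick, optional sanity check of the environment (run inside the virtual environment created above):
+
+ ```python
+ import diffusers
+ import torch
+
+ # Expect a 0.36.0.dev0-or-newer build installed from source.
+ print(diffusers.__version__)
+ # The example below loads the transformer in bfloat16 on a CUDA device.
+ print(torch.cuda.is_available())
+ ```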
+
+ ## How to Use
+
+ ```python
+ import torch
+ from diffusers import ModularPipelineBlocks
+ from diffusers.utils import export_to_video, load_image
+ from diffusers.modular_pipelines import WanModularPipeline
+
+ class MatrixGameWanModularPipeline(WanModularPipeline):
+     """
+     A ModularPipeline for MatrixGameWan.
+
+     <Tip warning={true}>
+
+     This is an experimental feature and is likely to change in the future.
+
+     </Tip>
+     """
+
+     @property
+     def default_sample_height(self):
+         return 44
+
+     @property
+     def default_sample_width(self):
+         return 80
+
+ # Download custom blocks for the pipeline
+ blocks = ModularPipelineBlocks.from_pretrained(
+     "diffusers-internal-dev/matrix-game-2-modular",
+     trust_remote_code=True,
+ )
+
+ # Initialize the pipeline runtime using the config in the repo
+ pipe = MatrixGameWanModularPipeline(blocks, "diffusers-internal-dev/matrix-game-2-modular")
+
+ # Load the model components of the pipeline
+ pipe.load_components(
+     trust_remote_code=True,
+     device_map="cuda",
+     torch_dtype={"default": torch.bfloat16, "vae": torch.float32}
+ )
+
+ image = load_image("https://github.com/SkyworkAI/Matrix-Game/blob/main/Matrix-Game-2/demo_images/universal/0016.png?raw=true")
+ output = pipe(image=image, num_frames=141)
+ export_to_video(output.values['videos'][0], "matrix-game.mp4")
+ ```
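+
+ The decode block (`MatrixGameWanDecodeStep`) also exposes an `output_type` input that defaults to `"pil"`. Assuming the modular runtime forwards extra call kwargs to the block state the same way it does for `image` and `num_frames`, a rough sketch for getting NumPy frames instead of PIL images:
+
+ ```py
+ # Sketch: request NumPy output from the decode step ("pil", "np" and "pt" follow
+ # the usual diffusers VideoProcessor conventions).
+ output = pipe(image=image, num_frames=141, output_type="np")
+ frames = output.values['videos'][0]  # array of shape (frames, height, width, 3), values in [0, 1]
+ print(frames.shape)
+ ```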
+
+ ## Providing Actions as Inputs
+
+ Each action is represented as a string. The available actions are:
+
+ - Motion Actions: `["forward", "left", "right"]`
+ - Camera Actions: `["camera_l", "camera_r", "camera_up", "camera_down", "camera_ul", "camera_ur", "camera_dl", "camera_dr"]`
+ - Compound Actions: combinations of motion and camera actions joined with `_`, e.g. `"forward_left"`, `"forward_left_camera_l"`
+
+ ```py
+ image = load_image("https://github.com/SkyworkAI/Matrix-Game/blob/main/Matrix-Game-2/demo_images/universal/0016.png?raw=true")
+ output = pipe(image=image, actions=["forward", "camera_l"], num_frames=141)
+ export_to_video(output.values['videos'][0], "matrix-game.mp4")
+ ```
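+
+ Actions are spread evenly across the requested frames by `sync_actions_to_frames` (see `before_denoise.py`): each action gets roughly `num_frames / len(actions)` frames, rounded down to a multiple of 4 and no less than 12, and actions beyond `num_frames // 12` are dropped. A sketch of a longer, chained action sequence, assuming the same `pipe` and `image` as above:
+
+ ```py
+ # (num_frames - 1) must stay divisible by 4 (the temporal VAE scale factor).
+ actions = ["forward", "forward_camera_l", "right_camera_up", "forward"]
+ output = pipe(image=image, actions=actions, num_frames=189)
+ export_to_video(output.values['videos'][0], "matrix-game-actions.mp4")
+ ```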
__init__.py ADDED
File without changes
before_denoise.py ADDED
@@ -0,0 +1,604 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from typing import Dict, List, Optional, Tuple, Union
17
+
18
+ import torch
19
+
20
+ from diffusers import AutoencoderKLWan
21
+ from diffusers.configuration_utils import FrozenDict
22
+ from diffusers.schedulers import UniPCMultistepScheduler
23
+ from diffusers.utils import logging
24
+ from diffusers.utils.torch_utils import randn_tensor
25
+ from diffusers.video_processor import VideoProcessor
26
+ from diffusers.modular_pipelines import ModularPipeline, ModularPipelineBlocks, PipelineState
27
+ from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+ # Constants
32
+ FRAME_MULTIPLE = 4
33
+ DEFAULT_SAMPLES_PER_ACTION = 4
34
+ DEFAULT_FRAMES_PER_ACTION = 12
35
+
36
+ DEFAULT_MOUSE_DIM = 2
37
+ DEFAULT_KEYBOARD_DIM = 4
38
+
39
+ # Camera movement configuration
40
+ CAMERA_MOVEMENT_VALUE = 0.1
41
+ CAMERA_VALUE_MAP = {
42
+ "camera_up": [CAMERA_MOVEMENT_VALUE, 0],
43
+ "camera_down": [-CAMERA_MOVEMENT_VALUE, 0],
44
+ "camera_l": [0, -CAMERA_MOVEMENT_VALUE],
45
+ "camera_r": [0, CAMERA_MOVEMENT_VALUE],
46
+ "camera_ur": [CAMERA_MOVEMENT_VALUE, CAMERA_MOVEMENT_VALUE],
47
+ "camera_ul": [CAMERA_MOVEMENT_VALUE, -CAMERA_MOVEMENT_VALUE],
48
+ "camera_dr": [-CAMERA_MOVEMENT_VALUE, CAMERA_MOVEMENT_VALUE],
49
+ "camera_dl": [-CAMERA_MOVEMENT_VALUE, -CAMERA_MOVEMENT_VALUE],
50
+ }
51
+
52
+ # Define available actions
53
+ MOVEMENT_ACTIONS = ["forward", "left", "right"]
54
+ COMPOUND_MOVEMENTS = ["forward_left", "forward_right"]
55
+ CAMERA_ACTIONS = list(CAMERA_VALUE_MAP.keys())
56
+
57
+ # Keyboard action indices
58
+ KEYBOARD_ACTION_INDICES = {"forward": 0, "back": 1, "left": 2, "right": 3}
59
+
60
+
61
+ def sync_actions_to_frames(
62
+ actions: List[str],
63
+ num_frames: int,
64
+ min_duration: int = 12
65
+ ) -> List[Dict[str, Union[str, int]]]:
66
+ """
67
+ Synchronize a list of actions to fit exactly within the given number of frames
68
+ using equal distribution strategy.
69
+
70
+ Args:
71
+ actions: List of action names to perform
72
+ num_frames: Total frames to fill
73
+ min_duration: Minimum frames per action (should be multiple of frame_multiple)
74
+ frame_multiple: Actions must be multiples of this value
75
+
76
+ Returns:
77
+ List of action dictionaries with 'type', 'start_frame', and 'duration'
78
+ """
79
+
80
+ if not actions:
81
+ raise ValueError("No actions provided")
82
+
83
+ max_possible_actions = num_frames // DEFAULT_FRAMES_PER_ACTION
84
+ if len(actions) > max_possible_actions:
85
+ actions = actions[:max_possible_actions]
86
+
87
+ num_actions = len(actions)
88
+
89
+ frames_per_action = num_frames // num_actions
90
+ frames_per_action = (frames_per_action // FRAME_MULTIPLE) * FRAME_MULTIPLE
91
+ frames_per_action = max(DEFAULT_FRAMES_PER_ACTION, frames_per_action)
92
+
93
+ remaining_frames = num_frames - (frames_per_action * num_actions)
94
+ output = []
95
+ current_frame = 0
96
+
97
+ for i, action in enumerate(actions):
98
+ duration = frames_per_action if i != num_actions - 1 else num_frames - current_frame
99
+
100
+ output.append({
101
+ "action_type": action,
102
+ "start_frame": current_frame,
103
+ "duration": duration
104
+ })
105
+
106
+ current_frame += duration
107
+
108
+ return output
109
+
110
+
111
+ def actions_to_condition_tensors(actions, num_frames):
112
+ keyboard_conditions = torch.zeros((num_frames, DEFAULT_KEYBOARD_DIM))
113
+ mouse_conditions = torch.zeros((num_frames, DEFAULT_MOUSE_DIM))
114
+
115
+ for action in actions:
116
+ action_type = action['action_type']
117
+ start_frame = action['start_frame']
118
+ end_frame = start_frame + action['duration']
119
+
120
+ action_components = action_type.split("_")
121
+ for component in action_components:
122
+ if component in KEYBOARD_ACTION_INDICES:
123
+ action_idx = KEYBOARD_ACTION_INDICES[component]
124
+ keyboard_conditions[start_frame:end_frame, action_idx] = 1.0
125
+
126
+ if not "camera" in action_type:
127
+ continue
128
+
129
+ mouse_x = mouse_y = 0.0
130
+ for idx, component in enumerate(action_components):
131
+ if not action_components[idx] == "camera":
132
+ continue
133
+
134
+ camera_action = f"camera_{action_components[idx+1]}"
135
+ if camera_action not in CAMERA_VALUE_MAP:
136
+ continue
137
+
138
+ camera_values = CAMERA_VALUE_MAP[camera_action]
139
+ mouse_x += camera_values[0]
140
+ mouse_y += camera_values[1]
141
+
142
+ mouse_conditions[start_frame:end_frame, 0] = mouse_x
143
+ mouse_conditions[start_frame:end_frame, 1] = mouse_y
144
+
145
+ return keyboard_conditions, mouse_conditions
146
+
147
+
148
+ def _build_test_actions(
149
+ movement_actions: List[str],
150
+ compound_movements: List[str],
151
+ camera_actions: List[str],
152
+ ) -> List[str]:
153
+ """Build comprehensive list of test action combinations.
154
+
155
+ Args:
156
+ movement_actions: List of basic movement actions
157
+ compound_movements: List of compound movement actions
158
+ camera_actions: List of camera control actions
159
+
160
+ Returns:
161
+ List of all action combinations to test
162
+ """
163
+ # Create base test actions with repetition for variety
164
+ test_actions = compound_movements * 5 + camera_actions * 5 + movement_actions * 5
165
+
166
+ # Add combined movement + camera actions
167
+ for movement in movement_actions + compound_movements:
168
+ for camera in camera_actions:
169
+ combined_action = f"{movement}_{camera}"
170
+ test_actions.append(combined_action)
171
+
172
+ return test_actions
173
+
174
+
175
+ def generate_random_condition_tensors(num_frames: int) -> Tuple[torch.Tensor, torch.Tensor]:
+ """Generate condition tensors from a predefined benchmark action sequence.
+
+ Args:
+ num_frames: Total number of frames to generate.
+
+ Returns:
+ Tuple of (keyboard_conditions, mouse_conditions) tensors for the benchmark actions.
+ """
185
+ # Build test action combinations
186
+ actions = _build_test_actions(
187
+ MOVEMENT_ACTIONS, COMPOUND_MOVEMENTS, CAMERA_ACTIONS
188
+ )
189
+ actions = sync_actions_to_frames(actions, num_frames)
190
+ return actions_to_condition_tensors(actions, num_frames)
191
+
192
+
193
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
194
+ def retrieve_timesteps(
195
+ scheduler,
196
+ num_inference_steps: Optional[int] = None,
197
+ device: Optional[Union[str, torch.device]] = None,
198
+ timesteps: Optional[List[int]] = None,
199
+ sigmas: Optional[List[float]] = None,
200
+ **kwargs,
201
+ ):
202
+ r"""
203
+ Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
204
+ custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
205
+
206
+ Args:
207
+ scheduler (`SchedulerMixin`):
208
+ The scheduler to get timesteps from.
209
+ num_inference_steps (`int`):
210
+ The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
211
+ must be `None`.
212
+ device (`str` or `torch.device`, *optional*):
213
+ The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
214
+ timesteps (`List[int]`, *optional*):
215
+ Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
216
+ `num_inference_steps` and `sigmas` must be `None`.
217
+ sigmas (`List[float]`, *optional*):
218
+ Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
219
+ `num_inference_steps` and `timesteps` must be `None`.
220
+
221
+ Returns:
222
+ `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
223
+ second element is the number of inference steps.
224
+ """
225
+ if timesteps is not None and sigmas is not None:
226
+ raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
227
+ if timesteps is not None:
228
+ accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
229
+ if not accepts_timesteps:
230
+ raise ValueError(
231
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
232
+ f" timestep schedules. Please check whether you are using the correct scheduler."
233
+ )
234
+ scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
235
+ timesteps = scheduler.timesteps
236
+ num_inference_steps = len(timesteps)
237
+ elif sigmas is not None:
238
+ accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
239
+ if not accept_sigmas:
240
+ raise ValueError(
241
+ f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
242
+ f" sigmas schedules. Please check whether you are using the correct scheduler."
243
+ )
244
+ scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
245
+ timesteps = scheduler.timesteps
246
+ num_inference_steps = len(timesteps)
247
+ else:
248
+ scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
249
+ timesteps = scheduler.timesteps
250
+ return timesteps, num_inference_steps
251
+
252
+
253
+ def retrieve_latents(
254
+ encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
255
+ ):
256
+ if hasattr(encoder_output, "latent_dist") and sample_mode == "sample":
257
+ return encoder_output.latent_dist.sample(generator)
258
+ elif hasattr(encoder_output, "latent_dist") and sample_mode == "argmax":
259
+ return encoder_output.latent_dist.mode()
260
+ elif hasattr(encoder_output, "latents"):
261
+ return encoder_output.latents
262
+ else:
263
+ raise AttributeError("Could not access latents of provided encoder_output")
264
+
265
+
266
+ class MatrixGameWanActionInputStep(ModularPipelineBlocks):
267
+ model_name = "MatrixGameWan"
268
+
269
+ @property
270
+ def description(self) -> str:
271
+ return "Action Input step"
272
+
273
+ @property
274
+ def expected_components(self) -> List[ComponentSpec]:
275
+ return []
276
+
277
+ @property
278
+ def inputs(self) -> List[InputParam]:
279
+ return [InputParam("num_frames", type_hint=int, required=True), InputParam("actions", type_hint=List[str])]
280
+
281
+ @property
282
+ def intermediate_outputs(self) -> List[OutputParam]:
283
+ return [
284
+ OutputParam(
285
+ "keyboard_conditions",
286
+ type_hint=torch.Tensor,
287
+ description="image embeddings used to guide the image generation",
288
+ ),
289
+ OutputParam(
290
+ "mouse_conditions",
291
+ type_hint=torch.Tensor,
292
+ description="image embeddings used to guide the image generation",
293
+ )
294
+ ]
295
+
296
+ @torch.no_grad()
297
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
298
+ # Get inputs and intermediates
299
+ block_state = self.get_block_state(state)
300
+ block_state.device = components._execution_device
301
+ actions = block_state.actions
302
+
303
+ if actions is not None:
304
+ actions = sync_actions_to_frames(actions, block_state.num_frames)
305
+ keyboard_conditions, mouse_conditions = actions_to_condition_tensors(actions, block_state.num_frames)
306
+ else:
307
+ keyboard_conditions, mouse_conditions = generate_random_condition_tensors(block_state.num_frames)
308
+
309
+ block_state.keyboard_conditions = keyboard_conditions.to(block_state.device)
310
+ block_state.mouse_conditions = mouse_conditions.to(block_state.device)
311
+
312
+ # Add outputs
313
+ self.set_block_state(state, block_state)
314
+ return components, state
315
+
316
+
317
+ class MatrixGameWanSetTimestepsStep(ModularPipelineBlocks):
318
+ model_name = "MatrixGameWan"
319
+
320
+ @property
321
+ def expected_components(self) -> List[ComponentSpec]:
322
+ return [
323
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
324
+ ]
325
+
326
+ @property
327
+ def description(self) -> str:
328
+ return "Step that sets the scheduler's timesteps for inference"
329
+
330
+ @property
331
+ def inputs(self) -> List[InputParam]:
332
+ return [
333
+ InputParam("num_inference_steps", default=4),
334
+ InputParam("timesteps"),
335
+ InputParam("sigmas"),
336
+ ]
337
+
338
+ @property
339
+ def intermediate_outputs(self) -> List[OutputParam]:
340
+ return [
341
+ OutputParam("timesteps", type_hint=torch.Tensor, description="The timesteps to use for inference"),
342
+ OutputParam(
343
+ "num_inference_steps",
344
+ type_hint=int,
345
+ description="The number of denoising steps to perform at inference time",
346
+ ),
347
+ ]
348
+
349
+ @torch.no_grad()
350
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
351
+ block_state = self.get_block_state(state)
352
+ block_state.device = components._execution_device
353
+
354
+ block_state.timesteps, block_state.num_inference_steps = retrieve_timesteps(
355
+ components.scheduler,
356
+ block_state.num_inference_steps,
357
+ block_state.device,
358
+ block_state.timesteps,
359
+ block_state.sigmas,
360
+ )
361
+
362
+ self.set_block_state(state, block_state)
363
+ return components, state
364
+
365
+
366
+ class MatrixGameWanPrepareLatentsStep(ModularPipelineBlocks):
367
+ model_name = "MatrixGameWan"
368
+
369
+ @property
370
+ def expected_components(self) -> List[ComponentSpec]:
371
+ return [ComponentSpec("vae", AutoencoderKLWan),]
372
+
373
+ @property
374
+ def description(self) -> str:
375
+ return "Prepare latents step that prepares the latents for the text-to-video generation process"
376
+
377
+ @property
378
+ def inputs(self) -> List[InputParam]:
379
+ return [
380
+ InputParam("height", type_hint=int),
381
+ InputParam("width", type_hint=int),
382
+ InputParam("num_frames", type_hint=int),
383
+ InputParam("latents", type_hint=Optional[torch.Tensor]),
384
+ InputParam("num_videos_per_prompt", type_hint=int, default=1),
385
+ InputParam("generator"),
386
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
387
+ ]
388
+
389
+ @property
390
+ def intermediate_outputs(self) -> List[OutputParam]:
391
+ return [
392
+ OutputParam(
393
+ "latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
394
+ )
395
+ ]
396
+
397
+ @staticmethod
398
+ def check_inputs(components, block_state):
399
+ if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
400
+ block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
401
+ ):
402
+ raise ValueError(
403
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
404
+ )
405
+ if block_state.num_frames is not None and (
406
+ block_state.num_frames < 1 or (block_state.num_frames - 1) % components.vae_scale_factor_temporal != 0
407
+ ):
408
+ raise ValueError(
409
+ f"`num_frames` has to be greater than 0, and (num_frames - 1) must be divisible by {components.vae_scale_factor_temporal}, but got {block_state.num_frames}."
410
+ )
411
+
412
+ @staticmethod
413
+ def prepare_latents(
414
+ components,
415
+ batch_size: int,
416
+ num_channels_latents: int = 16,
417
+ height: int = 352,
418
+ width: int = 640,
419
+ num_frames: int = 81,
420
+ dtype: Optional[torch.dtype] = None,
421
+ device: Optional[torch.device] = None,
422
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
423
+ latents: Optional[torch.Tensor] = None,
424
+ ) -> torch.Tensor:
425
+ if latents is not None:
426
+ return latents.to(device=device, dtype=dtype)
427
+
428
+ num_latent_frames = (num_frames - 1) // components.vae_scale_factor_temporal + 1
429
+ shape = (
430
+ batch_size,
431
+ num_channels_latents,
432
+ num_latent_frames,
433
+ int(height) // components.vae_scale_factor_spatial,
434
+ int(width) // components.vae_scale_factor_spatial,
435
+ )
436
+ if isinstance(generator, list) and len(generator) != batch_size:
437
+ raise ValueError(
438
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
439
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
440
+ )
441
+
442
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
443
+ return latents
444
+
445
+ @torch.no_grad()
446
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
447
+ block_state = self.get_block_state(state)
448
+
449
+ block_state.height = block_state.height or components.default_height
450
+ block_state.width = block_state.width or components.default_width
451
+ block_state.num_frames = block_state.num_frames or components.default_num_frames
452
+ block_state.device = components._execution_device
453
+ block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
454
+ block_state.num_channels_latents = components.num_channels_latents
455
+
456
+ self.check_inputs(components, block_state)
457
+
458
+ block_state.latents = self.prepare_latents(
459
+ components,
460
+ 1,
461
+ block_state.num_channels_latents,
462
+ block_state.height,
463
+ block_state.width,
464
+ block_state.num_frames,
465
+ block_state.dtype,
466
+ block_state.device,
467
+ block_state.generator,
468
+ block_state.latents,
469
+ )
470
+
471
+ self.set_block_state(state, block_state)
472
+
473
+ return components, state
474
+
475
+
476
+ class MatrixGameWanPrepareImageMaskLatentsStep(ModularPipelineBlocks):
477
+ model_name = "MatrixGameWan"
478
+
479
+ @property
480
+ def expected_components(self) -> List[ComponentSpec]:
481
+ return [
482
+ ComponentSpec("vae", AutoencoderKLWan),
483
+ ComponentSpec("video_processor", VideoProcessor, config=FrozenDict({"vae_scale_factor": 8}))
484
+ ]
485
+
486
+ @property
487
+ def description(self) -> str:
488
+ return "Prepare latents step that prepares the latents for the text-to-video generation process"
489
+
490
+ @property
491
+ def inputs(self) -> List[InputParam]:
492
+ return [
493
+ InputParam("image"),
494
+ InputParam("height", type_hint=int),
495
+ InputParam("width", type_hint=int),
496
+ InputParam("num_frames", type_hint=int),
497
+ InputParam("image_mask_latents", type_hint=Optional[torch.Tensor]),
498
+ InputParam("num_videos_per_prompt", type_hint=int, default=1),
499
+ InputParam("generator"),
500
+ InputParam("dtype", type_hint=torch.dtype, description="The dtype of the model inputs"),
501
+ ]
502
+
503
+ @property
504
+ def intermediate_outputs(self) -> List[OutputParam]:
505
+ return [
506
+ OutputParam(
507
+ "image_mask_latents", type_hint=torch.Tensor, description="The initial latents to use for the denoising process"
508
+ )
509
+ ]
510
+
511
+ @staticmethod
512
+ def check_inputs(components, block_state):
513
+ if (block_state.height is not None and block_state.height % components.vae_scale_factor_spatial != 0) or (
514
+ block_state.width is not None and block_state.width % components.vae_scale_factor_spatial != 0
515
+ ):
516
+ raise ValueError(
517
+ f"`height` and `width` have to be divisible by {components.vae_scale_factor_spatial} but are {block_state.height} and {block_state.width}."
518
+ )
519
+
520
+ @staticmethod
521
+ @torch.no_grad()
522
+ def prepare_latents(
523
+ components,
524
+ image,
525
+ batch_size: int,
526
+ num_channels_latents: int = 16,
527
+ height: int = 352,
528
+ width: int = 640,
529
+ num_frames: int = 81,
530
+ dtype: Optional[torch.dtype] = None,
531
+ device: Optional[torch.device] = None,
532
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
533
+ latents: Optional[torch.Tensor] = None,
534
+ ) -> torch.Tensor:
535
+ if latents is not None:
536
+ return latents.to(device=device, dtype=dtype)
537
+
538
+ image = components.video_processor.preprocess(image, height, width).to(device, torch.float32)
539
+ image = image.unsqueeze(2) # [batch_size, channels, 1, height, width]
540
+
541
+ video_condition = torch.cat(
542
+ [image, image.new_zeros(image.shape[0], image.shape[1], num_frames - 1, height, width)], dim=2
543
+ )
544
+ video_condition = video_condition.to(device=device, dtype=components.vae.dtype)
545
+ latent_condition = retrieve_latents(components.vae.encode(video_condition), sample_mode="argmax")
546
+ latent_condition = latent_condition.repeat(batch_size, 1, 1, 1, 1)
547
+
548
+ latents_mean = (
549
+ torch.tensor(components.vae.config.latents_mean)
550
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
551
+ .to(device, dtype)
552
+ )
553
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(1, components.vae.config.z_dim, 1, 1, 1).to(
554
+ device, dtype
555
+ )
556
+ latent_condition = latent_condition.to(dtype)
557
+ latent_condition = (latent_condition - latents_mean) * latents_std
558
+
559
+ latent_height = height // components.vae_scale_factor_spatial
560
+ latent_width = width // components.vae_scale_factor_spatial
561
+
562
+ mask_lat_size = torch.ones(batch_size, 1, num_frames, latent_height, latent_width)
563
+ mask_lat_size[:, :, list(range(1, num_frames))] = 0
564
+
565
+ first_frame_mask = mask_lat_size[:, :, 0:1]
566
+ first_frame_mask = torch.repeat_interleave(first_frame_mask, dim=2, repeats=components.vae_scale_factor_temporal)
567
+
568
+ mask_lat_size = torch.concat([first_frame_mask, mask_lat_size[:, :, 1:, :]], dim=2)
569
+ mask_lat_size = mask_lat_size.view(batch_size, -1, components.vae_scale_factor_temporal, latent_height, latent_width)
570
+ mask_lat_size = mask_lat_size.transpose(1, 2).to(latent_condition.device)
571
+
572
+ image_mask_latents = torch.concat([mask_lat_size, latent_condition], dim=1)
573
+ return image_mask_latents
574
+
575
+ @torch.no_grad()
576
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
577
+ block_state = self.get_block_state(state)
578
+
579
+ block_state.height = block_state.height or components.default_height
580
+ block_state.width = block_state.width or components.default_width
581
+ block_state.num_frames = block_state.num_frames or components.default_num_frames
582
+ block_state.device = components._execution_device
583
+ block_state.dtype = torch.float32 # Wan latents should be torch.float32 for best quality
584
+ block_state.num_channels_latents = components.num_channels_latents
585
+
586
+ self.check_inputs(components, block_state)
587
+ block_state.image_mask_latents = self.prepare_latents(
588
+ components,
589
+ block_state.image,
590
+ 1,
591
+ block_state.num_channels_latents,
592
+ block_state.height,
593
+ block_state.width,
594
+ block_state.num_frames,
595
+ block_state.dtype,
596
+ block_state.device,
597
+ block_state.generator,
598
+ block_state.image_mask_latents,
599
+ )
600
+
601
+ self.set_block_state(state, block_state)
602
+
603
+ return components, state
604
+
block.py ADDED
@@ -0,0 +1,6 @@
1
+ from diffusers.modular_pipelines import SequentialPipelineBlocks
2
+ from .modular_blocks import ACTION2VIDEO_BLOCKS
3
+
4
+ class MatrixGameWanBlocks(SequentialPipelineBlocks):
5
+ block_classes = list(ACTION2VIDEO_BLOCKS.copy().values())
6
+ block_names = list(ACTION2VIDEO_BLOCKS.copy().keys())
decoders.py ADDED
@@ -0,0 +1,100 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, List, Tuple, Union
16
+
17
+ import numpy as np
18
+ import PIL
19
+ import torch
20
+
21
+ from diffusers.configuration_utils import FrozenDict
22
+ from diffusers.models import AutoencoderKLWan
23
+ from diffusers.utils import logging
24
+ from diffusers.video_processor import VideoProcessor
25
+ from diffusers.modular_pipelines import ModularPipelineBlocks, PipelineState
26
+ from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
27
+
28
+
29
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
30
+
31
+
32
+ class MatrixGameWanDecodeStep(ModularPipelineBlocks):
33
+ model_name = "MatrixGameWan"
34
+
35
+ @property
36
+ def expected_components(self) -> List[ComponentSpec]:
37
+ return [
38
+ ComponentSpec("vae", AutoencoderKLWan, repo="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae"),
39
+ ComponentSpec(
40
+ "video_processor",
41
+ VideoProcessor,
42
+ config=FrozenDict({"vae_scale_factor": 8}),
43
+ default_creation_method="from_config",
44
+ ),
45
+ ]
46
+
47
+ @property
48
+ def description(self) -> str:
49
+ return "Step that decodes the denoised latents into images"
50
+
51
+ @property
52
+ def inputs(self) -> List[Tuple[str, Any]]:
53
+ return [
54
+ InputParam("output_type", default="pil"),
55
+ InputParam(
56
+ "latents",
57
+ required=True,
58
+ type_hint=torch.Tensor,
59
+ description="The denoised latents from the denoising step",
60
+ )
61
+ ]
62
+
63
+ @property
64
+ def intermediate_outputs(self) -> List[OutputParam]:
65
+ return [
66
+ OutputParam(
67
+ "videos",
68
+ type_hint=Union[List[List[PIL.Image.Image]], List[torch.Tensor], List[np.ndarray]],
69
+ description="The generated videos, can be a PIL.Image.Image, torch.Tensor or a numpy array",
70
+ )
71
+ ]
72
+
73
+ @torch.no_grad()
74
+ def __call__(self, components, state: PipelineState) -> PipelineState:
75
+ block_state = self.get_block_state(state)
76
+ vae_dtype = components.vae.dtype
77
+
78
+ if not block_state.output_type == "latent":
79
+ latents = block_state.latents
80
+ latents_mean = (
81
+ torch.tensor(components.vae.config.latents_mean)
82
+ .view(1, components.vae.config.z_dim, 1, 1, 1)
83
+ .to(latents.device, latents.dtype)
84
+ )
85
+ latents_std = 1.0 / torch.tensor(components.vae.config.latents_std).view(
86
+ 1, components.vae.config.z_dim, 1, 1, 1
87
+ ).to(latents.device, latents.dtype)
88
+ latents = latents / latents_std + latents_mean
89
+ latents = latents.to(vae_dtype)
90
+ block_state.videos = components.vae.decode(latents, return_dict=False)[0]
91
+ else:
92
+ block_state.videos = block_state.latents
93
+
94
+ block_state.videos = components.video_processor.postprocess_video(
95
+ block_state.videos, output_type=block_state.output_type
96
+ )
97
+
98
+ self.set_block_state(state, block_state)
99
+
100
+ return components, state
denoise.py ADDED
@@ -0,0 +1,420 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import Any, List, Tuple
16
+
17
+ import torch
18
+
19
+ from diffusers.configuration_utils import FrozenDict
20
+ from diffusers.guiders import ClassifierFreeGuidance
21
+ from diffusers.models import AutoModel, WanTransformer3DModel
22
+ from diffusers.schedulers import UniPCMultistepScheduler
23
+ from diffusers.utils import logging
24
+ from diffusers.utils.torch_utils import randn_tensor
25
+ from diffusers.modular_pipelines import (
26
+ BlockState,
27
+ LoopSequentialPipelineBlocks,
28
+ ModularPipelineBlocks,
29
+ PipelineState,
30
+ ModularPipeline
31
+ )
32
+ from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, InputParam, OutputParam
33
+
34
+
35
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
36
+
37
+
38
+ class MatrixGameWanLoopDenoiser(ModularPipelineBlocks):
39
+ model_name = "MatrixGameWan"
40
+ frame_seq_length = 880
41
+
42
+ @property
43
+ def expected_components(self) -> List[ComponentSpec]:
44
+ return [
45
+ ComponentSpec(
46
+ "guider",
47
+ ClassifierFreeGuidance,
48
+ config=FrozenDict({"guidance_scale": 5.0}),
49
+ default_creation_method="from_config",
50
+ ),
51
+ ComponentSpec("transformer", AutoModel),
52
+ ]
53
+
54
+ @property
55
+ def description(self) -> str:
56
+ return (
57
+ "Step within the denoising loop that denoise the latents with guidance. "
58
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
59
+ "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)"
60
+ )
61
+
62
+ @property
63
+ def inputs(self) -> List[Tuple[str, Any]]:
64
+ return [
65
+ InputParam("attention_kwargs"),
66
+ InputParam(
67
+ "latents",
68
+ required=True,
69
+ type_hint=torch.Tensor,
70
+ description="The initial latents to use for the denoising process. Can be generated in prepare_latent step.",
71
+ ),
72
+ InputParam(
73
+ "image_mask_latents",
74
+ required=True,
75
+ type_hint=torch.Tensor,
76
+ ),
77
+ InputParam(
78
+ "image_embeds",
79
+ required=True,
80
+ type_hint=torch.Tensor,
81
+ ),
82
+ InputParam(
83
+ "keyboard_conditions",
84
+ required=True,
85
+ type_hint=torch.Tensor,
86
+ ),
87
+ InputParam(
88
+ "mouse_conditions",
89
+ required=True,
90
+ type_hint=torch.Tensor,
91
+ ),
92
+ InputParam(
93
+ "num_inference_steps",
94
+ required=True,
95
+ type_hint=int,
96
+ default=4,
97
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
98
+ ),
99
+ InputParam(
100
+ kwargs_type="guider_input_fields",
101
+ description=(
102
+ "All conditional model inputs that need to be prepared with guider. "
103
+ "It should contain prompt_embeds/negative_prompt_embeds. "
104
+ "Please add `kwargs_type=guider_input_fields` to their parameter spec (`OutputParam`) when they are created and added to the pipeline state"
105
+ ),
106
+ ),
107
+ ]
108
+
109
+ @torch.no_grad()
110
+ def __call__(
111
+ self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor
112
+ ) -> PipelineState:
113
+ cond_concat = block_state.image_mask_latents
114
+ keyboard_conditions = block_state.keyboard_conditions
115
+ mouse_conditions = block_state.mouse_conditions
116
+ visual_context = block_state.image_embeds
117
+
118
+ transformer_dtype = components.transformer.dtype
119
+ components.guider.set_state(step=i, num_inference_steps=block_state.num_inference_steps, timestep=t)
120
+
121
+ # Prepare mini‐batches according to guidance method and `guider_input_fields`
122
+ # Each guider_state_batch will have .prompt_embeds, .time_ids, text_embeds, image_embeds.
123
+ # e.g. for CFG, we prepare two batches: one for uncond, one for cond
124
+ # for first batch, guider_state_batch.prompt_embeds correspond to block_state.prompt_embeds
125
+ # for second batch, guider_state_batch.prompt_embeds correspond to block_state.negative_prompt_embeds
126
+ guider_state = components.guider.prepare_inputs(block_state, {})
127
+
128
+ # run the denoiser for each guidance batch
129
+ for guider_state_batch in guider_state:
130
+ components.guider.prepare_models(components.transformer)
131
+ cond_kwargs = guider_state_batch.as_dict()
132
+
133
+ # Predict the noise residual
134
+ # store the noise_pred in guider_state_batch so that we can apply guidance across all batches
135
+ guider_state_batch.noise_pred = components.transformer(
136
+ x=block_state.latents.to(transformer_dtype),
137
+ t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block),
138
+ visual_context=visual_context.to(transformer_dtype),
139
+ cond_concat=cond_concat.to(transformer_dtype),
140
+ keyboard_cond=keyboard_conditions,
141
+ mouse_cond=mouse_conditions,
142
+ kv_cache=block_state.kv_cache,
143
+ kv_cache_mouse=block_state.kv_cache_mouse,
144
+ kv_cache_keyboard=block_state.kv_cache_keyboard,
145
+ crossattn_cache=block_state.kv_cache_cross_attn,
146
+ current_start=block_state.current_frame_idx * self.frame_seq_length,
147
+ num_frames_per_block=block_state.num_frames_per_block,
148
+ )[0]
149
+ components.guider.cleanup_models(components.transformer)
150
+
151
+ # Perform guidance
152
+ block_state.noise_pred = components.guider(guider_state)[0]
153
+
154
+ return components, block_state
155
+
156
+
157
+ class MatrixGameWanLoopAfterDenoiser(ModularPipelineBlocks):
158
+ model_name = "MatrixGameWan"
159
+
160
+ @property
161
+ def expected_components(self) -> List[ComponentSpec]:
162
+ return [
163
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
164
+ ]
165
+
166
+ @property
167
+ def description(self) -> str:
168
+ return (
169
+ "step within the denoising loop that update the latents. "
170
+ "This block should be used to compose the `sub_blocks` attribute of a `LoopSequentialPipelineBlocks` "
171
+ "object (e.g. `MatrixGameWanDenoiseLoopWrapper`)"
172
+ )
173
+
174
+ @property
175
+ def inputs(self) -> List[Tuple[str, Any]]:
176
+ return []
177
+
178
+ @property
179
+ def intermediate_inputs(self) -> List[str]:
180
+ return [
181
+ InputParam("generator"),
182
+ ]
183
+
184
+ @property
185
+ def intermediate_outputs(self) -> List[OutputParam]:
186
+ return [OutputParam("latents", type_hint=torch.Tensor, description="The denoised latents")]
187
+
188
+ @torch.no_grad()
189
+ def __call__(self, components: ModularPipeline, block_state: BlockState, i: int, t: torch.Tensor):
190
+ # Perform scheduler step using the predicted output
191
+ latents_dtype = block_state.latents.dtype
192
+
193
+ step_index = components.scheduler.index_for_timestep(t)
194
+ sigma_t = components.scheduler.sigmas[step_index]
195
+
196
+ latents = block_state.latents.double() - sigma_t.double() * block_state.noise_pred.double()
197
+ block_state.latents = latents
198
+
199
+ if block_state.latents.dtype != latents_dtype:
200
+ block_state.latents = block_state.latents.to(latents_dtype)
201
+
202
+ return components, block_state
203
+
204
+
205
+ class MatrixGameWanDenoiseLoopWrapper(LoopSequentialPipelineBlocks):
206
+ model_name = "MatrixGameWan"
207
+ frame_seq_length = 880
208
+ local_attn_size = 6
209
+ num_transformer_blocks = 30
210
+
211
+ def _initialize_kv_cache(self, batch_size, dtype, device):
212
+ """
213
+ Initialize a Per-GPU KV cache for the Wan model.
214
+ """
215
+ cache = []
216
+ if self.local_attn_size != -1:
217
+ # Use the local attention size to compute the KV cache size
218
+ kv_cache_size = self.local_attn_size * self.frame_seq_length
219
+ else:
220
+ # Use the default KV cache size
221
+ kv_cache_size = 15 * 1 * self.frame_seq_length # 32760
222
+
223
+ for _ in range(self.num_transformer_blocks):
224
+ cache.append({
225
+ "k": torch.zeros((batch_size, kv_cache_size, 12, 128), dtype=dtype, device=device),
226
+ "v": torch.zeros((batch_size, kv_cache_size, 12, 128), dtype=dtype, device=device),
227
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
228
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
229
+ })
230
+
231
+ return cache # always store the clean cache
232
+
233
+ def _initialize_kv_cache_mouse_and_keyboard(self, batch_size, dtype, device):
234
+ """
235
+ Initialize a Per-GPU KV cache for the Wan model.
236
+ """
237
+ kv_cache_mouse = []
238
+ kv_cache_keyboard = []
239
+ if self.local_attn_size != -1:
240
+ kv_cache_size = self.local_attn_size
241
+ else:
242
+ kv_cache_size = 15 * 1
243
+ for _ in range(self.num_transformer_blocks):
244
+ kv_cache_keyboard.append({
245
+ "k": torch.zeros([batch_size, kv_cache_size, 16, 64], dtype=dtype, device=device),
246
+ "v": torch.zeros([batch_size, kv_cache_size, 16, 64], dtype=dtype, device=device),
247
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
248
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
249
+ })
250
+ kv_cache_mouse.append({
251
+ "k": torch.zeros([batch_size * self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
252
+ "v": torch.zeros([batch_size * self.frame_seq_length, kv_cache_size, 16, 64], dtype=dtype, device=device),
253
+ "global_end_index": torch.tensor([0], dtype=torch.long, device=device),
254
+ "local_end_index": torch.tensor([0], dtype=torch.long, device=device)
255
+ })
256
+ return kv_cache_mouse, kv_cache_keyboard # always store the clean cache
257
+
258
+ def _initialize_crossattn_cache(self, batch_size, dtype, device):
259
+ """
260
+ Initialize a Per-GPU cross-attention cache for the Wan model.
261
+ """
262
+ crossattn_cache = []
263
+
264
+ for _ in range(self.num_transformer_blocks):
265
+ crossattn_cache.append({
266
+ "k": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
267
+ "v": torch.zeros([batch_size, 257, 12, 128], dtype=dtype, device=device),
268
+ "is_init": False
269
+ })
270
+
271
+ return crossattn_cache
272
+
273
+ @property
274
+ def description(self) -> str:
275
+ return (
276
+ "Pipeline block that iteratively denoise the latents over `timesteps`. "
277
+ "The specific steps with each iteration can be customized with `sub_blocks` attributes"
278
+ )
279
+
280
+ @property
281
+ def loop_expected_components(self) -> List[ComponentSpec]:
282
+ return [
283
+ ComponentSpec(
284
+ "guider",
285
+ ClassifierFreeGuidance,
286
+ config=FrozenDict({"guidance_scale": 5.0}),
287
+ default_creation_method="from_config",
288
+ ),
289
+ ComponentSpec("scheduler", UniPCMultistepScheduler),
290
+ ComponentSpec("transformer", AutoModel),
291
+ ]
292
+
293
+ @property
294
+ def loop_inputs(self) -> List[InputParam]:
295
+ return [
296
+ InputParam(
297
+ "timesteps",
298
+ required=True,
299
+ type_hint=torch.Tensor,
300
+ description="The timesteps to use for the denoising process. Can be generated in set_timesteps step.",
301
+ ),
302
+ InputParam(
303
+ "num_inference_steps",
304
+ required=True,
305
+ type_hint=int,
306
+ description="The number of inference steps to use for the denoising process. Can be generated in set_timesteps step.",
307
+ ),
308
+ InputParam(
309
+ "num_frames_per_block",
310
+ required=True,
311
+ type_hint=int,
312
+ default=3,
313
+ ),
314
+ ]
315
+
316
+ @torch.no_grad()
317
+ def __call__(
318
+ self, components: ModularPipeline, state: PipelineState
319
+ ) -> PipelineState:
320
+ block_state = self.get_block_state(state)
321
+ transformer_dtype = components.transformer.dtype
322
+
323
+ num_frames_per_block = block_state.num_frames_per_block
324
+ latents = block_state.latents.to(transformer_dtype)
325
+ image_mask_latents = block_state.image_mask_latents.to(transformer_dtype)
326
+ mouse_conditions = block_state.mouse_conditions.unsqueeze(0).to(transformer_dtype)
327
+ keyboard_conditions = block_state.keyboard_conditions.unsqueeze(0).to(transformer_dtype)
328
+ visual_context = block_state.image_embeds
329
+
330
+ batch_size, num_channels, num_frames, height, width = latents.shape
331
+ output = torch.zeros(
332
+ (batch_size, num_channels, num_frames, height, width),
333
+ device=latents.device,
334
+ dtype=latents.dtype,
335
+ )
336
+
337
+ current_frame_idx = 0
338
+ num_blocks = num_frames // num_frames_per_block
339
+
340
+ kv_cache = self._initialize_kv_cache(batch_size, latents.dtype, latents.device)
341
+ kv_cache_mouse, kv_cache_keyboard = self._initialize_kv_cache_mouse_and_keyboard(batch_size, latents.dtype, latents.device)
342
+ kv_cache_cross_attn = self._initialize_crossattn_cache(batch_size, latents.dtype, latents.device)
343
+
344
+ block_state.kv_cache = kv_cache
345
+ block_state.kv_cache_mouse = kv_cache_mouse
346
+ block_state.kv_cache_keyboard = kv_cache_keyboard
347
+ block_state.kv_cache_cross_attn = kv_cache_cross_attn
348
+
349
+ for _ in range(num_blocks):
350
+ block_state.current_frame_idx = current_frame_idx
351
+ block_state.image_mask_latents = image_mask_latents[
352
+ :, :, current_frame_idx : current_frame_idx + num_frames_per_block
353
+ ]
354
+ cond_idx = 1 + 4 * (current_frame_idx + num_frames_per_block - 1)
355
+ block_state.mouse_conditions = mouse_conditions[:, :cond_idx]
356
+ block_state.keyboard_conditions = keyboard_conditions[:, :cond_idx]
357
+
358
+ block_state.latents = latents[
359
+ :, :, current_frame_idx : current_frame_idx + num_frames_per_block
360
+ ]
361
+ for i, t in enumerate(block_state.timesteps):
362
+ components, block_state = self.loop_step(
363
+ components, block_state, i=i, t=t
364
+ )
365
+
366
+ if i < (block_state.num_inference_steps - 1):
367
+ t1 = components.scheduler.timesteps[i+1]
368
+ block_state.latents = components.scheduler.add_noise(
369
+ block_state.latents,
370
+ randn_tensor(
371
+ block_state.latents.shape,
372
+ device=block_state.latents.device,
373
+ dtype=block_state.latents.dtype
374
+ ),
375
+ t1.expand(block_state.latents.shape[0])
376
+ )
377
+
378
+ output[
379
+ :, :, current_frame_idx : current_frame_idx + num_frames_per_block
380
+ ] = block_state.latents
381
+
382
+ components.transformer(
383
+ x=block_state.latents,
384
+ t=t.expand(block_state.latents.shape[0], block_state.num_frames_per_block) * 0.0,
385
+ visual_context=visual_context,
386
+ cond_concat=block_state.image_mask_latents,
387
+ keyboard_cond=block_state.keyboard_conditions,
388
+ mouse_cond=block_state.mouse_conditions,
389
+ kv_cache=block_state.kv_cache,
390
+ kv_cache_mouse=block_state.kv_cache_mouse,
391
+ kv_cache_keyboard=block_state.kv_cache_keyboard,
392
+ crossattn_cache=block_state.kv_cache_cross_attn,
393
+ current_start=block_state.current_frame_idx * self.frame_seq_length,
394
+ num_frames_per_block=block_state.num_frames_per_block,
395
+ )[0]
396
+ current_frame_idx += num_frames_per_block
397
+
398
+ block_state.latents = output
399
+ self.set_block_state(state, block_state)
400
+
401
+ return components, state
402
+
403
+
404
+ class MatrixGameWanDenoiseStep(MatrixGameWanDenoiseLoopWrapper):
405
+ block_classes = [
406
+ MatrixGameWanLoopDenoiser,
407
+ MatrixGameWanLoopAfterDenoiser,
408
+ ]
409
+ block_names = ["denoiser", "after_denoiser"]
410
+
411
+ @property
412
+ def description(self) -> str:
413
+ return (
414
+ "Denoise step that iteratively denoise the latents. \n"
415
+ "Its loop logic is defined in `MatrixGameWanDenoiseLoopWrapper.__call__` method \n"
416
+ "At each iteration, it runs blocks defined in `sub_blocks` sequencially:\n"
417
+ " - `MatrixGameWanLoopDenoiser`\n"
418
+ " - `MatrixGameWanLoopAfterDenoiser`\n"
419
+ "This block supports both text2vid tasks."
420
+ )
encoders.py ADDED
@@ -0,0 +1,89 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ from typing import Dict, List
15
+
16
+ import random
17
+ import torch
18
+ from torchvision.transforms import v2
19
+
20
+ from diffusers.utils import logging
21
+ from diffusers import ModularPipeline, ModularPipelineBlocks
22
+ from diffusers.modular_pipelines import PipelineState
23
+ from diffusers.modular_pipelines.modular_pipeline_utils import ComponentSpec, ConfigSpec, InputParam, OutputParam
24
+ from transformers import CLIPVisionModelWithProjection, CLIPImageProcessor
25
+
26
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
27
+
28
+
29
+ class MatrixGameWanImageEncoderStep(ModularPipelineBlocks):
30
+ model_name = "MatrixGameWan"
31
+
32
+ @property
33
+ def description(self) -> str:
34
+ return "Image Encoder step that generate image_embeddings to guide the video generation"
35
+
36
+ @property
37
+ def expected_components(self) -> List[ComponentSpec]:
38
+ return [
39
+ ComponentSpec(
40
+ "image_encoder",
41
+ CLIPVisionModelWithProjection,
42
+ repo="laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
43
+ ),
44
+ ComponentSpec(
45
+ "image_processor",
46
+ CLIPImageProcessor,
47
+ repo="Wan-AI/Wan2.1-I2V-14B-720P-Diffusers",
48
+ subfolder="image_processor"
49
+ ),
50
+ ]
51
+
52
+ @property
53
+ def expected_configs(self) -> List[ConfigSpec]:
54
+ return []
55
+
56
+ @property
57
+ def inputs(self) -> List[InputParam]:
58
+ return [
59
+ InputParam("image"),
60
+ ]
61
+
62
+ @property
63
+ def intermediate_outputs(self) -> List[OutputParam]:
64
+ return [
65
+ OutputParam(
66
+ "image_embeds",
67
+ type_hint=torch.Tensor,
68
+ description="image embeddings used to guide the image generation",
69
+ )
70
+ ]
71
+
72
+ def encode_image(self, components, image):
73
+ device = components._execution_device
74
+ image = components.image_processor(images=image, return_tensors="pt").to(device)
75
+ image_embeds = components.image_encoder(**image, output_hidden_states=True)
76
+ return image_embeds.hidden_states[-2]
77
+
78
+ @torch.no_grad()
79
+ def __call__(self, components: ModularPipeline, state: PipelineState) -> PipelineState:
80
+ # Get inputs and intermediates
81
+ block_state = self.get_block_state(state)
82
+ block_state.device = components._execution_device
83
+ #image_tensor = preprocess(block_state.image)
84
+ #image_tensor = image_tensor.to(block_state.device)
85
+ block_state.image_embeds = self.encode_image(components, block_state.image)
86
+
87
+ # Add outputs
88
+ self.set_block_state(state, block_state)
89
+ return components, state
modular_blocks.py ADDED
@@ -0,0 +1,66 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from diffusers.utils import logging
16
+ from diffusers.modular_pipelines import SequentialPipelineBlocks
17
+ from diffusers.modular_pipelines.modular_pipeline_utils import InsertableDict
18
+
19
+ from .before_denoise import (
20
+ MatrixGameWanActionInputStep,
21
+ MatrixGameWanPrepareImageMaskLatentsStep,
22
+ MatrixGameWanPrepareLatentsStep,
23
+ MatrixGameWanSetTimestepsStep,
24
+ )
25
+ from .decoders import MatrixGameWanDecodeStep
26
+ from .encoders import MatrixGameWanImageEncoderStep
27
+ from .denoise import MatrixGameWanDenoiseStep
28
+
29
+
30
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
31
+
32
+
33
+ class MatrixGameWanBeforeDenoiseStep(SequentialPipelineBlocks):
34
+ block_classes = [
35
+ MatrixGameWanActionInputStep,
36
+ MatrixGameWanSetTimestepsStep,
37
+ MatrixGameWanPrepareLatentsStep,
38
+ MatrixGameWanPrepareImageMaskLatentsStep,
39
+ ]
40
+ block_names = ["action_input", "set_timesteps", "prepare_latents", "prepare_mask_latents"]
41
+
42
+ @property
43
+ def description(self):
44
+ return (
45
+ "Before denoise step that prepare the inputs for the denoise step.\n"
46
+ + "This is a sequential pipeline blocks:\n"
47
+ + " - `MatrixGameWanInputStep` is used to adjust the batch size of the model inputs\n"
48
+ + " - `MatrixGameWanSetTimestepsStep` is used to set the timesteps\n"
49
+ + " - `MatrixGameWanPrepareLatentsStep` is used to prepare the latents\n"
50
+ )
51
+
52
+ ACTION2VIDEO_BLOCKS = InsertableDict(
53
+ [
54
+ ("action_input", MatrixGameWanActionInputStep),
55
+ ("image_encoder", MatrixGameWanImageEncoderStep),
56
+ ("set_timesteps", MatrixGameWanSetTimestepsStep),
57
+ ("prepare_latents", MatrixGameWanPrepareLatentsStep),
58
+ ("prepare_masked_latents", MatrixGameWanPrepareImageMaskLatentsStep),
59
+ ("denoise", MatrixGameWanDenoiseStep),
60
+ ("decode", MatrixGameWanDecodeStep),
61
+ ]
62
+ )
63
+
64
+ ALL_BLOCKS = {
65
+ "action2video": ACTION2VIDEO_BLOCKS,
66
+ }
modular_config.json ADDED
@@ -0,0 +1,7 @@
1
+ {
2
+ "_class_name": "MatrixGameWanBlocks",
3
+ "_diffusers_version": "0.36.0.dev0",
4
+ "auto_map": {
5
+ "ModularPipelineBlocks": "block.MatrixGameWanBlocks"
6
+ }
7
+ }
modular_model_index.json ADDED
@@ -0,0 +1,60 @@
1
+ {
2
+ "_blocks_class_name": "SequentialPipelineBlocks",
3
+ "_class_name": "MatrixGameWanModularPipeline",
4
+ "_diffusers_version": "0.36.0.dev0",
5
+ "image_encoder": [
6
+ null,
7
+ null,
8
+ {
9
+ "repo": "laion/CLIP-ViT-H-14-laion2B-s32B-b79K",
10
+ "revision": null,
11
+ "type_hint": [
12
+ "transformers",
13
+ "CLIPVisionModelWithProjection"
14
+ ],
15
+ "variant": null
16
+ }
17
+ ],
18
+ "scheduler": [
19
+ null,
20
+ null,
21
+ {
22
+ "repo": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
23
+ "revision": null,
24
+ "subfolder": "scheduler",
25
+ "type_hint": [
26
+ "diffusers",
27
+ "UniPCMultistepScheduler"
28
+ ],
29
+ "variant": null
30
+ }
31
+ ],
32
+ "transformer": [
33
+ null,
34
+ null,
35
+ {
36
+ "repo": "diffusers-internal-dev/matrix-game-2-modular",
37
+ "revision": null,
38
+ "subfolder": "transformer",
39
+ "type_hint": [
40
+ "diffusers",
41
+ "AutoModel"
42
+ ],
43
+ "variant": null
44
+ }
45
+ ],
46
+ "vae": [
47
+ "diffusers",
48
+ "AutoencoderKLWan",
49
+ {
50
+ "repo": "Wan-AI/Wan2.1-T2V-1.3B-Diffusers",
51
+ "revision": null,
52
+ "subfolder": "vae",
53
+ "type_hint": [
54
+ "diffusers",
55
+ "AutoencoderKLWan"
56
+ ],
57
+ "variant": null
58
+ }
59
+ ]
60
+ }
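
Each component entry above pairs an optional library/class with a loading spec (`repo`, `subfolder`, `type_hint`) that `load_components` resolves at runtime. As a hedged illustration, the `scheduler` and `vae` specs correspond roughly to the following manual loads, which the pipeline normally performs itself:

```python
# Rough manual equivalent of the "scheduler" and "vae" specs above; shown for
# illustration only, load_components handles this automatically.
import torch
from diffusers import AutoencoderKLWan, UniPCMultistepScheduler

scheduler = UniPCMultistepScheduler.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="scheduler"
)
vae = AutoencoderKLWan.from_pretrained(
    "Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32
)
```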
modular_pipeline.py ADDED
@@ -0,0 +1,85 @@
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+
16
+ from diffusers.loaders import WanLoraLoaderMixin
17
+ from diffusers.utils import logging
18
+ from diffusers.modular_pipelines import ModularPipeline
19
+
20
+
21
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
+
23
+
24
+ class MatrixGameWanModularPipeline(ModularPipeline, WanLoraLoaderMixin):
25
+ """
26
+ A ModularPipeline for MatrixGameWan.
27
+
28
+ <Tip warning={true}>
29
+
30
+ This is an experimental feature and is likely to change in the future.
31
+
32
+ </Tip>
33
+ """
34
+
35
+ @property
36
+ def default_height(self):
37
+ return self.default_sample_height * self.vae_scale_factor_spatial
38
+
39
+ @property
40
+ def default_width(self):
41
+ return self.default_sample_width * self.vae_scale_factor_spatial
42
+
43
+ @property
44
+ def default_num_frames(self):
45
+ return (self.default_sample_num_frames - 1) * self.vae_scale_factor_temporal + 1
46
+
47
+ @property
48
+ def default_sample_height(self):
49
+ return 44
50
+
51
+ @property
52
+ def default_sample_width(self):
53
+ return 80
54
+
55
+ @property
56
+ def default_sample_num_frames(self):
57
+ return 21
58
+
59
+ @property
60
+ def vae_scale_factor_spatial(self):
61
+ vae_scale_factor = 8
62
+ if hasattr(self, "vae") and self.vae is not None:
63
+ vae_scale_factor = 2 ** len(self.vae.temperal_downsample)
64
+ return vae_scale_factor
65
+
66
+ @property
67
+ def vae_scale_factor_temporal(self):
68
+ vae_scale_factor = 4
69
+ if hasattr(self, "vae") and self.vae is not None:
70
+ vae_scale_factor = 2 ** sum(self.vae.temperal_downsample)
71
+ return vae_scale_factor
72
+
73
+ @property
74
+ def num_channels_transformer(self):
75
+ num_channels_transformer = 16
76
+ if hasattr(self, "transformer") and self.transformer is not None:
77
+ num_channels_transformer = self.transformer.config.in_channels
78
+ return num_channels_transformer
79
+
80
+ @property
81
+ def num_channels_latents(self):
82
+ num_channels_latents = 16
83
+ if hasattr(self, "vae") and self.vae is not None:
84
+ num_channels_latents = self.vae.config.z_dim
85
+ return num_channels_latents
requirements.txt ADDED
@@ -0,0 +1,14 @@
1
+ accelerate==1.10.1
2
+ einops==0.8.1
3
+ flash-attn @ https://github.com/Dao-AILab/flash-attention/releases/download/v2.7.4.post1/flash_attn-2.7.4.post1+cu12torch2.6cxx11abiTRUE-cp310-cp310-linux_x86_64.whl
4
+ hf-transfer==0.1.9
5
+ hf-xet==1.1.8
6
+ huggingface-hub==0.34.4
7
+ imageio==2.37.0
8
+ imageio-ffmpeg==0.6.0
9
+ safetensors==0.6.2
10
+ sentencepiece==0.2.1
11
+ torch==2.7.0
12
+ torchao==0.12.0
13
+ torchvision==0.22.0
14
+ transformers==4.55.4
test_pipeline.py ADDED
@@ -0,0 +1,24 @@
1
+ import sys
2
+ import os
3
+ sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
4
+
5
+ import torch
6
+ from modular_pipeline import MatrixGameWanModularPipeline
7
+ from modular_blocks import ACTION2VIDEO_BLOCKS
8
+ from diffusers.modular_pipelines import SequentialPipelineBlocks
9
+ from diffusers import AutoModel
10
+ from diffusers.utils import load_image, export_to_video
11
+
12
+ blocks = SequentialPipelineBlocks.from_blocks_dict(ACTION2VIDEO_BLOCKS.copy())
13
+ pipe = MatrixGameWanModularPipeline(blocks)
14
+ pipe.load_components(torch_dtype=torch.bfloat16)
15
+ pipe.load_components(["vae"], repo="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="vae", torch_dtype=torch.float32)
16
+ pipe.load_components(["scheduler"], repo="Wan-AI/Wan2.1-T2V-1.3B-Diffusers", subfolder="scheduler")
17
+
18
+ transformer = AutoModel.from_pretrained("./transformer", trust_remote_code=True, torch_dtype=torch.bfloat16)
19
+ pipe.transformer = transformer
20
+ pipe.to("cuda")
21
+
22
+ image = load_image("/home/dhruv/matrix-game-workspace/Matrix-Game/Matrix-Game-2/demo_images/universal/0000.png")
23
+ output = pipe(image=image, num_frames=141)
24
+ export_to_video(output.values['videos'][0], "modular-matrix-game.mp4")
transformer/__init__.py ADDED
File without changes
transformer/action_module.py ADDED
@@ -0,0 +1,1148 @@
1
+ from typing import Any, List, Tuple, Optional, Union, Dict
2
+ from einops import rearrange
3
+ from flash_attn import flash_attn_func
4
+ import torch
5
+ import torch.nn as nn
6
+ import math
7
+ from torch.nn.attention.flex_attention import flex_attention
8
+
9
+ try:
10
+ import flash_attn
11
+
12
+ except ImportError:
13
+ flash_attn = None  # only flash_attn_func from the import above is used in that case
14
+
15
+ FLASH_ATTN_3_AVAILABLE = False
16
+
17
+
18
+ DISABLE_COMPILE = False # get os env
19
+ flex_attention = torch.compile(
20
+ flex_attention, dynamic=False, mode="max-autotune-no-cudagraphs"
21
+ )
22
+
23
+ import torch
24
+ from typing import Union, Tuple, List
25
+
26
+
27
+ def _to_tuple(x, dim=2):
28
+ if isinstance(x, int):
29
+ return (x,) * dim
30
+ elif len(x) == dim:
31
+ return x
32
+ else:
33
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
34
+
35
+
36
+ def get_meshgrid_nd(start, *args, dim=2):
37
+ """
38
+ Get n-D meshgrid with start, stop and num.
39
+
40
+ Args:
41
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
42
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
43
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
44
+ n-tuples.
45
+ *args: See above.
46
+ dim (int): Dimension of the meshgrid. Defaults to 2.
47
+
48
+ Returns:
49
+ grid (torch.Tensor): [dim, ...]
50
+ """
51
+ if len(args) == 0:
52
+ # start is grid_size
53
+ num = _to_tuple(start, dim=dim)
54
+ start = (0,) * dim
55
+ stop = num
56
+ elif len(args) == 1:
57
+ # start is start, args[0] is stop, step is 1
58
+ start = _to_tuple(start, dim=dim)
59
+ stop = _to_tuple(args[0], dim=dim)
60
+ num = [stop[i] - start[i] for i in range(dim)]
61
+ elif len(args) == 2:
62
+ # start is start, args[0] is stop, args[1] is num
63
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
64
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
65
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
66
+ else:
67
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
68
+
69
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
70
+ axis_grid = []
71
+ for i in range(dim):
72
+ a, b, n = start[i], stop[i], num[i]
73
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32, device=torch.cuda.current_device())[:n]
74
+ axis_grid.append(g)
75
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
76
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
77
+
78
+ return grid
79
+
80
+
81
+ #################################################################################
82
+ # Rotary Positional Embedding Functions #
83
+ #################################################################################
84
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
85
+
86
+
87
+ def reshape_for_broadcast(
88
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
89
+ x: torch.Tensor,
90
+ head_first=False,
91
+ ):
92
+ """
93
+ Reshape frequency tensor for broadcasting it with another tensor.
94
+
95
+ This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
96
+ for the purpose of broadcasting the frequency tensor during element-wise operations.
97
+
98
+ Notes:
99
+ When using FlashMHAModified, head_first should be False.
100
+ When using Attention, head_first should be True.
101
+
102
+ Args:
103
+ freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
104
+ x (torch.Tensor): Target tensor for broadcasting compatibility.
105
+ head_first (bool): head dimension first (except batch dim) or not.
106
+
107
+ Returns:
108
+ torch.Tensor: Reshaped frequency tensor.
109
+
110
+ Raises:
111
+ AssertionError: If the frequency tensor doesn't match the expected shape.
112
+ AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
113
+ """
114
+ ndim = x.ndim
115
+ assert 0 <= 1 < ndim
116
+
117
+ if isinstance(freqs_cis, tuple):
118
+ # freqs_cis: (cos, sin) in real space
119
+ if head_first:
120
+ assert freqs_cis[0].shape == (
121
+ x.shape[-2],
122
+ x.shape[-1],
123
+ ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
124
+ shape = [
125
+ d if i == ndim - 2 or i == ndim - 1 else 1
126
+ for i, d in enumerate(x.shape)
127
+ ]
128
+ else:
129
+ # assert freqs_cis[0].shape == (
130
+ # x.shape[1],
131
+ # x.shape[-1],
132
+ # ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
133
+ # shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
134
+ shape = [1, freqs_cis[0].shape[0], 1, freqs_cis[0].shape[1]]
135
+ return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
136
+ else:
137
+ # freqs_cis: values in complex space
138
+ if head_first:
139
+ assert freqs_cis.shape == (
140
+ x.shape[-2],
141
+ x.shape[-1],
142
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
143
+ shape = [
144
+ d if i == ndim - 2 or i == ndim - 1 else 1
145
+ for i, d in enumerate(x.shape)
146
+ ]
147
+ else:
148
+ assert freqs_cis.shape == (
149
+ x.shape[1],
150
+ x.shape[-1],
151
+ ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
152
+ shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
153
+ return freqs_cis.view(*shape)
154
+
155
+
156
+ def rotate_half(x):
157
+ x_real, x_imag = (
158
+ x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
159
+ ) # [B, S, H, D//2]
160
+ return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
161
+
162
+
163
+ def apply_rotary_emb(
164
+ xq: torch.Tensor,
165
+ xk: torch.Tensor,
166
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
167
+ head_first: bool = False,
168
+ start_offset: int = 0,
169
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
170
+ """
171
+ Apply rotary embeddings to input tensors using the given frequency tensor.
172
+
173
+ This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
174
+ frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
175
+ is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
176
+ returned as real tensors.
177
+
178
+ Args:
179
+ xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
180
+ xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
181
+ freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
182
+ head_first (bool): head dimension first (except batch dim) or not.
183
+
184
+ Returns:
185
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
186
+
187
+ """
188
+ # print(freqs_cis[0].shape, xq.shape, xk.shape)
189
+ xk_out = None
190
+ assert isinstance(freqs_cis, tuple)
191
+ if isinstance(freqs_cis, tuple):
192
+ cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first) # [S, D]
193
+ cos, sin = cos.to(xq.device), sin.to(xq.device)
194
+ # real * cos - imag * sin
195
+ # imag * cos + real * sin
196
+ xq_out = (xq.float() * cos[:, start_offset:start_offset + xq.shape[1], :, :] + rotate_half(xq.float()) * sin[:, start_offset:start_offset + xq.shape[1], :, :]).type_as(xq)
197
+ xk_out = (xk.float() * cos[:, start_offset:start_offset + xk.shape[1], :, :] + rotate_half(xk.float()) * sin[:, start_offset:start_offset + xk.shape[1], :, :]).type_as(xk)
198
+ else:
199
+ # view_as_complex will pack [..., D/2, 2](real) to [..., D/2](complex)
200
+ xq_ = torch.view_as_complex(
201
+ xq.float().reshape(*xq.shape[:-1], -1, 2)
202
+ ) # [B, S, H, D//2]
203
+ freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
204
+ xq.device
205
+ ) # [S, D//2] --> [1, S, 1, D//2]
206
+ # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
207
+ # view_as_real will expand [..., D/2](complex) to [..., D/2, 2](real)
208
+ xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
209
+ xk_ = torch.view_as_complex(
210
+ xk.float().reshape(*xk.shape[:-1], -1, 2)
211
+ ) # [B, S, H, D//2]
212
+ xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
213
+
214
+ return xq_out, xk_out
215
+
216
+
217
+ def get_nd_rotary_pos_embed(
218
+ rope_dim_list,
219
+ start,
220
+ *args,
221
+ theta=10000.0,
222
+ use_real=False,
223
+ theta_rescale_factor: Union[float, List[float]] = 1.0,
224
+ interpolation_factor: Union[float, List[float]] = 1.0,
225
+ ):
226
+ """
227
+ This is a n-d version of precompute_freqs_cis, which is a RoPE for tokens with n-d structure.
228
+
229
+ Args:
230
+ rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal to n.
231
+ sum(rope_dim_list) should equal to head_dim of attention layer.
232
+ start (int | tuple of int | list of int): If len(args) == 0, start is num; If len(args) == 1, start is start,
233
+ args[0] is stop, step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num.
234
+ *args: See above.
235
+ theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
236
+ use_real (bool): If True, return real part and imaginary part separately. Otherwise, return complex numbers.
237
+ Some libraries such as TensorRT do not support the complex64 data type, so it is useful to provide a real
238
+ part and an imaginary part separately.
239
+ theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
240
+
241
+ Returns:
242
+ pos_embed (torch.Tensor): [HW, D/2]
243
+ """
244
+
245
+ grid = get_meshgrid_nd(
246
+ start, *args, dim=len(rope_dim_list)
247
+ ) # [3, W, H, D] / [2, W, H]
248
+
249
+ if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
250
+ theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
251
+ elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
252
+ theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
253
+ assert len(theta_rescale_factor) == len(
254
+ rope_dim_list
255
+ ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
256
+
257
+ if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
258
+ interpolation_factor = [interpolation_factor] * len(rope_dim_list)
259
+ elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
260
+ interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
261
+ assert len(interpolation_factor) == len(
262
+ rope_dim_list
263
+ ), "len(interpolation_factor) should equal to len(rope_dim_list)"
264
+
265
+ # use 1/ndim of dimensions to encode grid_axis
266
+ embs = []
267
+ for i in range(len(rope_dim_list)):
268
+ emb = get_1d_rotary_pos_embed(
269
+ rope_dim_list[i],
270
+ grid[i].reshape(-1),
271
+ theta,
272
+ use_real=use_real,
273
+ theta_rescale_factor=theta_rescale_factor[i],
274
+ interpolation_factor=interpolation_factor[i],
275
+ ) # 2 x [WHD, rope_dim_list[i]]
276
+ embs.append(emb)
277
+
278
+ if use_real:
279
+ cos = torch.cat([emb[0] for emb in embs], dim=1) # (WHD, D/2)
280
+ sin = torch.cat([emb[1] for emb in embs], dim=1) # (WHD, D/2)
281
+ return cos, sin
282
+ else:
283
+ emb = torch.cat(embs, dim=1) # (WHD, D/2)
284
+ return emb
285
+
286
+
287
+ def get_1d_rotary_pos_embed(
288
+ dim: int,
289
+ pos: Union[torch.FloatTensor, int],
290
+ theta: float = 10000.0,
291
+ use_real: bool = False,
292
+ theta_rescale_factor: float = 1.0,
293
+ interpolation_factor: float = 1.0,
294
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
295
+ """
296
+ Precompute the frequency tensor for complex exponential (cis) with given dimensions.
297
+ (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
298
+
299
+ This function calculates a frequency tensor with complex exponential using the given dimension 'dim'
300
+ and the end index 'end'. The 'theta' parameter scales the frequencies.
301
+ The returned tensor contains complex values in complex64 data type.
302
+
303
+ Args:
304
+ dim (int): Dimension of the frequency tensor.
305
+ pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
306
+ theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
307
+ use_real (bool, optional): If True, return real part and imaginary part separately.
308
+ Otherwise, return complex numbers.
309
+ theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
310
+
311
+ Returns:
312
+ freqs_cis: Precomputed frequency tensor with complex exponential. [S, D/2]
313
+ freqs_cos, freqs_sin: Precomputed frequency tensor with real and imaginary parts separately. [S, D]
314
+ """
315
+ if isinstance(pos, int):
316
+ pos = torch.arange(pos, device=torch.cuda.current_device()).float()
317
+
318
+ # proposed by reddit user bloc97, to rescale rotary embeddings to longer sequence length without fine-tuning
319
+ # has some connection to NTK literature
320
+ if theta_rescale_factor != 1.0:
321
+ theta *= theta_rescale_factor ** (dim / (dim - 2))
322
+
323
+ freqs = 1.0 / (
324
+ theta ** (torch.arange(0, dim, 2, device=torch.cuda.current_device())[: (dim // 2)].float() / dim)
325
+ ) # [D/2]
326
+ # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
327
+ freqs = torch.outer(pos * interpolation_factor, freqs) # [S, D/2]
328
+ if use_real:
329
+ freqs_cos = freqs.cos().repeat_interleave(2, dim=1) # [S, D]
330
+ freqs_sin = freqs.sin().repeat_interleave(2, dim=1) # [S, D]
331
+ return freqs_cos, freqs_sin
332
+ else:
333
+ freqs_cis = torch.polar(
334
+ torch.ones_like(freqs), freqs
335
+ ) # complex64 # [S, D/2]
336
+ return freqs_cis
337
+
338
+
339
+ class MatrixGameWanRMSNorm(nn.Module):
340
+ def __init__(self, dim, eps=1e-5):
341
+ super().__init__()
342
+ self.dim = dim
343
+ self.eps = eps
344
+ self.weight = nn.Parameter(torch.ones(dim))
345
+
346
+ def forward(self, x):
347
+ r"""
348
+ Args:
349
+ x(Tensor): Shape [B, L, C]
350
+ """
351
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
352
+
353
+
354
+ class ActionModule(nn.Module):
355
+ """
356
+ action module from https://arxiv.org/pdf/2501.08325
357
+ The mouse control signal is provided as an L x D vector.
358
+ The keyboard control signal is provided in the same form.
359
+ """
360
+
361
+ def __init__(
362
+ self,
363
+ mouse_dim_in: int = 2,
364
+ keyboard_dim_in: int = 6,
365
+ hidden_size: int = 128,
366
+ img_hidden_size: int = 1536,
367
+ keyboard_hidden_dim: int = 1024,
368
+ mouse_hidden_dim: int = 1024,
369
+ vae_time_compression_ratio: int = 4,
370
+ windows_size: int = 3,
371
+ heads_num: int = 16,
372
+ patch_size: list = [1, 2, 2],
373
+ qk_norm: bool = True,
374
+ qkv_bias: bool = False,
375
+ rope_dim_list: list = [8, 28, 28],
376
+ rope_theta=256,
377
+ mouse_qk_dim_list=[8, 28, 28],
378
+ enable_mouse=True,
379
+ enable_keyboard=True,
380
+ local_attn_size=6,
381
+ blocks=[],
382
+ ):
383
+ device = None
384
+
385
+ super().__init__()
386
+ self.local_attn_size = local_attn_size
387
+ self.enable_mouse = enable_mouse
388
+ self.enable_keyboard = enable_keyboard
389
+
390
+ self.rope_dim_list = rope_dim_list
391
+ self.rope_theta = rope_theta
392
+ if self.enable_keyboard:
393
+ self.keyboard_embed = nn.Sequential(
394
+ nn.Linear(keyboard_dim_in, hidden_size, bias=True),
395
+ nn.SiLU(),
396
+ nn.Linear(hidden_size, hidden_size, bias=True),
397
+ )
398
+
399
+ self.mouse_qk_dim_list = mouse_qk_dim_list
400
+ self.heads_num = heads_num
401
+ if self.enable_mouse:
402
+ c = mouse_hidden_dim
403
+ self.mouse_mlp = torch.nn.Sequential(
404
+ torch.nn.Linear(
405
+ mouse_dim_in * vae_time_compression_ratio * windows_size
406
+ + img_hidden_size,
407
+ c,
408
+ bias=True,
409
+ ),
410
+ torch.nn.GELU(approximate="tanh"),
411
+ torch.nn.Linear(c, c),
412
+ torch.nn.LayerNorm(c),
413
+ )
414
+
415
+ head_dim = c // heads_num
416
+ self.t_qkv = nn.Linear(c, c * 3, bias=qkv_bias)
417
+ self.img_attn_q_norm = (
418
+ MatrixGameWanRMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
419
+ )
420
+ self.img_attn_k_norm = (
421
+ MatrixGameWanRMSNorm(head_dim, eps=1e-6) if qk_norm else nn.Identity()
422
+ )
423
+ self.proj_mouse = nn.Linear(c, img_hidden_size, bias=qkv_bias)
424
+
425
+ if self.enable_keyboard:
426
+ head_dim_key = keyboard_hidden_dim // heads_num
427
+ self.key_attn_q_norm = (
428
+ MatrixGameWanRMSNorm(head_dim_key, eps=1e-6) if qk_norm else nn.Identity()
429
+ )
430
+ self.key_attn_k_norm = (
431
+ MatrixGameWanRMSNorm(head_dim_key, eps=1e-6) if qk_norm else nn.Identity()
432
+ )
433
+
434
+ self.mouse_attn_q = nn.Linear(
435
+ img_hidden_size, keyboard_hidden_dim, bias=qkv_bias
436
+ )
437
+ self.keyboard_attn_kv = nn.Linear(
438
+ hidden_size * windows_size * vae_time_compression_ratio,
439
+ keyboard_hidden_dim * 2,
440
+ bias=qkv_bias,
441
+ )
442
+ self.proj_keyboard = nn.Linear(
443
+ keyboard_hidden_dim, img_hidden_size, bias=qkv_bias
444
+ )
445
+
446
+ self.vae_time_compression_ratio = vae_time_compression_ratio
447
+ self.windows_size = windows_size
448
+ self.patch_size = patch_size
449
+ self.freqs_cos, self.freqs_sin = self.get_rotary_pos_embed(
450
+ 7500,
451
+ self.patch_size[1],
452
+ self.patch_size[2],
453
+ 64,
454
+ self.mouse_qk_dim_list,
455
+ start_offset=0,
456
+ )
457
+
458
+ def patchify(self, x, patch_size):
459
+ """
460
+ x : (N C T H W)
461
+ """
462
+ pt, ph, pw = self.patch_size
463
+ t, h, w = x.shape[2] // pt, x.shape[3] // ph, x.shape[4] // pw
464
+ c = x.shape[1]
465
+ x = x.reshape(shape=(x.shape[0], c, t, pt, h, ph, w, pw))
466
+ x = torch.einsum("nctohpwq->nthwcopq", x)
467
+ x = x.reshape(shape=(x.shape[0], t * h * w, c * pt * ph * pw))
468
+ return x
469
+
470
+ def unpatchify(self, x, t, h, w, patch_size):
471
+ """
472
+ x: (N, T, patch_size**2 * C)
473
+ imgs: (N, H, W, C)
474
+ """
475
+ c = x.shape[2] // patch_size # self.unpatchify_channels
476
+ pt, ph, pw = self.patch_size
477
+ assert t * h * w == x.shape[1]
478
+
479
+ x = x.reshape(shape=(x.shape[0], t, h, w, c, pt, ph, pw))
480
+ x = torch.einsum("nthwcopq->nctohpwq", x)
481
+ imgs = x.reshape(shape=(x.shape[0], c, t * pt, h * ph, w * pw))
482
+
483
+ return imgs
484
+
485
+ def get_rotary_pos_embed(
486
+ self, video_length, height, width, head_dim, rope_dim_list=None, start_offset=0
487
+ ):
488
+ target_ndim = 3
489
+ ndim = 5 - 2
490
+
491
+ latents_size = [video_length + start_offset, height, width]
492
+
493
+ if isinstance(self.patch_size, int):
494
+ assert all(s % self.patch_size == 0 for s in latents_size), (
495
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), "
496
+ f"but got {latents_size}."
497
+ )
498
+ rope_sizes = [s // self.patch_size for s in latents_size]
499
+ elif isinstance(self.patch_size, list):
500
+ assert all(
501
+ s % self.patch_size[idx] == 0 for idx, s in enumerate(latents_size)
502
+ ), (
503
+ f"Latent size(last {ndim} dimensions) should be divisible by patch size({self.patch_size}), "
504
+ f"but got {latents_size}."
505
+ )
506
+ rope_sizes = [
507
+ s // self.patch_size[idx] for idx, s in enumerate(latents_size)
508
+ ]
509
+
510
+ if len(rope_sizes) != target_ndim:
511
+ rope_sizes = [1] * (target_ndim - len(rope_sizes)) + rope_sizes # time axis
512
+
513
+ if rope_dim_list is None:
514
+ rope_dim_list = [head_dim // target_ndim for _ in range(target_ndim)]
515
+ assert (
516
+ sum(rope_dim_list) == head_dim
517
+ ), "sum(rope_dim_list) should equal to head_dim of attention layer"
518
+ freqs_cos, freqs_sin = get_nd_rotary_pos_embed(
519
+ rope_dim_list,
520
+ rope_sizes,
521
+ theta=self.rope_theta,
522
+ use_real=True,
523
+ theta_rescale_factor=1,
524
+ )
525
+ return freqs_cos[
526
+ -video_length * rope_sizes[1] * rope_sizes[2] // self.patch_size[0] :
527
+ ], freqs_sin[
528
+ -video_length * rope_sizes[1] * rope_sizes[2] // self.patch_size[0] :
529
+ ]
530
+
531
+ def forward(
532
+ self,
533
+ x,
534
+ tt,
535
+ th,
536
+ tw,
537
+ mouse_condition=None,
538
+ keyboard_condition=None,
539
+ block_mask_mouse=None,
540
+ block_mask_keyboard=None,
541
+ is_causal=False,
542
+ kv_cache_mouse=None,
543
+ kv_cache_keyboard=None,
544
+ start_frame=0,
545
+ use_rope_keyboard=True,
546
+ num_frame_per_block=3,
547
+ ):
548
+ """
549
+ hidden_states: B, tt*th*tw, C
550
+ mouse_condition: B, N_frames, C1
551
+ keyboard_condition: B, N_frames, C2
552
+ """
553
+ assert use_rope_keyboard == True
554
+
555
+ B, N_frames, C = keyboard_condition.shape
556
+
557
+ assert tt * th * tw == x.shape[1]
558
+ assert (
559
+ (N_frames - 1) + self.vae_time_compression_ratio
560
+ ) % self.vae_time_compression_ratio == 0
561
+ N_feats = int((N_frames - 1) / self.vae_time_compression_ratio) + 1
562
+
563
+ # Define freqs_cis early so it is available for both mouse and keyboard
564
+ freqs_cis = (self.freqs_cos, self.freqs_sin)
565
+
566
+ assert (
567
+ N_feats == tt and ((is_causal and kv_cache_mouse == None) or not is_causal)
568
+ ) or (
569
+ (N_frames - 1) // self.vae_time_compression_ratio + 1 == start_frame + num_frame_per_block and is_causal
570
+ )
571
+
572
+ if self.enable_mouse and mouse_condition is not None:
573
+ hidden_states = rearrange(
574
+ x, "B (T S) C -> (B S) T C", T=tt, S=th * tw
575
+ ) # 65*272*480 -> 17*(272//16)*(480//16) -> 8670
576
+ B, N_frames, C = mouse_condition.shape
577
+ else:
578
+ hidden_states = x
579
+ # padding
580
+
581
+ pad_t = self.vae_time_compression_ratio * self.windows_size
582
+ if self.enable_mouse and mouse_condition is not None:
583
+ pad = mouse_condition[:, 0:1, :].expand(-1, pad_t, -1)
584
+ mouse_condition = torch.cat([pad, mouse_condition], dim=1)
585
+ if is_causal and kv_cache_mouse is not None:
586
+ mouse_condition = mouse_condition[
587
+ :,
588
+ self.vae_time_compression_ratio
589
+ * (N_feats - num_frame_per_block - self.windows_size)
590
+ + pad_t :,
591
+ :,
592
+ ]
593
+ group_mouse = [
594
+ mouse_condition[
595
+ :,
596
+ self.vae_time_compression_ratio * (i - self.windows_size)
597
+ + pad_t : i * self.vae_time_compression_ratio + pad_t,
598
+ :,
599
+ ]
600
+ for i in range(num_frame_per_block)
601
+ ]
602
+ else:
603
+ group_mouse = [
604
+ mouse_condition[
605
+ :,
606
+ self.vae_time_compression_ratio * (i - self.windows_size)
607
+ + pad_t : i * self.vae_time_compression_ratio + pad_t,
608
+ :,
609
+ ]
610
+ for i in range(N_feats)
611
+ ]
612
+
613
+ group_mouse = torch.stack(group_mouse, dim=1)
614
+
615
+ S = th * tw
616
+ group_mouse = group_mouse.unsqueeze(-1).expand(
617
+ B, num_frame_per_block, pad_t, C, S
618
+ )
619
+ group_mouse = group_mouse.permute(0, 4, 1, 2, 3).reshape(
620
+ B * S, num_frame_per_block, pad_t * C
621
+ )
622
+
623
+ group_mouse = torch.cat([hidden_states, group_mouse], dim=-1)
624
+ group_mouse = self.mouse_mlp(group_mouse)
625
+
626
+ # qkv
627
+ mouse_qkv = self.t_qkv(group_mouse)
628
+ q, k, v = rearrange(
629
+ mouse_qkv, "B L (K H D) -> K B L H D", K=3, H=self.heads_num
630
+ ) # BHW F H C
631
+ q = self.img_attn_q_norm(q).to(v)
632
+ k = self.img_attn_k_norm(k).to(v)
633
+ # rope embd
634
+
635
+ # freqs_cis = (self.freqs_cos, self.freqs_sin)
636
+
637
+ q, k = apply_rotary_emb(
638
+ q, k, freqs_cis, start_offset=start_frame, head_first=False
639
+ )
640
+ ## TODO: adding cache here
641
+ if is_causal:
642
+ if kv_cache_mouse is None:
643
+ assert (
644
+ q.shape[0] == k.shape[0] and q.shape[0] % 880 == 0
645
+ ) # == 880, f"{q.shape[0]},{k.shape[0]}"
646
+ padded_length = math.ceil(q.shape[1] / 32) * 32 - q.shape[1]
647
+ padded_q = torch.cat(
648
+ [
649
+ q,
650
+ torch.zeros(
651
+ [q.shape[0], padded_length, q.shape[2], q.shape[3]],
652
+ device=q.device,
653
+ dtype=v.dtype,
654
+ ),
655
+ ],
656
+ dim=1,
657
+ )
658
+ padded_k = torch.cat(
659
+ [
660
+ k,
661
+ torch.zeros(
662
+ [k.shape[0], padded_length, k.shape[2], k.shape[3]],
663
+ device=k.device,
664
+ dtype=v.dtype,
665
+ ),
666
+ ],
667
+ dim=1,
668
+ )
669
+ padded_v = torch.cat(
670
+ [
671
+ v,
672
+ torch.zeros(
673
+ [v.shape[0], padded_length, v.shape[2], v.shape[3]],
674
+ device=v.device,
675
+ dtype=v.dtype,
676
+ ),
677
+ ],
678
+ dim=1,
679
+ )
680
+ attn = flex_attention(
681
+ query=padded_q.transpose(2, 1), # after: B, HW, F, C
682
+ key=padded_k.transpose(2, 1),
683
+ value=padded_v.transpose(2, 1),
684
+ block_mask=block_mask_mouse,
685
+ )[:, :, :-padded_length].transpose(2, 1)
686
+ else:
687
+ current_start = start_frame
688
+ current_end = current_start + q.shape[1]
689
+
690
+ assert q.shape[1] == num_frame_per_block
691
+ sink_size = 0
692
+ max_attention_size = self.local_attn_size
693
+ sink_tokens = sink_size * 1
694
+ kv_cache_size = kv_cache_mouse["k"].shape[1]
695
+ num_new_tokens = q.shape[1]
696
+
697
+ if (current_end > kv_cache_mouse["global_end_index"].item()) and (
698
+ num_new_tokens + kv_cache_mouse["local_end_index"].item()
699
+ > kv_cache_size
700
+ ):
701
+ num_evicted_tokens = (
702
+ num_new_tokens
703
+ + kv_cache_mouse["local_end_index"].item()
704
+ - kv_cache_size
705
+ )
706
+ num_rolled_tokens = (
707
+ kv_cache_mouse["local_end_index"].item()
708
+ - num_evicted_tokens
709
+ - sink_tokens
710
+ )
711
+ kv_cache_mouse["k"][
712
+ :, sink_tokens : sink_tokens + num_rolled_tokens
713
+ ] = kv_cache_mouse["k"][
714
+ :,
715
+ sink_tokens + num_evicted_tokens : sink_tokens
716
+ + num_evicted_tokens
717
+ + num_rolled_tokens,
718
+ ].clone()
719
+ kv_cache_mouse["v"][
720
+ :, sink_tokens : sink_tokens + num_rolled_tokens
721
+ ] = kv_cache_mouse["v"][
722
+ :,
723
+ sink_tokens + num_evicted_tokens : sink_tokens
724
+ + num_evicted_tokens
725
+ + num_rolled_tokens,
726
+ ].clone()
727
+ # Insert the new keys/values at the end
728
+ local_end_index = (
729
+ kv_cache_mouse["local_end_index"].item()
730
+ + current_end
731
+ - kv_cache_mouse["global_end_index"].item()
732
+ - num_evicted_tokens
733
+ )
734
+ local_start_index = local_end_index - num_new_tokens
735
+ else:
736
+ local_end_index = (
737
+ kv_cache_mouse["local_end_index"].item()
738
+ + current_end
739
+ - kv_cache_mouse["global_end_index"].item()
740
+ )
741
+ local_start_index = local_end_index - num_new_tokens
742
+
743
+ kv_cache_mouse["k"][:, local_start_index:local_end_index] = k
744
+ kv_cache_mouse["v"][:, local_start_index:local_end_index] = v
745
+
746
+ if FLASH_ATTN_3_AVAILABLE:
747
+ attn, attn_prob = flash_attn.flash_attn_func(
748
+ q,
749
+ kv_cache_mouse["k"][
750
+ :,
751
+ max(
752
+ 0, local_end_index - max_attention_size
753
+ ) : local_end_index,
754
+ ],
755
+ kv_cache_mouse["v"][
756
+ :,
757
+ max(
758
+ 0, local_end_index - max_attention_size
759
+ ) : local_end_index,
760
+ ],
761
+ )
762
+ else:
763
+ attn = flash_attn_func(
764
+ q,
765
+ kv_cache_mouse["k"][
766
+ :,
767
+ max(
768
+ 0, local_end_index - max_attention_size
769
+ ) : local_end_index,
770
+ ],
771
+ kv_cache_mouse["v"][
772
+ :,
773
+ max(
774
+ 0, local_end_index - max_attention_size
775
+ ) : local_end_index,
776
+ ],
777
+ )
778
+ kv_cache_mouse["global_end_index"].fill_(current_end)
779
+ kv_cache_mouse["local_end_index"].fill_(local_end_index)
780
+ else:
781
+ attn = flash_attn_func(
782
+ q, # 880, f, 16, 64
783
+ k, # 880, f, 16, 64
784
+ v, # 880, f, 16, 64
785
+ )
786
+ # Compute cu_seqlens and max_seqlen for flash attention
787
+ # qk norm
788
+ attn = rearrange(attn, "(b S) T h d -> b (T S) (h d)", b=B)
789
+
790
+ hidden_states = rearrange(x, "(B S) T C -> B (T S) C", B=B)
791
+ attn = self.proj_mouse(attn)
792
+
793
+ hidden_states = hidden_states + attn
794
+
795
+ if self.enable_keyboard and keyboard_condition is not None:
796
+ pad = keyboard_condition[:, 0:1, :].expand(-1, pad_t, -1)
797
+ keyboard_condition = torch.cat([pad, keyboard_condition], dim=1)
798
+ if is_causal and kv_cache_keyboard is not None:
799
+ keyboard_condition = keyboard_condition[
800
+ :,
801
+ self.vae_time_compression_ratio
802
+ * (N_feats - num_frame_per_block - self.windows_size)
803
+ + pad_t :,
804
+ :,
805
+ ] # keyboard_condition[:, self.vae_time_compression_ratio*(start_frame - self.windows_size) + pad_t:start_frame * self.vae_time_compression_ratio + pad_t,:]
806
+ keyboard_condition = self.keyboard_embed(keyboard_condition)
807
+ group_keyboard = [
808
+ keyboard_condition[
809
+ :,
810
+ self.vae_time_compression_ratio * (i - self.windows_size)
811
+ + pad_t : i * self.vae_time_compression_ratio + pad_t,
812
+ :,
813
+ ]
814
+ for i in range(num_frame_per_block)
815
+ ]
816
+ else:
817
+ keyboard_condition = self.keyboard_embed(keyboard_condition)
818
+ group_keyboard = [
819
+ keyboard_condition[
820
+ :,
821
+ self.vae_time_compression_ratio * (i - self.windows_size)
822
+ + pad_t : i * self.vae_time_compression_ratio + pad_t,
823
+ :,
824
+ ]
825
+ for i in range(N_feats)
826
+ ]
827
+ group_keyboard = torch.stack(group_keyboard, dim=1) # B F RW C
828
+ group_keyboard = group_keyboard.reshape(
829
+ shape=(group_keyboard.shape[0], group_keyboard.shape[1], -1)
830
+ )
831
+ # apply cross attn
832
+ mouse_q = self.mouse_attn_q(hidden_states)
833
+ keyboard_kv = self.keyboard_attn_kv(group_keyboard)
834
+
835
+ B, L, HD = mouse_q.shape
836
+ D = HD // self.heads_num
837
+ q = mouse_q.view(B, L, self.heads_num, D)
838
+
839
+ B, L, KHD = keyboard_kv.shape
840
+ k, v = keyboard_kv.view(B, L, 2, self.heads_num, D).permute(2, 0, 1, 3, 4)
841
+
842
+ # Compute cu_seqlens and max_seqlen for flash attention
843
+ # qk norm
844
+
845
+ q = self.key_attn_q_norm(q).to(v)
846
+ k = self.key_attn_k_norm(k).to(v)
847
+ S = th * tw
848
+ assert S == 880
849
+ # position embed
850
+ if use_rope_keyboard:
851
+ B, TS, H, D = q.shape
852
+ T_ = TS // S
853
+ q = q.view(B, T_, S, H, D).transpose(1, 2).reshape(B * S, T_, H, D)
854
+ q, k = apply_rotary_emb(
855
+ q, k, freqs_cis, start_offset=start_frame, head_first=False
856
+ )
857
+
858
+ k1, k2, k3, k4 = k.shape
859
+ k = k.expand(S, k2, k3, k4)
860
+ v = v.expand(S, k2, k3, k4)
861
+
862
+ if is_causal:
863
+ if kv_cache_keyboard is None:
864
+ assert q.shape[0] == k.shape[0] and q.shape[0] % 880 == 0
865
+
866
+ padded_length = math.ceil(q.shape[1] / 32) * 32 - q.shape[1]
867
+ padded_q = torch.cat(
868
+ [
869
+ q,
870
+ torch.zeros(
871
+ [q.shape[0], padded_length, q.shape[2], q.shape[3]],
872
+ device=q.device,
873
+ dtype=v.dtype,
874
+ ),
875
+ ],
876
+ dim=1,
877
+ )
878
+ padded_k = torch.cat(
879
+ [
880
+ k,
881
+ torch.zeros(
882
+ [k.shape[0], padded_length, k.shape[2], k.shape[3]],
883
+ device=k.device,
884
+ dtype=v.dtype,
885
+ ),
886
+ ],
887
+ dim=1,
888
+ )
889
+ padded_v = torch.cat(
890
+ [
891
+ v,
892
+ torch.zeros(
893
+ [v.shape[0], padded_length, v.shape[2], v.shape[3]],
894
+ device=v.device,
895
+ dtype=v.dtype,
896
+ ),
897
+ ],
898
+ dim=1,
899
+ )
900
+ attn = flex_attention(
901
+ query=padded_q.transpose(2, 1), # after: B, HW, F, C
902
+ key=padded_k.transpose(2, 1),
903
+ value=padded_v.transpose(2, 1),
904
+ block_mask=block_mask_keyboard,
905
+ )[:, :, :-padded_length].transpose(2, 1)
906
+ else:
907
+ current_start = start_frame
908
+ current_end = current_start + k.shape[1]
909
+ assert k.shape[1] == num_frame_per_block
910
+ sink_size = 0
911
+ max_attention_size = self.local_attn_size
912
+ sink_tokens = sink_size * 1
913
+ kv_cache_size = kv_cache_keyboard["k"].shape[1]
914
+ num_new_tokens = k.shape[1]
915
+
916
+ if (
917
+ current_end > kv_cache_keyboard["global_end_index"].item()
918
+ ) and (
919
+ num_new_tokens + kv_cache_keyboard["local_end_index"].item()
920
+ > kv_cache_size
921
+ ):
922
+ num_evicted_tokens = (
923
+ num_new_tokens
924
+ + kv_cache_keyboard["local_end_index"].item()
925
+ - kv_cache_size
926
+ )
927
+ num_rolled_tokens = (
928
+ kv_cache_keyboard["local_end_index"].item()
929
+ - num_evicted_tokens
930
+ - sink_tokens
931
+ )
932
+ kv_cache_keyboard["k"][
933
+ :, sink_tokens : sink_tokens + num_rolled_tokens
934
+ ] = kv_cache_keyboard["k"][
935
+ :,
936
+ sink_tokens + num_evicted_tokens : sink_tokens
937
+ + num_evicted_tokens
938
+ + num_rolled_tokens,
939
+ ].clone()
940
+ kv_cache_keyboard["v"][
941
+ :, sink_tokens : sink_tokens + num_rolled_tokens
942
+ ] = kv_cache_keyboard["v"][
943
+ :,
944
+ sink_tokens + num_evicted_tokens : sink_tokens
945
+ + num_evicted_tokens
946
+ + num_rolled_tokens,
947
+ ].clone()
948
+ # Insert the new keys/values at the end
949
+ local_end_index = (
950
+ kv_cache_keyboard["local_end_index"].item()
951
+ + current_end
952
+ - kv_cache_keyboard["global_end_index"].item()
953
+ - num_evicted_tokens
954
+ )
955
+ local_start_index = local_end_index - num_new_tokens
956
+ else:
957
+ local_end_index = (
958
+ kv_cache_keyboard["local_end_index"].item()
959
+ + current_end
960
+ - kv_cache_keyboard["global_end_index"].item()
961
+ )
962
+ local_start_index = local_end_index - num_new_tokens
963
+ assert (
964
+ k.shape[0] == 880
965
+ ) # BS == 1 or the cache should not be saved/ load method should be modified
966
+ kv_cache_keyboard["k"][:, local_start_index:local_end_index] = (
967
+ k[:1]
968
+ )
969
+ kv_cache_keyboard["v"][:, local_start_index:local_end_index] = (
970
+ v[:1]
971
+ )
972
+
973
+ if FLASH_ATTN_3_AVAILABLE:
974
+ attn, attn_prob = flash_attn.flash_attn_func(
975
+ q,
976
+ kv_cache_keyboard["k"][
977
+ :,
978
+ max(
979
+ 0, local_end_index - max_attention_size
980
+ ) : local_end_index,
981
+ ].repeat(S, 1, 1, 1),
982
+ kv_cache_keyboard["v"][
983
+ :,
984
+ max(
985
+ 0, local_end_index - max_attention_size
986
+ ) : local_end_index,
987
+ ].repeat(S, 1, 1, 1),
988
+ )
989
+ else:
990
+ attn = flash_attn_func(
991
+ q,
992
+ kv_cache_keyboard["k"][
993
+ :,
994
+ max(
995
+ 0, local_end_index - max_attention_size
996
+ ) : local_end_index,
997
+ ].repeat(S, 1, 1, 1),
998
+ kv_cache_keyboard["v"][
999
+ :,
1000
+ max(
1001
+ 0, local_end_index - max_attention_size
1002
+ ) : local_end_index,
1003
+ ].repeat(S, 1, 1, 1),
1004
+ )
1005
+
1006
+ kv_cache_keyboard["global_end_index"].fill_(current_end)
1007
+ kv_cache_keyboard["local_end_index"].fill_(local_end_index)
1008
+ else:
1009
+ attn = flash_attn_func(
1010
+ q, # 1, f*880, 16, 64
1011
+ k, # 1, f, 16, 64
1012
+ v, # 1, f, 16, 64
1013
+ causal=False,
1014
+ )
1015
+ attn = rearrange(attn, "(B S) T H D -> B (T S) (H D)", S=S)
1016
+ else:
1017
+ if is_causal:
1018
+ if kv_cache_keyboard is None:
1019
+ padded_length = math.ceil(q.shape[1] / 32) * 32 - q.shape[1]
1020
+ padded_q = torch.cat(
1021
+ [
1022
+ q,
1023
+ torch.zeros(
1024
+ [q.shape[0], padded_length, q.shape[2], q.shape[3]],
1025
+ device=q.device,
1026
+ dtype=v.dtype,
1027
+ ),
1028
+ ],
1029
+ dim=1,
1030
+ )
1031
+ padded_k = torch.cat(
1032
+ [
1033
+ k,
1034
+ torch.zeros(
1035
+ [k.shape[0], padded_length, k.shape[2], k.shape[3]],
1036
+ device=k.device,
1037
+ dtype=v.dtype,
1038
+ ),
1039
+ ],
1040
+ dim=1,
1041
+ )
1042
+ padded_v = torch.cat(
1043
+ [
1044
+ v,
1045
+ torch.zeros(
1046
+ [v.shape[0], padded_length, v.shape[2], v.shape[3]],
1047
+ device=v.device,
1048
+ dtype=v.dtype,
1049
+ ),
1050
+ ],
1051
+ dim=1,
1052
+ )
1053
+ attn = flex_attention(
1054
+ query=padded_q.transpose(2, 1), # after: B, HW, F, C
1055
+ key=padded_k.transpose(2, 1),
1056
+ value=padded_v.transpose(2, 1),
1057
+ block_mask=block_mask_keyboard,
1058
+ )[:, :, :-padded_length].transpose(2, 1)
1059
+ else:
1060
+ current_start = start_frame
1061
+ current_end = current_start + k.shape[1]
1062
+ assert k.shape[1] == num_frame_per_block
1063
+ sink_size = 0
1064
+ local_attn_size = self.local_attn_size
1065
+ max_attention_size = self.local_attn_size
1066
+ sink_tokens = sink_size * 1
1067
+ kv_cache_size = kv_cache_keyboard["k"].shape[1]
1068
+ num_new_tokens = k.shape[1]
1069
+
1070
+ if (
1071
+ current_end > kv_cache_keyboard["global_end_index"].item()
1072
+ ) and (
1073
+ num_new_tokens + kv_cache_keyboard["local_end_index"].item()
1074
+ > kv_cache_size
1075
+ ):
1076
+ num_evicted_tokens = (
1077
+ num_new_tokens
1078
+ + kv_cache_keyboard["local_end_index"].item()
1079
+ - kv_cache_size
1080
+ )
1081
+ num_rolled_tokens = (
1082
+ kv_cache_keyboard["local_end_index"].item()
1083
+ - num_evicted_tokens
1084
+ - sink_tokens
1085
+ )
1086
+ kv_cache_keyboard["k"][
1087
+ :, sink_tokens : sink_tokens + num_rolled_tokens
1088
+ ] = kv_cache_keyboard["k"][
1089
+ :,
1090
+ sink_tokens + num_evicted_tokens : sink_tokens
1091
+ + num_evicted_tokens
1092
+ + num_rolled_tokens,
1093
+ ].clone()
1094
+ kv_cache_keyboard["v"][
1095
+ :, sink_tokens : sink_tokens + num_rolled_tokens
1096
+ ] = kv_cache_keyboard["v"][
1097
+ :,
1098
+ sink_tokens + num_evicted_tokens : sink_tokens
1099
+ + num_evicted_tokens
1100
+ + num_rolled_tokens,
1101
+ ].clone()
1102
+ # Insert the new keys/values at the end
1103
+ local_end_index = (
1104
+ kv_cache_keyboard["local_end_index"].item()
1105
+ + current_end
1106
+ - kv_cache_keyboard["global_end_index"].item()
1107
+ - num_evicted_tokens
1108
+ )
1109
+ local_start_index = local_end_index - num_new_tokens
1110
+
1111
+ else:
1112
+ local_end_index = (
1113
+ kv_cache_keyboard["local_end_index"].item()
1114
+ + current_end
1115
+ - kv_cache_keyboard["global_end_index"].item()
1116
+ )
1117
+ local_start_index = local_end_index - num_new_tokens
1118
+ kv_cache_keyboard["k"][:, local_start_index:local_end_index] = k
1119
+ kv_cache_keyboard["v"][:, local_start_index:local_end_index] = v
1120
+ attn = flash_attn_func(
1121
+ q,
1122
+ kv_cache_keyboard["k"][
1123
+ :,
1124
+ max(
1125
+ 0, local_end_index - max_attention_size
1126
+ ) : local_end_index,
1127
+ ],
1128
+ kv_cache_keyboard["v"][
1129
+ :,
1130
+ max(
1131
+ 0, local_end_index - max_attention_size
1132
+ ) : local_end_index,
1133
+ ],
1134
+ # causal=is_causal
1135
+ )
1136
+ kv_cache_keyboard["global_end_index"].fill_(current_end)
1137
+ kv_cache_keyboard["local_end_index"].fill_(local_end_index)
1138
+ else:
1139
+ attn = flash_attn_func(
1140
+ q, # 1, f*880, 16, 64
1141
+ k, # 1, f, 16, 64
1142
+ v, # 1, f, 16, 64
1143
+ # causal=is_causal,
1144
+ )
1145
+ attn = rearrange(attn, "B L H D -> B L (H D)")
1146
+ attn = self.proj_keyboard(attn)
1147
+ hidden_states = hidden_states + attn
1148
+ return hidden_states
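
The rotary-embedding helpers in this file build separate cos/sin tables (`use_real=True`) and rotate query/key features with `rotate_half` inside `apply_rotary_emb`. Below is a CPU-only toy sketch of that cos/sin rotation; the real helpers allocate their tables on the current CUDA device and use the module's default `rope_theta=256`:

```python
# Toy, CPU-only sketch of the cos/sin RoPE used by apply_rotary_emb above.
import torch

def rope_tables(dim, seq_len, theta=256.0):
    # per-channel frequencies, then outer product with positions
    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2).float() / dim))   # [D/2]
    angles = torch.outer(torch.arange(seq_len).float(), freqs)         # [S, D/2]
    return (angles.cos().repeat_interleave(2, dim=1),
            angles.sin().repeat_interleave(2, dim=1))                  # [S, D] each

def rotate_half(x):
    x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1)
    return torch.stack([-x_imag, x_real], dim=-1).flatten(-2)

cos, sin = rope_tables(dim=64, seq_len=8)
q = torch.randn(1, 8, 4, 64)  # [B, S, H, D]
q_rot = q * cos[None, :, None, :] + rotate_half(q) * sin[None, :, None, :]
print(q_rot.shape)  # torch.Size([1, 8, 4, 64])
```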
transformer/attention.py ADDED
@@ -0,0 +1,198 @@
1
+ # Copyright 2024-2025 The Alibaba MatrixGameWan Team Authors. All rights reserved.
2
+ import torch
3
+
4
+ try:
5
+ import flash_attn
6
+
7
+ def is_hopper_gpu():
8
+ if not torch.cuda.is_available():
9
+ return False
10
+ device_name = torch.cuda.get_device_name(0).lower()
11
+ return (
12
+ "h100" in device_name
13
+ or "hopper" in device_name
14
+ or "l20y" in device_name
15
+ or "h800" in device_name
16
+ )
17
+
18
+ FLASH_ATTN_3_AVAILABLE = is_hopper_gpu()
19
+ except ModuleNotFoundError:
20
+ FLASH_ATTN_3_AVAILABLE = False
21
+
22
+ try:
23
+ import flash_attn
24
+
25
+ FLASH_ATTN_2_AVAILABLE = True
26
+ except ModuleNotFoundError:
27
+ FLASH_ATTN_2_AVAILABLE = False
28
+
29
+
30
+ import warnings
31
+
32
+ __all__ = [
33
+ "flash_attention",
34
+ "attention",
35
+ ]
36
+
37
+
38
+ def flash_attention(
39
+ q,
40
+ k,
41
+ v,
42
+ q_lens=None,
43
+ k_lens=None,
44
+ dropout_p=0.0,
45
+ softmax_scale=None,
46
+ q_scale=None,
47
+ causal=False,
48
+ window_size=(-1, -1),
49
+ deterministic=False,
50
+ dtype=torch.bfloat16,
51
+ version=None,
52
+ ):
53
+ """
54
+ q: [B, Lq, Nq, C1].
55
+ k: [B, Lk, Nk, C1].
56
+ v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
57
+ q_lens: [B].
58
+ k_lens: [B].
59
+ dropout_p: float. Dropout probability.
60
+ softmax_scale: float. The scaling of QK^T before applying softmax.
61
+ causal: bool. Whether to apply causal attention mask.
62
+ window_size: (left right). If not (-1, -1), apply sliding window local attention.
63
+ deterministic: bool. If True, slightly slower and uses more memory.
64
+ dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
65
+ """
66
+ half_dtypes = (torch.float16, torch.bfloat16)
67
+ assert dtype in half_dtypes
68
+ assert q.device.type == "cuda" and q.size(-1) <= 256
69
+
70
+ # params
71
+ b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
72
+
73
+ def half(x):
74
+ return x if x.dtype in half_dtypes else x.to(dtype)
75
+
76
+ # preprocess query
77
+ if q_lens is None:
78
+ q = half(q.flatten(0, 1))
79
+ q_lens = torch.tensor([lq] * b, dtype=torch.int32).to(
80
+ device=q.device, non_blocking=True
81
+ )
82
+ else:
83
+ q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
84
+
85
+ # preprocess key, value
86
+ if k_lens is None:
87
+ k = half(k.flatten(0, 1))
88
+ v = half(v.flatten(0, 1))
89
+ k_lens = torch.tensor([lk] * b, dtype=torch.int32).to(
90
+ device=k.device, non_blocking=True
91
+ )
92
+ else:
93
+ k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
94
+ v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
95
+
96
+ q = q.to(v.dtype)
97
+ k = k.to(v.dtype)
98
+
99
+ if q_scale is not None:
100
+ q = q * q_scale
101
+
102
+ if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
103
+ warnings.warn(
104
+ "Flash attention 3 is not available, use flash attention 2 instead."
105
+ )
106
+
107
+ # apply attention
108
+ if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
109
+ # Note: dropout_p, window_size are not supported in FA3 now.
110
+ x = flash_attn.flash_attn_varlen_func(
111
+ q=q,
112
+ k=k,
113
+ v=v,
114
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
115
+ .cumsum(0, dtype=torch.int32)
116
+ .to(q.device, non_blocking=True),
117
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
118
+ .cumsum(0, dtype=torch.int32)
119
+ .to(q.device, non_blocking=True),
120
+ max_seqlen_q=lq,
121
+ max_seqlen_k=lk,
122
+ softmax_scale=softmax_scale,
123
+ causal=causal,
124
+ deterministic=deterministic,
125
+ )[0].unflatten(0, (b, lq))
126
+ else:
127
+ assert FLASH_ATTN_2_AVAILABLE
128
+ x = flash_attn.flash_attn_varlen_func(
129
+ q=q,
130
+ k=k,
131
+ v=v,
132
+ cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens])
133
+ .cumsum(0, dtype=torch.int32)
134
+ .to(q.device, non_blocking=True),
135
+ cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens])
136
+ .cumsum(0, dtype=torch.int32)
137
+ .to(q.device, non_blocking=True),
138
+ max_seqlen_q=lq,
139
+ max_seqlen_k=lk,
140
+ dropout_p=dropout_p,
141
+ softmax_scale=softmax_scale,
142
+ causal=causal,
143
+ window_size=window_size,
144
+ deterministic=deterministic,
145
+ ).unflatten(0, (b, lq))
146
+
147
+ # output
148
+ return x.type(out_dtype)
149
+
150
+
151
+ def attention(
152
+ q,
153
+ k,
154
+ v,
155
+ q_lens=None,
156
+ k_lens=None,
157
+ dropout_p=0.0,
158
+ softmax_scale=None,
159
+ q_scale=None,
160
+ causal=False,
161
+ window_size=(-1, -1),
162
+ deterministic=False,
163
+ dtype=torch.bfloat16,
164
+ fa_version=None,
165
+ ):
166
+ if FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE:
167
+ return flash_attention(
168
+ q=q,
169
+ k=k,
170
+ v=v,
171
+ q_lens=q_lens,
172
+ k_lens=k_lens,
173
+ dropout_p=dropout_p,
174
+ softmax_scale=softmax_scale,
175
+ q_scale=q_scale,
176
+ causal=causal,
177
+ window_size=window_size,
178
+ deterministic=deterministic,
179
+ dtype=dtype,
180
+ version=fa_version,
181
+ )
182
+ else:
183
+ if q_lens is not None or k_lens is not None:
184
+ warnings.warn(
185
+ "Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance."
186
+ )
187
+ attn_mask = None
188
+
189
+ q = q.transpose(1, 2).to(dtype)
190
+ k = k.transpose(1, 2).to(dtype)
191
+ v = v.transpose(1, 2).to(dtype)
192
+
193
+ out = torch.nn.functional.scaled_dot_product_attention(
194
+ q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p
195
+ )
196
+
197
+ out = out.transpose(1, 2).contiguous()
198
+ return out
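
When neither flash-attn path is available, `attention()` falls back to `torch.nn.functional.scaled_dot_product_attention`, transposing between the `[B, L, N, C]` convention documented above and SDPA's `[B, N, L, C]` layout. A small sketch of that fallback path:

```python
# Sketch of the SDPA fallback branch in attention() above: inputs are
# [B, L, N, C]; SDPA expects [B, N, L, C], so the wrapper transposes around the call.
import torch
import torch.nn.functional as F

q = torch.randn(2, 16, 8, 64)  # [batch, seq len, num heads, head dim]
k = torch.randn(2, 16, 8, 64)
v = torch.randn(2, 16, 8, 64)

out = F.scaled_dot_product_attention(
    q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), is_causal=False
).transpose(1, 2).contiguous()
print(out.shape)  # torch.Size([2, 16, 8, 64])
```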
transformer/causal_model.py ADDED
@@ -0,0 +1,949 @@
1
+ from .attention import attention
2
+ from .model import (
3
+ MatrixGameWanRMSNorm,
4
+ rope_apply,
5
+ MatrixGameWanLayerNorm,
6
+ MatrixGameWan_CROSSATTENTION_CLASSES,
7
+ rope_params,
8
+ MLPProj,
9
+ sinusoidal_embedding_1d,
10
+ )
11
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
12
+ from torch.nn.attention.flex_attention import create_block_mask, flex_attention
13
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
14
+ from torch.nn.attention.flex_attention import BlockMask
15
+ from diffusers.models.modeling_utils import ModelMixin
16
+ import torch.nn as nn
17
+ import torch
18
+ import math
19
+ import torch.distributed as dist
20
+ from .action_module import ActionModule
21
+
22
+
23
+ def causal_rope_apply(x, grid_sizes, freqs, start_frame=0):
24
+ n, c = x.size(2), x.size(3) // 2
25
+
26
+ # split freqs
27
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
28
+
29
+ # loop over samples
30
+ output = []
31
+ f, h, w = grid_sizes.tolist()
32
+
33
+ for i in range(len(x)):
34
+ seq_len = f * h * w
35
+
36
+ # precompute multipliers
37
+ x_i = torch.view_as_complex(
38
+ x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
39
+ )
40
+ freqs_i = torch.cat(
41
+ [
42
+ freqs[0][start_frame : start_frame + f]
43
+ .view(f, 1, 1, -1)
44
+ .expand(f, h, w, -1),
45
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
46
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
47
+ ],
48
+ dim=-1,
49
+ ).reshape(seq_len, 1, -1)
50
+
51
+ # apply rotary embedding
52
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
53
+ x_i = torch.cat([x_i, x[i, seq_len:]])
54
+
55
+ # append to collection
56
+ output.append(x_i)
57
+ return torch.stack(output).type_as(x)
58
+
59
+
60
+ class MatrixGameWanCausalSelfAttention(nn.Module):
61
+ def __init__(
62
+ self, dim, num_heads, local_attn_size=-1, sink_size=0, qk_norm=True, eps=1e-6
63
+ ):
64
+ assert dim % num_heads == 0
65
+ super().__init__()
66
+ self.dim = dim
67
+ self.num_heads = num_heads
68
+ self.head_dim = dim // num_heads
69
+ self.local_attn_size = local_attn_size
70
+ self.sink_size = sink_size
71
+ self.qk_norm = qk_norm
72
+ self.eps = eps
73
+ self.max_attention_size = (
74
+ 15 * 1 * 880 if local_attn_size == -1 else local_attn_size * 880
75
+ )
76
+ # layers
77
+ self.q = nn.Linear(dim, dim)
78
+ self.k = nn.Linear(dim, dim)
79
+ self.v = nn.Linear(dim, dim)
80
+ self.o = nn.Linear(dim, dim)
81
+ self.norm_q = MatrixGameWanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
82
+ self.norm_k = MatrixGameWanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
83
+
84
+ def forward(
85
+ self,
86
+ x,
87
+ seq_lens,
88
+ grid_sizes,
89
+ freqs,
90
+ block_mask,
91
+ kv_cache=None,
92
+ current_start=0,
93
+ cache_start=None,
94
+ ):
95
+ r"""
96
+ Args:
97
+ x(Tensor): Shape [B, L, C]
98
+ seq_lens(Tensor): Shape [B]
99
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
100
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
101
+ block_mask (BlockMask)
102
+ """
103
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
104
+ if cache_start is None:
105
+ cache_start = current_start
106
+
107
+ # query, key, value function
108
+ def qkv_fn(x):
109
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
110
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
111
+ v = self.v(x).view(b, s, n, d)
112
+ return q, k, v
113
+
114
+ q, k, v = qkv_fn(x) # B, F, HW, C
115
+
116
+ if kv_cache is None:
117
+ roped_query = rope_apply(q, grid_sizes, freqs).type_as(v)
118
+ roped_key = rope_apply(k, grid_sizes, freqs).type_as(v)
119
+
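+ # The block mask below is built for sequence lengths padded to a multiple of
+ # 128 (the default FlexAttention block size), so q/k/v are right-padded with
+ # zeros and the padding is trimmed from the attention output afterwards.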
120
+ padded_length = math.ceil(q.shape[1] / 128) * 128 - q.shape[1]
121
+ padded_roped_query = torch.cat(
122
+ [
123
+ roped_query,
124
+ torch.zeros(
125
+ [q.shape[0], padded_length, q.shape[2], q.shape[3]],
126
+ device=q.device,
127
+ dtype=v.dtype,
128
+ ),
129
+ ],
130
+ dim=1,
131
+ )
132
+
133
+ padded_roped_key = torch.cat(
134
+ [
135
+ roped_key,
136
+ torch.zeros(
137
+ [k.shape[0], padded_length, k.shape[2], k.shape[3]],
138
+ device=k.device,
139
+ dtype=v.dtype,
140
+ ),
141
+ ],
142
+ dim=1,
143
+ )
144
+
145
+ padded_v = torch.cat(
146
+ [
147
+ v,
148
+ torch.zeros(
149
+ [v.shape[0], padded_length, v.shape[2], v.shape[3]],
150
+ device=v.device,
151
+ dtype=v.dtype,
152
+ ),
153
+ ],
154
+ dim=1,
155
+ )
156
+
157
+ x = flex_attention(
158
+ query=padded_roped_query.transpose(2, 1), # after: B, HW, F, C
159
+ key=padded_roped_key.transpose(2, 1),
160
+ value=padded_v.transpose(2, 1),
161
+ block_mask=block_mask,
162
+ )[:, :, :q.shape[1]].transpose(2, 1)  # trim padding (also correct when padded_length == 0)
163
+ else:
164
+ assert grid_sizes.ndim == 1
165
+ frame_seqlen = math.prod(grid_sizes[1:]).item()
166
+ current_start_frame = current_start // frame_seqlen
167
+ roped_query = causal_rope_apply(
168
+ q, grid_sizes, freqs, start_frame=current_start_frame
169
+ ).type_as(v)
170
+ roped_key = causal_rope_apply(
171
+ k, grid_sizes, freqs, start_frame=current_start_frame
172
+ ).type_as(v)
173
+
174
+ current_end = current_start + roped_query.shape[1]
175
+ sink_tokens = self.sink_size * frame_seqlen
176
+
177
+ kv_cache_size = kv_cache["k"].shape[1]
178
+ num_new_tokens = roped_query.shape[1]
179
+
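+ # Rolling KV cache: if the incoming tokens would overflow the fixed-size
+ # cache, evict the oldest non-sink entries by shifting the retained window
+ # left while keeping the first `sink_tokens` entries (attention sinks) fixed.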
180
+ if (current_end > kv_cache["global_end_index"].item()) and (
181
+ num_new_tokens + kv_cache["local_end_index"].item() > kv_cache_size
182
+ ):
183
+ num_evicted_tokens = (
184
+ num_new_tokens + kv_cache["local_end_index"].item() - kv_cache_size
185
+ )
186
+ num_rolled_tokens = (
187
+ kv_cache["local_end_index"].item()
188
+ - num_evicted_tokens
189
+ - sink_tokens
190
+ )
191
+ kv_cache["k"][:, sink_tokens : sink_tokens + num_rolled_tokens] = (
192
+ kv_cache["k"][
193
+ :,
194
+ sink_tokens + num_evicted_tokens : sink_tokens
195
+ + num_evicted_tokens
196
+ + num_rolled_tokens,
197
+ ].clone()
198
+ )
199
+ kv_cache["v"][:, sink_tokens : sink_tokens + num_rolled_tokens] = (
200
+ kv_cache["v"][
201
+ :,
202
+ sink_tokens + num_evicted_tokens : sink_tokens
203
+ + num_evicted_tokens
204
+ + num_rolled_tokens,
205
+ ].clone()
206
+ )
207
+ # Insert the new keys/values at the end
208
+ local_end_index = (
209
+ kv_cache["local_end_index"].item()
210
+ + current_end
211
+ - kv_cache["global_end_index"].item()
212
+ - num_evicted_tokens
213
+ )
214
+ local_start_index = local_end_index - num_new_tokens
215
+ kv_cache["k"][:, local_start_index:local_end_index] = roped_key
216
+ kv_cache["v"][:, local_start_index:local_end_index] = v
217
+ else:
218
+ # Assign new keys/values directly up to current_end
219
+ local_end_index = (
220
+ kv_cache["local_end_index"].item()
221
+ + current_end
222
+ - kv_cache["global_end_index"].item()
223
+ )
224
+ local_start_index = local_end_index - num_new_tokens
225
+
226
+ kv_cache["k"][:, local_start_index:local_end_index] = roped_key
227
+ kv_cache["v"][:, local_start_index:local_end_index] = v
228
+ x = attention(
229
+ roped_query,
230
+ kv_cache["k"][
231
+ :,
232
+ max(0, local_end_index - self.max_attention_size) : local_end_index,
233
+ ],
234
+ kv_cache["v"][
235
+ :,
236
+ max(0, local_end_index - self.max_attention_size) : local_end_index,
237
+ ],
238
+ )
239
+ kv_cache["global_end_index"].fill_(current_end)
240
+ kv_cache["local_end_index"].fill_(local_end_index)
241
+
242
+ # output
243
+ x = x.flatten(2)
244
+ x = self.o(x)
245
+ return x
246
+
247
+
248
+ class MatrixGameWanCausalAttentionBlock(nn.Module):
249
+ def __init__(
250
+ self,
251
+ cross_attn_type,
252
+ dim,
253
+ ffn_dim,
254
+ num_heads,
255
+ local_attn_size=-1,
256
+ sink_size=0,
257
+ qk_norm=True,
258
+ cross_attn_norm=False,
259
+ action_config={},
260
+ block_idx=0,
261
+ eps=1e-6,
262
+ ):
263
+ super().__init__()
264
+ self.dim = dim
265
+ self.ffn_dim = ffn_dim
266
+ self.num_heads = num_heads
267
+ self.local_attn_size = local_attn_size
268
+ self.qk_norm = qk_norm
269
+ self.cross_attn_norm = cross_attn_norm
270
+ self.eps = eps
271
+ if len(action_config) != 0 and block_idx in action_config["blocks"]:
272
+ self.action_model = ActionModule(
273
+ **action_config, local_attn_size=self.local_attn_size
274
+ )
275
+ else:
276
+ self.action_model = None
277
+ # layers
278
+ self.norm1 = MatrixGameWanLayerNorm(dim, eps)
279
+ self.self_attn = MatrixGameWanCausalSelfAttention(
280
+ dim, num_heads, local_attn_size, sink_size, qk_norm, eps
281
+ )
282
+ self.norm3 = (
283
+ MatrixGameWanLayerNorm(dim, eps, elementwise_affine=True)
284
+ if cross_attn_norm
285
+ else nn.Identity()
286
+ )
287
+ self.cross_attn = MatrixGameWan_CROSSATTENTION_CLASSES[cross_attn_type](
288
+ dim, num_heads, (-1, -1), qk_norm, eps
289
+ )
290
+ self.norm2 = MatrixGameWanLayerNorm(dim, eps)
291
+ self.ffn = nn.Sequential(
292
+ nn.Linear(dim, ffn_dim),
293
+ nn.GELU(approximate="tanh"),
294
+ nn.Linear(ffn_dim, dim),
295
+ )
296
+
297
+ # modulation
298
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
299
+
300
+ def forward(
301
+ self,
302
+ x,
303
+ e,
304
+ seq_lens,
305
+ grid_sizes,
306
+ freqs,
307
+ context,
308
+ block_mask,
309
+ block_mask_mouse,
310
+ block_mask_keyboard,
311
+ num_frame_per_block=3,
312
+ use_rope_keyboard=False,
313
+ mouse_cond=None,
314
+ keyboard_cond=None,
315
+ kv_cache=None,
316
+ kv_cache_mouse=None,
317
+ kv_cache_keyboard=None,
318
+ crossattn_cache=None,
319
+ current_start=0,
320
+ cache_start=None,
321
+ context_lens=None,
322
+ ):
323
+ r"""
324
+ Args:
325
+ x(Tensor): Shape [B, L, C]
326
+ e(Tensor): Shape [B, F, 6, C]
327
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
328
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
329
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
330
+ """
331
+ assert e.ndim == 4
332
+ num_frames, frame_seqlen = e.shape[1], x.shape[1] // e.shape[1]
333
+
334
+ e = (self.modulation.unsqueeze(1) + e).chunk(6, dim=2)
335
+
336
+ y = self.self_attn(
337
+ (
338
+ self.norm1(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen))
339
+ * (1 + e[1])
340
+ + e[0]
341
+ ).flatten(1, 2),
342
+ seq_lens,
343
+ grid_sizes,
344
+ freqs,
345
+ block_mask,
346
+ kv_cache,
347
+ current_start,
348
+ cache_start,
349
+ )
350
+
351
+ x = x + (y.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * e[2]).flatten(
352
+ 1, 2
353
+ )
354
+
355
+ # cross-attention & ffn function
356
+ def cross_attn_ffn(
357
+ x,
358
+ context,
359
+ e,
360
+ mouse_cond,
361
+ keyboard_cond,
362
+ block_mask_mouse,
363
+ block_mask_keyboard,
364
+ kv_cache_mouse=None,
365
+ kv_cache_keyboard=None,
366
+ crossattn_cache=None,
367
+ start_frame=0,
368
+ use_rope_keyboard=False,
369
+ num_frame_per_block=3,
370
+ ):
371
+ x = x + self.cross_attn(
372
+ self.norm3(x.to(context.dtype)),
373
+ context,
374
+ crossattn_cache=crossattn_cache,
375
+ )
376
+ if self.action_model is not None:
377
+ assert mouse_cond is not None or keyboard_cond is not None
378
+ x = self.action_model(
379
+ x.to(context.dtype),
380
+ grid_sizes[0],
381
+ grid_sizes[1],
382
+ grid_sizes[2],
383
+ mouse_cond,
384
+ keyboard_cond,
385
+ block_mask_mouse,
386
+ block_mask_keyboard,
387
+ is_causal=True,
388
+ kv_cache_mouse=kv_cache_mouse,
389
+ kv_cache_keyboard=kv_cache_keyboard,
390
+ start_frame=start_frame,
391
+ use_rope_keyboard=use_rope_keyboard,
392
+ num_frame_per_block=num_frame_per_block,
393
+ )
394
+
395
+ y = self.ffn(
396
+ (
397
+ self.norm2(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen))
398
+ * (1 + e[4])
399
+ + e[3]
400
+ ).flatten(1, 2)
401
+ )
402
+
403
+ x = x + (
404
+ y.unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * e[5]
405
+ ).flatten(1, 2)
406
+ return x
407
+
408
+ assert grid_sizes.ndim == 1
409
+ x = cross_attn_ffn(
410
+ x,
411
+ context,
412
+ e,
413
+ mouse_cond,
414
+ keyboard_cond,
415
+ block_mask_mouse,
416
+ block_mask_keyboard,
417
+ kv_cache_mouse,
418
+ kv_cache_keyboard,
419
+ crossattn_cache,
420
+ start_frame=current_start // math.prod(grid_sizes[1:]).item(),
421
+ use_rope_keyboard=use_rope_keyboard,
422
+ num_frame_per_block=num_frame_per_block,
423
+ )
424
+ return x
425
+
426
+
427
+ class CausalHead(nn.Module):
428
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
429
+ super().__init__()
430
+ self.dim = dim
431
+ self.out_dim = out_dim
432
+ self.patch_size = patch_size
433
+ self.eps = eps
434
+
435
+ # layers
436
+ out_dim = math.prod(patch_size) * out_dim
437
+ self.norm = MatrixGameWanLayerNorm(dim, eps)
438
+ self.head = nn.Linear(dim, out_dim)
439
+
440
+ # modulation
441
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
442
+
443
+ def forward(self, x, e):
444
+ r"""
445
+ Args:
446
+ x(Tensor): Shape [B, L1, C]
447
+ e(Tensor): Shape [B, F, 1, C]
448
+ """
449
+
450
+ num_frames, frame_seqlen = e.shape[1], x.shape[1] // e.shape[1]
451
+ e = (self.modulation.unsqueeze(1) + e).chunk(2, dim=2)
452
+ x = self.head(
453
+ self.norm(x).unflatten(dim=1, sizes=(num_frames, frame_seqlen)) * (1 + e[1])
454
+ + e[0]
455
+ )
456
+ return x
457
+
458
+
459
+ class MatrixGameWanCausalModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
460
+ r"""
461
+ Causal (autoregressive) MatrixGameWan diffusion backbone for image-to-video generation with KV caching.
462
+ """
463
+
464
+ ignore_for_config = ["patch_size", "cross_attn_norm", "qk_norm", "text_dim"]
465
+ _no_split_modules = ["MatrixGameWanCausalAttentionBlock"]
466
+ _supports_gradient_checkpointing = True
467
+
468
+ @register_to_config
469
+ def __init__(
470
+ self,
471
+ model_type="t2v",
472
+ patch_size=(1, 2, 2),
473
+ text_len=512,
474
+ in_dim=36,
475
+ dim=1536,
476
+ ffn_dim=8960,
477
+ freq_dim=256,
478
+ text_dim=4096,
479
+ out_dim=16,
480
+ num_heads=12,
481
+ num_layers=30,
482
+ local_attn_size=-1,
483
+ sink_size=0,
484
+ qk_norm=True,
485
+ cross_attn_norm=True,
486
+ action_config={},
487
+ eps=1e-6,
488
+ ):
489
+ r"""
490
+ Initialize the diffusion model backbone.
491
+
492
+ Args:
493
+ model_type (`str`, *optional*, defaults to 'i2v'):
494
+ Model variant - only 'i2v' (image-to-video) is supported
495
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
496
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
497
+ text_len (`int`, *optional*, defaults to 512):
498
+ Fixed length for text embeddings
499
+ in_dim (`int`, *optional*, defaults to 36):
500
+ Input video channels (C_in)
501
+ dim (`int`, *optional*, defaults to 1536):
502
+ Hidden dimension of the transformer
503
+ ffn_dim (`int`, *optional*, defaults to 8960):
504
+ Intermediate dimension in feed-forward network
505
+ freq_dim (`int`, *optional*, defaults to 256):
506
+ Dimension for sinusoidal time embeddings
507
+ text_dim (`int`, *optional*, defaults to 4096):
508
+ Input dimension for text embeddings
509
+ out_dim (`int`, *optional*, defaults to 16):
510
+ Output video channels (C_out)
511
+ num_heads (`int`, *optional*, defaults to 12):
512
+ Number of attention heads
513
+ num_layers (`int`, *optional*, defaults to 30):
514
+ Number of transformer blocks
515
+ local_attn_size (`int`, *optional*, defaults to -1):
516
+ Window size for temporal local attention (-1 indicates global attention)
517
+ sink_size (`int`, *optional*, defaults to 0):
518
+ Size of the attention sink, we keep the first `sink_size` frames unchanged when rolling the KV cache
519
+ qk_norm (`bool`, *optional*, defaults to True):
520
+ Enable query/key normalization
521
+ cross_attn_norm (`bool`, *optional*, defaults to True):
522
+ Enable cross-attention normalization
523
+ eps (`float`, *optional*, defaults to 1e-6):
524
+ Epsilon value for normalization layers
525
+ """
526
+
527
+ super().__init__()
528
+
529
+ assert model_type in ["i2v"]
530
+ self.model_type = model_type
531
+ self.use_action_module = len(action_config) > 0
532
+ self.patch_size = patch_size
533
+ self.text_len = text_len
534
+ self.in_dim = in_dim
535
+ self.dim = dim
536
+ self.ffn_dim = ffn_dim
537
+ self.freq_dim = freq_dim
538
+ self.text_dim = text_dim
539
+ self.out_dim = out_dim
540
+ self.num_heads = num_heads
541
+ self.num_layers = num_layers
542
+ self.local_attn_size = local_attn_size
543
+ self.qk_norm = qk_norm
544
+ self.cross_attn_norm = cross_attn_norm
545
+ self.eps = eps
546
+
547
+ # embeddings
548
+ self.patch_embedding = nn.Conv3d(
549
+ in_dim, dim, kernel_size=patch_size, stride=patch_size
550
+ )
551
+
552
+ self.time_embedding = nn.Sequential(
553
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
554
+ )
555
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
556
+
557
+ # blocks
558
+ cross_attn_type = "i2v_cross_attn"
559
+ self.blocks = nn.ModuleList(
560
+ [
561
+ MatrixGameWanCausalAttentionBlock(
562
+ cross_attn_type,
563
+ dim,
564
+ ffn_dim,
565
+ num_heads,
566
+ local_attn_size,
567
+ sink_size,
568
+ qk_norm,
569
+ cross_attn_norm,
570
+ action_config=action_config,
571
+ eps=eps,
572
+ block_idx=idx,
573
+ )
574
+ for idx in range(num_layers)
575
+ ]
576
+ )
577
+
578
+ # head
579
+ self.head = CausalHead(dim, out_dim, patch_size, eps)
580
+
581
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
582
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
583
+ d = dim // num_heads
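+ # RoPE tables: the per-head dimension is split into one temporal band and
+ # two equal spatial bands (height/width), matching the split in rope_apply.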
584
+ self.freqs = torch.cat(
585
+ [
586
+ rope_params(1024, d - 4 * (d // 6)),
587
+ rope_params(1024, 2 * (d // 6)),
588
+ rope_params(1024, 2 * (d // 6)),
589
+ ],
590
+ dim=1,
591
+ )
592
+
593
+ if model_type == "i2v":
594
+ self.img_emb = MLPProj(1280, dim)
595
+
596
+ self.gradient_checkpointing = False
597
+
598
+ self.block_mask = None
599
+ self.block_mask_keyboard = None
600
+ self.block_mask_mouse = None
601
+ self.use_rope_keyboard = True
602
+
603
+ def _set_gradient_checkpointing(self, module, value=False):
604
+ self.gradient_checkpointing = value
605
+
606
+ @staticmethod
607
+ def _prepare_blockwise_causal_attn_mask(
608
+ device: torch.device | str,
609
+ num_frames: int = 9,
610
+ frame_seqlen: int = 880,
611
+ num_frame_per_block=1,
612
+ local_attn_size=-1,
613
+ ) -> BlockMask:
614
+ """
615
+ The token sequence is divided into blocks of latent frames:
616
+ [1 latent frame] [1 latent frame] ... [1 latent frame]
617
+ FlexAttention is used to construct the block-wise causal attention mask.
618
+ """
619
+ total_length = num_frames * frame_seqlen
620
+
621
+ # we do right padding to get to a multiple of 128
622
+ padded_length = math.ceil(total_length / 128) * 128 - total_length
623
+
624
+ ends = torch.zeros(
625
+ total_length + padded_length, device=device, dtype=torch.long
626
+ )
627
+
628
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
629
+ frame_indices = torch.arange(
630
+ start=0,
631
+ end=total_length,
632
+ step=frame_seqlen * num_frame_per_block,
633
+ device=device,
634
+ )
635
+
636
+ for tmp in frame_indices:
637
+ ends[tmp : tmp + frame_seqlen * num_frame_per_block] = (
638
+ tmp + frame_seqlen * num_frame_per_block
639
+ )
640
+
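+ # Each query may attend to all keys up to the end of its own frame block
+ # (block-wise causal); with a local window, keys older than
+ # `local_attn_size` frames are masked out as well.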
641
+ def attention_mask(b, h, q_idx, kv_idx):
642
+ if local_attn_size == -1:
643
+ return (kv_idx < ends[q_idx]) | (q_idx == kv_idx)
644
+ else:
645
+ return (
646
+ (kv_idx < ends[q_idx])
647
+ & (kv_idx >= (ends[q_idx] - local_attn_size * frame_seqlen))
648
+ ) | (q_idx == kv_idx)
649
+ # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
650
+
651
+ block_mask = create_block_mask(
652
+ attention_mask,
653
+ B=None,
654
+ H=None,
655
+ Q_LEN=total_length + padded_length,
656
+ KV_LEN=total_length + padded_length,
657
+ _compile=False,
658
+ device=device,
659
+ )
660
+
661
+ import torch.distributed as dist
662
+
663
+ if not dist.is_initialized() or dist.get_rank() == 0:
664
+ print(
665
+ f" cache a block wise causal mask with block size of {num_frame_per_block} frames"
666
+ )
667
+
668
+ return block_mask
669
+
670
+ @staticmethod
671
+ def _prepare_blockwise_causal_attn_mask_keyboard(
672
+ device: torch.device | str,
673
+ num_frames: int = 9,
674
+ frame_seqlen: int = 880,
675
+ num_frame_per_block=1,
676
+ local_attn_size=-1,
677
+ ) -> BlockMask:
678
+ """
679
+ The video token sequence is divided into blocks of latent frames:
680
+ [1 latent frame] [1 latent frame] ... [1 latent frame]
681
+ FlexAttention builds the block-wise causal mask between video tokens (queries) and per-frame keyboard actions (keys).
682
+ """
683
+ total_length2 = num_frames * frame_seqlen
684
+
685
+ # we do right padding to get to a multiple of 128
686
+ padded_length2 = math.ceil(total_length2 / 32) * 32 - total_length2
687
+ padded_length_kv2 = math.ceil(num_frames / 32) * 32 - num_frames
688
+ ends2 = torch.zeros(
689
+ total_length2 + padded_length2, device=device, dtype=torch.long
690
+ )
691
+
692
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
693
+ frame_indices2 = torch.arange(
694
+ start=0,
695
+ end=total_length2,
696
+ step=frame_seqlen * num_frame_per_block,
697
+ device=device,
698
+ )
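+ # Keys here are per-frame action embeddings (KV_LEN == num_frames), so the
+ # causal boundary `ends2` is expressed in frames rather than in video tokens.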
699
+ cnt = num_frame_per_block
700
+ for tmp in frame_indices2:
701
+ ends2[tmp : tmp + frame_seqlen * num_frame_per_block] = cnt
702
+ cnt += num_frame_per_block
703
+
704
+ def attention_mask2(b, h, q_idx, kv_idx):
705
+ if local_attn_size == -1:
706
+ return (kv_idx < ends2[q_idx]) | (q_idx == kv_idx)
707
+ else:
708
+ return (
709
+ (kv_idx < ends2[q_idx])
710
+ & (kv_idx >= (ends2[q_idx] - local_attn_size))
711
+ ) | (q_idx == kv_idx)
712
+ # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
713
+
714
+ block_mask2 = create_block_mask(
715
+ attention_mask2,
716
+ B=None,
717
+ H=None,
718
+ Q_LEN=total_length2 + padded_length2,
719
+ KV_LEN=num_frames + padded_length_kv2,
720
+ _compile=False,
721
+ device=device,
722
+ )
723
+
724
+ import torch.distributed as dist
725
+
726
+ if not dist.is_initialized() or dist.get_rank() == 0:
727
+ print(
728
+ f" cache a block wise causal mask with block size of {num_frame_per_block} frames"
729
+ )
730
+
731
+ return block_mask2
732
+
733
+ @staticmethod
734
+ def _prepare_blockwise_causal_attn_mask_action(
735
+ device: torch.device | str,
736
+ num_frames: int = 9,
737
+ frame_seqlen: int = 1,
738
+ num_frame_per_block=1,
739
+ local_attn_size=-1,
740
+ ) -> BlockMask:
741
+ """
742
+ The token sequence is divided into blocks of latent frames:
743
+ [1 latent frame] [1 latent frame] ... [1 latent frame]
744
+ FlexAttention builds the block-wise causal mask for the per-frame action conditioning.
745
+ """
746
+ total_length2 = num_frames * frame_seqlen
747
+
748
+ # we do right padding to get to a multiple of 128
749
+ padded_length2 = math.ceil(total_length2 / 32) * 32 - total_length2
750
+ padded_length_kv2 = math.ceil(num_frames / 32) * 32 - num_frames
751
+ ends2 = torch.zeros(
752
+ total_length2 + padded_length2, device=device, dtype=torch.long
753
+ )
754
+
755
+ # Block-wise causal mask will attend to all elements that are before the end of the current chunk
756
+ frame_indices2 = torch.arange(
757
+ start=0,
758
+ end=total_length2,
759
+ step=frame_seqlen * num_frame_per_block,
760
+ device=device,
761
+ )
762
+ cnt = num_frame_per_block
763
+ for tmp in frame_indices2:
764
+ ends2[tmp : tmp + frame_seqlen * num_frame_per_block] = cnt
765
+ cnt += num_frame_per_block
766
+
767
+ def attention_mask2(b, h, q_idx, kv_idx):
768
+ if local_attn_size == -1:
769
+ return (kv_idx < ends2[q_idx]) | (q_idx == kv_idx)
770
+ else:
771
+ return (
772
+ (kv_idx < ends2[q_idx])
773
+ & (kv_idx >= (ends2[q_idx] - local_attn_size))
774
+ ) | (q_idx == kv_idx)
775
+ # return ((kv_idx < total_length) & (q_idx < total_length)) | (q_idx == kv_idx) # bidirectional mask
776
+
777
+ block_mask2 = create_block_mask(
778
+ attention_mask2,
779
+ B=None,
780
+ H=None,
781
+ Q_LEN=total_length2 + padded_length2,
782
+ KV_LEN=num_frames + padded_length_kv2,
783
+ _compile=False,
784
+ device=device,
785
+ )
786
+
787
+ import torch.distributed as dist
788
+
789
+ if not dist.is_initialized() or dist.get_rank() == 0:
790
+ print(
791
+ f" cache a block wise causal mask with block size of {num_frame_per_block} frames"
792
+ )
793
+
794
+ return block_mask2
795
+
796
+ def _forward_inference(
797
+ self,
798
+ x,
799
+ t,
800
+ visual_context,
801
+ cond_concat,
802
+ mouse_cond=None,
803
+ keyboard_cond=None,
804
+ kv_cache: dict = None,
805
+ kv_cache_mouse=None,
806
+ kv_cache_keyboard=None,
807
+ crossattn_cache: dict = None,
808
+ current_start: int = 0,
809
+ cache_start: int = 0,
810
+ num_frames_per_block=3,
811
+ ):
812
+ r"""
813
+ Run the diffusion model with kv caching.
814
+ See Algorithm 2 of CausVid paper https://arxiv.org/abs/2412.07772 for details.
815
+ This function is called once per block of latent frames during autoregressive generation.
816
+ Latent frames are processed block by block (880 tokens per latent frame).
817
+
818
+ Args:
819
+ x (Tensor):
820
+ Noisy latent video tensor of shape [B, C_in, F, H, W]
821
+ t (Tensor):
822
+ Diffusion timesteps tensor of shape [B, F] (one timestep per latent frame)
823
+ visual_context (Tensor):
824
+ CLIP image embedding used as cross-attention context
825
+ cond_concat (Tensor):
826
+ Conditioning latents concatenated to x along the channel dimension
827
+ mouse_cond / keyboard_cond (Tensor, *optional*):
828
+ Per-frame mouse and keyboard action conditioning for the action module
829
+ kv_cache, kv_cache_mouse, kv_cache_keyboard, crossattn_cache (*optional*):
830
+ Per-block caches reused across autoregressive steps; current_start and cache_start index the new tokens
831
+
832
+ Returns:
833
+ Tensor:
834
+ Denoised latent video tensor of shape [B, C_out, F, H, W] (latent resolution)
835
+ """
836
+
837
+ if mouse_cond is not None or keyboard_cond is not None:
838
+ assert self.use_action_module == True
839
+ # params
840
+ device = self.patch_embedding.weight.device
841
+ if self.freqs.device != device:
842
+ self.freqs = self.freqs.to(device)
843
+
844
+ x = torch.cat([x, cond_concat], dim=1) # B C' F H W
845
+
846
+ # embeddings
847
+ x = self.patch_embedding(x)
848
+ grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
849
+
850
+ x = x.flatten(2).transpose(1, 2) # B FHW C'
851
+ seq_lens = torch.tensor([u.size(0) for u in x], dtype=torch.long)
852
+ assert seq_lens[0] <= 15 * 1 * 880
853
+
854
+ e = self.time_embedding(
855
+ sinusoidal_embedding_1d(self.freq_dim, t.flatten()).type_as(x)
856
+ )
857
+ e0 = (
858
+ self.time_projection(e)
859
+ .unflatten(1, (6, self.dim))
860
+ .unflatten(dim=0, sizes=t.shape)
861
+ )
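+ # e0 has shape [B, F, 6, dim]: six AdaLN modulation vectors per latent
+ # frame, since each frame in a block can carry its own diffusion timestep.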
862
+ # context
863
+ context_lens = None
864
+ context = self.img_emb(visual_context)
865
+ # arguments
866
+ kwargs = dict(
867
+ e=e0,
868
+ seq_lens=seq_lens,
869
+ grid_sizes=grid_sizes,
870
+ freqs=self.freqs,
871
+ context=context,
872
+ mouse_cond=mouse_cond,
873
+ context_lens=context_lens,
874
+ keyboard_cond=keyboard_cond,
875
+ block_mask=self.block_mask,
876
+ block_mask_mouse=self.block_mask_mouse,
877
+ block_mask_keyboard=self.block_mask_keyboard,
878
+ use_rope_keyboard=self.use_rope_keyboard,
879
+ num_frame_per_block=num_frames_per_block,
880
+ )
881
+
882
+ def create_custom_forward(module):
883
+ def custom_forward(*inputs, **kwargs):
884
+ return module(*inputs, **kwargs)
885
+
886
+ return custom_forward
887
+
888
+ for block_index, block in enumerate(self.blocks):
889
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
890
+ kwargs.update(
891
+ {
892
+ "kv_cache": kv_cache[block_index],
893
+ "kv_cache_mouse": kv_cache_mouse[block_index],
894
+ "kv_cache_keyboard": kv_cache_keyboard[block_index],
895
+ "current_start": current_start,
896
+ "cache_start": cache_start,
897
+ }
898
+ )
899
+ x = torch.utils.checkpoint.checkpoint(
900
+ create_custom_forward(block),
901
+ x,
902
+ **kwargs,
903
+ use_reentrant=False,
904
+ )
905
+ else:
906
+ kwargs.update(
907
+ {
908
+ "kv_cache": kv_cache[block_index],
909
+ "kv_cache_mouse": kv_cache_mouse[block_index],
910
+ "kv_cache_keyboard": kv_cache_keyboard[block_index],
911
+ "crossattn_cache": crossattn_cache[block_index],
912
+ "current_start": current_start,
913
+ "cache_start": cache_start,
914
+ }
915
+ )
916
+ x = block(x, **kwargs)
917
+
918
+ # head
919
+ x = self.head(x, e.unflatten(dim=0, sizes=t.shape).unsqueeze(2))
920
+ # unpatchify
921
+ x = self.unpatchify(x, grid_sizes)
922
+ return x
923
+
924
+ def forward(self, *args, **kwargs):
925
+ return self._forward_inference(*args, **kwargs)
926
+
927
+ def unpatchify(self, x, grid_sizes):
928
+ r"""
929
+ Reconstruct video tensors from patch embeddings.
930
+
931
+ Args:
932
+ x (List[Tensor]):
933
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
934
+ grid_sizes (Tensor):
935
+ Original spatial-temporal grid dimensions before patching,
936
+ shape [3] (3 dimensions correspond to F_patches, H_patches, W_patches)
937
+
938
+ Returns:
939
+ List[Tensor]:
940
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
941
+ """
942
+
943
+ c = self.out_dim
944
+ bs = x.shape[0]
945
+ x = x.view(bs, *grid_sizes, *self.patch_size, c)
946
+ x = torch.einsum("bfhwpqrc->bcfphqwr", x)
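+ # fold the (t_patch, h_patch, w_patch) factors back into the frame/height/width axes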
947
+ x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
948
+ return x
949
+
transformer/config.json ADDED
@@ -0,0 +1,68 @@
1
+ {
2
+ "_class_name": "MatrixGameWanCausalModel",
3
+ "_diffusers_version": "0.35.1",
4
+ "auto_map": {
5
+ "AutoModel": "causal_model.MatrixGameWanCausalModel"
6
+ },
7
+ "action_config": {
8
+ "blocks": [
9
+ 0,
10
+ 1,
11
+ 2,
12
+ 3,
13
+ 4,
14
+ 5,
15
+ 6,
16
+ 7,
17
+ 8,
18
+ 9,
19
+ 10,
20
+ 11,
21
+ 12,
22
+ 13,
23
+ 14
24
+ ],
25
+ "enable_keyboard": true,
26
+ "enable_mouse": true,
27
+ "heads_num": 16,
28
+ "hidden_size": 128,
29
+ "img_hidden_size": 1536,
30
+ "keyboard_dim_in": 4,
31
+ "keyboard_hidden_dim": 1024,
32
+ "mouse_dim_in": 2,
33
+ "mouse_hidden_dim": 1024,
34
+ "mouse_qk_dim_list": [
35
+ 8,
36
+ 28,
37
+ 28
38
+ ],
39
+ "patch_size": [
40
+ 1,
41
+ 2,
42
+ 2
43
+ ],
44
+ "qk_norm": true,
45
+ "qkv_bias": false,
46
+ "rope_dim_list": [
47
+ 8,
48
+ 28,
49
+ 28
50
+ ],
51
+ "rope_theta": 256,
52
+ "vae_time_compression_ratio": 4,
53
+ "windows_size": 3
54
+ },
55
+ "dim": 1536,
56
+ "eps": 1e-06,
57
+ "ffn_dim": 8960,
58
+ "freq_dim": 256,
59
+ "in_dim": 36,
60
+ "inject_sample_info": false,
61
+ "local_attn_size": 6,
62
+ "model_type": "i2v",
63
+ "num_heads": 12,
64
+ "num_layers": 30,
65
+ "out_dim": 16,
66
+ "sink_size": 0,
67
+ "text_len": 512
68
+ }
transformer/diffusion_pytorch_model.safetensors ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d4d7930e670e1475abdeed25fa0f1c34c47c8be51d9e7b8f637e2a4b720548a4
3
+ size 3238601944
transformer/model.py ADDED
@@ -0,0 +1,781 @@
1
+ # Copyright 2024-2025 The Alibaba MatrixGameWan Team Authors. All rights reserved.
2
+ import math
3
+ import numpy as np
4
+ import torch
5
+ import torch.amp as amp
6
+ import torch.nn as nn
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.loaders import FromOriginalModelMixin, PeftAdapterMixin
9
+ from diffusers.models.modeling_utils import ModelMixin
10
+ from einops import repeat, rearrange
11
+ from .action_module import ActionModule
12
+ from .attention import flash_attention
13
+
14
+ DISABLE_COMPILE = False # get os env
15
+ __all__ = ["MatrixGameWanModel"]
16
+
17
+
18
+ def sinusoidal_embedding_1d(dim, position):
19
+ # preprocess
20
+ assert dim % 2 == 0
21
+ half = dim // 2
22
+ position = position.type(torch.float64)
23
+
24
+ # calculation
25
+ sinusoid = torch.outer(
26
+ position, torch.pow(10000, -torch.arange(half).to(position).div(half))
27
+ )
28
+ x = torch.cat([torch.cos(sinusoid), torch.sin(sinusoid)], dim=1)
29
+ return x
30
+
31
+
32
+ # @amp.autocast(enabled=False)
33
+ def rope_params(max_seq_len, dim, theta=10000):
34
+ assert dim % 2 == 0
35
+ freqs = torch.outer(
36
+ torch.arange(max_seq_len),
37
+ 1.0 / torch.pow(theta, torch.arange(0, dim, 2).to(torch.float64).div(dim)),
38
+ )
39
+ freqs = torch.polar(torch.ones_like(freqs), freqs)
40
+ return freqs
41
+
42
+
43
+ # @amp.autocast(enabled=False)
44
+ def rope_apply(x, grid_sizes, freqs):
45
+ n, c = x.size(2), x.size(3) // 2
46
+
47
+ # split freqs into temporal (F), height (H) and width (W) bands of the per-head dimension
48
+ freqs = freqs.split([c - 2 * (c // 3), c // 3, c // 3], dim=1)
49
+
50
+ # loop over samples
51
+ output = []
52
+ # print(grid_sizes.shape, len(grid_sizes.tolist()), grid_sizes.tolist()[0])
53
+ f, h, w = grid_sizes.tolist()
54
+ for i in range(len(x)):
55
+ seq_len = f * h * w
56
+
57
+ # precompute multipliers
58
+ x_i = torch.view_as_complex(
59
+ x[i, :seq_len].to(torch.float64).reshape(seq_len, n, -1, 2)
60
+ )
61
+ freqs_i = torch.cat(
62
+ [
63
+ freqs[0][:f].view(f, 1, 1, -1).expand(f, h, w, -1),
64
+ freqs[1][:h].view(1, h, 1, -1).expand(f, h, w, -1),
65
+ freqs[2][:w].view(1, 1, w, -1).expand(f, h, w, -1),
66
+ ],
67
+ dim=-1,
68
+ ).reshape(seq_len, 1, -1)
69
+
70
+ # apply rotary embedding
71
+ x_i = torch.view_as_real(x_i * freqs_i).flatten(2)
72
+ x_i = torch.cat([x_i, x[i, seq_len:]])
73
+
74
+ # append to collection
75
+ output.append(x_i)
76
+ return torch.stack(output).type_as(x)
77
+
78
+
79
+ class MatrixGameWanRMSNorm(nn.Module):
80
+ def __init__(self, dim, eps=1e-5):
81
+ super().__init__()
82
+ self.dim = dim
83
+ self.eps = eps
84
+ self.weight = nn.Parameter(torch.ones(dim))
85
+
86
+ def forward(self, x):
87
+ r"""
88
+ Args:
89
+ x(Tensor): Shape [B, L, C]
90
+ """
91
+ return self._norm(x.float()).type_as(x) * self.weight
92
+
93
+ def _norm(self, x):
94
+ return x * torch.rsqrt(x.pow(2).mean(dim=-1, keepdim=True) + self.eps)
95
+
96
+
97
+ class MatrixGameWanLayerNorm(nn.LayerNorm):
98
+ def __init__(self, dim, eps=1e-6, elementwise_affine=False):
99
+ super().__init__(dim, elementwise_affine=elementwise_affine, eps=eps)
100
+
101
+ def forward(self, x):
102
+ r"""
103
+ Args:
104
+ x(Tensor): Shape [B, L, C]
105
+ """
106
+ return super().forward(x).type_as(x)
107
+
108
+
109
+ class MatrixGameWanSelfAttention(nn.Module):
110
+ def __init__(self, dim, num_heads, window_size=(-1, -1), qk_norm=True, eps=1e-6):
111
+ assert dim % num_heads == 0
112
+ super().__init__()
113
+ self.dim = dim
114
+ self.num_heads = num_heads
115
+ self.head_dim = dim // num_heads
116
+ self.window_size = window_size
117
+ self.qk_norm = qk_norm
118
+ self.eps = eps
119
+
120
+ # layers
121
+ self.q = nn.Linear(dim, dim)
122
+ self.k = nn.Linear(dim, dim)
123
+ self.v = nn.Linear(dim, dim)
124
+ self.o = nn.Linear(dim, dim)
125
+ self.norm_q = MatrixGameWanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
126
+ self.norm_k = MatrixGameWanRMSNorm(dim, eps=eps) if qk_norm else nn.Identity()
127
+
128
+ def forward(self, x, seq_lens, grid_sizes, freqs):
129
+ r"""
130
+ Args:
131
+ x(Tensor): Shape [B, L, num_heads, C / num_heads]
132
+ seq_lens(Tensor): Shape [B]
133
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
134
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
135
+ """
136
+ b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
137
+
138
+ # query, key, value function
139
+ def qkv_fn(x):
140
+ q = self.norm_q(self.q(x)).view(b, s, n, d)
141
+ k = self.norm_k(self.k(x)).view(b, s, n, d)
142
+ v = self.v(x).view(b, s, n, d)
143
+ return q, k, v
144
+
145
+ q, k, v = qkv_fn(x)
146
+ # print(k.shape, seq_lens)
147
+ x = flash_attention(
148
+ q=rope_apply(q, grid_sizes, freqs),
149
+ k=rope_apply(k, grid_sizes, freqs),
150
+ v=v,
151
+ k_lens=seq_lens,
152
+ window_size=self.window_size,
153
+ )
154
+
155
+ # output
156
+ x = x.flatten(2)
157
+ x = self.o(x)
158
+ return x
159
+
160
+
161
+ # class MatrixGameWanT2VCrossAttention(MatrixGameWanSelfAttention):
162
+
163
+ # def forward(self, x, context, context_lens, crossattn_cache=None):
164
+ # r"""
165
+ # Args:
166
+ # x(Tensor): Shape [B, L1, C]
167
+ # context(Tensor): Shape [B, L2, C]
168
+ # context_lens(Tensor): Shape [B]
169
+ # crossattn_cache (List[dict], *optional*): Contains the cached key and value tensors for context embedding.
170
+ # """
171
+ # b, n, d = x.size(0), self.num_heads, self.head_dim
172
+
173
+ # # compute query, key, value
174
+ # q = self.norm_q(self.q(x)).view(b, -1, n, d)
175
+
176
+ # if crossattn_cache is not None:
177
+ # if not crossattn_cache["is_init"]:
178
+ # crossattn_cache["is_init"] = True
179
+ # k = self.norm_k(self.k(context)).view(b, -1, n, d)
180
+ # v = self.v(context).view(b, -1, n, d)
181
+ # crossattn_cache["k"] = k
182
+ # crossattn_cache["v"] = v
183
+ # else:
184
+ # k = crossattn_cache["k"]
185
+ # v = crossattn_cache["v"]
186
+ # else:
187
+ # k = self.norm_k(self.k(context)).view(b, -1, n, d)
188
+ # v = self.v(context).view(b, -1, n, d)
189
+
190
+ # # compute attention
191
+ # x = flash_attention(q, k, v, k_lens=context_lens)
192
+
193
+ # # output
194
+ # x = x.flatten(2)
195
+ # x = self.o(x)
196
+ # return x
197
+
198
+
199
+ # class MatrixGameWanGanCrossAttention(MatrixGameWanSelfAttention):
200
+
201
+ # def forward(self, x, context, crossattn_cache=None):
202
+ # r"""
203
+ # Args:
204
+ # x(Tensor): Shape [B, L1, C]
205
+ # context(Tensor): Shape [B, L2, C]
206
+ # context_lens(Tensor): Shape [B]
207
+ # crossattn_cache (List[dict], *optional*): Contains the cached key and value tensors for context embedding.
208
+ # """
209
+ # b, n, d = x.size(0), self.num_heads, self.head_dim
210
+
211
+ # # compute query, key, value
212
+ # qq = self.norm_q(self.q(context)).view(b, 1, -1, d)
213
+
214
+ # kk = self.norm_k(self.k(x)).view(b, -1, n, d)
215
+ # vv = self.v(x).view(b, -1, n, d)
216
+
217
+ # # compute attention
218
+ # x = flash_attention(qq, kk, vv)
219
+
220
+ # # output
221
+ # x = x.flatten(2)
222
+ # x = self.o(x)
223
+ # return x
224
+
225
+
226
+ class MatrixGameWanI2VCrossAttention(MatrixGameWanSelfAttention):
227
+ def forward(self, x, context, crossattn_cache=None):
228
+ r"""
229
+ Args:
230
+ x(Tensor): Shape [B, L1, C]
231
+ context(Tensor): Shape [B, L2, C]
232
+ context_lens(Tensor): Shape [B]
233
+ """
234
+ b, n, d = x.size(0), self.num_heads, self.head_dim
235
+
236
+ # compute query, key, value
237
+ q = self.norm_q(self.q(x)).view(b, -1, n, d)
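+ # The image-conditioning context is constant across denoising steps, so its
+ # keys/values are computed once and then reused from `crossattn_cache`.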
238
+ if crossattn_cache is not None:
239
+ if not crossattn_cache["is_init"]:
240
+ crossattn_cache["is_init"] = True
241
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
242
+ v = self.v(context).view(b, -1, n, d)
243
+ crossattn_cache["k"] = k
244
+ crossattn_cache["v"] = v
245
+ else:
246
+ k = crossattn_cache["k"]
247
+ v = crossattn_cache["v"]
248
+ else:
249
+ k = self.norm_k(self.k(context)).view(b, -1, n, d)
250
+ v = self.v(context).view(b, -1, n, d)
251
+ # compute attention
252
+ x = flash_attention(q, k, v, k_lens=None)
253
+
254
+ # output
255
+ x = x.flatten(2)
256
+ x = self.o(x)
257
+ return x
258
+
259
+
260
+ MatrixGameWan_CROSSATTENTION_CLASSES = {
261
+ "i2v_cross_attn": MatrixGameWanI2VCrossAttention,
262
+ }
263
+
264
+
265
+ def mul_add(x, y, z):
266
+ return x.float() + y.float() * z.float()
267
+
268
+
269
+ def mul_add_add(x, y, z):
270
+ return x.float() * (1 + y) + z
271
+
272
+
273
+ class MatrixGameWanAttentionBlock(nn.Module):
274
+ def __init__(
275
+ self,
276
+ cross_attn_type,
277
+ dim,
278
+ ffn_dim,
279
+ num_heads,
280
+ window_size=(-1, -1),
281
+ qk_norm=True,
282
+ cross_attn_norm=False,
283
+ action_config={},
284
+ eps=1e-6,
285
+ ):
286
+ super().__init__()
287
+ self.dim = dim
288
+ self.ffn_dim = ffn_dim
289
+ self.num_heads = num_heads
290
+ self.window_size = window_size
291
+ self.qk_norm = qk_norm
292
+ self.cross_attn_norm = cross_attn_norm
293
+ self.eps = eps
294
+ if len(action_config) != 0:
295
+ self.action_model = ActionModule(**action_config)
296
+ else:
297
+ self.action_model = None
298
+ # layers
299
+ self.norm1 = MatrixGameWanLayerNorm(dim, eps)
300
+ self.self_attn = MatrixGameWanSelfAttention(dim, num_heads, window_size, qk_norm, eps)
301
+ self.norm3 = (
302
+ MatrixGameWanLayerNorm(dim, eps, elementwise_affine=True)
303
+ if cross_attn_norm
304
+ else nn.Identity()
305
+ )
306
+ self.cross_attn = MatrixGameWan_CROSSATTENTION_CLASSES[cross_attn_type](
307
+ dim, num_heads, (-1, -1), qk_norm, eps
308
+ )
309
+ self.norm2 = MatrixGameWanLayerNorm(dim, eps)
310
+ self.ffn = nn.Sequential(
311
+ nn.Linear(dim, ffn_dim),
312
+ nn.GELU(approximate="tanh"),
313
+ nn.Linear(ffn_dim, dim),
314
+ )
315
+
316
+ # modulation
317
+ self.modulation = nn.Parameter(torch.randn(1, 6, dim) / dim**0.5)
318
+
319
+ def forward(
320
+ self,
321
+ x,
322
+ e,
323
+ seq_lens,
324
+ grid_sizes,
325
+ freqs,
326
+ context,
327
+ mouse_cond=None,
328
+ keyboard_cond=None,
329
+ # context_lens,
330
+ ):
331
+ r"""
332
+ Args:
333
+ x(Tensor): Shape [B, L, C]
334
+ e(Tensor): Shape [B, 6, C]
335
+ seq_lens(Tensor): Shape [B], length of each sequence in batch
336
+ grid_sizes(Tensor): Shape [B, 3], the second dimension contains (F, H, W)
337
+ freqs(Tensor): Rope freqs, shape [1024, C / num_heads / 2]
338
+ """
339
+ # assert e.dtype == torch.float32
340
+ if e.dim() == 3:
341
+ modulation = self.modulation
342
+ # with amp.autocast(dtype=torch.float32):
343
+ e = (self.modulation + e).chunk(6, dim=1)
344
+ elif e.dim() == 4:
345
+ modulation = self.modulation.unsqueeze(2) # 1, 6, 1, dim
346
+ # with amp.autocast("cuda", dtype=torch.float32):
347
+ e = (modulation + e).chunk(6, dim=1)
348
+ e = [ei.squeeze(1) for ei in e]
349
+ # assert e[0].dtype == torch.float32
350
+
351
+ # self-attention
352
+ y = self.self_attn(
353
+ self.norm1(x) * (1 + e[1]) + e[0], seq_lens, grid_sizes, freqs
354
+ )
355
+ # with amp.autocast(dtype=torch.float32):
356
+ x = x + y * e[2]
357
+
358
+ # cross-attention & ffn function
359
+ def cross_attn_ffn(x, context, e, mouse_cond, keyboard_cond):
360
+ dtype = context.dtype
361
+ x = x + self.cross_attn(self.norm3(x.to(dtype)), context)
362
+ if self.action_model is not None:
363
+ assert mouse_cond is not None or keyboard_cond is not None
364
+ x = self.action_model(
365
+ x.to(dtype),
366
+ grid_sizes[0],
367
+ grid_sizes[1],
368
+ grid_sizes[2],
369
+ mouse_cond,
370
+ keyboard_cond,
371
+ )
372
+ y = self.ffn(self.norm2(x) * (1 + e[4]) + e[3])
373
+ # with amp.autocast(dtype=torch.float32):
374
+ x = x + y * e[5]
375
+ return x
376
+
377
+ x = cross_attn_ffn(x, context, e, mouse_cond, keyboard_cond)
378
+ return x
379
+
380
+
381
+ class Head(nn.Module):
382
+ def __init__(self, dim, out_dim, patch_size, eps=1e-6):
383
+ super().__init__()
384
+ self.dim = dim
385
+ self.out_dim = out_dim
386
+ self.patch_size = patch_size
387
+ self.eps = eps
388
+
389
+ # layers
390
+ out_dim = math.prod(patch_size) * out_dim
391
+ self.norm = MatrixGameWanLayerNorm(dim, eps)
392
+ self.head = nn.Linear(dim, out_dim)
393
+
394
+ # modulation
395
+ self.modulation = nn.Parameter(torch.randn(1, 2, dim) / dim**0.5)
396
+
397
+ def forward(self, x, e):
398
+ r"""
399
+ Args:
400
+ x(Tensor): Shape [B, L1, C]
401
+ e(Tensor): Shape [B, C]
402
+ """
403
+ # assert e.dtype == torch.float32
404
+ # with amp.autocast(dtype=torch.float32):
405
+ if e.dim() == 2:
406
+ modulation = self.modulation # 1, 2, dim
407
+ e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
408
+ elif e.dim() == 3:
409
+ modulation = self.modulation.unsqueeze(2) # 1, 2, seq, dim
410
+ e = (modulation + e.unsqueeze(1)).chunk(2, dim=1)
411
+ e = [ei.squeeze(1) for ei in e]
412
+ x = self.head(self.norm(x) * (1 + e[1]) + e[0])
413
+ return x
414
+
415
+
416
+ class MLPProj(torch.nn.Module):
417
+ def __init__(self, in_dim, out_dim):
418
+ super().__init__()
419
+
420
+ self.proj = torch.nn.Sequential(
421
+ torch.nn.LayerNorm(in_dim),
422
+ torch.nn.Linear(in_dim, in_dim),
423
+ torch.nn.GELU(),
424
+ torch.nn.Linear(in_dim, out_dim),
425
+ torch.nn.LayerNorm(out_dim),
426
+ )
427
+
428
+ def forward(self, image_embeds):
429
+ clip_extra_context_tokens = self.proj(image_embeds)
430
+ return clip_extra_context_tokens
431
+
432
+
433
+ # class RegisterTokens(nn.Module):
434
+ # def __init__(self, num_registers: int, dim: int):
435
+ # super().__init__()
436
+ # self.register_tokens = nn.Parameter(torch.randn(num_registers, dim) * 0.02)
437
+ # self.rms_norm = MatrixGameWanRMSNorm(dim, eps=1e-6)
438
+
439
+ # def forward(self):
440
+ # return self.rms_norm(self.register_tokens)
441
+
442
+ # def reset_parameters(self):
443
+ # nn.init.normal_(self.register_tokens, std=0.02)
444
+
445
+
446
+ class MatrixGameWanModel(ModelMixin, ConfigMixin, FromOriginalModelMixin, PeftAdapterMixin):
447
+ r"""
448
+ MatrixGameWan diffusion backbone for image-to-video generation.
449
+ """
450
+
451
+ ignore_for_config = [
452
+ "patch_size",
453
+ "cross_attn_norm",
454
+ "qk_norm",
455
+ "text_dim",
456
+ "window_size",
457
+ ]
458
+ _no_split_modules = ["MatrixGameWanAttentionBlock"]
459
+ _supports_gradient_checkpointing = True
460
+
461
+ @register_to_config
462
+ def __init__(
463
+ self,
464
+ model_type="i2v",
465
+ patch_size=(1, 2, 2),
466
+ text_len=512,
467
+ in_dim=36,
468
+ dim=1536,
469
+ ffn_dim=8960,
470
+ freq_dim=256,
471
+ text_dim=4096,
472
+ out_dim=16,
473
+ num_heads=12,
474
+ num_layers=30,
475
+ window_size=(-1, -1),
476
+ qk_norm=True,
477
+ cross_attn_norm=True,
478
+ inject_sample_info=False,
479
+ action_config={},
480
+ eps=1e-6,
481
+ ):
482
+ r"""
483
+ Initialize the diffusion model backbone.
484
+
485
+ Args:
486
+ model_type (`str`, *optional*, defaults to 't2v'):
487
+ Model variant - 't2v' (text-to-video) or 'i2v' (image-to-video)
488
+ patch_size (`tuple`, *optional*, defaults to (1, 2, 2)):
489
+ 3D patch dimensions for video embedding (t_patch, h_patch, w_patch)
490
+ text_len (`int`, *optional*, defaults to 512):
491
+ Fixed length for text embeddings
492
+ in_dim (`int`, *optional*, defaults to 16):
493
+ Input video channels (C_in)
494
+ dim (`int`, *optional*, defaults to 2048):
495
+ Hidden dimension of the transformer
496
+ ffn_dim (`int`, *optional*, defaults to 8192):
497
+ Intermediate dimension in feed-forward network
498
+ freq_dim (`int`, *optional*, defaults to 256):
499
+ Dimension for sinusoidal time embeddings
500
+ text_dim (`int`, *optional*, defaults to 4096):
501
+ Input dimension for text embeddings
502
+ out_dim (`int`, *optional*, defaults to 16):
503
+ Output video channels (C_out)
504
+ num_heads (`int`, *optional*, defaults to 16):
505
+ Number of attention heads
506
+ num_layers (`int`, *optional*, defaults to 32):
507
+ Number of transformer blocks
508
+ window_size (`tuple`, *optional*, defaults to (-1, -1)):
509
+ Window size for local attention (-1 indicates global attention)
510
+ qk_norm (`bool`, *optional*, defaults to True):
511
+ Enable query/key normalization
512
+ cross_attn_norm (`bool`, *optional*, defaults to False):
513
+ Enable cross-attention normalization
514
+ eps (`float`, *optional*, defaults to 1e-6):
515
+ Epsilon value for normalization layers
516
+ """
517
+
518
+ super().__init__()
519
+
520
+ assert model_type in ["i2v"]
521
+ self.model_type = model_type
522
+ self.use_action_module = len(action_config) > 0
523
+ assert self.use_action_module == True
524
+ self.patch_size = patch_size
525
+ self.text_len = text_len
526
+ self.in_dim = in_dim
527
+ self.dim = dim
528
+ self.ffn_dim = ffn_dim
529
+ self.freq_dim = freq_dim
530
+ self.text_dim = text_dim
531
+ self.out_dim = out_dim
532
+ self.num_heads = num_heads
533
+ self.num_layers = num_layers
534
+ self.window_size = window_size
535
+ self.qk_norm = qk_norm
536
+ self.cross_attn_norm = cross_attn_norm
537
+ self.eps = eps
538
+ self.local_attn_size = -1
539
+
540
+ # embeddings
541
+ self.patch_embedding = nn.Conv3d(
542
+ in_dim, dim, kernel_size=patch_size, stride=patch_size
543
+ )
544
+ # self.text_embedding = nn.Sequential(
545
+ # nn.Linear(text_dim, dim), nn.GELU(approximate='tanh'),
546
+ # nn.Linear(dim, dim))
547
+
548
+ self.time_embedding = nn.Sequential(
549
+ nn.Linear(freq_dim, dim), nn.SiLU(), nn.Linear(dim, dim)
550
+ )
551
+ self.time_projection = nn.Sequential(nn.SiLU(), nn.Linear(dim, dim * 6))
552
+
553
+ # blocks
554
+ cross_attn_type = "i2v_cross_attn"
555
+ self.blocks = nn.ModuleList(
556
+ [
557
+ MatrixGameWanAttentionBlock(
558
+ cross_attn_type,
559
+ dim,
560
+ ffn_dim,
561
+ num_heads,
562
+ window_size,
563
+ qk_norm,
564
+ cross_attn_norm,
565
+ eps=eps,
566
+ action_config=action_config,
567
+ )
568
+ for _ in range(num_layers)
569
+ ]
570
+ )
571
+
572
+ # head
573
+ self.head = Head(dim, out_dim, patch_size, eps)
574
+
575
+ # buffers (don't use register_buffer otherwise dtype will be changed in to())
576
+ assert (dim % num_heads) == 0 and (dim // num_heads) % 2 == 0
577
+ d = dim // num_heads
578
+ self.freqs = torch.cat(
579
+ [
580
+ rope_params(1024, d - 4 * (d // 6)),
581
+ rope_params(1024, 2 * (d // 6)),
582
+ rope_params(1024, 2 * (d // 6)),
583
+ ],
584
+ dim=1,
585
+ )
586
+
587
+ if model_type == "i2v":
588
+ self.img_emb = MLPProj(1280, dim)
589
+
590
+ # initialize weights
591
+ self.init_weights()
592
+
593
+ self.gradient_checkpointing = False
594
+
595
+ def _set_gradient_checkpointing(self, module, value=False):
596
+ self.gradient_checkpointing = value
597
+
598
+ def forward(self, *args, **kwargs):
599
+ # if kwargs.get('classify_mode', False) is True:
600
+ # kwargs.pop('classify_mode')
601
+ # return self._forward_classify(*args, **kwargs)
602
+ # else:
603
+ return self._forward(*args, **kwargs)
604
+
605
+ def _forward(
606
+ self,
607
+ x,
608
+ t,
609
+ visual_context,
610
+ cond_concat,
611
+ mouse_cond=None,
612
+ keyboard_cond=None,
613
+ fps=None,
614
+ # seq_len,
615
+ # classify_mode=False,
616
+ # concat_time_embeddings=False,
617
+ # register_tokens=None,
618
+ # cls_pred_branch=None,
619
+ # gan_ca_blocks=None,
620
+ # clip_fea=None,
621
+ # y=None,
622
+ ):
623
+ r"""
624
+ Forward pass through the diffusion model
625
+
626
+ Args:
627
+ x (List[Tensor]):
628
+ List of input video tensors, each with shape [C_in, F, H, W]
629
+ t (Tensor):
630
+ Diffusion timesteps tensor of shape [B]
631
+ context (List[Tensor]):
632
+ List of text embeddings each with shape [L, C]
633
+ seq_len (`int`):
634
+ Maximum sequence length for positional encoding
635
+ clip_fea (Tensor, *optional*):
636
+ CLIP image features for image-to-video mode
637
+ y (List[Tensor], *optional*):
638
+ Conditional video inputs for image-to-video mode, same shape as x
639
+
640
+ Returns:
641
+ List[Tensor]:
642
+ List of denoised video tensors with original input shapes [C_out, F, H / 8, W / 8]
643
+ """
644
+ # params
645
+ if mouse_cond is not None or keyboard_cond is not None:
646
+ assert self.use_action_module == True
647
+ device = self.patch_embedding.weight.device
648
+ if self.freqs.device != device:
649
+ self.freqs = self.freqs.to(device)
650
+
651
+ x = torch.cat([x, cond_concat], dim=1)
652
+ # embeddings
653
+ x = self.patch_embedding(x)
654
+ grid_sizes = torch.tensor(x.shape[2:], dtype=torch.long)
655
+ x = x.flatten(2).transpose(1, 2)
656
+ seq_lens = torch.tensor([u.size(0) for u in x], dtype=torch.long)
657
+ # seq_len = seq_lens.max()
658
+ # # assert seq_lens.max() <= seq_len
659
+ # x = torch.cat([
660
+ # torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))],
661
+ # dim=1) for u in x
662
+ # ])
663
+
664
+ # time embeddings
665
+ # with amp.autocast(dtype=torch.float32):
666
+ # assert t.ndim == 1
667
+ e = self.time_embedding(
668
+ sinusoidal_embedding_1d(self.freq_dim, t).type_as(x)
669
+ ) # TODO: check if t ndim == 1
670
+
671
+ e0 = self.time_projection(e).unflatten(1, (6, self.dim))
672
+ # assert e.dtype == torch.float32 and e0.dtype == torch.float32
673
+
674
+ # context
675
+ context_lens = None
676
+ # context = self.text_embedding(
677
+ # torch.stack([
678
+ # torch.cat(
679
+ # [u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
680
+ # for u in context
681
+ # ]))
682
+
683
+ # if clip_fea is not None:
684
+ # context_clip = self.img_emb(clip_fea) # bs x 257 x dim
685
+ context = self.img_emb(visual_context)
686
+
687
+ # arguments
688
+ # kwargs = dict(
689
+ # e=e0,
690
+ # seq_lens=seq_lens,
691
+ # grid_sizes=grid_sizes,
692
+ # freqs=self.freqs,
693
+ # context=context,
694
+ # context_lens=context_lens)
695
+ kwargs = dict(
696
+ e=e0,
697
+ grid_sizes=grid_sizes,
698
+ seq_lens=seq_lens,
699
+ freqs=self.freqs,
700
+ context=context,
701
+ mouse_cond=mouse_cond,
702
+ # context_lens=context_lens,
703
+ keyboard_cond=keyboard_cond,
704
+ )
705
+
706
+ def create_custom_forward(module):
707
+ def custom_forward(*inputs, **kwargs):
708
+ return module(*inputs, **kwargs)
709
+
710
+ return custom_forward
711
+
712
+ for ii, block in enumerate(self.blocks):
713
+ if torch.is_grad_enabled() and self.gradient_checkpointing:
714
+ x = torch.utils.checkpoint.checkpoint(
715
+ create_custom_forward(block),
716
+ x,
717
+ **kwargs,
718
+ use_reentrant=False,
719
+ )
720
+ else:
721
+ x = block(x, **kwargs)
722
+
723
+ # head
724
+ x = self.head(x, e)
725
+
726
+ # unpatchify
727
+ x = self.unpatchify(x, grid_sizes)
728
+
729
+ return x.float()
730
+
731
+ def unpatchify(self, x, grid_sizes): # TODO check grid sizes
732
+ r"""
733
+ Reconstruct video tensors from patch embeddings.
734
+
735
+ Args:
736
+ x (List[Tensor]):
737
+ List of patchified features, each with shape [L, C_out * prod(patch_size)]
738
+ grid_sizes (Tensor):
739
+ Original spatial-temporal grid dimensions before patching,
740
+ shape [3] (3 dimensions correspond to F_patches, H_patches, W_patches)
741
+
742
+ Returns:
743
+ List[Tensor]:
744
+ Reconstructed video tensors with shape [C_out, F, H / 8, W / 8]
745
+ """
746
+
747
+ c = self.out_dim
748
+ bs = x.shape[0]
749
+ x = x.view(bs, *grid_sizes, *self.patch_size, c)
750
+ x = torch.einsum("bfhwpqrc->bcfphqwr", x)
751
+ x = x.reshape(bs, c, *[i * j for i, j in zip(grid_sizes, self.patch_size)])
752
+ return x
753
+
754
+ def init_weights(self):
755
+ r"""
756
+ Initialize model parameters using Xavier initialization.
757
+ """
758
+
759
+ # basic init
760
+ for m in self.modules():
761
+ if isinstance(m, nn.Linear):
762
+ nn.init.xavier_uniform_(m.weight)
763
+ if m.bias is not None:
764
+ nn.init.zeros_(m.bias)
765
+
766
+ # init embeddings
767
+ nn.init.xavier_uniform_(self.patch_embedding.weight.flatten(1))
768
+ for m in self.time_embedding.modules():
769
+ if isinstance(m, nn.Linear):
770
+ nn.init.normal_(m.weight, std=0.02)
771
+
772
+ # init output layer
773
+ nn.init.zeros_(self.head.head.weight)
774
+ if self.use_action_module == True:
775
+ for m in self.blocks:
776
+ nn.init.zeros_(m.action_model.proj_mouse.weight)
777
+ if m.action_model.proj_mouse.bias is not None:
778
+ nn.init.zeros_(m.action_model.proj_mouse.bias)
779
+ nn.init.zeros_(m.action_model.proj_keyboard.weight)
780
+ if m.action_model.proj_keyboard.bias is not None:
781
+ nn.init.zeros_(m.action_model.proj_keyboard.bias)
transformer/posemb_layers.py ADDED
@@ -0,0 +1,314 @@
1
+ import torch
2
+ from typing import Union, Tuple, List
3
+
4
+
5
+ def _to_tuple(x, dim=2):
6
+ if isinstance(x, int):
7
+ return (x,) * dim
8
+ elif len(x) == dim:
9
+ return x
10
+ else:
11
+ raise ValueError(f"Expected length {dim} or int, but got {x}")
12
+
13
+
14
+ def get_meshgrid_nd(start, *args, dim=2):
15
+ """
16
+ Get n-D meshgrid with start, stop and num.
17
+
18
+ Args:
19
+ start (int or tuple): If len(args) == 0, start is num; If len(args) == 1, start is start, args[0] is stop,
20
+ step is 1; If len(args) == 2, start is start, args[0] is stop, args[1] is num. For n-dim, start/stop/num
21
+ should be int or n-tuple. If n-tuple is provided, the meshgrid will be stacked following the dim order in
22
+ n-tuples.
23
+ *args: See above.
24
+ dim (int): Dimension of the meshgrid. Defaults to 2.
25
+
26
+ Returns:
27
+ grid (torch.Tensor): [dim, ...]
28
+ """
29
+ if len(args) == 0:
30
+ # start is grid_size
31
+ num = _to_tuple(start, dim=dim)
32
+ start = (0,) * dim
33
+ stop = num
34
+ elif len(args) == 1:
35
+ # start is start, args[0] is stop, step is 1
36
+ start = _to_tuple(start, dim=dim)
37
+ stop = _to_tuple(args[0], dim=dim)
38
+ num = [stop[i] - start[i] for i in range(dim)]
39
+ elif len(args) == 2:
40
+ # start is start, args[0] is stop, args[1] is num
41
+ start = _to_tuple(start, dim=dim) # Left-Top eg: 12,0
42
+ stop = _to_tuple(args[0], dim=dim) # Right-Bottom eg: 20,32
43
+ num = _to_tuple(args[1], dim=dim) # Target Size eg: 32,124
44
+ else:
45
+ raise ValueError(f"len(args) should be 0, 1 or 2, but got {len(args)}")
46
+
47
+ # PyTorch implement of np.linspace(start[i], stop[i], num[i], endpoint=False)
48
+ axis_grid = []
49
+ for i in range(dim):
50
+ a, b, n = start[i], stop[i], num[i]
51
+ g = torch.linspace(a, b, n + 1, dtype=torch.float32, device=torch.cuda.current_device())[:n]
52
+ axis_grid.append(g)
53
+ grid = torch.meshgrid(*axis_grid, indexing="ij") # dim x [W, H, D]
54
+ grid = torch.stack(grid, dim=0) # [dim, W, H, D]
55
+
56
+ return grid
57
+
58
+
59
+ #################################################################################
60
+ # Rotary Positional Embedding Functions #
61
+ #################################################################################
62
+ # https://github.com/meta-llama/llama/blob/be327c427cc5e89cc1d3ab3d3fec4484df771245/llama/model.py#L80
63
+
64
+
+ def reshape_for_broadcast(
+     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
+     x: torch.Tensor,
+     head_first=False,
+ ):
+     """
+     Reshape frequency tensor for broadcasting it with another tensor.
+
+     This function reshapes the frequency tensor to have the same shape as the target tensor 'x'
+     for the purpose of broadcasting the frequency tensor during element-wise operations.
+
+     Notes:
+         When using FlashMHAModified, head_first should be False.
+         When using Attention, head_first should be True.
+
+     Args:
+         freqs_cis (Union[torch.Tensor, Tuple[torch.Tensor]]): Frequency tensor to be reshaped.
+         x (torch.Tensor): Target tensor for broadcasting compatibility.
+         head_first (bool): head dimension first (except batch dim) or not.
+
+     Returns:
+         torch.Tensor: Reshaped frequency tensor.
+
+     Raises:
+         AssertionError: If the frequency tensor doesn't match the expected shape.
+         AssertionError: If the target tensor 'x' doesn't have the expected number of dimensions.
+     """
+     ndim = x.ndim
+     assert 0 <= 1 < ndim
+
+     if isinstance(freqs_cis, tuple):
+         # freqs_cis: (cos, sin) in real space
+         if head_first:
+             assert freqs_cis[0].shape == (
+                 x.shape[-2],
+                 x.shape[-1],
+             ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+             shape = [
+                 d if i == ndim - 2 or i == ndim - 1 else 1
+                 for i, d in enumerate(x.shape)
+             ]
+         else:
+             # assert freqs_cis[0].shape == (
+             #     x.shape[1],
+             #     x.shape[-1],
+             # ), f"freqs_cis shape {freqs_cis[0].shape} does not match x shape {x.shape}"
+             # shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+             shape = [1, freqs_cis[0].shape[0], 1, freqs_cis[0].shape[1]]
+         return freqs_cis[0].view(*shape), freqs_cis[1].view(*shape)
+     else:
+         # freqs_cis: values in complex space
+         if head_first:
+             assert freqs_cis.shape == (
+                 x.shape[-2],
+                 x.shape[-1],
+             ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+             shape = [
+                 d if i == ndim - 2 or i == ndim - 1 else 1
+                 for i, d in enumerate(x.shape)
+             ]
+         else:
+             assert freqs_cis.shape == (
+                 x.shape[1],
+                 x.shape[-1],
+             ), f"freqs_cis shape {freqs_cis.shape} does not match x shape {x.shape}"
+             shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+         return freqs_cis.view(*shape)
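
A shape-only sketch of the `(cos, sin)` path with the default `head_first=False`; `reshape_for_broadcast` only reshapes its inputs, so this runs on CPU (the module name `rope_utils` is illustrative):

```python
import torch
from rope_utils import reshape_for_broadcast  # hypothetical module name

B, S, H, D = 2, 16, 8, 64
x = torch.randn(B, S, H, D)              # query/key laid out as [B, S, H, D]
cos, sin = torch.randn(S, D), torch.randn(S, D)

cos_b, sin_b = reshape_for_broadcast((cos, sin), x, head_first=False)
print(cos_b.shape)        # torch.Size([1, 16, 1, 64])
print((x * cos_b).shape)  # broadcasts to torch.Size([2, 16, 8, 64])
```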
+
+
+ def rotate_half(x):
+     x_real, x_imag = (
+         x.float().reshape(*x.shape[:-1], -1, 2).unbind(-1)
+     )  # [B, S, H, D//2]
+     return torch.stack([-x_imag, x_real], dim=-1).flatten(3)
+
+
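`rotate_half` treats adjacent channels of the last dimension as (real, imaginary) pairs, which matches the `repeat_interleave` layout produced by `get_1d_rotary_pos_embed(..., use_real=True)` further down. A quick check, again with an illustrative module name:

```python
import torch
from rope_utils import rotate_half  # hypothetical module name

x = torch.tensor([[[[1.0, 2.0, 3.0, 4.0]]]])  # [B=1, S=1, H=1, D=4]
print(rotate_half(x))  # tensor([[[[-2., 1., -4., 3.]]]])
```
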
+ def apply_rotary_emb(
+     xq: torch.Tensor,
+     xk: torch.Tensor,
+     freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]],
+     head_first: bool = False,
+     start_offset: int = 0,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """
+     Apply rotary embeddings to input tensors using the given frequency tensor.
+
+     This function applies rotary embeddings to the given query 'xq' and key 'xk' tensors using the provided
+     frequency tensor 'freqs_cis'. The input tensors are reshaped as complex numbers, and the frequency tensor
+     is reshaped for broadcasting compatibility. The resulting tensors contain rotary embeddings and are
+     returned as real tensors.
+
+     Args:
+         xq (torch.Tensor): Query tensor to apply rotary embeddings. [B, S, H, D]
+         xk (torch.Tensor): Key tensor to apply rotary embeddings. [B, S, H, D]
+         freqs_cis (torch.Tensor or tuple): Precomputed frequency tensor for complex exponential.
+         head_first (bool): head dimension first (except batch dim) or not.
+         start_offset (int): Offset into the precomputed frequencies along the sequence dimension.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
+
+     """
+     xk_out = None
+     assert isinstance(freqs_cis, tuple)
+     if isinstance(freqs_cis, tuple):
+         cos, sin = reshape_for_broadcast(freqs_cis, xq, head_first)  # [1, S, 1, D] for head_first=False
+         cos, sin = cos.to(xq.device), sin.to(xq.device)
+         # real * cos - imag * sin
+         # imag * cos + real * sin
+         xq_out = (
+             xq.float() * cos[:, start_offset : start_offset + xq.shape[1], :, :]
+             + rotate_half(xq.float()) * sin[:, start_offset : start_offset + xq.shape[1], :, :]
+         ).type_as(xq)
+         xk_out = (
+             xk.float() * cos[:, start_offset : start_offset + xk.shape[1], :, :]
+             + rotate_half(xk.float()) * sin[:, start_offset : start_offset + xk.shape[1], :, :]
+         ).type_as(xk)
+     else:
+         # view_as_complex will pack [..., D/2, 2] (real) to [..., D/2] (complex)
+         xq_ = torch.view_as_complex(
+             xq.float().reshape(*xq.shape[:-1], -1, 2)
+         )  # [B, S, H, D//2]
+         freqs_cis = reshape_for_broadcast(freqs_cis, xq_, head_first).to(
+             xq.device
+         )  # [S, D//2] --> [1, S, 1, D//2]
+         # (real, imag) * (cos, sin) = (real * cos - imag * sin, imag * cos + real * sin)
+         # view_as_real will expand [..., D/2] (complex) to [..., D/2, 2] (real)
+         xq_out = torch.view_as_real(xq_ * freqs_cis).flatten(3).type_as(xq)
+         xk_ = torch.view_as_complex(
+             xk.float().reshape(*xk.shape[:-1], -1, 2)
+         )  # [B, S, H, D//2]
+         xk_out = torch.view_as_real(xk_ * freqs_cis).flatten(3).type_as(xk)
+
+     return xq_out, xk_out
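
An end-to-end sketch for the 1-D case; it assumes a CUDA device (the frequency helpers in this file allocate on `torch.cuda.current_device()`) and the illustrative module name `rope_utils`:

```python
import torch
from rope_utils import apply_rotary_emb, get_1d_rotary_pos_embed  # hypothetical module name

B, S, H, D = 1, 32, 8, 64
xq = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)
xk = torch.randn(B, S, H, D, device="cuda", dtype=torch.bfloat16)

# (cos, sin), each of shape [S, D]; apply_rotary_emb asserts this tuple form.
freqs = get_1d_rotary_pos_embed(D, S, use_real=True)

xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs, head_first=False)
print(xq_rot.shape, xq_rot.dtype)  # torch.Size([1, 32, 8, 64]) torch.bfloat16
```

The rotation is computed in float32 and cast back with `type_as`, so low-precision query/key tensors can be passed in directly.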
+
+
+ def get_nd_rotary_pos_embed(
+     rope_dim_list,
+     start,
+     *args,
+     theta=10000.0,
+     use_real=False,
+     theta_rescale_factor: Union[float, List[float]] = 1.0,
+     interpolation_factor: Union[float, List[float]] = 1.0,
+ ):
+     """
+     This is an n-d version of precompute_freqs_cis, i.e. a RoPE for tokens with an n-d structure.
+
+     Args:
+         rope_dim_list (list of int): Dimension of each rope. len(rope_dim_list) should equal n, and
+             sum(rope_dim_list) should equal the head_dim of the attention layer.
+         start (int | tuple of int | list of int): If len(args) == 0, start is num; if len(args) == 1, start is start,
+             args[0] is stop, step is 1; if len(args) == 2, start is start, args[0] is stop, args[1] is num.
+         *args: See above.
+         theta (float): Scaling factor for frequency computation. Defaults to 10000.0.
+         use_real (bool): If True, return the real part and imaginary part separately. Otherwise, return complex
+             numbers. Some libraries, such as TensorRT, do not support the complex64 data type, so it can be useful
+             to provide the real and imaginary parts separately.
+         theta_rescale_factor (float): Rescale factor for theta. Defaults to 1.0.
+
+     Returns:
+         pos_embed: (cos, sin), each of shape [prod(grid), sum(rope_dim_list)] if use_real is True; otherwise a
+             complex tensor of shape [prod(grid), sum(rope_dim_list) // 2].
+     """
+
+     grid = get_meshgrid_nd(
+         start, *args, dim=len(rope_dim_list)
+     )  # [3, W, H, D] / [2, W, H]
+
+     if isinstance(theta_rescale_factor, int) or isinstance(theta_rescale_factor, float):
+         theta_rescale_factor = [theta_rescale_factor] * len(rope_dim_list)
+     elif isinstance(theta_rescale_factor, list) and len(theta_rescale_factor) == 1:
+         theta_rescale_factor = [theta_rescale_factor[0]] * len(rope_dim_list)
+     assert len(theta_rescale_factor) == len(
+         rope_dim_list
+     ), "len(theta_rescale_factor) should equal to len(rope_dim_list)"
+
+     if isinstance(interpolation_factor, int) or isinstance(interpolation_factor, float):
+         interpolation_factor = [interpolation_factor] * len(rope_dim_list)
+     elif isinstance(interpolation_factor, list) and len(interpolation_factor) == 1:
+         interpolation_factor = [interpolation_factor[0]] * len(rope_dim_list)
+     assert len(interpolation_factor) == len(
+         rope_dim_list
+     ), "len(interpolation_factor) should equal to len(rope_dim_list)"
+
+     # use 1/ndim of dimensions to encode grid_axis
+     embs = []
+     for i in range(len(rope_dim_list)):
+         emb = get_1d_rotary_pos_embed(
+             rope_dim_list[i],
+             grid[i].reshape(-1),
+             theta,
+             use_real=use_real,
+             theta_rescale_factor=theta_rescale_factor[i],
+             interpolation_factor=interpolation_factor[i],
+         )  # 2 x [WHD, rope_dim_list[i]]
+         embs.append(emb)
+
+     if use_real:
+         cos = torch.cat([emb[0] for emb in embs], dim=1)  # (WHD, D)
+         sin = torch.cat([emb[1] for emb in embs], dim=1)  # (WHD, D)
+         return cos, sin
+     else:
+         emb = torch.cat(embs, dim=1)  # (WHD, D/2)
+         return emb
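
A sketch of building a 3-D (time, height, width) RoPE table for a latent video grid; the `rope_dim_list` split and the grid sizes below are illustrative values, not the model's actual configuration, and a CUDA device is required:

```python
import torch
from rope_utils import get_nd_rotary_pos_embed  # hypothetical module name

rope_dim_list = [16, 56, 56]   # t/h/w split; must sum to the attention head_dim (128 here)
grid_size = (8, 44, 80)        # latent frames, height, width (illustrative values)

cos, sin = get_nd_rotary_pos_embed(rope_dim_list, grid_size, use_real=True)
print(cos.shape)  # torch.Size([28160, 128]) == [8 * 44 * 80, sum(rope_dim_list)]
```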
+
+
+ def get_1d_rotary_pos_embed(
+     dim: int,
+     pos: Union[torch.FloatTensor, int],
+     theta: float = 10000.0,
+     use_real: bool = False,
+     theta_rescale_factor: float = 1.0,
+     interpolation_factor: float = 1.0,
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+     """
+     Precompute the frequency tensor for complex exponentials (cis) with the given dimension.
+     (Note: `cis` means `cos + i * sin`, where i is the imaginary unit.)
+
+     This function calculates a frequency tensor with complex exponentials using the given dimension 'dim'
+     and the position indices 'pos'. The 'theta' parameter scales the frequencies.
+     The returned tensor contains complex values in the complex64 data type.
+
+     Args:
+         dim (int): Dimension of the frequency tensor.
+         pos (int or torch.FloatTensor): Position indices for the frequency tensor. [S] or scalar
+         theta (float, optional): Scaling factor for frequency computation. Defaults to 10000.0.
+         use_real (bool, optional): If True, return the real part and imaginary part separately.
+             Otherwise, return complex numbers.
+         theta_rescale_factor (float, optional): Rescale factor for theta. Defaults to 1.0.
+
+     Returns:
+         freqs_cis: Precomputed frequency tensor with complex exponentials. [S, D/2]
+         freqs_cos, freqs_sin: Precomputed frequency tensors with the real and imaginary parts separately. [S, D]
+     """
+     if isinstance(pos, int):
+         pos = torch.arange(pos, device=torch.cuda.current_device()).float()
+
+     # Proposed by reddit user bloc97, to rescale rotary embeddings to a longer sequence length without fine-tuning;
+     # has some connection to the NTK literature.
+     if theta_rescale_factor != 1.0:
+         theta *= theta_rescale_factor ** (dim / (dim - 2))
+
+     freqs = 1.0 / (
+         theta ** (torch.arange(0, dim, 2, device=torch.cuda.current_device())[: (dim // 2)].float() / dim)
+     )  # [D/2]
+     # assert interpolation_factor == 1.0, f"interpolation_factor: {interpolation_factor}"
+     freqs = torch.outer(pos * interpolation_factor, freqs)  # [S, D/2]
+     if use_real:
+         freqs_cos = freqs.cos().repeat_interleave(2, dim=1)  # [S, D]
+         freqs_sin = freqs.sin().repeat_interleave(2, dim=1)  # [S, D]
+         return freqs_cos, freqs_sin
+     else:
+         freqs_cis = torch.polar(
+             torch.ones_like(freqs), freqs
+         )  # complex64  # [S, D/2]
+         return freqs_cis
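
The two output modes side by side; as above, this assumes a CUDA device and an illustrative module name:

```python
import torch
from rope_utils import get_1d_rotary_pos_embed  # hypothetical module name

# Real mode: (cos, sin), each [S, D], with adjacent channels sharing one frequency.
cos, sin = get_1d_rotary_pos_embed(64, 128, use_real=True)
print(cos.shape, sin.shape)  # torch.Size([128, 64]) torch.Size([128, 64])

# Complex mode: a single complex64 tensor of shape [S, D/2].
freqs_cis = get_1d_rotary_pos_embed(64, 128, use_real=False)
print(freqs_cis.shape, freqs_cis.dtype)  # torch.Size([128, 32]) torch.complex64
```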