alexnasa committed
Commit c5d62d8 (verified)
1 Parent(s): ff7fb3a

Update humo/generate.py

Files changed (1)
  1. humo/generate.py +973 -983
humo/generate.py CHANGED
@@ -1,984 +1,974 @@
1
- # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
- # Licensed under the Apache License, Version 2.0 (the "License");
3
- # you may not use this file except in compliance with the License.
4
- # You may obtain a copy of the License at
5
- # http://www.apache.org/licenses/LICENSE-2.0
6
- # Unless required by applicable law or agreed to in writing, software
7
- # distributed under the License is distributed on an "AS IS" BASIS,
8
- # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
- # See the License for the specific language governing permissions and
10
- # limitations under the License.
11
-
12
- # Inference codes adapted from [SeedVR]
13
- # https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
14
-
15
- import math
16
- import os
17
- import gc
18
- import random
19
- import sys
20
- import mediapy
21
- import numpy as np
22
- import torch
23
- import torch.distributed as dist
24
- from omegaconf import DictConfig, ListConfig, OmegaConf
25
- from einops import rearrange
26
- from omegaconf import OmegaConf
27
- from PIL import Image, ImageOps
28
- from torchvision.transforms import ToTensor
29
- from tqdm import tqdm
30
- from torch.distributed.device_mesh import init_device_mesh
31
- from torch.distributed.fsdp import (
32
- BackwardPrefetch,
33
- FullyShardedDataParallel,
34
- MixedPrecision,
35
- ShardingStrategy,
36
- )
37
- from common.distributed import (
38
- get_device,
39
- get_global_rank,
40
- get_local_rank,
41
- meta_param_init_fn,
42
- meta_non_persistent_buffer_init_fn,
43
- init_torch,
44
- )
45
- from common.distributed.advanced import (
46
- init_unified_parallel,
47
- get_unified_parallel_world_size,
48
- get_sequence_parallel_rank,
49
- init_model_shard_cpu_group,
50
- )
51
- from common.logger import get_logger
52
- from common.config import create_object
53
- from common.distributed import get_device, get_global_rank
54
- from torchvision.transforms import Compose, Normalize, ToTensor
55
- from humo.models.wan_modules.t5 import T5EncoderModel
56
- from humo.models.wan_modules.vae import WanVAE
57
- from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
58
- from contextlib import contextmanager
59
- import torch.cuda.amp as amp
60
- from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
61
- from humo.utils.audio_processor_whisper import AudioProcessor
62
- from humo.utils.wav2vec import linear_interpolation_fps
63
- from torchao.quantization import quantize_
64
-
65
- import torch._dynamo as dynamo
66
- dynamo.config.capture_scalar_outputs = True
67
- torch.set_float32_matmul_precision("high")
68
-
69
- import torch
70
- import torch.nn as nn
71
- import transformer_engine.pytorch as te
72
-
73
- image_transform = Compose([
74
- ToTensor(),
75
- Normalize(mean=0.5, std=0.5),
76
- ])
77
-
78
- SIZE_CONFIGS = {
79
- '720*1280': (720, 1280),
80
- '1280*720': (1280, 720),
81
- '480*832': (480, 832),
82
- '832*480': (832, 480),
83
- '1024*1024': (1024, 1024),
84
- }
85
-
86
- def clever_format(nums, format="%.2f"):
87
- from typing import Iterable
88
- if not isinstance(nums, Iterable):
89
- nums = [nums]
90
- clever_nums = []
91
- for num in nums:
92
- if num > 1e12:
93
- clever_nums.append(format % (num / 1e12) + "T")
94
- elif num > 1e9:
95
- clever_nums.append(format % (num / 1e9) + "G")
96
- elif num > 1e6:
97
- clever_nums.append(format % (num / 1e6) + "M")
98
- elif num > 1e3:
99
- clever_nums.append(format % (num / 1e3) + "K")
100
- else:
101
- clever_nums.append(format % num + "B")
102
-
103
- clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)
104
-
105
- return clever_nums
106
-
107
-
108
-
109
- # --- put near your imports ---
110
- import torch
111
- import torch.nn as nn
112
- import contextlib
113
- import transformer_engine.pytorch as te
114
-
115
- # FP8 autocast compatibility for different TE versions
116
- try:
117
- # Preferred modern API
118
- from transformer_engine.pytorch import fp8_autocast
119
- try:
120
- # Newer TE: use recipe-based API
121
- from transformer_engine.common.recipe import DelayedScaling, Format
122
- def make_fp8_ctx(enabled: bool = True):
123
- if not enabled:
124
- return contextlib.nullcontext()
125
- fp8_recipe = DelayedScaling(fp8_format=Format.E4M3) # E4M3 format
126
- return fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
127
- except Exception:
128
- # Very old variant that might still accept fp8_format directly
129
- def make_fp8_ctx(enabled: bool = True):
130
- # If TE doesn't have FP8Format, just no-op
131
- if not hasattr(te, "FP8Format"):
132
- return contextlib.nullcontext()
133
- return te.fp8_autocast(enabled=enabled, fp8_format=te.FP8Format.E4M3)
134
- except Exception:
135
- # TE not present or totally incompatible — no-op
136
- def make_fp8_ctx(enabled: bool = True):
137
- return contextlib.nullcontext()
138
-
139
-
140
- # TE sometimes exposes Linear at different paths; this normalizes it.
141
- try:
142
- TELinear = te.Linear
143
- except AttributeError: # very old layouts
144
- from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore
145
-
146
- # --- near imports ---
147
- import torch
148
- import torch.nn as nn
149
- import transformer_engine.pytorch as te
150
-
151
- try:
152
- TELinear = te.Linear
153
- except AttributeError:
154
- from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore
155
-
156
- import torch
157
- import torch.nn as nn
158
- import transformer_engine.pytorch as te
159
-
160
- try:
161
- TELinear = te.Linear
162
- except AttributeError:
163
- from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore
164
-
165
- def _default_te_allow(fullname: str, lin: nn.Linear) -> bool:
166
- """
167
- Allow TE only where it's shape-safe & beneficial.
168
- Skip small/special layers (time/timestep/pos embeds, heads).
169
- Enforce multiples of 16 for in/out features (FP8 kernel friendly).
170
- Also skip very small projections likely to see M=1.
171
- """
172
- blocked_keywords = (
173
- "time_embedding", "timestep", "time_embed",
174
- "time_projection", "pos_embedding", "pos_embed",
175
- "to_logits", "logits", "final_proj", "proj_out", "output_projection",
176
- )
177
- if any(k in fullname for k in blocked_keywords):
178
- return False
179
-
180
- # TE FP8 kernels like K, N divisible by 16
181
- if lin.in_features % 16 != 0 or lin.out_features % 16 != 0:
182
- return False
183
-
184
- # Heuristic: avoid tiny layers; keeps attention/MLP, skips small MLPs
185
- if lin.in_features < 512 or lin.out_features < 512:
186
- return False
187
-
188
- # Whitelist: only convert inside transformer blocks if you know their prefix
189
- # This further reduces risk of catching special heads elsewhere.
190
- allowed_context = ("blocks", "layers", "transformer", "attn", "mlp", "ffn")
191
- if not any(tok in fullname for tok in allowed_context):
192
- return False
193
-
194
- return True
195
-
196
- @torch.no_grad()
197
- def convert_linears_to_te_fp8(module: nn.Module, allow_pred=_default_te_allow, _prefix=""):
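- # Recursively swaps eligible nn.Linear submodules for Transformer Engine Linear
- # layers (bf16 params), copying weights and biases in place so the FP8 autocast
- # context can use them.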
198
- for name, child in list(module.named_children()):
199
- full = f"{_prefix}.{name}" if _prefix else name
200
- convert_linears_to_te_fp8(child, allow_pred, full)
201
-
202
- if isinstance(child, nn.Linear):
203
- if allow_pred is not None and not allow_pred(full, child):
204
- continue
205
-
206
- te_lin = TELinear(
207
- in_features=child.in_features,
208
- out_features=child.out_features,
209
- bias=(child.bias is not None),
210
- params_dtype=torch.bfloat16,
211
- ).to(child.weight.device)
212
-
213
- te_lin.weight.copy_(child.weight.to(te_lin.weight.dtype))
214
- if child.bias is not None:
215
- te_lin.bias.copy_(child.bias.to(te_lin.bias.dtype))
216
-
217
- setattr(module, name, te_lin)
218
- return module
219
-
220
- class Generator():
221
- def __init__(self, config: DictConfig):
222
- self.config = config.copy()
223
- OmegaConf.set_readonly(self.config, True)
224
- self.logger = get_logger(self.__class__.__name__)
225
-
226
- # init_torch(cudnn_benchmark=False)
227
- self.configure_models()
228
-
229
- def entrypoint(self):
230
-
231
- self.inference_loop()
232
-
233
- def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
234
- device_mesh = None
235
- fsdp_strategy = ShardingStrategy[sharding_strategy]
236
- if (
237
- fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD]
238
- and device_mesh_config is not None
239
- ):
240
- device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
241
- return device_mesh, fsdp_strategy
242
-
243
-
244
- def configure_models(self):
245
- self.configure_dit_model(device="cuda")
246
-
247
- self.dit.eval().to("cuda")
248
- convert_linears_to_te_fp8(self.dit)
249
-
250
- self.dit = torch.compile(self.dit, )
251
-
252
-
253
- self.configure_vae_model(device="cuda")
254
- if self.config.generation.get('extract_audio_feat', False):
255
- self.configure_wav2vec(device="cpu")
256
- self.configure_text_model(device="cuda")
257
-
258
- # # Initialize fsdp.
259
- # self.configure_dit_fsdp_model()
260
- # self.configure_text_fsdp_model()
261
-
262
- # quantize_(self.text_encoder, Int8WeightOnlyConfig())
263
- # quantize_(self.dit, Float8DynamicActivationFloat8WeightConfig())
264
-
265
-
266
- def configure_dit_model(self, device=get_device()):
267
-
268
- init_unified_parallel(self.config.dit.sp_size)
269
- self.sp_size = get_unified_parallel_world_size()
270
-
271
- # Create DiT model on meta, then mark dtype as bfloat16 (no real allocation yet).
272
- init_device = "meta"
273
- with torch.device(init_device):
274
- self.dit = create_object(self.config.dit.model)
275
- self.dit = self.dit.to(dtype=torch.bfloat16) # or: self.dit.bfloat16()
276
- self.logger.info(f"Load DiT model on {init_device}.")
277
- self.dit.eval().requires_grad_(False)
278
-
279
- # Load dit checkpoint.
280
- path = self.config.dit.checkpoint_dir
281
-
282
- def _cast_state_dict_to_bf16(state):
283
- for k, v in state.items():
284
- if isinstance(v, torch.Tensor) and v.is_floating_point():
285
- state[k] = v.to(dtype=torch.bfloat16, copy=False)
286
- return state
287
-
288
- if path.endswith(".pth"):
289
- # Load to CPU first; we’ll move the model later.
290
- state = torch.load(path, map_location="cpu", mmap=True)
291
- state = _cast_state_dict_to_bf16(state)
292
- missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
293
- self.logger.info(
294
- f"dit loaded from {path}. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}"
295
- )
296
- else:
297
- from safetensors.torch import load_file
298
- import json
299
- def load_custom_sharded_weights(model_dir, base_name):
300
- index_path = f"{model_dir}/{base_name}.safetensors.index.json"
301
- with open(index_path, "r") as f:
302
- index = json.load(f)
303
- weight_map = index["weight_map"]
304
- shard_files = set(weight_map.values())
305
- state_dict = {}
306
- for shard_file in shard_files:
307
- shard_path = f"{model_dir}/{shard_file}"
308
- # Load on CPU, then cast to bf16; we’ll move the whole module later.
309
- shard_state = load_file(shard_path, device="cpu")
310
- shard_state = {k: (v.to(dtype=torch.bfloat16, copy=False) if v.is_floating_point() else v)
311
- for k, v in shard_state.items()}
312
- state_dict.update(shard_state)
313
- return state_dict
314
-
315
- state = load_custom_sharded_weights(path, 'humo')
316
- self.dit.load_state_dict(state, strict=False, assign=True)
317
-
318
- self.dit = meta_non_persistent_buffer_init_fn(self.dit)
319
-
320
- target_device = get_device() if device in [get_device(), "cuda"] else device
321
- self.dit.to(target_device) # dtype already bf16
322
-
323
- # Print model size.
324
- params = sum(p.numel() for p in self.dit.parameters())
325
- self.logger.info(
326
- f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
327
- )
328
-
329
-
330
- def configure_vae_model(self, device=get_device()):
331
- self.vae_stride = self.config.vae.vae_stride
332
- self.vae = WanVAE(
333
- vae_pth=self.config.vae.checkpoint,
334
- device=device)
335
-
336
- if self.config.generation.height == 480:
337
- self.zero_vae = torch.load(self.config.dit.zero_vae_path)
338
- elif self.config.generation.height == 720:
339
- self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path)
340
- else:
341
- raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.")
342
-
343
- def configure_wav2vec(self, device=get_device()):
344
- audio_separator_model_file = self.config.audio.vocal_separator
345
- wav2vec_model_path = self.config.audio.wav2vec_model
346
-
347
- self.audio_processor = AudioProcessor(
348
- 16000,
349
- 25,
350
- wav2vec_model_path,
351
- "all",
352
- audio_separator_model_file,
353
- None, # do not separate vocals
354
- os.path.join(self.config.generation.output.dir, "vocals"),
355
- device=device,
356
- )
357
-
358
- def configure_text_model(self, device=get_device()):
359
- self.text_encoder = T5EncoderModel(
360
- text_len=self.config.dit.model.text_len,
361
- dtype=torch.bfloat16,
362
- device=device,
363
- checkpoint_path=self.config.text.t5_checkpoint,
364
- tokenizer_path=self.config.text.t5_tokenizer,
365
- )
366
-
367
-
368
- def configure_dit_fsdp_model(self):
369
- from humo.models.wan_modules.model_humo import WanAttentionBlock
370
-
371
- dit_blocks = (WanAttentionBlock,)
372
-
373
- # Init model_shard_cpu_group for saving checkpoint with sharded state_dict.
374
- init_model_shard_cpu_group(
375
- self.config.dit.fsdp.sharding_strategy,
376
- self.config.dit.fsdp.get("device_mesh", None),
377
- )
378
-
379
- # Assert that dit has wrappable blocks.
380
- assert any(isinstance(m, dit_blocks) for m in self.dit.modules())
381
-
382
- # Define wrap policy on all dit blocks.
383
- def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
384
- return recurse or isinstance(module, dit_blocks)
385
-
386
- # Configure FSDP settings.
387
- device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
388
- self.config.dit.fsdp.sharding_strategy,
389
- self.config.dit.fsdp.get("device_mesh", None),
390
- )
391
- settings = dict(
392
- auto_wrap_policy=custom_auto_wrap_policy,
393
- sharding_strategy=fsdp_strategy,
394
- backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
395
- device_id=get_local_rank(),
396
- use_orig_params=False,
397
- sync_module_states=True,
398
- forward_prefetch=True,
399
- limit_all_gathers=False, # False for ZERO2.
400
- mixed_precision=MixedPrecision(
401
- param_dtype=torch.bfloat16,
402
- reduce_dtype=torch.float32,
403
- buffer_dtype=torch.float32,
404
- ),
405
- device_mesh=device_mesh,
406
- param_init_fn=meta_param_init_fn,
407
- )
408
-
409
- # Apply FSDP.
410
- self.dit = FullyShardedDataParallel(self.dit, **settings)
411
- # self.dit.to(get_device())
412
-
413
-
414
- def configure_text_fsdp_model(self):
415
- # If FSDP is not enabled, put text_encoder to GPU and return.
416
- if not self.config.text.fsdp.enabled:
417
- self.text_encoder.to(get_device())
418
- return
419
-
420
- # from transformers.models.t5.modeling_t5 import T5Block
421
- from humo.models.wan_modules.t5 import T5SelfAttention
422
-
423
- text_blocks = (torch.nn.Embedding, T5SelfAttention)
424
- # text_blocks_names = ("QWenBlock", "QWenModel") # QWen cannot be imported. Use str.
425
-
426
- def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
427
- return (
428
- recurse
429
- or isinstance(module, text_blocks)
430
- )
431
-
432
- # Apply FSDP.
433
- text_encoder_dtype = getattr(torch, self.config.text.dtype)
434
- device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
435
- self.config.text.fsdp.sharding_strategy,
436
- self.config.text.fsdp.get("device_mesh", None),
437
- )
438
- self.text_encoder = FullyShardedDataParallel(
439
- module=self.text_encoder,
440
- auto_wrap_policy=custom_auto_wrap_policy,
441
- sharding_strategy=fsdp_strategy,
442
- backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
443
- device_id=get_local_rank(),
444
- use_orig_params=False,
445
- sync_module_states=False,
446
- forward_prefetch=True,
447
- limit_all_gathers=True,
448
- mixed_precision=MixedPrecision(
449
- param_dtype=text_encoder_dtype,
450
- reduce_dtype=text_encoder_dtype,
451
- buffer_dtype=text_encoder_dtype,
452
- ),
453
- device_mesh=device_mesh,
454
- )
455
- self.text_encoder.to(get_device()).requires_grad_(False)
456
-
457
-
458
- def load_image_latent_ref_id(self, path: str, size, device):
459
- # Load size.
460
- h, w = size[1], size[0]
461
-
462
- # Load image.
463
- if len(path) > 1 and not isinstance(path, str):
464
- ref_vae_latents = []
465
- for image_path in path:
466
- with Image.open(image_path) as img:
467
- img = img.convert("RGB")
468
-
469
- # Calculate the required size to keep aspect ratio and fill the rest with padding.
470
- img_ratio = img.width / img.height
471
- target_ratio = w / h
472
-
473
- if img_ratio > target_ratio: # Image is wider than target
474
- new_width = w
475
- new_height = int(new_width / img_ratio)
476
- else: # Image is taller than target
477
- new_height = h
478
- new_width = int(new_height * img_ratio)
479
-
480
- # img = img.resize((new_width, new_height), Image.ANTIALIAS)
481
- img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
482
-
483
- # Create a new image with the target size and place the resized image in the center
484
- delta_w = w - img.size[0]
485
- delta_h = h - img.size[1]
486
- padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
487
- new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
488
-
489
- # Transform to tensor and normalize.
490
- transform = Compose(
491
- [
492
- ToTensor(),
493
- Normalize(0.5, 0.5),
494
- ]
495
- )
496
- new_img = transform(new_img)
497
- # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0]
498
- img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
499
- ref_vae_latents.append(img_vae_latent[0])
500
-
501
- return [torch.cat(ref_vae_latents, dim=1)]
502
- else:
503
- if not isinstance(path, str):
504
- path = path[0]
505
- with Image.open(path) as img:
506
- img = img.convert("RGB")
507
-
508
- # Calculate the required size to keep aspect ratio and fill the rest with padding.
509
- img_ratio = img.width / img.height
510
- target_ratio = w / h
511
-
512
- if img_ratio > target_ratio: # Image is wider than target
513
- new_width = w
514
- new_height = int(new_width / img_ratio)
515
- else: # Image is taller than target
516
- new_height = h
517
- new_width = int(new_height * img_ratio)
518
-
519
- # img = img.resize((new_width, new_height), Image.ANTIALIAS)
520
- img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
521
-
522
- # Create a new image with the target size and place the resized image in the center
523
- delta_w = w - img.size[0]
524
- delta_h = h - img.size[1]
525
- padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
526
- new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
527
-
528
- # Transform to tensor and normalize.
529
- transform = Compose(
530
- [
531
- ToTensor(),
532
- Normalize(0.5, 0.5),
533
- ]
534
- )
535
- new_img = transform(new_img)
536
- img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
537
-
538
- # Vae encode.
539
- return img_vae_latent
540
-
541
- def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
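- # Builds one audio-embedding window per latent frame (4 video frames per latent step):
- # the first window is front-padded with 3 zero frames, later windows take
- # 4 + 2*audio_shift neighbouring embeddings; out-of-range indices are zero-filled.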
542
- zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
543
- zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device
544
- iter_ = 1 + (frame_num - 1) // 4
545
- audio_emb_wind = []
546
- for lt_i in range(iter_):
547
- if lt_i == 0:
548
- st = frame0_idx + lt_i - 2
549
- ed = frame0_idx + lt_i + 3
550
- wind_feat = torch.stack([
551
- audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
552
- for i in range(st, ed)
553
- ], dim=0)
554
- wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
555
- else:
556
- st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
557
- ed = frame0_idx + 1 + 4 * lt_i + audio_shift
558
- wind_feat = torch.stack([
559
- audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
560
- for i in range(st, ed)
561
- ], dim=0)
562
- audio_emb_wind.append(wind_feat)
563
- audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
564
-
565
- return audio_emb_wind, ed - audio_shift
566
-
567
- def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
568
- if wav_enc_type == "wav2vec":
569
- feat_merge = audio_emb
570
- elif wav_enc_type == "whisper":
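- # Whisper path: average the hidden states in groups of 8 layers (plus the final layer)
- # and resample each stream from 50 Hz to the 25 fps video rate.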
571
- feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
572
- feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
573
- feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
574
- feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
575
- feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
576
- feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]
577
- else:
578
- raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
579
-
580
- return feat_merge
581
-
582
- def parse_output(self, output):
583
- latent = output[0]
584
- mask = None
585
- return latent, mask
586
-
587
- def forward_tia(self, latents, timestep, t, step_change, arg_tia, arg_ti, arg_i, arg_null):
588
- pos_tia, _ = self.parse_output(self.dit(
589
- latents, t=timestep, **arg_tia
590
- ))
591
- torch.cuda.empty_cache()
592
-
593
- pos_ti, _ = self.parse_output(self.dit(
594
- latents, t=timestep, **arg_ti
595
- ))
596
- torch.cuda.empty_cache()
597
-
598
- if t > step_change:
599
- neg, _ = self.parse_output(self.dit(
600
- latents, t=timestep, **arg_i
601
- )) # img included in null, same as official Wan-2.1
602
- torch.cuda.empty_cache()
603
-
604
- noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \
605
- self.config.generation.scale_t * (pos_ti - neg) + \
606
- neg
607
- else:
608
- neg, _ = self.parse_output(self.dit(
609
- latents, t=timestep, **arg_null
610
- )) # img not included in null
611
- torch.cuda.empty_cache()
612
-
613
- noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \
614
- (self.config.generation.scale_t - 2.0) * (pos_ti - neg) + \
615
- neg
616
- return noise_pred
617
-
618
- def forward_ti(self, latents, timestep, t, step_change, arg_ti, arg_t, arg_i, arg_null):
619
- # Positive with text+image (no audio)
620
- pos_ti, _ = self.parse_output(self.dit(
621
- latents, t=timestep, **arg_ti
622
- ))
623
- torch.cuda.empty_cache()
624
-
625
- # Positive with text only (no image, no audio)
626
- pos_t, _ = self.parse_output(self.dit(
627
- latents, t=timestep, **arg_t
628
- ))
629
- torch.cuda.empty_cache()
630
-
631
- # Negative branch: before step_change, don't include image in null; after, include image (like Wan-2.1)
632
- if t > step_change:
633
- neg, _ = self.parse_output(self.dit(
634
- latents, t=timestep, **arg_i
635
- )) # img included in null
636
- else:
637
- neg, _ = self.parse_output(self.dit(
638
- latents, t=timestep, **arg_null
639
- )) # img NOT included in null
640
- torch.cuda.empty_cache()
641
-
642
- # Guidance blend: replace "scale_a" below with "scale_i" if you add a separate image scale in config
643
- noise_pred = self.config.generation.scale_a * (pos_ti - pos_t) + \
644
- self.config.generation.scale_t * (pos_t - neg) + \
645
- neg
646
- return noise_pred
647
-
648
- def forward_ta(self, latents, timestep, arg_ta, arg_t, arg_null):
649
- pos_ta, _ = self.parse_output(self.dit(
650
- latents, t=timestep, **arg_ta
651
- ))
652
- torch.cuda.empty_cache()
653
-
654
- pos_t, _ = self.parse_output(self.dit(
655
- latents, t=timestep, **arg_t
656
- ))
657
- torch.cuda.empty_cache()
658
-
659
- neg, _ = self.parse_output(self.dit(
660
- latents, t=timestep, **arg_null
661
- ))
662
- torch.cuda.empty_cache()
663
-
664
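- # Two-axis classifier-free guidance: scale_a weights the audio direction (pos_ta - pos_t),
- # scale_t the text direction (pos_t - neg), added on top of the unconditional prediction.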
- noise_pred = self.config.generation.scale_a * (pos_ta - pos_t) + \
665
- self.config.generation.scale_t * (pos_t - neg) + \
666
- neg
667
- return noise_pred
668
-
669
- @torch.no_grad()
670
- def inference(self,
671
- input_prompt,
672
- img_path,
673
- audio_path,
674
- size=(1280, 720),
675
- frame_num=81,
676
- shift=5.0,
677
- sample_solver='unipc',
678
- inference_mode='TIA',
679
- sampling_steps=50,
680
- n_prompt="",
681
- seed=-1,
682
- tea_cache_l1_thresh = 0.0,
683
- device = get_device(),
684
- ):
685
-
686
- print("inference started")
687
-
688
- # self.vae.model.to(device=device)
689
- if img_path is not None:
690
- latents_ref = self.load_image_latent_ref_id(img_path, size, device)
691
- else:
692
- latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]
693
-
694
- # self.vae.model.to(device="cpu")
695
-
696
- print("vae finished")
697
-
698
- latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]
699
-
700
- # audio
701
- if audio_path is not None:
702
- if self.config.generation.extract_audio_feat:
703
- self.audio_processor.whisper.to(device=device)
704
- audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
705
- self.audio_processor.whisper.to(device='cpu')
706
- else:
707
- audio_emb_path = audio_path.replace(".wav", ".pt")
708
- audio_emb = torch.load(audio_emb_path).to(device=device)
709
- audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
710
- self.logger.info("Using pre-extracted audio features: %s", audio_emb_path)
711
- else:
712
- audio_emb = torch.zeros(frame_num, 5, 1280).to(device)
713
-
714
- frame_num = frame_num if frame_num != -1 else audio_length
715
- frame_num = 4 * ((frame_num - 1) // 4) + 1
716
- audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
717
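- # Append zero audio windows for the reference-image latents that are concatenated
- # after the video latents.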
- zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device)
718
- audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
719
- audio_emb = [audio_emb.to(device)]
720
- audio_emb_neg = [torch.zeros_like(audio_emb[0])]
721
-
722
- # preprocess
723
- self.patch_size = self.config.dit.model.patch_size
724
- F = frame_num
725
- target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1],
726
- size[1] // self.vae_stride[1],
727
- size[0] // self.vae_stride[2])
728
-
729
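- # seq_len = number of DiT patch tokens (spatial patches x latent frames),
- # rounded up to a multiple of the sequence-parallel world size.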
- seq_len = math.ceil((target_shape[2] * target_shape[3]) /
730
- (self.patch_size[1] * self.patch_size[2]) *
731
- target_shape[1] / self.sp_size) * self.sp_size
732
-
733
- if n_prompt == "":
734
- n_prompt = self.config.generation.sample_neg_prompt
735
- seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
736
- seed_g = torch.Generator(device=device)
737
- seed_g.manual_seed(seed)
738
-
739
- # self.text_encoder.model.to(device)
740
- context = self.text_encoder([input_prompt], device)
741
- context_null = self.text_encoder([n_prompt], device)
742
- # self.text_encoder.model.cpu()
743
-
744
- print("text encoder finished")
745
-
746
- noise = [
747
- torch.randn(
748
- target_shape[0],
749
- target_shape[1], # - latents_ref[0].shape[1],
750
- target_shape[2],
751
- target_shape[3],
752
- dtype=torch.float32,
753
- device=device,
754
- generator=seed_g)
755
- ]
756
-
757
- @contextmanager
758
- def noop_no_sync():
759
- yield
760
-
761
- no_sync = getattr(self.dit, 'no_sync', noop_no_sync)
762
- step_change = self.config.generation.step_change # 980
763
-
764
- # evaluation mode
765
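- # Sampling runs inside TE FP8 autocast (a no-op if TE is unavailable) nested in
- # bf16 autocast, with gradients disabled.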
- with make_fp8_ctx(True), torch.autocast('cuda', dtype=torch.bfloat16), torch.no_grad(), no_sync():
766
-
767
- if sample_solver == 'unipc':
768
- sample_scheduler = FlowUniPCMultistepScheduler(
769
- num_train_timesteps=1000,
770
- shift=1,
771
- use_dynamic_shifting=False)
772
- sample_scheduler.set_timesteps(
773
- sampling_steps, device=device, shift=shift)
774
- timesteps = sample_scheduler.timesteps
775
-
776
- # sample videos
777
- latents = noise
778
-
779
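- # Conditioning input y: a 4-channel mask that is 1 only over the reference-image latent
- # frames appended at the end, concatenated with zero-VAE latents plus the reference
- # latents (y_c) or with zero-VAE latents alone (y_null).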
- msk = torch.ones(4, target_shape[1], target_shape[2], target_shape[3], device=get_device())
780
- msk[:,:-latents_ref[0].shape[1]] = 0
781
-
782
- zero_vae = self.zero_vae[:, :(target_shape[1]-latents_ref[0].shape[1])].to(
783
- device=get_device(), dtype=latents_ref[0].dtype)
784
- y_c = torch.cat([
785
- zero_vae,
786
- latents_ref[0]
787
- ], dim=1)
788
- y_c = [torch.concat([msk, y_c])]
789
-
790
- y_null = self.zero_vae[:, :target_shape[1]].to(
791
- device=get_device(), dtype=latents_ref[0].dtype)
792
- y_null = [torch.concat([msk, y_null])]
793
-
794
- tea_cache_l1_thresh = tea_cache_l1_thresh
795
- tea_cache_model_id = "Wan2.1-T2V-14B"
796
-
797
- arg_null = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context_null, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
798
- arg_t = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
799
- arg_i = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context_null, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
800
- arg_ti = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
801
- arg_ta = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_null, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
802
- arg_tia = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_c, 'context': context, "tea_cache": TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id) if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0 else None}
803
-
804
- torch.cuda.empty_cache()
805
- # self.dit.to(device=get_device())
806
- for _, t in enumerate(tqdm(timesteps)):
807
- timestep = [t]
808
- timestep = torch.stack(timestep)
809
-
810
- if inference_mode == "TIA":
811
- noise_pred = self.forward_tia(latents, timestep, t, step_change,
812
- arg_tia, arg_ti, arg_i, arg_null)
813
- elif inference_mode == "TA":
814
- noise_pred = self.forward_ta(latents, timestep, arg_ta, arg_t, arg_null)
815
- elif inference_mode == "TI":
816
- noise_pred = self.forward_ti(latents, timestep, t, step_change,
817
- arg_ti, arg_t, arg_i, arg_null)
818
- else:
819
- raise ValueError(f"Unsupported inference mode: {inference_mode}")
820
-
821
- temp_x0 = sample_scheduler.step(
822
- noise_pred.unsqueeze(0),
823
- t,
824
- latents[0].unsqueeze(0),
825
- return_dict=False,
826
- generator=seed_g)[0]
827
- latents = [temp_x0.squeeze(0)]
828
-
829
- del timestep
830
- torch.cuda.empty_cache()
831
-
832
- x0 = latents
833
- x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0]
834
-
835
- # if offload_model:
836
- # self.dit.cpu()
837
-
838
- print("dit finished")
839
-
840
- torch.cuda.empty_cache()
841
- # if get_local_rank() == 0:
842
- # self.vae.model.to(device=device)
843
- videos = self.vae.decode(x0)
844
- # self.vae.model.to(device="cpu")
845
-
846
- print("vae 2 finished")
847
-
848
- del noise, latents, noise_pred
849
- del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null
850
- del x0, temp_x0
851
- del sample_scheduler
852
- torch.cuda.empty_cache()
853
- gc.collect()
854
- torch.cuda.synchronize()
855
- if dist.is_initialized():
856
- dist.barrier()
857
-
858
- return videos[0] # if get_local_rank() == 0 else None
859
-
860
-
861
- def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, inference_mode = "TIA", width = 832, height = 480, steps=50, frames = 97, tea_cache_l1_thresh = 0.0, seed = 0):
862
-
863
- video = self.inference(
864
- prompt,
865
- ref_img_path,
866
- audio_path,
867
- size=SIZE_CONFIGS[f"{width}*{height}"],
868
- frame_num=frames,
869
- shift=self.config.diffusion.timesteps.sampling.shift,
870
- sample_solver='unipc',
871
- sampling_steps=steps,
872
- inference_mode = inference_mode,
873
- tea_cache_l1_thresh = tea_cache_l1_thresh,
874
- seed=seed
875
- )
876
-
877
- torch.cuda.empty_cache()
878
- gc.collect()
879
-
880
- # Save samples.
881
- if get_sequence_parallel_rank() == 0:
882
- pathname = self.save_sample(
883
- sample=video,
884
- audio_path=audio_path,
885
- output_dir = output_dir,
886
- filename=filename,
887
- )
888
- self.logger.info(f"Finished {filename}, saved to {pathname}.")
889
-
890
- del video, prompt
891
- torch.cuda.empty_cache()
892
- gc.collect()
893
-
894
-
895
- def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
896
- gen_config = self.config.generation
897
- # Prepare file path.
898
- extension = ".mp4" if sample.ndim == 4 else ".png"
899
- filename += extension
900
- pathname = os.path.join(output_dir, filename)
901
- # Convert sample.
902
- sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
903
- sample = rearrange(sample, "c t h w -> t h w c")
904
- # Save file.
905
- if sample.ndim == 4:
906
- if audio_path is not None:
907
- tensor_to_video(
908
- sample.numpy(),
909
- pathname,
910
- audio_path,
911
- fps=gen_config.fps)
912
- else:
913
- mediapy.write_video(
914
- path=pathname,
915
- images=sample.numpy(),
916
- fps=gen_config.fps,
917
- )
918
- else:
919
- raise ValueError
920
- return pathname
921
-
922
-
923
- def prepare_positive_prompts(self):
924
- pos_prompts = self.config.generation.positive_prompt
925
- if pos_prompts.endswith(".json"):
926
- pos_prompts = prepare_json_dataset(pos_prompts)
927
- else:
928
- raise NotImplementedError
929
- assert isinstance(pos_prompts, ListConfig)
930
-
931
- return pos_prompts
932
-
933
- class TeaCache:
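- # Caches DiT outputs across sampling steps: when the accumulated (polynomial-rescaled)
- # relative L1 change of the timestep-modulated input stays below rel_l1_thresh, the full
- # forward pass is skipped and the stored residual is reused instead.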
934
- def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
935
- self.num_inference_steps = num_inference_steps
936
- self.step = 0
937
- self.accumulated_rel_l1_distance = 0
938
- self.previous_modulated_input = None
939
- self.rel_l1_thresh = rel_l1_thresh
940
- self.previous_residual = None
941
- self.previous_hidden_states = None
942
-
943
- self.coefficients_dict = {
944
- "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
945
- "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
946
- "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
947
- "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
948
- }
949
- if model_id not in self.coefficients_dict:
950
- supported_model_ids = ", ".join([i for i in self.coefficients_dict])
951
- raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
952
- self.coefficients = self.coefficients_dict[model_id]
953
-
954
- def check(self, dit, x, t_mod):
955
- modulated_inp = t_mod.clone()
956
- if self.step == 0 or self.step == self.num_inference_steps - 1:
957
- should_calc = True
958
- self.accumulated_rel_l1_distance = 0
959
- else:
960
- coefficients = self.coefficients
961
- rescale_func = np.poly1d(coefficients)
962
- self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
963
- if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
964
- should_calc = False
965
- else:
966
- should_calc = True
967
- self.accumulated_rel_l1_distance = 0
968
- self.previous_modulated_input = modulated_inp
969
- self.step += 1
970
- if self.step == self.num_inference_steps:
971
- self.step = 0
972
- if should_calc:
973
- self.previous_hidden_states = x.clone()
974
- return not should_calc
975
-
976
- def store(self, hidden_states):
977
- if self.previous_hidden_states is None:
978
- return
979
- self.previous_residual = hidden_states - self.previous_hidden_states
980
- self.previous_hidden_states = None
981
-
982
- def update(self, hidden_states):
983
- hidden_states = hidden_states + self.previous_residual
984
  return hidden_states
 
1
+ # Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ # http://www.apache.org/licenses/LICENSE-2.0
6
+ # Unless required by applicable law or agreed to in writing, software
7
+ # distributed under the License is distributed on an "AS IS" BASIS,
8
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
9
+ # See the License for the specific language governing permissions and
10
+ # limitations under the License.
11
+
12
+ # Inference codes adapted from [SeedVR]
13
+ # https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
14
+
15
+ import math
16
+ import os
17
+ import gc
18
+ import random
19
+ import sys
20
+ import mediapy
21
+ import numpy as np
22
+ import torch
23
+ import torch.distributed as dist
24
+ from omegaconf import DictConfig, ListConfig, OmegaConf
25
+ from einops import rearrange
26
+ from omegaconf import OmegaConf
27
+ from PIL import Image, ImageOps
28
+ from torchvision.transforms import ToTensor
29
+ from tqdm import tqdm
30
+ from torch.distributed.device_mesh import init_device_mesh
31
+ from torch.distributed.fsdp import (
32
+ BackwardPrefetch,
33
+ FullyShardedDataParallel,
34
+ MixedPrecision,
35
+ ShardingStrategy,
36
+ )
37
+ from common.distributed import (
38
+ get_device,
39
+ get_global_rank,
40
+ get_local_rank,
41
+ meta_param_init_fn,
42
+ meta_non_persistent_buffer_init_fn,
43
+ init_torch,
44
+ )
45
+ from common.distributed.advanced import (
46
+ init_unified_parallel,
47
+ get_unified_parallel_world_size,
48
+ get_sequence_parallel_rank,
49
+ init_model_shard_cpu_group,
50
+ )
51
+ from common.logger import get_logger
52
+ from common.config import create_object
53
+ from common.distributed import get_device, get_global_rank
54
+ from torchvision.transforms import Compose, Normalize, ToTensor
55
+ from humo.models.wan_modules.t5 import T5EncoderModel
56
+ from humo.models.wan_modules.vae import WanVAE
57
+ from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
58
+ from contextlib import contextmanager
59
+ import torch.cuda.amp as amp
60
+ from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
61
+ from humo.utils.audio_processor_whisper import AudioProcessor
62
+ from humo.utils.wav2vec import linear_interpolation_fps
63
+ from torchao.quantization import quantize_
64
+
65
+ import torch._dynamo as dynamo
66
+ dynamo.config.capture_scalar_outputs = True
67
+ torch.set_float32_matmul_precision("high")
68
+
69
+ import torch
70
+ import torch.nn as nn
71
+ import transformer_engine.pytorch as te
72
+
73
+ image_transform = Compose([
74
+ ToTensor(),
75
+ Normalize(mean=0.5, std=0.5),
76
+ ])
77
+
78
+ SIZE_CONFIGS = {
79
+ '720*1280': (720, 1280),
80
+ '1280*720': (1280, 720),
81
+ '480*832': (480, 832),
82
+ '832*480': (832, 480),
83
+ '1024*1024': (1024, 1024),
84
+ }
85
+
86
+ def clever_format(nums, format="%.2f"):
87
+ from typing import Iterable
88
+ if not isinstance(nums, Iterable):
89
+ nums = [nums]
90
+ clever_nums = []
91
+ for num in nums:
92
+ if num > 1e12:
93
+ clever_nums.append(format % (num / 1e12) + "T")
94
+ elif num > 1e9:
95
+ clever_nums.append(format % (num / 1e9) + "G")
96
+ elif num > 1e6:
97
+ clever_nums.append(format % (num / 1e6) + "M")
98
+ elif num > 1e3:
99
+ clever_nums.append(format % (num / 1e3) + "K")
100
+ else:
101
+ clever_nums.append(format % num + "B")
102
+
103
+ clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)
104
+
105
+ return clever_nums
106
+
107
+
108
+
109
+ # --- put near your imports ---
110
+ import torch
111
+ import torch.nn as nn
112
+ import contextlib
113
+ import transformer_engine.pytorch as te
114
+
115
+ # FP8 autocast compatibility for different TE versions
116
+ try:
117
+ # Preferred modern API
118
+ from transformer_engine.pytorch import fp8_autocast
119
+ try:
120
+ # Newer TE: use recipe-based API
121
+ from transformer_engine.common.recipe import DelayedScaling, Format
122
+ def make_fp8_ctx(enabled: bool = True):
123
+ if not enabled:
124
+ return contextlib.nullcontext()
125
+ fp8_recipe = DelayedScaling(fp8_format=Format.E4M3) # E4M3 format
126
+ return fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
127
+ except Exception:
128
+ # Very old variant that might still accept fp8_format directly
129
+ def make_fp8_ctx(enabled: bool = True):
130
+ # If TE doesn't have FP8Format, just no-op
131
+ if not hasattr(te, "FP8Format"):
132
+ return contextlib.nullcontext()
133
+ return te.fp8_autocast(enabled=enabled, fp8_format=te.FP8Format.E4M3)
134
+ except Exception:
135
+ # TE not present or totally incompatible — no-op
136
+ def make_fp8_ctx(enabled: bool = True):
137
+ return contextlib.nullcontext()
138
+
139
+
140
+ # TE sometimes exposes Linear at different paths; this normalizes it.
141
+ try:
142
+ TELinear = te.Linear
143
+ except AttributeError: # very old layouts
144
+ from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore
145
+
146
+ # --- near imports ---
147
+ import torch
148
+ import torch.nn as nn
149
+ import transformer_engine.pytorch as te
150
+
151
+ try:
152
+ TELinear = te.Linear
153
+ except AttributeError:
154
+ from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore
155
+
156
+ import torch
157
+ import torch.nn as nn
158
+ import transformer_engine.pytorch as te
159
+
160
+ try:
161
+ TELinear = te.Linear
162
+ except AttributeError:
163
+ from transformer_engine.pytorch.modules.linear import Linear as TELinear # type: ignore
164
+
165
+ def _default_te_allow(fullname: str, lin: nn.Linear) -> bool:
166
+ """
167
+ Allow TE only where it's shape-safe & beneficial.
168
+ Skip small/special layers (time/timestep/pos embeds, heads).
169
+ Enforce multiples of 16 for in/out features (FP8 kernel friendly).
170
+ Also skip very small projections likely to see M=1.
171
+ """
172
+ blocked_keywords = (
173
+ "time_embedding", "timestep", "time_embed",
174
+ "time_projection", "pos_embedding", "pos_embed",
175
+ "to_logits", "logits", "final_proj", "proj_out", "output_projection",
176
+ )
177
+ if any(k in fullname for k in blocked_keywords):
178
+ return False
179
+
180
+ # TE FP8 kernels like K, N divisible by 16
181
+ if lin.in_features % 16 != 0 or lin.out_features % 16 != 0:
182
+ return False
183
+
184
+ # Heuristic: avoid tiny layers; keeps attention/MLP, skips small MLPs
185
+ if lin.in_features < 512 or lin.out_features < 512:
186
+ return False
187
+
188
+ # Whitelist: only convert inside transformer blocks if you know their prefix
189
+ # This further reduces risk of catching special heads elsewhere.
190
+ allowed_context = ("blocks", "layers", "transformer", "attn", "mlp", "ffn")
191
+ if not any(tok in fullname for tok in allowed_context):
192
+ return False
193
+
194
+ return True
195
+
196
+ @torch.no_grad()
197
+ def convert_linears_to_te_fp8(module: nn.Module, allow_pred=_default_te_allow, _prefix=""):
198
+ for name, child in list(module.named_children()):
199
+ full = f"{_prefix}.{name}" if _prefix else name
200
+ convert_linears_to_te_fp8(child, allow_pred, full)
201
+
202
+ if isinstance(child, nn.Linear):
203
+ if allow_pred is not None and not allow_pred(full, child):
204
+ continue
205
+
206
+ te_lin = TELinear(
207
+ in_features=child.in_features,
208
+ out_features=child.out_features,
209
+ bias=(child.bias is not None),
210
+ params_dtype=torch.bfloat16,
211
+ ).to(child.weight.device)
212
+
213
+ te_lin.weight.copy_(child.weight.to(te_lin.weight.dtype))
214
+ if child.bias is not None:
215
+ te_lin.bias.copy_(child.bias.to(te_lin.bias.dtype))
216
+
217
+ setattr(module, name, te_lin)
218
+ return module
219
+
220
+ class Generator():
221
+ def __init__(self, config: DictConfig):
222
+ self.config = config.copy()
223
+ OmegaConf.set_readonly(self.config, True)
224
+ self.logger = get_logger(self.__class__.__name__)
225
+
226
+ # init_torch(cudnn_benchmark=False)
227
+ self.configure_models()
228
+
229
+ def entrypoint(self):
230
+
231
+ self.inference_loop()
232
+
233
+ def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
234
+ device_mesh = None
235
+ fsdp_strategy = ShardingStrategy[sharding_strategy]
236
+ if (
237
+ fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD]
238
+ and device_mesh_config is not None
239
+ ):
240
+ device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
241
+ return device_mesh, fsdp_strategy
242
+
243
+
244
+ def configure_models(self):
245
+ self.configure_dit_model(device="cuda")
246
+
247
+ self.dit.eval().to("cuda")
248
+ convert_linears_to_te_fp8(self.dit)
249
+
250
+ self.dit = torch.compile(self.dit, )
251
+
252
+
253
+ self.configure_vae_model(device="cuda")
254
+ if self.config.generation.get('extract_audio_feat', False):
255
+ self.configure_wav2vec(device="cpu")
256
+ self.configure_text_model(device="cuda")
257
+
258
+ # # Initialize fsdp.
259
+ # self.configure_dit_fsdp_model()
260
+ # self.configure_text_fsdp_model()
261
+
262
+ # quantize_(self.text_encoder, Int8WeightOnlyConfig())
263
+ # quantize_(self.dit, Float8DynamicActivationFloat8WeightConfig())
264
+
265
+
266
+ def configure_dit_model(self, device=get_device()):
267
+
268
+ init_unified_parallel(self.config.dit.sp_size)
269
+ self.sp_size = get_unified_parallel_world_size()
270
+
271
+ # Create DiT model on meta, then mark dtype as bfloat16 (no real allocation yet).
272
+ init_device = "meta"
273
+ with torch.device(init_device):
274
+ self.dit = create_object(self.config.dit.model)
275
+ self.dit = self.dit.to(dtype=torch.bfloat16) # or: self.dit.bfloat16()
276
+ self.logger.info(f"Load DiT model on {init_device}.")
277
+ self.dit.eval().requires_grad_(False)
278
+
279
+ # Load dit checkpoint.
280
+ path = self.config.dit.checkpoint_dir
281
+
282
+ def _cast_state_dict_to_bf16(state):
283
+ for k, v in state.items():
284
+ if isinstance(v, torch.Tensor) and v.is_floating_point():
285
+ state[k] = v.to(dtype=torch.bfloat16, copy=False)
286
+ return state
287
+
288
+ if path.endswith(".pth"):
289
+ # Load to CPU first; we’ll move the model later.
290
+ state = torch.load(path, map_location="cpu", mmap=True)
291
+ state = _cast_state_dict_to_bf16(state)
292
+ missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
293
+ self.logger.info(
294
+ f"dit loaded from {path}. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}"
295
+ )
296
+ else:
297
+ from safetensors.torch import load_file
298
+ import json
299
+ def load_custom_sharded_weights(model_dir, base_name):
300
+ index_path = f"{model_dir}/{base_name}.safetensors.index.json"
301
+ with open(index_path, "r") as f:
302
+ index = json.load(f)
303
+ weight_map = index["weight_map"]
304
+ shard_files = set(weight_map.values())
305
+ state_dict = {}
306
+ for shard_file in shard_files:
307
+ shard_path = f"{model_dir}/{shard_file}"
308
+ # Load on CPU, then cast to bf16; we’ll move the whole module later.
309
+ shard_state = load_file(shard_path, device="cpu")
310
+ shard_state = {k: (v.to(dtype=torch.bfloat16, copy=False) if v.is_floating_point() else v)
311
+ for k, v in shard_state.items()}
312
+ state_dict.update(shard_state)
313
+ return state_dict
314
+
315
+ state = load_custom_sharded_weights(path, 'humo')
316
+ self.dit.load_state_dict(state, strict=False, assign=True)
317
+
318
+ self.dit = meta_non_persistent_buffer_init_fn(self.dit)
319
+
320
+ target_device = get_device() if device in [get_device(), "cuda"] else device
321
+ self.dit.to(target_device) # dtype already bf16
322
+
323
+ # Print model size.
324
+ params = sum(p.numel() for p in self.dit.parameters())
325
+ self.logger.info(
326
+ f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
327
+ )
328
+
329
+
330
+ def configure_vae_model(self, device=get_device()):
331
+ self.vae_stride = self.config.vae.vae_stride
332
+ self.vae = WanVAE(
333
+ vae_pth=self.config.vae.checkpoint,
334
+ device=device)
335
+
336
+ if self.config.generation.height == 480:
337
+ self.zero_vae = torch.load(self.config.dit.zero_vae_path)
338
+ elif self.config.generation.height == 720:
339
+ self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path)
340
+ else:
341
+ raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.")
342
+
343
+ def configure_wav2vec(self, device=get_device()):
344
+ audio_separator_model_file = self.config.audio.vocal_separator
345
+ wav2vec_model_path = self.config.audio.wav2vec_model
346
+
347
+ self.audio_processor = AudioProcessor(
348
+ 16000,
349
+ 25,
350
+ wav2vec_model_path,
351
+ "all",
352
+ audio_separator_model_file,
353
+ None, # do not separate vocals
354
+ os.path.join(self.config.generation.output.dir, "vocals"),
355
+ device=device,
356
+ )
357
+
358
+ def configure_text_model(self, device=get_device()):
359
+ self.text_encoder = T5EncoderModel(
360
+ text_len=self.config.dit.model.text_len,
361
+ dtype=torch.bfloat16,
362
+ device=device,
363
+ checkpoint_path=self.config.text.t5_checkpoint,
364
+ tokenizer_path=self.config.text.t5_tokenizer,
365
+ )
366
+
367
+
368
+ def configure_dit_fsdp_model(self):
369
+ from humo.models.wan_modules.model_humo import WanAttentionBlock
370
+
371
+ dit_blocks = (WanAttentionBlock,)
372
+
373
+ # Init model_shard_cpu_group for saving checkpoint with sharded state_dict.
374
+ init_model_shard_cpu_group(
375
+ self.config.dit.fsdp.sharding_strategy,
376
+ self.config.dit.fsdp.get("device_mesh", None),
377
+ )
378
+
379
+ # Assert that dit has wrappable blocks.
380
+ assert any(isinstance(m, dit_blocks) for m in self.dit.modules())
381
+
382
+ # Define wrap policy on all dit blocks.
383
+ def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
384
+ return recurse or isinstance(module, dit_blocks)
385
+
386
+ # Configure FSDP settings.
387
+ device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
388
+ self.config.dit.fsdp.sharding_strategy,
389
+ self.config.dit.fsdp.get("device_mesh", None),
390
+ )
391
+ settings = dict(
392
+ auto_wrap_policy=custom_auto_wrap_policy,
393
+ sharding_strategy=fsdp_strategy,
394
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
395
+ device_id=get_local_rank(),
396
+ use_orig_params=False,
397
+ sync_module_states=True,
398
+ forward_prefetch=True,
399
+ limit_all_gathers=False, # False for ZERO2.
400
+ mixed_precision=MixedPrecision(
401
+ param_dtype=torch.bfloat16,
402
+ reduce_dtype=torch.float32,
403
+ buffer_dtype=torch.float32,
404
+ ),
405
+ device_mesh=device_mesh,
406
+ param_init_fn=meta_param_init_fn,
407
+ )
408
+
409
+ # Apply FSDP.
410
+ self.dit = FullyShardedDataParallel(self.dit, **settings)
411
+ # self.dit.to(get_device())
412
+
413
+
414
+ def configure_text_fsdp_model(self):
415
+ # If FSDP is not enabled, put text_encoder to GPU and return.
416
+ if not self.config.text.fsdp.enabled:
417
+ self.text_encoder.to(get_device())
418
+ return
419
+
420
+ # from transformers.models.t5.modeling_t5 import T5Block
421
+ from humo.models.wan_modules.t5 import T5SelfAttention
422
+
423
+ text_blocks = (torch.nn.Embedding, T5SelfAttention)
424
+ # text_blocks_names = ("QWenBlock", "QWenModel") # QWen cannot be imported. Use str.
425
+
426
+ def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
427
+ return (
428
+ recurse
429
+ or isinstance(module, text_blocks)
430
+ )
431
+
432
+ # Apply FSDP.
433
+ text_encoder_dtype = getattr(torch, self.config.text.dtype)
434
+ device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
435
+ self.config.text.fsdp.sharding_strategy,
436
+ self.config.text.fsdp.get("device_mesh", None),
437
+ )
438
+ self.text_encoder = FullyShardedDataParallel(
439
+ module=self.text_encoder,
440
+ auto_wrap_policy=custom_auto_wrap_policy,
441
+ sharding_strategy=fsdp_strategy,
442
+ backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
443
+ device_id=get_local_rank(),
444
+ use_orig_params=False,
445
+ sync_module_states=False,
446
+ forward_prefetch=True,
447
+ limit_all_gathers=True,
448
+ mixed_precision=MixedPrecision(
449
+ param_dtype=text_encoder_dtype,
450
+ reduce_dtype=text_encoder_dtype,
451
+ buffer_dtype=text_encoder_dtype,
452
+ ),
453
+ device_mesh=device_mesh,
454
+ )
455
+ self.text_encoder.to(get_device()).requires_grad_(False)
456
+
457
+
458
+ def load_image_latent_ref_id(self, path: str, size, device):
459
+ # Load size.
460
+ h, w = size[1], size[0]
461
+
462
+ # Load image.
463
+ if len(path) > 1 and not isinstance(path, str):
464
+ ref_vae_latents = []
465
+ for image_path in path:
466
+ with Image.open(image_path) as img:
467
+ img = img.convert("RGB")
468
+
469
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
470
+ img_ratio = img.width / img.height
471
+ target_ratio = w / h
472
+
473
+ if img_ratio > target_ratio: # Image is wider than target
474
+ new_width = w
475
+ new_height = int(new_width / img_ratio)
476
+ else: # Image is taller than target
477
+ new_height = h
478
+ new_width = int(new_height * img_ratio)
479
+
480
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
481
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
482
+
483
+ # Create a new image with the target size and place the resized image in the center
484
+ delta_w = w - img.size[0]
485
+ delta_h = h - img.size[1]
486
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
487
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
488
+
489
+ # Transform to tensor and normalize.
490
+ transform = Compose(
491
+ [
492
+ ToTensor(),
493
+ Normalize(0.5, 0.5),
494
+ ]
495
+ )
496
+ new_img = transform(new_img)
497
+ # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0]
498
+ img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
499
+ ref_vae_latents.append(img_vae_latent[0])
500
+
501
+ return [torch.cat(ref_vae_latents, dim=1)]
502
+ else:
503
+ if not isinstance(path, str):
504
+ path = path[0]
505
+ with Image.open(path) as img:
506
+ img = img.convert("RGB")
507
+
508
+ # Calculate the required size to keep aspect ratio and fill the rest with padding.
509
+ img_ratio = img.width / img.height
510
+ target_ratio = w / h
511
+
512
+ if img_ratio > target_ratio: # Image is wider than target
513
+ new_width = w
514
+ new_height = int(new_width / img_ratio)
515
+ else: # Image is taller than target
516
+ new_height = h
517
+ new_width = int(new_height * img_ratio)
518
+
519
+ # img = img.resize((new_width, new_height), Image.ANTIALIAS)
520
+ img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)
521
+
522
+ # Create a new image with the target size and place the resized image in the center
523
+ delta_w = w - img.size[0]
524
+ delta_h = h - img.size[1]
525
+ padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
526
+ new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))
527
+
528
+ # Transform to tensor and normalize.
529
+ transform = Compose(
530
+ [
531
+ ToTensor(),
532
+ Normalize(0.5, 0.5),
533
+ ]
534
+ )
535
+ new_img = transform(new_img)
536
+ img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
537
+
538
+ # Vae encode.
539
+ return img_vae_latent
540
+
541
+ def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
542
+ zero_audio_embed = torch.zeros((audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
543
+ zero_audio_embed_3 = torch.zeros((3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device) # device=audio_emb.device
544
+ iter_ = 1 + (frame_num - 1) // 4
545
+ audio_emb_wind = []
546
+ for lt_i in range(iter_):
547
+ if lt_i == 0:
548
+ st = frame0_idx + lt_i - 2
549
+ ed = frame0_idx + lt_i + 3
550
+ wind_feat = torch.stack([
551
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
552
+ for i in range(st, ed)
553
+ ], dim=0)
554
+ wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
555
+ else:
556
+ st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
557
+ ed = frame0_idx + 1 + 4 * lt_i + audio_shift
558
+ wind_feat = torch.stack([
559
+ audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
560
+ for i in range(st, ed)
561
+ ], dim=0)
562
+ audio_emb_wind.append(wind_feat)
563
+ audio_emb_wind = torch.stack(audio_emb_wind, dim=0)
564
+
565
+ return audio_emb_wind, ed - audio_shift
566
+
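+ # Worked example of the windowing above (frame0_idx=0, audio_shift=2): the first latent frame
+ # gathers audio frames [-2, 2] (out-of-range indices become zero rows) plus three leading zero
+ # rows, and every later latent frame lt_i gathers audio frames [4*lt_i - 5, 4*lt_i + 2], so each
+ # of the 1 + (frame_num - 1) // 4 latent frames receives an 8-frame audio window.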
567
+ def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
568
+ if wav_enc_type == "wav2vec":
569
+ feat_merge = audio_emb
570
+ elif wav_enc_type == "whisper":
571
+ feat0 = linear_interpolation_fps(audio_emb[:, :, 0: 8].mean(dim=2), 50, 25)
572
+ feat1 = linear_interpolation_fps(audio_emb[:, :, 8: 16].mean(dim=2), 50, 25)
573
+ feat2 = linear_interpolation_fps(audio_emb[:, :, 16: 24].mean(dim=2), 50, 25)
574
+ feat3 = linear_interpolation_fps(audio_emb[:, :, 24: 32].mean(dim=2), 50, 25)
575
+ feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
576
+ feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]
577
+ else:
578
+ raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
579
+
580
+ return feat_merge
581
+
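+ # Shape flow for the whisper branch above, assuming whisper features of dim 1280 taken from 33
+ # encoder layers at 50 fps: (1, T50, 33, 1280) -> four 8-layer group means plus the last layer ->
+ # each interpolated from 50 to 25 fps -> stacked and squeezed to (T25, 5, 1280).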
582
+ def parse_output(self, output):
583
+ latent = output[0]
584
+ mask = None
585
+ return latent, mask
586
+
587
+ def forward_tia(self, latents, timestep, t, step_change, arg_tia, arg_ti, arg_i, arg_null):
588
+ pos_tia, _ = self.parse_output(self.dit(
589
+ latents, t=timestep, **arg_tia
590
+ ))
591
+ torch.cuda.empty_cache()
592
+
593
+ pos_ti, _ = self.parse_output(self.dit(
594
+ latents, t=timestep, **arg_ti
595
+ ))
596
+ torch.cuda.empty_cache()
597
+
598
+ if t > step_change:
599
+ neg, _ = self.parse_output(self.dit(
600
+ latents, t=timestep, **arg_i
601
+ )) # img included in null, same with official Wan-2.1
602
+ torch.cuda.empty_cache()
603
+
604
+ noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \
605
+ self.config.generation.scale_t * (pos_ti - neg) + \
606
+ neg
607
+ else:
608
+ neg, _ = self.parse_output(self.dit(
609
+ latents, t=timestep, **arg_null
610
+ )) # img not included in null
611
+ torch.cuda.empty_cache()
612
+
613
+ noise_pred = self.config.generation.scale_a * (pos_tia - pos_ti) + \
614
+ (self.config.generation.scale_t - 2.0) * (pos_ti - neg) + \
615
+ neg
616
+ return noise_pred
617
+
618
+ def forward_ti(self, latents, timestep, t, step_change, arg_ti, arg_t, arg_i, arg_null):
619
+ # Positive with text+image (no audio)
620
+ pos_ti, _ = self.parse_output(self.dit(
621
+ latents, t=timestep, **arg_ti
622
+ ))
623
+ torch.cuda.empty_cache()
624
+
625
+ # Positive with text only (no image, no audio)
626
+ pos_t, _ = self.parse_output(self.dit(
627
+ latents, t=timestep, **arg_t
628
+ ))
629
+ torch.cuda.empty_cache()
630
+
631
+ # Negative branch: before step_change, don't include image in null; after, include image (like Wan-2.1)
632
+ if t > step_change:
633
+ neg, _ = self.parse_output(self.dit(
634
+ latents, t=timestep, **arg_i
635
+ )) # img included in null
636
+ else:
637
+ neg, _ = self.parse_output(self.dit(
638
+ latents, t=timestep, **arg_null
639
+ )) # img NOT included in null
640
+ torch.cuda.empty_cache()
641
+
642
+ # Guidance blend: replace "scale_a" below with "scale_i" if you add a separate image scale in config
643
+ noise_pred = self.config.generation.scale_a * (pos_ti - pos_t) + \
644
+ self.config.generation.scale_t * (pos_t - neg) + \
645
+ neg
646
+ return noise_pred
647
+
648
+ def forward_ta(self, latents, timestep, arg_ta, arg_t, arg_null):
649
+ pos_ta, _ = self.parse_output(self.dit(
650
+ latents, t=timestep, **arg_ta
651
+ ))
652
+ torch.cuda.empty_cache()
653
+
654
+ pos_t, _ = self.parse_output(self.dit(
655
+ latents, t=timestep, **arg_t
656
+ ))
657
+ torch.cuda.empty_cache()
658
+
659
+ neg, _ = self.parse_output(self.dit(
660
+ latents, t=timestep, **arg_null
661
+ ))
662
+ torch.cuda.empty_cache()
663
+
664
+ noise_pred = self.config.generation.scale_a * (pos_ta - pos_t) + \
665
+ self.config.generation.scale_t * (pos_t - neg) + \
666
+ neg
667
+ return noise_pred
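+ # The three forward_* helpers above share one nested classifier-free-guidance pattern; a minimal
+ # sketch of that blend (the function and argument names here are illustrative, not pipeline API):
+ def _blend_guidance(self, pos_full, pos_partial, neg, scale_extra, scale_text):
+     # pos_full:    prediction with every condition active (e.g. text + image + audio)
+     # pos_partial: prediction with the extra condition dropped (e.g. text + image only)
+     # neg:         negative / unconditional prediction
+     return scale_extra * (pos_full - pos_partial) + scale_text * (pos_partial - neg) + neg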
668
+
669
+ @torch.no_grad()
670
+ def inference(self,
671
+ input_prompt,
672
+ img_path,
673
+ audio_path,
674
+ size=(1280, 720),
675
+ frame_num=81,
676
+ shift=5.0,
677
+ sample_solver='unipc',
678
+ inference_mode='TIA',
679
+ sampling_steps=50,
680
+ n_prompt="",
681
+ seed=-1,
682
+ tea_cache_l1_thresh = 0.0,
683
+ device = get_device(),
684
+ ):
685
+
686
+ # self.vae.model.to(device=device)
687
+ if img_path is not None:
688
+ latents_ref = self.load_image_latent_ref_id(img_path, size, device)
689
+ else:
690
+ latents_ref = [torch.zeros(16, 1, size[1]//8, size[0]//8).to(device)]
691
+
692
+ # self.vae.model.to(device="cpu")
693
+
694
+ latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]
695
+
696
+ # audio
697
+ if audio_path is not None:
698
+ if self.config.generation.extract_audio_feat:
699
+ self.audio_processor.whisper.to(device=device)
700
+ audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
701
+ self.audio_processor.whisper.to(device='cpu')
702
+ else:
703
+ audio_emb_path = audio_path.replace(".wav", ".pt")
704
+ audio_emb = torch.load(audio_emb_path).to(device=device)
705
+ audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
706
+ audio_length = audio_emb.shape[0]  # number of 25 fps frames, so frame_num == -1 also works on this path
+ self.logger.info("Using pre-extracted audio features: %s", audio_emb_path)
707
+ else:
708
+ # No audio provided: use zero audio embeddings; frame_num must be a positive frame count here.
+ audio_emb = torch.zeros(frame_num, 5, 1280).to(device)
+ audio_length = frame_num
709
+
710
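+ # Fall back to the audio-derived length when frame_num == -1, then clamp to 4k + 1 frames so the
+ # clip maps onto a whole number of temporal VAE latents (temporal stride 4).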
+ frame_num = frame_num if frame_num != -1 else audio_length
711
+ frame_num = 4 * ((frame_num - 1) // 4) + 1
712
+ audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
713
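+ # Append zero audio windows for the reference latent frames that are concatenated after the video latents.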
+ zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device)
714
+ audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
715
+ audio_emb = [audio_emb.to(device)]
716
+ audio_emb_neg = [torch.zeros_like(audio_emb[0])]
717
+
718
+ # preprocess
719
+ self.patch_size = self.config.dit.model.patch_size
720
+ F = frame_num
721
+ target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1],
722
+ size[1] // self.vae_stride[1],
723
+ size[0] // self.vae_stride[2])
724
+
725
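+ # seq_len = number of DiT patch tokens: latent frames * (latent H / patch_h) * (latent W / patch_w),
+ # rounded up to a multiple of the sequence-parallel world size.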
+ seq_len = math.ceil((target_shape[2] * target_shape[3]) /
726
+ (self.patch_size[1] * self.patch_size[2]) *
727
+ target_shape[1] / self.sp_size) * self.sp_size
728
+
729
+ if n_prompt == "":
730
+ n_prompt = self.config.generation.sample_neg_prompt
731
+ seed = seed if seed >= 0 else random.randint(0, sys.maxsize)
732
+ seed_g = torch.Generator(device=device)
733
+ seed_g.manual_seed(seed)
734
+
735
+ # self.text_encoder.model.to(device)
736
+ context = self.text_encoder([input_prompt], device)
737
+ context_null = self.text_encoder([n_prompt], device)
738
+ # self.text_encoder.model.cpu()
739
+
740
+ noise = [
741
+ torch.randn(
742
+ target_shape[0],
743
+ target_shape[1], # - latents_ref[0].shape[1],
744
+ target_shape[2],
745
+ target_shape[3],
746
+ dtype=torch.float32,
747
+ device=device,
748
+ generator=seed_g)
749
+ ]
750
+
751
+ @contextmanager
752
+ def noop_no_sync():
753
+ yield
754
+
755
+ no_sync = getattr(self.dit, 'no_sync', noop_no_sync)
756
+ step_change = self.config.generation.step_change # 980
757
+
758
+ # evaluation mode
759
+ with make_fp8_ctx(True), torch.autocast('cuda', dtype=torch.bfloat16), torch.no_grad(), no_sync():
760
+
761
+ if sample_solver == 'unipc':
762
+ sample_scheduler = FlowUniPCMultistepScheduler(
763
+ num_train_timesteps=1000,
764
+ shift=1,
765
+ use_dynamic_shifting=False)
766
+ sample_scheduler.set_timesteps(
767
+ sampling_steps, device=device, shift=shift)
768
+ timesteps = sample_scheduler.timesteps
769
+
770
+ # sample videos
771
+ latents = noise
772
+
773
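+ # Image conditioning: msk marks which latent frames carry reference latents (only the trailing
+ # latents_ref frames); y_c packs [mask, zero video latents + reference latents] and y_null packs
+ # [mask, all-zero latents] for the image-free branches.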
+ msk = torch.ones(4, target_shape[1], target_shape[2], target_shape[3], device=get_device())
774
+ msk[:,:-latents_ref[0].shape[1]] = 0
775
+
776
+ zero_vae = self.zero_vae[:, :(target_shape[1]-latents_ref[0].shape[1])].to(
777
+ device=get_device(), dtype=latents_ref[0].dtype)
778
+ y_c = torch.cat([
779
+ zero_vae,
780
+ latents_ref[0]
781
+ ], dim=1)
782
+ y_c = [torch.concat([msk, y_c])]
783
+
784
+ y_null = self.zero_vae[:, :target_shape[1]].to(
785
+ device=get_device(), dtype=latents_ref[0].dtype)
786
+ y_null = [torch.concat([msk, y_null])]
787
+
788
+ # TeaCache coefficients are model-specific; this pipeline uses the Wan2.1-T2V-14B profile.
789
+ tea_cache_model_id = "Wan2.1-T2V-14B"
790
+
791
+ def _new_tea_cache():
+     # One fresh TeaCache per guidance branch; disabled when the threshold is unset or not positive.
+     if tea_cache_l1_thresh is not None and tea_cache_l1_thresh > 0:
+         return TeaCache(sampling_steps, rel_l1_thresh=tea_cache_l1_thresh, model_id=tea_cache_model_id)
+     return None
+
+ # Argument sets for the guidance branches (t = text, i = reference image, a = audio).
+ arg_null = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context_null, 'tea_cache': _new_tea_cache()}
792
+ arg_t = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_null, 'context': context, 'tea_cache': _new_tea_cache()}
793
+ arg_i = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context_null, 'tea_cache': _new_tea_cache()}
794
+ arg_ti = {'seq_len': seq_len, 'audio': audio_emb_neg, 'y': y_c, 'context': context, 'tea_cache': _new_tea_cache()}
795
+ arg_ta = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_null, 'context': context, 'tea_cache': _new_tea_cache()}
796
+ arg_tia = {'seq_len': seq_len, 'audio': audio_emb, 'y': y_c, 'context': context, 'tea_cache': _new_tea_cache()}
797
+
798
+ torch.cuda.empty_cache()
799
+ # self.dit.to(device=get_device())
800
+ for t in tqdm(timesteps):
801
+ timestep = [t]
802
+ timestep = torch.stack(timestep)
803
+
804
+ if inference_mode == "TIA":
805
+ noise_pred = self.forward_tia(latents, timestep, t, step_change,
806
+ arg_tia, arg_ti, arg_i, arg_null)
807
+ elif inference_mode == "TA":
808
+ noise_pred = self.forward_ta(latents, timestep, arg_ta, arg_t, arg_null)
809
+ elif inference_mode == "TI":
810
+ noise_pred = self.forward_ti(latents, timestep, t, step_change,
811
+ arg_ti, arg_t, arg_i, arg_null)
812
+ else:
813
+ raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}")
814
+
815
+ temp_x0 = sample_scheduler.step(
816
+ noise_pred.unsqueeze(0),
817
+ t,
818
+ latents[0].unsqueeze(0),
819
+ return_dict=False,
820
+ generator=seed_g)[0]
821
+ latents = [temp_x0.squeeze(0)]
822
+
823
+ del timestep
824
+ torch.cuda.empty_cache()
825
+
826
+ x0 = latents
827
+ x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0]
828
+
829
+ # if offload_model:
830
+ # self.dit.cpu()
831
+
832
+ torch.cuda.empty_cache()
833
+ # if get_local_rank() == 0:
834
+ # self.vae.model.to(device=device)
835
+ videos = self.vae.decode(x0)
836
+ # self.vae.model.to(device="cpu")
837
+
838
+ del noise, latents, noise_pred
839
+ del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null
840
+ del x0, temp_x0
841
+ del sample_scheduler
842
+ torch.cuda.empty_cache()
843
+ gc.collect()
844
+ torch.cuda.synchronize()
845
+ if dist.is_initialized():
846
+ dist.barrier()
847
+
848
+ return videos[0] # if get_local_rank() == 0 else None
849
+
850
+
851
+ def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename, inference_mode = "TIA", width = 832, height = 480, steps=50, frames = 97, tea_cache_l1_thresh = 0.0, seed = 0):
852
+
853
+ video = self.inference(
854
+ prompt,
855
+ ref_img_path,
856
+ audio_path,
857
+ size=SIZE_CONFIGS[f"{width}*{height}"],
858
+ frame_num=frames,
859
+ shift=self.config.diffusion.timesteps.sampling.shift,
860
+ sample_solver='unipc',
861
+ sampling_steps=steps,
862
+ inference_mode = inference_mode,
863
+ tea_cache_l1_thresh = tea_cache_l1_thresh,
864
+ seed=seed
865
+ )
866
+
867
+ torch.cuda.empty_cache()
868
+ gc.collect()
869
+
870
+ # Save samples.
871
+ if get_sequence_parallel_rank() == 0:
872
+ pathname = self.save_sample(
873
+ sample=video,
874
+ audio_path=audio_path,
875
+ output_dir = output_dir,
876
+ filename=filename,
877
+ )
878
+ self.logger.info(f"Finished {filename}, saved to {pathname}.")
879
+
880
+ del video, prompt
881
+ torch.cuda.empty_cache()
882
+ gc.collect()
883
+
884
+
885
+ def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
886
+ gen_config = self.config.generation
887
+ # Prepare file path.
888
+ extension = ".mp4" if sample.ndim == 4 else ".png"
889
+ filename += extension
890
+ pathname = os.path.join(output_dir, filename)
891
+ # Convert sample.
892
+ sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
893
+ sample = rearrange(sample, "c t h w -> t h w c")
894
+ # Save file.
895
+ if sample.ndim == 4:
896
+ if audio_path is not None:
897
+ tensor_to_video(
898
+ sample.numpy(),
899
+ pathname,
900
+ audio_path,
901
+ fps=gen_config.fps)
902
+ else:
903
+ mediapy.write_video(
904
+ path=pathname,
905
+ images=sample.numpy(),
906
+ fps=gen_config.fps,
907
+ )
908
+ else:
909
+ raise ValueError(f"Expected a 4D video tensor for saving, got shape {tuple(sample.shape)}")
910
+ return pathname
911
+
912
+
913
+ def prepare_positive_prompts(self):
914
+ pos_prompts = self.config.generation.positive_prompt
915
+ if pos_prompts.endswith(".json"):
916
+ pos_prompts = prepare_json_dataset(pos_prompts)
917
+ else:
918
+ raise NotImplementedError
919
+ assert isinstance(pos_prompts, ListConfig)
920
+
921
+ return pos_prompts
922
+
923
+ class TeaCache:
924
+ def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
925
+ self.num_inference_steps = num_inference_steps
926
+ self.step = 0
927
+ self.accumulated_rel_l1_distance = 0
928
+ self.previous_modulated_input = None
929
+ self.rel_l1_thresh = rel_l1_thresh
930
+ self.previous_residual = None
931
+ self.previous_hidden_states = None
932
+
933
+ self.coefficients_dict = {
934
+ "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
935
+ "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
936
+ "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
937
+ "Wan2.1-I2V-14B-720P": [ 8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
938
+ }
939
+ if model_id not in self.coefficients_dict:
940
+ supported_model_ids = ", ".join(self.coefficients_dict)
941
+ raise ValueError(f"{model_id} is not a supported TeaCache model id. Please choose a valid model id in ({supported_model_ids}).")
942
+ self.coefficients = self.coefficients_dict[model_id]
943
+
944
+ def check(self, dit, x, t_mod):
945
+ modulated_inp = t_mod.clone()
946
+ if self.step == 0 or self.step == self.num_inference_steps - 1:
947
+ should_calc = True
948
+ self.accumulated_rel_l1_distance = 0
949
+ else:
950
+ coefficients = self.coefficients
951
+ rescale_func = np.poly1d(coefficients)
952
+ self.accumulated_rel_l1_distance += rescale_func(((modulated_inp-self.previous_modulated_input).abs().mean() / self.previous_modulated_input.abs().mean()).cpu().item())
953
+ if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
954
+ should_calc = False
955
+ else:
956
+ should_calc = True
957
+ self.accumulated_rel_l1_distance = 0
958
+ self.previous_modulated_input = modulated_inp
959
+ self.step += 1
960
+ if self.step == self.num_inference_steps:
961
+ self.step = 0
962
+ if should_calc:
963
+ self.previous_hidden_states = x.clone()
964
+ return not should_calc
965
+
966
+ def store(self, hidden_states):
967
+ if self.previous_hidden_states is None:
968
+ return
969
+ self.previous_residual = hidden_states - self.previous_hidden_states
970
+ self.previous_hidden_states = None
971
+
972
+ def update(self, hidden_states):
973
+ hidden_states = hidden_states + self.previous_residual
974
  return hidden_states
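+ # How the transformer is expected to consume TeaCache (a sketch under the assumption that the DiT
+ # wraps its block stack with check/store/update; the DiT forward is not part of this file):
+ #
+ # if tea_cache is not None and tea_cache.check(self, x, t_mod):   # True -> this step can be skipped
+ #     x = tea_cache.update(x)                                     # reuse the cached residual
+ # else:
+ #     x = run_transformer_blocks(x, ...)                          # run_transformer_blocks is hypothetical
+ #     if tea_cache is not None:
+ #         tea_cache.store(x)                                      # cache the residual vs. the stored input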