# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Inference code adapted from [SeedVR]
# https://github.com/ByteDance-Seed/SeedVR/blob/main/projects/inference_seedvr2_7b.py
import contextlib
import gc
import math
import os
import random
import sys
from contextlib import contextmanager

import mediapy
import numpy as np
import torch
import torch.cuda.amp as amp
import torch.distributed as dist
import torch.nn as nn
import torch._dynamo as dynamo
import transformer_engine.pytorch as te
from einops import rearrange
from omegaconf import DictConfig, ListConfig, OmegaConf
from PIL import Image, ImageOps
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.fsdp import (
    BackwardPrefetch,
    FullyShardedDataParallel,
    MixedPrecision,
    ShardingStrategy,
)
from torchvision.transforms import Compose, Normalize, ToTensor
from tqdm import tqdm

from common.config import create_object
from common.distributed import (
    get_device,
    get_global_rank,
    get_local_rank,
    init_torch,
    meta_non_persistent_buffer_init_fn,
    meta_param_init_fn,
)
from common.distributed.advanced import (
    get_sequence_parallel_rank,
    get_unified_parallel_world_size,
    init_model_shard_cpu_group,
    init_unified_parallel,
)
from common.logger import get_logger
from humo.models.utils.fm_solvers_unipc import FlowUniPCMultistepScheduler
from humo.models.utils.utils import tensor_to_video, prepare_json_dataset
from humo.models.wan_modules.t5 import T5EncoderModel
from humo.models.wan_modules.vae import WanVAE
from humo.utils.audio_processor_whisper import AudioProcessor
from humo.utils.wav2vec import linear_interpolation_fps
from torchao.quantization import quantize_

dynamo.config.capture_scalar_outputs = True
torch.set_float32_matmul_precision("high")
image_transform = Compose([
    ToTensor(),
    Normalize(mean=0.5, std=0.5),
])

SIZE_CONFIGS = {
    '720*1280': (720, 1280),
    '1280*720': (1280, 720),
    '480*832': (480, 832),
    '832*480': (832, 480),
    '1024*1024': (1024, 1024),
}
def clever_format(nums, format="%.2f"):
    """Format parameter counts with T/G/M/K suffixes (numbers below 1e3 get a bare 'B')."""
    from collections.abc import Iterable

    if not isinstance(nums, Iterable):
        nums = [nums]
    clever_nums = []
    for num in nums:
        if num > 1e12:
            clever_nums.append(format % (num / 1e12) + "T")
        elif num > 1e9:
            clever_nums.append(format % (num / 1e9) + "G")
        elif num > 1e6:
            clever_nums.append(format % (num / 1e6) + "M")
        elif num > 1e3:
            clever_nums.append(format % (num / 1e3) + "K")
        else:
            clever_nums.append(format % num + "B")
    clever_nums = clever_nums[0] if len(clever_nums) == 1 else (*clever_nums,)
    return clever_nums
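# Illustrative usage (hypothetical values, not produced anywhere in this file):
#     clever_format(1_500_000_000)        -> "1.50G"
#     clever_format([7e9, 3.2e6], "%.1f") -> ("7.0G", "3.2M")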
# FP8 autocast compatibility shim for different TransformerEngine versions.
try:
    # Preferred modern API.
    from transformer_engine.pytorch import fp8_autocast

    try:
        # Newer TE: use the recipe-based API.
        from transformer_engine.common.recipe import DelayedScaling, Format

        def make_fp8_ctx(enabled: bool = True):
            if not enabled:
                return contextlib.nullcontext()
            fp8_recipe = DelayedScaling(fp8_format=Format.E4M3)  # E4M3 format
            return fp8_autocast(enabled=True, fp8_recipe=fp8_recipe)
    except Exception:
        # Very old variants may still accept fp8_format directly.
        def make_fp8_ctx(enabled: bool = True):
            # If TE doesn't have FP8Format, just no-op.
            if not hasattr(te, "FP8Format"):
                return contextlib.nullcontext()
            return te.fp8_autocast(enabled=enabled, fp8_format=te.FP8Format.E4M3)
except Exception:
    # TE not present or totally incompatible -- no-op.
    def make_fp8_ctx(enabled: bool = True):
        return contextlib.nullcontext()

# TE sometimes exposes Linear at different paths; this normalizes it.
try:
    TELinear = te.Linear
except AttributeError:  # very old layouts
    from transformer_engine.pytorch.modules.linear import Linear as TELinear  # type: ignore
def _default_te_allow(fullname: str, lin: nn.Linear) -> bool:
    """
    Allow TE only where it's shape-safe & beneficial.
    Skip small/special layers (time/timestep/pos embeds, heads).
    Enforce multiples of 16 for in/out features (FP8 kernel friendly).
    Also skip very small projections likely to see M=1.
    """
    blocked_keywords = (
        "time_embedding", "timestep", "time_embed",
        "time_projection", "pos_embedding", "pos_embed",
        "to_logits", "logits", "final_proj", "proj_out", "output_projection",
    )
    if any(k in fullname for k in blocked_keywords):
        return False
    # TE FP8 kernels like K, N divisible by 16.
    if lin.in_features % 16 != 0 or lin.out_features % 16 != 0:
        return False
    # Heuristic: avoid tiny layers; keeps attention/MLP, skips small MLPs.
    if lin.in_features < 512 or lin.out_features < 512:
        return False
    # Whitelist: only convert inside transformer blocks if you know their prefix.
    # This further reduces risk of catching special heads elsewhere.
    allowed_context = ("blocks", "layers", "transformer", "attn", "mlp", "ffn")
    if not any(tok in fullname for tok in allowed_context):
        return False
    return True
def convert_linears_to_te_fp8(module: nn.Module, allow_pred=_default_te_allow, _prefix=""):
    """Recursively replace eligible nn.Linear layers with TransformerEngine Linear layers."""
    for name, child in list(module.named_children()):
        full = f"{_prefix}.{name}" if _prefix else name
        convert_linears_to_te_fp8(child, allow_pred, full)
        if isinstance(child, nn.Linear):
            if allow_pred is not None and not allow_pred(full, child):
                continue
            te_lin = TELinear(
                in_features=child.in_features,
                out_features=child.out_features,
                bias=(child.bias is not None),
                params_dtype=torch.bfloat16,
            ).to(child.weight.device)
            # Copy weights under no_grad: the fresh TE parameters require grad by default,
            # and an in-place copy into a leaf tensor that requires grad would raise.
            with torch.no_grad():
                te_lin.weight.copy_(child.weight.to(te_lin.weight.dtype))
                if child.bias is not None:
                    te_lin.bias.copy_(child.bias.to(te_lin.bias.dtype))
            setattr(module, name, te_lin)
    return module
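# Illustrative usage of the conversion together with the FP8 context (a sketch; `model`
# and `x` are placeholders, not objects defined in this file):
#     model = convert_linears_to_te_fp8(model.eval().to("cuda", torch.bfloat16))
#     with make_fp8_ctx(True), torch.autocast("cuda", dtype=torch.bfloat16), torch.no_grad():
#         out = model(x)
# Only the swapped-in TE Linear layers run FP8 GEMMs inside the context; every layer the
# allow-predicate rejects stays a plain nn.Linear in bf16.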
class Generator:
    def __init__(self, config: DictConfig):
        self.config = config.copy()
        OmegaConf.set_readonly(self.config, True)
        self.logger = get_logger(self.__class__.__name__)
        # init_torch(cudnn_benchmark=False)
        self.configure_models()

    def entrypoint(self):
        self.inference_loop()

    def get_fsdp_sharding_config(self, sharding_strategy, device_mesh_config):
        device_mesh = None
        fsdp_strategy = ShardingStrategy[sharding_strategy]
        if (
            fsdp_strategy in [ShardingStrategy._HYBRID_SHARD_ZERO2, ShardingStrategy.HYBRID_SHARD]
            and device_mesh_config is not None
        ):
            device_mesh = init_device_mesh("cuda", tuple(device_mesh_config))
        return device_mesh, fsdp_strategy

    def configure_models(self):
        self.configure_dit_model(device="cuda")
        self.dit.eval().to("cuda")
        convert_linears_to_te_fp8(self.dit)
        self.dit = torch.compile(self.dit)

        self.configure_vae_model(device="cuda")
        if self.config.generation.get('extract_audio_feat', False):
            self.configure_wav2vec(device="cpu")
        self.configure_text_model(device="cuda")

        # # Initialize FSDP.
        # self.configure_dit_fsdp_model()
        # self.configure_text_fsdp_model()
        # quantize_(self.text_encoder, Int8WeightOnlyConfig())
        # quantize_(self.dit, Float8DynamicActivationFloat8WeightConfig())
    def configure_dit_model(self, device=get_device()):
        init_unified_parallel(self.config.dit.sp_size)
        self.sp_size = get_unified_parallel_world_size()

        # Create DiT model on meta, then mark dtype as bfloat16 (no real allocation yet).
        init_device = "meta"
        with torch.device(init_device):
            self.dit = create_object(self.config.dit.model)
        self.dit = self.dit.to(dtype=torch.bfloat16)  # or: self.dit.bfloat16()
        self.logger.info(f"Load DiT model on {init_device}.")
        self.dit.eval().requires_grad_(False)

        # Load the DiT checkpoint.
        path = self.config.dit.checkpoint_dir

        def _cast_state_dict_to_bf16(state):
            for k, v in state.items():
                if isinstance(v, torch.Tensor) and v.is_floating_point():
                    state[k] = v.to(dtype=torch.bfloat16, copy=False)
            return state

        if path.endswith(".pth"):
            # Load to CPU first; we'll move the model later.
            state = torch.load(path, map_location="cpu", mmap=True)
            state = _cast_state_dict_to_bf16(state)
            missing_keys, unexpected_keys = self.dit.load_state_dict(state, strict=False, assign=True)
            self.logger.info(
                f"dit loaded from {path}. Missing keys: {len(missing_keys)}, Unexpected keys: {len(unexpected_keys)}"
            )
        else:
            import json

            from safetensors.torch import load_file

            def load_custom_sharded_weights(model_dir, base_name):
                index_path = f"{model_dir}/{base_name}.safetensors.index.json"
                with open(index_path, "r") as f:
                    index = json.load(f)
                weight_map = index["weight_map"]
                shard_files = set(weight_map.values())
                state_dict = {}
                for shard_file in shard_files:
                    shard_path = f"{model_dir}/{shard_file}"
                    # Load on CPU, then cast to bf16; we'll move the whole module later.
                    shard_state = load_file(shard_path, device="cpu")
                    shard_state = {
                        k: (v.to(dtype=torch.bfloat16, copy=False) if v.is_floating_point() else v)
                        for k, v in shard_state.items()
                    }
                    state_dict.update(shard_state)
                return state_dict

            state = load_custom_sharded_weights(path, 'humo')
            self.dit.load_state_dict(state, strict=False, assign=True)

        self.dit = meta_non_persistent_buffer_init_fn(self.dit)
        target_device = get_device() if device in [get_device(), "cuda"] else device
        self.dit.to(target_device)  # dtype already bf16

        # Print model size.
        params = sum(p.numel() for p in self.dit.parameters())
        self.logger.info(
            f"[RANK:{get_global_rank()}] DiT Parameters: {clever_format(params, '%.3f')}"
        )
    def configure_vae_model(self, device=get_device()):
        self.vae_stride = self.config.vae.vae_stride
        self.vae = WanVAE(
            vae_pth=self.config.vae.checkpoint,
            device=device)

        if self.config.generation.height == 480:
            self.zero_vae = torch.load(self.config.dit.zero_vae_path)
        elif self.config.generation.height == 720:
            self.zero_vae = torch.load(self.config.dit.zero_vae_720p_path)
        else:
            raise ValueError(f"Unsupported height {self.config.generation.height} for zero-vae.")

    def configure_wav2vec(self, device=get_device()):
        audio_separator_model_file = self.config.audio.vocal_separator
        wav2vec_model_path = self.config.audio.wav2vec_model
        self.audio_processor = AudioProcessor(
            16000,
            25,
            wav2vec_model_path,
            "all",
            audio_separator_model_file,
            None,  # do not separate vocals
            os.path.join(self.config.generation.output.dir, "vocals"),
            device=device,
        )

    def configure_text_model(self, device=get_device()):
        self.text_encoder = T5EncoderModel(
            text_len=self.config.dit.model.text_len,
            dtype=torch.bfloat16,
            device=device,
            checkpoint_path=self.config.text.t5_checkpoint,
            tokenizer_path=self.config.text.t5_tokenizer,
        )
    def configure_dit_fsdp_model(self):
        from humo.models.wan_modules.model_humo import WanAttentionBlock

        dit_blocks = (WanAttentionBlock,)

        # Init model_shard_cpu_group for saving checkpoints with a sharded state_dict.
        init_model_shard_cpu_group(
            self.config.dit.fsdp.sharding_strategy,
            self.config.dit.fsdp.get("device_mesh", None),
        )

        # Assert that dit has wrappable blocks.
        assert any(isinstance(m, dit_blocks) for m in self.dit.modules())

        # Define wrap policy on all dit blocks.
        def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
            return recurse or isinstance(module, dit_blocks)

        # Configure FSDP settings.
        device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
            self.config.dit.fsdp.sharding_strategy,
            self.config.dit.fsdp.get("device_mesh", None),
        )
        settings = dict(
            auto_wrap_policy=custom_auto_wrap_policy,
            sharding_strategy=fsdp_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            device_id=get_local_rank(),
            use_orig_params=False,
            sync_module_states=True,
            forward_prefetch=True,
            limit_all_gathers=False,  # False for ZeRO-2.
            mixed_precision=MixedPrecision(
                param_dtype=torch.bfloat16,
                reduce_dtype=torch.float32,
                buffer_dtype=torch.float32,
            ),
            device_mesh=device_mesh,
            param_init_fn=meta_param_init_fn,
        )

        # Apply FSDP.
        self.dit = FullyShardedDataParallel(self.dit, **settings)
        # self.dit.to(get_device())

    def configure_text_fsdp_model(self):
        # If FSDP is not enabled, put text_encoder on the GPU and return.
        if not self.config.text.fsdp.enabled:
            self.text_encoder.to(get_device())
            return

        # from transformers.models.t5.modeling_t5 import T5Block
        from humo.models.wan_modules.t5 import T5SelfAttention

        text_blocks = (torch.nn.Embedding, T5SelfAttention)
        # text_blocks_names = ("QWenBlock", "QWenModel")  # QWen cannot be imported. Use str.

        def custom_auto_wrap_policy(module, recurse, *args, **kwargs):
            return (
                recurse
                or isinstance(module, text_blocks)
            )

        # Apply FSDP.
        text_encoder_dtype = getattr(torch, self.config.text.dtype)
        device_mesh, fsdp_strategy = self.get_fsdp_sharding_config(
            self.config.text.fsdp.sharding_strategy,
            self.config.text.fsdp.get("device_mesh", None),
        )
        self.text_encoder = FullyShardedDataParallel(
            module=self.text_encoder,
            auto_wrap_policy=custom_auto_wrap_policy,
            sharding_strategy=fsdp_strategy,
            backward_prefetch=BackwardPrefetch.BACKWARD_PRE,
            device_id=get_local_rank(),
            use_orig_params=False,
            sync_module_states=False,
            forward_prefetch=True,
            limit_all_gathers=True,
            mixed_precision=MixedPrecision(
                param_dtype=text_encoder_dtype,
                reduce_dtype=text_encoder_dtype,
                buffer_dtype=text_encoder_dtype,
            ),
            device_mesh=device_mesh,
        )
        self.text_encoder.to(get_device()).requires_grad_(False)
    def load_image_latent_ref_id(self, path: str, size, device):
        # Target size.
        h, w = size[1], size[0]

        # Load image(s).
        if not isinstance(path, str) and len(path) > 1:
            ref_vae_latents = []
            for image_path in path:
                with Image.open(image_path) as img:
                    img = img.convert("RGB")

                    # Calculate the required size to keep aspect ratio and fill the rest with padding.
                    img_ratio = img.width / img.height
                    target_ratio = w / h
                    if img_ratio > target_ratio:  # Image is wider than target.
                        new_width = w
                        new_height = int(new_width / img_ratio)
                    else:  # Image is taller than target.
                        new_height = h
                        new_width = int(new_height * img_ratio)
                    # img = img.resize((new_width, new_height), Image.ANTIALIAS)
                    img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

                    # Create a new image with the target size and place the resized image in the center.
                    delta_w = w - img.size[0]
                    delta_h = h - img.size[1]
                    padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
                    new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))

                    # Transform to tensor and normalize.
                    transform = Compose(
                        [
                            ToTensor(),
                            Normalize(0.5, 0.5),
                        ]
                    )
                    new_img = transform(new_img)

                    # VAE encode.
                    # img_vae_latent = self.vae_encode([new_img.unsqueeze(1)])[0]
                    img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
                    ref_vae_latents.append(img_vae_latent[0])

            return [torch.cat(ref_vae_latents, dim=1)]
        else:
            if not isinstance(path, str):
                path = path[0]
            with Image.open(path) as img:
                img = img.convert("RGB")

                # Calculate the required size to keep aspect ratio and fill the rest with padding.
                img_ratio = img.width / img.height
                target_ratio = w / h
                if img_ratio > target_ratio:  # Image is wider than target.
                    new_width = w
                    new_height = int(new_width / img_ratio)
                else:  # Image is taller than target.
                    new_height = h
                    new_width = int(new_height * img_ratio)
                # img = img.resize((new_width, new_height), Image.ANTIALIAS)
                img = img.resize((new_width, new_height), Image.Resampling.LANCZOS)

                # Create a new image with the target size and place the resized image in the center.
                delta_w = w - img.size[0]
                delta_h = h - img.size[1]
                padding = (delta_w // 2, delta_h // 2, delta_w - (delta_w // 2), delta_h - (delta_h // 2))
                new_img = ImageOps.expand(img, padding, fill=(255, 255, 255))

                # Transform to tensor and normalize.
                transform = Compose(
                    [
                        ToTensor(),
                        Normalize(0.5, 0.5),
                    ]
                )
                new_img = transform(new_img)

                # VAE encode.
                img_vae_latent = self.vae.encode([new_img.unsqueeze(1)], device)
            return img_vae_latent
    def get_audio_emb_window(self, audio_emb, frame_num, frame0_idx, audio_shift=2):
        zero_audio_embed = torch.zeros(
            (audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
        zero_audio_embed_3 = torch.zeros(
            (3, audio_emb.shape[1], audio_emb.shape[2]), dtype=audio_emb.dtype, device=audio_emb.device)
        iter_ = 1 + (frame_num - 1) // 4
        audio_emb_wind = []
        for lt_i in range(iter_):
            if lt_i == 0:
                st = frame0_idx + lt_i - 2
                ed = frame0_idx + lt_i + 3
                wind_feat = torch.stack([
                    audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
                    for i in range(st, ed)
                ], dim=0)
                wind_feat = torch.cat((zero_audio_embed_3, wind_feat), dim=0)
            else:
                st = frame0_idx + 1 + 4 * (lt_i - 1) - audio_shift
                ed = frame0_idx + 1 + 4 * lt_i + audio_shift
                wind_feat = torch.stack([
                    audio_emb[i] if (0 <= i < audio_emb.shape[0]) else zero_audio_embed
                    for i in range(st, ed)
                ], dim=0)
            audio_emb_wind.append(wind_feat)
        audio_emb_wind = torch.stack(audio_emb_wind, dim=0)

        return audio_emb_wind, ed - audio_shift
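    # Shape note (an informal reading of the code above, not an authoritative spec): given
    # audio_emb of shape (T, L, D), get_audio_emb_window returns a tensor of shape
    # (1 + (frame_num - 1) // 4, 2 * audio_shift + 4, L, D): one window of 8 audio frames
    # (with audio_shift=2) per latent frame, zero-padded wherever the window reaches outside
    # [0, T). The second return value, ed - audio_shift, is the audio-frame index just past
    # the centre of the last window.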
    def audio_emb_enc(self, audio_emb, wav_enc_type="whisper"):
        if wav_enc_type == "wav2vec":
            feat_merge = audio_emb
        elif wav_enc_type == "whisper":
            feat0 = linear_interpolation_fps(audio_emb[:, :, 0:8].mean(dim=2), 50, 25)
            feat1 = linear_interpolation_fps(audio_emb[:, :, 8:16].mean(dim=2), 50, 25)
            feat2 = linear_interpolation_fps(audio_emb[:, :, 16:24].mean(dim=2), 50, 25)
            feat3 = linear_interpolation_fps(audio_emb[:, :, 24:32].mean(dim=2), 50, 25)
            feat4 = linear_interpolation_fps(audio_emb[:, :, 32], 50, 25)
            feat_merge = torch.stack([feat0, feat1, feat2, feat3, feat4], dim=2)[0]
        else:
            raise ValueError(f"Unsupported wav_enc_type: {wav_enc_type}")
        return feat_merge
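    # Reading of the whisper branch above: the 33 encoder layers are pooled into 5 groups
    # (mean of layers 0-7, 8-15, 16-23, 24-31, plus layer 32 on its own), each resampled
    # from 50 fps to 25 fps, then stacked so every audio frame ends up as a (5, D) feature.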
    def forward_tia(self, latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null):
        neg = self.dit(
            [torch.cat([latent[:, :-latent_ref_neg.shape[1]], latent_ref_neg], dim=1)
             for latent, latent_ref_neg in zip(latents, latents_ref_neg)],
            t=timestep, **arg_null
        )[0]
        pos_t = self.dit(
            [torch.cat([latent[:, :-latent_ref_neg.shape[1]], latent_ref_neg], dim=1)
             for latent, latent_ref_neg in zip(latents, latents_ref_neg)],
            t=timestep, **arg_t
        )[0]
        pos_ta = self.dit(
            [torch.cat([latent[:, :-latent_ref_neg.shape[1]], latent_ref_neg], dim=1)
             for latent, latent_ref_neg in zip(latents, latents_ref_neg)],
            t=timestep, **arg_ta
        )[0]
        pos_tia = self.dit(
            [torch.cat([latent[:, :-latent_ref.shape[1]], latent_ref], dim=1)
             for latent, latent_ref in zip(latents, latents_ref)],
            t=timestep, **arg_ta
        )[0]

        noise_pred = self.config.generation.scale_i * (pos_tia - pos_ta) + \
            self.config.generation.scale_a * (pos_ta - pos_t) + \
            self.config.generation.scale_t * (pos_t - neg) + \
            neg
        return noise_pred

    def forward_ta(self, latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null):
        neg = self.dit(
            [torch.cat([latent[:, :-latent_ref_neg.shape[1]], latent_ref_neg], dim=1)
             for latent, latent_ref_neg in zip(latents, latents_ref_neg)],
            t=timestep, **arg_null
        )[0]
        pos_t = self.dit(
            [torch.cat([latent[:, :-latent_ref_neg.shape[1]], latent_ref_neg], dim=1)
             for latent, latent_ref_neg in zip(latents, latents_ref_neg)],
            t=timestep, **arg_t
        )[0]
        pos_ta = self.dit(
            [torch.cat([latent[:, :-latent_ref_neg.shape[1]], latent_ref_neg], dim=1)
             for latent, latent_ref_neg in zip(latents, latents_ref_neg)],
            t=timestep, **arg_ta
        )[0]

        noise_pred = self.config.generation.scale_a * (pos_ta - pos_t) + \
            self.config.generation.scale_t * (pos_t - neg) + \
            neg
        return noise_pred
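    # The two methods above implement nested classifier-free guidance. Rearranged, the TIA
    # prediction is
    #     neg + scale_t * (pos_t - neg) + scale_a * (pos_ta - pos_t) + scale_i * (pos_tia - pos_ta)
    # i.e. separate guidance scales for text, audio, and the reference image, each applied to
    # the difference between one forward pass and the next-less-conditioned one; forward_ta
    # simply drops the image term.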
    def inference(self,
                  input_prompt,
                  img_path,
                  audio_path,
                  size=(1280, 720),
                  frame_num=81,
                  shift=5.0,
                  sample_solver='unipc',
                  inference_mode='TIA',
                  sampling_steps=50,
                  n_prompt="",
                  seed=-1,
                  tea_cache_l1_thresh=0.0,
                  device=get_device(),
                  ):
        # Note: inference_mode and tea_cache_l1_thresh are accepted for interface
        # compatibility; the guidance mode actually used below is self.config.generation.mode.

        # Reference image -> VAE latents.
        # self.vae.model.to(device=device)
        if img_path is not None:
            latents_ref = self.load_image_latent_ref_id(img_path, size, device)
        else:
            latents_ref = [torch.zeros(16, 1, size[1] // 8, size[0] // 8).to(device)]
        # self.vae.model.to(device="cpu")
        latents_ref_neg = [torch.zeros_like(latent_ref) for latent_ref in latents_ref]

        # Audio features.
        if audio_path is not None:
            if self.config.generation.extract_audio_feat:
                self.audio_processor.whisper.to(device=device)
                audio_emb, audio_length = self.audio_processor.preprocess(audio_path)
                self.audio_processor.whisper.to(device='cpu')
            else:
                audio_emb_path = audio_path.replace(".wav", ".pt")
                audio_emb = torch.load(audio_emb_path).to(device=device)
                audio_emb = self.audio_emb_enc(audio_emb, wav_enc_type="whisper")
                self.logger.info("Using pre-extracted audio features: %s", audio_emb_path)
        else:
            audio_emb = torch.zeros(frame_num, 5, 1280).to(device)

        # frame_num == -1 means "follow the audio length", which requires extract_audio_feat.
        frame_num = frame_num if frame_num != -1 else audio_length
        frame_num = 4 * ((frame_num - 1) // 4) + 1
        audio_emb, _ = self.get_audio_emb_window(audio_emb, frame_num, frame0_idx=0)
        zero_audio_pad = torch.zeros(latents_ref[0].shape[1], *audio_emb.shape[1:]).to(audio_emb.device)
        audio_emb = torch.cat([audio_emb, zero_audio_pad], dim=0)
        audio_emb = [audio_emb.to(device)]
        audio_emb_neg = [torch.zeros_like(audio_emb[0])]

        # Preprocess.
        self.patch_size = self.config.dit.model.patch_size
        F = frame_num
        target_shape = (self.vae.model.z_dim, (F - 1) // self.vae_stride[0] + 1 + latents_ref[0].shape[1],
                        size[1] // self.vae_stride[1],
                        size[0] // self.vae_stride[2])

        seq_len = math.ceil((target_shape[2] * target_shape[3]) /
                            (self.patch_size[1] * self.patch_size[2]) *
                            target_shape[1] / self.sp_size) * self.sp_size
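        # seq_len is the number of patch tokens the DiT sees: (H / patch) * (W / patch)
        # spatial patches per latent frame times target_shape[1] frames, rounded up to a
        # multiple of self.sp_size so the sequence splits evenly across sequence-parallel ranks.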
| if n_prompt == "": | |
| n_prompt = self.config.generation.sample_neg_prompt | |
| seed = seed if seed >= 0 else random.randint(0, sys.maxsize) | |
| seed_g = torch.Generator(device=device) | |
| seed_g.manual_seed(seed) | |
| # self.text_encoder.model.to(device) | |
| context = self.text_encoder([input_prompt], device) | |
| context_null = self.text_encoder([n_prompt], device) | |
| # self.text_encoder.model.cpu() | |
| noise = [ | |
| torch.randn( | |
| target_shape[0], | |
| target_shape[1], # - latents_ref[0].shape[1], | |
| target_shape[2], | |
| target_shape[3], | |
| dtype=torch.float32, | |
| device=device, | |
| generator=seed_g) | |
| ] | |
| def noop_no_sync(): | |
| yield | |
| no_sync = getattr(self.dit, 'no_sync', noop_no_sync) | |
| # evaluation mode | |
| with make_fp8_ctx(True), torch.autocast('cuda', dtype=torch.bfloat16), torch.no_grad(), no_sync(): | |
| if sample_solver == 'unipc': | |
| sample_scheduler = FlowUniPCMultistepScheduler( | |
| num_train_timesteps=1000, | |
| shift=1, | |
| use_dynamic_shifting=False) | |
| sample_scheduler.set_timesteps( | |
| sampling_steps, device=device, shift=shift) | |
| timesteps = sample_scheduler.timesteps | |
| # sample videos | |
| latents = noise | |
| # referene image在下面的输入中手动指定, 不在arg中指定 | |
| arg_ta = {'context': context, 'seq_len': seq_len, 'audio': audio_emb} | |
| arg_t = {'context': context, 'seq_len': seq_len, 'audio': audio_emb_neg} | |
| arg_null = {'context': context_null, 'seq_len': seq_len, 'audio': audio_emb_neg} | |
| torch.cuda.empty_cache() | |
| for _, t in enumerate(tqdm(timesteps)): | |
| timestep = [t] | |
| timestep = torch.stack(timestep) | |
| if self.config.generation.mode == "TIA": | |
| noise_pred = self.forward_tia(latents, latents_ref, latents_ref_neg, timestep, arg_t, arg_ta, arg_null) | |
| elif self.config.generation.mode == "TA": | |
| noise_pred = self.forward_ta(latents, latents_ref_neg, timestep, arg_t, arg_ta, arg_null) | |
| else: | |
| raise ValueError(f"Unsupported generation mode: {self.config.generation.mode}") | |
| temp_x0 = sample_scheduler.step( | |
| noise_pred.unsqueeze(0), | |
| t, | |
| latents[0].unsqueeze(0), | |
| return_dict=False, | |
| generator=seed_g)[0] | |
| latents = [temp_x0.squeeze(0)] | |
| del timestep | |
| torch.cuda.empty_cache() | |
| x0 = latents | |
| x0 = [x0_[:,:-latents_ref[0].shape[1]] for x0_ in x0] | |
| # if offload_model: | |
| # self.dit.cpu() | |
| torch.cuda.empty_cache() | |
| # if get_local_rank() == 0: | |
| # self.vae.model.to(device=device) | |
| videos = self.vae.decode(x0) | |
| # self.vae.model.to(device="cpu") | |
| del noise, latents, noise_pred | |
| del audio_emb, audio_emb_neg, latents_ref, latents_ref_neg, context, context_null | |
| del x0, temp_x0 | |
| del sample_scheduler | |
| torch.cuda.empty_cache() | |
| gc.collect() | |
| torch.cuda.synchronize() | |
| if dist.is_initialized(): | |
| dist.barrier() | |
| return videos[0] # if get_local_rank() == 0 else None | |
    def inference_loop(self, prompt, ref_img_path, audio_path, output_dir, filename,
                       inference_mode="TIA", width=832, height=480, steps=50, frames=97,
                       tea_cache_l1_thresh=0.0, seed=0):
        video = self.inference(
            prompt,
            ref_img_path,
            audio_path,
            size=SIZE_CONFIGS[f"{width}*{height}"],
            frame_num=frames,
            shift=self.config.diffusion.timesteps.sampling.shift,
            sample_solver='unipc',
            sampling_steps=steps,
            inference_mode=inference_mode,
            tea_cache_l1_thresh=tea_cache_l1_thresh,
            seed=seed,
        )
        torch.cuda.empty_cache()
        gc.collect()

        # Save samples.
        if get_sequence_parallel_rank() == 0:
            pathname = self.save_sample(
                sample=video,
                audio_path=audio_path,
                output_dir=output_dir,
                filename=filename,
            )
            self.logger.info(f"Finished {filename}, saved to {pathname}.")

        del video, prompt
        torch.cuda.empty_cache()
        gc.collect()
    def save_sample(self, *, sample: torch.Tensor, audio_path: str, output_dir: str, filename: str):
        gen_config = self.config.generation

        # Prepare the file path.
        extension = ".mp4" if sample.ndim == 4 else ".png"
        filename += extension
        pathname = os.path.join(output_dir, filename)

        # Convert the sample to uint8 frames in THWC layout.
        sample = sample.clip(-1, 1).mul_(0.5).add_(0.5).mul_(255).to("cpu", torch.uint8)
        sample = rearrange(sample, "c t h w -> t h w c")

        # Save the file.
        if sample.ndim == 4:
            if audio_path is not None:
                tensor_to_video(
                    sample.numpy(),
                    pathname,
                    audio_path,
                    fps=gen_config.fps)
            else:
                mediapy.write_video(
                    path=pathname,
                    images=sample.numpy(),
                    fps=gen_config.fps,
                )
        else:
            raise ValueError
        return pathname
    def prepare_positive_prompts(self):
        pos_prompts = self.config.generation.positive_prompt
        if pos_prompts.endswith(".json"):
            pos_prompts = prepare_json_dataset(pos_prompts)
        else:
            raise NotImplementedError
        assert isinstance(pos_prompts, ListConfig)
        return pos_prompts
class TeaCache:
    def __init__(self, num_inference_steps, rel_l1_thresh, model_id):
        self.num_inference_steps = num_inference_steps
        self.step = 0
        self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = None
        self.rel_l1_thresh = rel_l1_thresh
        self.previous_residual = None
        self.previous_hidden_states = None
        self.coefficients_dict = {
            "Wan2.1-T2V-1.3B": [-5.21862437e+04, 9.23041404e+03, -5.28275948e+02, 1.36987616e+01, -4.99875664e-02],
            "Wan2.1-T2V-14B": [-3.03318725e+05, 4.90537029e+04, -2.65530556e+03, 5.87365115e+01, -3.15583525e-01],
            "Wan2.1-I2V-14B-480P": [2.57151496e+05, -3.54229917e+04, 1.40286849e+03, -1.35890334e+01, 1.32517977e-01],
            "Wan2.1-I2V-14B-720P": [8.10705460e+03, 2.13393892e+03, -3.72934672e+02, 1.66203073e+01, -4.17769401e-02],
        }
        if model_id not in self.coefficients_dict:
            supported_model_ids = ", ".join(self.coefficients_dict)
            raise ValueError(
                f"{model_id} is not a supported TeaCache model id. "
                f"Please choose a valid model id in ({supported_model_ids})."
            )
        self.coefficients = self.coefficients_dict[model_id]

    def check(self, dit, x, t_mod):
        modulated_inp = t_mod.clone()
        if self.step == 0 or self.step == self.num_inference_steps - 1:
            should_calc = True
            self.accumulated_rel_l1_distance = 0
        else:
            coefficients = self.coefficients
            rescale_func = np.poly1d(coefficients)
            self.accumulated_rel_l1_distance += rescale_func(
                ((modulated_inp - self.previous_modulated_input).abs().mean()
                 / self.previous_modulated_input.abs().mean()).cpu().item()
            )
            if self.accumulated_rel_l1_distance < self.rel_l1_thresh:
                should_calc = False
            else:
                should_calc = True
                self.accumulated_rel_l1_distance = 0
        self.previous_modulated_input = modulated_inp
        self.step += 1
        if self.step == self.num_inference_steps:
            self.step = 0
        if should_calc:
            self.previous_hidden_states = x.clone()
        return not should_calc

    def store(self, hidden_states):
        if self.previous_hidden_states is None:
            return
        self.previous_residual = hidden_states - self.previous_hidden_states
        self.previous_hidden_states = None

    def update(self, hidden_states):
        hidden_states = hidden_states + self.previous_residual
        return hidden_states
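# How TeaCache is meant to plug in (a sketch; the wiring into the DiT forward pass is not
# part of this file): check() returns True when the accumulated relative-L1 change of the
# timestep modulation, mapped through the model-specific polynomial, stays below
# rel_l1_thresh, meaning the transformer blocks can be skipped for this step and update()
# can re-apply the residual that store() cached after the last full forward pass.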