# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#     http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Ulysses (sequence-parallel) forward passes for the HuMo WAN DiT, its
self-attention, and its audio cross-attention."""

import torch
from einops import rearrange

from common.distributed import get_device
from common.distributed.advanced import (
    get_unified_parallel_world_size,
    get_unified_parallel_group,
    pad_tensor,
    Slice,
    gather_outputs,
    gather_seq_scatter_heads_qkv,
    gather_seq_scatter_double_head,
    gather_heads_scatter_seq,
    unpad_tensor,
)
from humo.models.wan_modules.attention import flash_attention
from humo.models.wan_modules.model_humo import rope_apply, sinusoidal_embedding_1d


def ulysses_dit_forward(
    self,
    x,
    t,
    context,
    seq_len,
    audio=None,
    y=None,
):
    """
    x: A list of videos, each with shape [C, T, H, W].
    t: [B].
    context: A list of text embeddings, each with shape [L, C].
    seq_len: Maximum patchified sequence length; shorter samples are zero-padded to it.
    audio: A list of audio feature sequences (required when self.insert_audio).
    y: A list of conditional videos, concatenated with x along the channel dim (i2v).
    """
    if self.model_type == 'i2v':
        # assert clip_fea is not None and y is not None
        assert y is not None

    # params
    device = self.patch_embedding.weight.device
    if self.freqs.device != device:
        self.freqs = self.freqs.to(device)

    if y is not None:
        x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]

    # embeddings
    x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
    grid_sizes = torch.stack(
        [torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
    x = [u.flatten(2).transpose(1, 2) for u in x]
    seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long, device=device)
    assert seq_lens.max() <= seq_len
    x = torch.cat([
        torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
        for u in x
    ])

    # time embeddings
    with torch.amp.autocast('cuda', dtype=torch.float32):
        e = self.time_embedding(
            sinusoidal_embedding_1d(self.freq_dim, t).float()).float()
        e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float()
        assert e.dtype == torch.float32 and e0.dtype == torch.float32

    # context
    context_lens = None
    context = self.text_embedding(
        torch.stack([
            torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
            for u in context
        ]))

    # audio embeddings
    if self.insert_audio:
        audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio]
        audio_seq_len = torch.tensor(
            max([au.shape[2] for au in audio]) * audio[0].shape[3], device=get_device())
        audio = [au.flatten(2).transpose(1, 2) for au in audio]  # [1, t*32, 1536]
        audio_seq_lens = torch.tensor([au.size(1) for au in audio], dtype=torch.long, device=device)
        audio = torch.cat([
            torch.cat([au, au.new_zeros(1, audio_seq_len - au.size(1), au.size(2))], dim=1)
            for au in audio
        ])
    else:
        audio = None
        audio_seq_len = None
        audio_seq_lens = None

    # ulysses support: pad the sequence to a multiple of the sequence-parallel
    # world size, then scatter it across ranks along the sequence dimension
    sp_world = get_unified_parallel_world_size()
    group = get_unified_parallel_group()
    if seq_len % sp_world:
        padding_size = sp_world - (seq_len % sp_world)
        x = pad_tensor(x, dim=1, padding_size=padding_size)
    if self.insert_audio:
        audio_padding_size = sp_world - (audio_seq_len % sp_world)
        audio = pad_tensor(audio, dim=1, padding_size=audio_padding_size)
    x = Slice.apply(group, x, 1, True)
    if self.insert_audio:
        audio = Slice.apply(group, audio, 1, True)

    # arguments
    kwargs = dict(
        e=e0,
        seq_lens=seq_lens,
        grid_sizes=grid_sizes,
        freqs=self.freqs,
        context=context,
        context_lens=context_lens,
        audio=audio,
        audio_seq_len=audio_seq_len)

    # transformer blocks
    for block in self.blocks:
        x = block(x, **kwargs)

    # head
    x = self.head(x, e)

    # ulysses support: gather the sequence shards back and drop the padding
    x = gather_outputs(
        x, gather_dim=1, padding_dim=1, unpad_dim_size=seq_len, scale_grad=True)

    # unpatchify
    x = self.unpatchify(x, grid_sizes)
    return [u.float() for u in x]


def ulysses_attn_forward(
    self,
    x,
    seq_lens,
    grid_sizes,
    freqs,
    dtype=torch.bfloat16,
):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    seq_len = seq_lens.max()
    half_dtypes = (torch.float16, torch.bfloat16)

    def half(x):
        return x if x.dtype in half_dtypes else x.to(dtype)

    # query, key, value function
    def qkv_fn(x):
        q = self.norm_q(self.q(x))
        k = self.norm_k(self.k(x))
        v = self.v(x)
        return q, k, v

    q, k, v = qkv_fn(x)

    # ulysses support: pad the head dimension to a multiple of the
    # sequence-parallel world size, then gather the full sequence while
    # scattering heads across ranks
    sp_size = get_unified_parallel_world_size()
    if n % sp_size:
        pad_size = sp_size - (n % sp_size)
        pad_size = pad_size * d
        pad_inner_dim = n * d + pad_size
        q = pad_tensor(q, dim=2, padding_size=pad_size)
        k = pad_tensor(k, dim=2, padding_size=pad_size)
        v = pad_tensor(v, dim=2, padding_size=pad_size)
    else:
        pad_inner_dim = n * d
    qkv = torch.cat([q, k, v], dim=2)
    qkv = gather_seq_scatter_heads_qkv(qkv, seq_dim=1, unpadded_dim_size=seq_len)
    q, k, v = qkv.split(pad_inner_dim // sp_size, dim=2)
    pad_n = pad_inner_dim // d
    pad_split_n = pad_n // sp_size
    q = q.view(b, seq_len, pad_split_n, d)
    k = k.view(b, seq_len, pad_split_n, d)
    v = v.view(b, seq_len, pad_split_n, d)

    q = rope_apply(q, grid_sizes, freqs)
    k = rope_apply(k, grid_sizes, freqs)

    x = flash_attention(
        q=half(q),
        k=half(k),
        v=half(v),
        k_lens=seq_lens,
        window_size=self.window_size)

    # ulysses support: gather heads back and re-scatter the sequence dimension
    x = x.flatten(2)
    x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
    if n % sp_size:
        # drop the padded head channels so the output projection sees n * d
        x = unpad_tensor(x, dim=2, unpad_dim_size=n * d)

    x = self.o(x)
    return x


def ulysses_audio_cross_attn_forward(
    self,
    x,
    audio,
    seq_lens,
    grid_sizes,
    freqs,
    audio_seq_len,
    dtype=torch.bfloat16,
):
    b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
    seq_len = seq_lens.max()

    q = self.norm_q(self.q(x))
    k = self.norm_k(self.k(audio))
    v = self.v(audio)

    # ulysses support: pad the head dimension, then gather the video (q) and
    # audio (k/v) sequences while scattering heads across ranks
    sp_size = get_unified_parallel_world_size()
    if n % sp_size:
        pad_size = sp_size - (n % sp_size)
        pad_size = pad_size * d
        pad_inner_dim = n * d + pad_size
        q = pad_tensor(q, dim=2, padding_size=pad_size)
        k = pad_tensor(k, dim=2, padding_size=pad_size)
        v = pad_tensor(v, dim=2, padding_size=pad_size)
    else:
        pad_inner_dim = n * d
    qq = torch.cat([q, q], dim=2)
    kv = torch.cat([k, v], dim=2)
    qq = gather_seq_scatter_double_head(qq, seq_dim=1, unpadded_dim_size=seq_len)
    kv = gather_seq_scatter_double_head(kv, seq_dim=1, unpadded_dim_size=audio_seq_len)
    q, _ = qq.split(pad_inner_dim // sp_size, dim=2)
    k, v = kv.split(pad_inner_dim // sp_size, dim=2)
    pad_n = pad_inner_dim // d
    pad_split_n = pad_n // sp_size
    q = q.view(b, seq_len, pad_split_n, d)
    k = k.view(b, audio_seq_len, pad_split_n, d)
    v = v.view(b, audio_seq_len, pad_split_n, d)

    # chunked cross-attention: each group of h*w video tokens attends to its
    # own 16-token audio chunk
    hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2])
    assert hlen_wlen == 1560 or hlen_wlen == 3600
    q = q.reshape(-1, hlen_wlen, pad_split_n, d)
    k = k.reshape(-1, 16, pad_split_n, d)
    v = v.reshape(-1, 16, pad_split_n, d)

    x = flash_attention(
        q=q,
        k=k,
        v=v,
        k_lens=None,
    )
    x = x.view(b, -1, pad_split_n, d)

    # ulysses support: gather heads back and re-scatter the sequence dimension
    x = x.flatten(2)
    x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
    if n % sp_size:
        # drop the padded head channels so the output projection sees n * d
        x = unpad_tensor(x, dim=2, unpad_dim_size=n * d)

    x = self.o(x)
    return x
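

# ---------------------------------------------------------------------------
# Illustrative sketch (not part of the upstream code): one way the Ulysses
# forwards above could be wired into a loaded model is to rebind them as
# instance methods. The attribute names `model.blocks`, `blk.self_attn`, and
# `blk.audio_cross_attn`, and the helper itself, are assumptions made for this
# example; adapt them to the actual module layout of the DiT being patched.
def apply_ulysses_forwards(model):
    """Hypothetical helper: swap in the sequence-parallel forward passes."""
    import types

    # Replace the top-level DiT forward with the Ulysses variant.
    model.forward = types.MethodType(ulysses_dit_forward, model)
    for blk in model.blocks:
        # Self-attention: gather sequence, scatter heads across ranks.
        blk.self_attn.forward = types.MethodType(ulysses_attn_forward, blk.self_attn)
        # Audio cross-attention, only if the model was built with it.
        if getattr(model, "insert_audio", False) and hasattr(blk, "audio_cross_attn"):
            blk.audio_cross_attn.forward = types.MethodType(
                ulysses_audio_cross_attn_forward, blk.audio_cross_attn)
    return model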