# humo/models/distributed/dit_ulysses_sequence_parallel.py
# Copyright (c) 2025 Bytedance Ltd. and/or its affiliates
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from einops import rearrange
from common.distributed import get_device
from common.distributed.advanced import (
get_unified_parallel_world_size,
get_unified_parallel_group,
pad_tensor,
Slice,
gather_outputs,
gather_seq_scatter_heads_qkv,
gather_seq_scatter_double_head,
gather_heads_scatter_seq,
unpad_tensor
)
from humo.models.wan_modules.attention import flash_attention
from humo.models.wan_modules.model_humo import rope_apply, sinusoidal_embedding_1d


def ulysses_dit_forward(
self,
x,
t,
context,
seq_len,
audio=None,
y=None
):
"""
    x: A list of videos, each with shape [C, T, H, W].
    t: Diffusion timesteps with shape [B].
    context: A list of text embeddings, each with shape [L, C].
    seq_len: Target padded length (in patch tokens) of each video sequence.
    audio: A list of per-video audio feature sequences (used when self.insert_audio is True).
    y: A list of conditioning videos concatenated with x along the channel dimension (required for the 'i2v' model type).
"""
if self.model_type == 'i2v':
# assert clip_fea is not None and y is not None
assert y is not None
# params
device = self.patch_embedding.weight.device
if self.freqs.device != device:
self.freqs = self.freqs.to(device)
if y is not None:
x = [torch.cat([u, v], dim=0) for u, v in zip(x, y)]
# embeddings
x = [self.patch_embedding(u.unsqueeze(0)) for u in x]
grid_sizes = torch.stack(
[torch.tensor(u.shape[2:], dtype=torch.long) for u in x])
x = [u.flatten(2).transpose(1, 2) for u in x]
seq_lens = torch.tensor([u.size(1) for u in x], dtype=torch.long, device=device)
assert seq_lens.max() <= seq_len
x = torch.cat([
torch.cat([u, u.new_zeros(1, seq_len - u.size(1), u.size(2))], dim=1)
for u in x
])
# time embeddings
with torch.amp.autocast('cuda', dtype=torch.float32):
e = self.time_embedding(
sinusoidal_embedding_1d(self.freq_dim, t).float()).float()
e0 = self.time_projection(e).unflatten(1, (6, self.dim)).float()
assert e.dtype == torch.float32 and e0.dtype == torch.float32
# context
context_lens = None
context = self.text_embedding(
torch.stack([
torch.cat([u, u.new_zeros(self.text_len - u.size(0), u.size(1))])
for u in context
]))
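    # audio embeddings: project each clip's audio features, flatten to a token sequence, and pad to a shared length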
if self.insert_audio:
audio = [self.audio_proj(au.unsqueeze(0)).permute(0, 3, 1, 2) for au in audio]
audio_seq_len = torch.tensor(max([au.shape[2] for au in audio]) * audio[0].shape[3], device=get_device())
audio = [au.flatten(2).transpose(1, 2) for au in audio] # [1, t*32, 1536]
audio_seq_lens = torch.tensor([au.size(1) for au in audio], dtype=torch.long, device=device)
audio = torch.cat([
torch.cat([au, au.new_zeros(1, audio_seq_len - au.size(1), au.size(2))],
dim=1) for au in audio
])
else:
audio = None
audio_seq_len = None
audio_seq_lens = None
# ulysses support
sp_world = get_unified_parallel_world_size()
group = get_unified_parallel_group()
if seq_len % sp_world:
padding_size = sp_world - (seq_len % sp_world)
x = pad_tensor(x, dim=1, padding_size=padding_size)
if self.insert_audio:
audio_padding_size = sp_world - (audio_seq_len % sp_world)
audio = pad_tensor(audio, dim=1, padding_size=audio_padding_size)
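    # scatter the (padded) token sequences across the sequence-parallel group; each rank keeps 1/sp_world of the tokens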
x = Slice.apply(group, x, 1, True)
if self.insert_audio:
audio = Slice.apply(group, audio, 1, True)
# arguments
kwargs = dict(
e=e0,
seq_lens=seq_lens,
grid_sizes=grid_sizes,
freqs=self.freqs,
context=context,
context_lens=context_lens,
audio=audio,
audio_seq_len=audio_seq_len)
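    # each transformer block runs on this rank's sequence shard; the patched attention forwards below handle the Ulysses all-to-all exchanges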
for block in self.blocks:
x = block(x, **kwargs)
# head
x = self.head(x, e)
# ulysses support
x = gather_outputs(x, gather_dim=1, padding_dim=1, unpad_dim_size=seq_len, scale_grad=True)
# unpatchify
x = self.unpatchify(x, grid_sizes)
return [u.float() for u in x]


def ulysses_attn_forward(
self,
x,
seq_lens,
grid_sizes,
freqs,
dtype=torch.bfloat16
):
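    """
    Self-attention with Ulysses sequence parallelism.

    q/k/v are computed on the local sequence shard and padded so the head count
    divides the sequence-parallel world size. An all-to-all then gathers the full
    sequence while scattering heads across ranks, flash attention runs over the
    full sequence with the local slice of heads, and a second all-to-all restores
    the original layout (heads gathered, sequence scattered).
    """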
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
seq_len = seq_lens.max()
half_dtypes = (torch.float16, torch.bfloat16)
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# query, key, value function
def qkv_fn(x):
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(x))
v = self.v(x)
return q, k, v
q, k, v = qkv_fn(x)
# ulysses support
sp_size = get_unified_parallel_world_size()
if n % sp_size:
pad_size = sp_size - (n % sp_size)
pad_size = pad_size * d
pad_inner_dim = n * d + pad_size
q = pad_tensor(q, dim=2, padding_size=pad_size)
k = pad_tensor(k, dim=2, padding_size=pad_size)
v = pad_tensor(v, dim=2, padding_size=pad_size)
else:
pad_inner_dim = n * d
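    # pack q, k, v so a single all-to-all can gather the full sequence while scattering heads across ranks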
qkv = torch.cat([q, k, v], dim=2)
qkv = gather_seq_scatter_heads_qkv(qkv, seq_dim=1, unpadded_dim_size=seq_len)
q, k, v = qkv.split(pad_inner_dim // sp_size, dim=2)
pad_n = pad_inner_dim // d
pad_split_n = pad_n // sp_size
q = q.view(b, seq_len, pad_split_n, d)
k = k.view(b, seq_len, pad_split_n, d)
v = v.view(b, seq_len, pad_split_n, d)
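    # RoPE is applied after the exchange, over the full gathered sequence, using this rank's slice of heads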
q = rope_apply(q, grid_sizes, freqs)
k = rope_apply(k, grid_sizes, freqs)
x = flash_attention(
q=half(q),
k=half(k),
v=half(v),
k_lens=seq_lens,
window_size=self.window_size
)
# ulysses support
x = x.flatten(2)
x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
if n % sp_size:
x = unpad_tensor(x, dim=2, unpad_dim_size=seq_len)
x = self.o(x)
return x


def ulysses_audio_cross_attn_forward(
self,
x,
audio,
seq_lens,
grid_sizes,
freqs,
audio_seq_len,
dtype=torch.bfloat16
):
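    """
    Audio cross-attention with Ulysses sequence parallelism.

    Video tokens provide the queries and audio tokens the keys/values. Heads are
    padded so they divide the sequence-parallel world size, all-to-all exchanges
    gather the full video and audio sequences while scattering heads across ranks,
    and attention is computed per latent frame (each frame's h*w video tokens
    attend to a 16-token audio chunk) before the layout is restored.
    """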
b, s, n, d = *x.shape[:2], self.num_heads, self.head_dim
seq_len = seq_lens.max()
q = self.norm_q(self.q(x))
k = self.norm_k(self.k(audio))
v = self.v(audio)
# ulysses support
sp_size = get_unified_parallel_world_size()
if n % sp_size:
pad_size = sp_size - (n % sp_size)
pad_size = pad_size * d
pad_inner_dim = n * d + pad_size
q = pad_tensor(q, dim=2, padding_size=pad_size)
k = pad_tensor(k, dim=2, padding_size=pad_size)
v = pad_tensor(v, dim=2, padding_size=pad_size)
else:
pad_inner_dim = n * d
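    # duplicate q so the double-head all-to-all helper used for the packed k/v can also exchange the query; the extra copy is discarded after the split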
qq = torch.cat([q, q], dim=2)
kv = torch.cat([k, v], dim=2)
qq = gather_seq_scatter_double_head(qq, seq_dim=1, unpadded_dim_size=seq_len)
kv = gather_seq_scatter_double_head(kv, seq_dim=1, unpadded_dim_size=audio_seq_len)
q, _ = qq.split(pad_inner_dim // sp_size, dim=2)
k, v = kv.split(pad_inner_dim // sp_size, dim=2)
pad_n = pad_inner_dim // d
pad_split_n = pad_n // sp_size
q = q.view(b, seq_len, pad_split_n, d)
k = k.view(b, audio_seq_len, pad_split_n, d)
v = v.view(b, audio_seq_len, pad_split_n, d)
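    # regroup for per-frame attention: each latent frame's h*w video tokens attend to a 16-token chunk of the audio sequence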
hlen_wlen = int(grid_sizes[0][1] * grid_sizes[0][2])
assert hlen_wlen == 1560 or hlen_wlen == 3600
q = q.reshape(-1, hlen_wlen, pad_split_n, d)
k = k.reshape(-1, 16, pad_split_n, d)
v = v.reshape(-1, 16, pad_split_n, d)
x = flash_attention(
q=q,
k=k,
v=v,
k_lens=None,
)
x = x.view(b, -1, pad_split_n, d)
# ulysses support
x = x.flatten(2)
x = gather_heads_scatter_seq(x, head_dim=2, seq_dim=1)
if n % sp_size:
x = unpad_tensor(x, dim=2, unpad_dim_size=seq_len)
x = self.o(x)
return x
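

# A minimal sketch, not part of the upstream file, of how these forwards might be
# bound onto the model so that Ulysses sequence parallelism takes effect; the
# attribute names `self_attn` and `audio_cross_attn` are illustrative assumptions:
#
#     import types
#     model.forward = types.MethodType(ulysses_dit_forward, model)
#     for block in model.blocks:
#         block.self_attn.forward = types.MethodType(ulysses_attn_forward, block.self_attn)
#         if getattr(model, "insert_audio", False):
#             block.audio_cross_attn.forward = types.MethodType(
#                 ulysses_audio_cross_attn_forward, block.audio_cross_attn)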