import torch
import torch.nn as nn

from ovi.modules.model import WanLayerNorm, WanModel, WanRMSNorm, gradient_checkpointing, rope_apply
from ovi.modules.attention import flash_attention
from ovi.distributed_comms.communications import all_gather, all_to_all_4D
from ovi.distributed_comms.parallel_states import nccl_info, get_sequence_parallel_state


class FusionModel(nn.Module):
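    """Fused audio-video transformer.

    Wraps a video WanModel and an audio WanModel and couples them with
    bidirectional ("fusion") cross-attention at every transformer block.
    """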

    def __init__(self, video_config=None, audio_config=None):
        super().__init__()
        has_video = True
        has_audio = True
        if video_config is not None:
            self.video_model = WanModel(**video_config)
        else:
            has_video = False
            self.video_model = None
            print("Warning: No video model is provided!")

        if audio_config is not None:
            self.audio_model = WanModel(**audio_config)
        else:
            has_audio = False
            self.audio_model = None
            print("Warning: No audio model is provided!")

        if has_video and has_audio:
            assert len(self.video_model.blocks) == len(self.audio_model.blocks)
            self.num_blocks = len(self.video_model.blocks)

            self.use_sp = get_sequence_parallel_state()
            if self.use_sp:
                self.sp_size = nccl_info.sp_size
                self.sp_rank = nccl_info.rank_within_group
            self.inject_cross_attention_kv_projections()

        # forward() consults this flag; assumed to default to off here and to be
        # enabled externally (e.g. by the training loop) when activation
        # checkpointing is desired.
        self.gradient_checkpointing = False

        self.init_weights()

    def inject_cross_attention_kv_projections(self):
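        # Give every block's cross-attention an extra K/V path ("fusion"), plus the
        # norms it needs, so each modality can attend to the other's hidden states.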
        for vid_block in self.video_model.blocks:
            vid_block.cross_attn.k_fusion = nn.Linear(vid_block.dim, vid_block.dim)
            vid_block.cross_attn.v_fusion = nn.Linear(vid_block.dim, vid_block.dim)
            vid_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(vid_block.dim, elementwise_affine=True)
            vid_block.cross_attn.norm_k_fusion = WanRMSNorm(vid_block.dim, eps=1e-6) if vid_block.qk_norm else nn.Identity()

        for audio_block in self.audio_model.blocks:
            audio_block.cross_attn.k_fusion = nn.Linear(audio_block.dim, audio_block.dim)
            audio_block.cross_attn.v_fusion = nn.Linear(audio_block.dim, audio_block.dim)
            audio_block.cross_attn.pre_attn_norm_fusion = WanLayerNorm(audio_block.dim, elementwise_affine=True)
            audio_block.cross_attn.norm_k_fusion = WanRMSNorm(audio_block.dim, eps=1e-6) if audio_block.qk_norm else nn.Identity()

    def merge_kwargs(self, vid_kwargs, audio_kwargs):
        """Merge the per-modality transformer-block kwargs into a single dict.

        Keys in each kwarg dict:
            e, seq_lens, grid_sizes, freqs, context, context_lens

        Each key is re-emitted with a "vid_" or "audio_" prefix so both sets can be
        passed together to a single fusion block forward.
        """
        merged_kwargs = {}
        for key in vid_kwargs:
            merged_kwargs[f"vid_{key}"] = vid_kwargs[key]
        for key in audio_kwargs:
            merged_kwargs[f"audio_{key}"] = audio_kwargs[key]
        return merged_kwargs

    def single_fusion_cross_attention_forward(self,
                                              cross_attn_block,
                                              src_seq,
                                              src_grid_sizes,
                                              src_freqs,
                                              target_seq,
                                              target_seq_lens,
                                              target_grid_sizes,
                                              target_freqs,
                                              context,
                                              context_lens):
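        """Cross-attention over the text context plus fused attention into the other modality.

        Runs the block's usual context cross-attention for ``src_seq``, then a second
        attention pass whose keys/values come from ``target_seq`` (the other modality's
        hidden states) via the injected ``*_fusion`` projections, with RoPE applied to
        the queries and fused keys. Both results are summed and passed through the
        block's output projection.
        """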
        b, n, d = src_seq.size(0), cross_attn_block.num_heads, cross_attn_block.head_dim
        if hasattr(cross_attn_block, "k_img"):
            q, k, v, k_img, v_img = cross_attn_block.qkv_fn(src_seq, context)
        else:
            q, k, v = cross_attn_block.qkv_fn(src_seq, context)
            k_img = v_img = None
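
        # Sequence parallelism: the all-to-all turns q from (sequence-sharded, all heads)
        # into (full sequence, local head shard); context k/v are chunked over heads so
        # each rank attends with its own head shard.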
        if self.use_sp:
            q = all_to_all_4D(q, scatter_dim=2, gather_dim=1)
            k = torch.chunk(k, self.sp_size, dim=2)[self.sp_rank]
            v = torch.chunk(v, self.sp_size, dim=2)[self.sp_rank]
            if k_img is not None:
                k_img = torch.chunk(k_img, self.sp_size, dim=2)[self.sp_rank]
            if v_img is not None:
                v_img = torch.chunk(v_img, self.sp_size, dim=2)[self.sp_rank]

        x = flash_attention(q, k, v, k_lens=context_lens)

        if k_img is not None:
            img_x = flash_attention(q, k_img, v_img, k_lens=None)
            x = x + img_x

        is_vid = src_grid_sizes.shape[1] > 1  # currently unused
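
        # Fusion path: build keys/values from the other modality's sequence through the
        # injected fusion projections, then attend to them with the same queries.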
        target_seq = cross_attn_block.pre_attn_norm_fusion(target_seq)
        k_target = cross_attn_block.norm_k_fusion(cross_attn_block.k_fusion(target_seq)).view(b, -1, n, d)
        v_target = cross_attn_block.v_fusion(target_seq).view(b, -1, n, d)
        if self.use_sp:
            k_target = all_to_all_4D(k_target, scatter_dim=2, gather_dim=1)
            v_target = all_to_all_4D(v_target, scatter_dim=2, gather_dim=1)

        q = rope_apply(q, src_grid_sizes, src_freqs)
        k_target = rope_apply(k_target, target_grid_sizes, target_freqs)

        target_x = flash_attention(q, k_target, v_target, k_lens=target_seq_lens)

        x = x + target_x
        if self.use_sp:
            x = all_to_all_4D(x, scatter_dim=1, gather_dim=2)

        x = x.flatten(2)

        x = cross_attn_block.o(x)
        return x

    def single_fusion_cross_attention_ffn_forward(self,
                                                  attn_block,
                                                  src_seq,
                                                  src_grid_sizes,
                                                  src_freqs,
                                                  target_seq,
                                                  target_seq_lens,
                                                  target_grid_sizes,
                                                  target_freqs,
                                                  context,
                                                  context_lens,
                                                  src_e):
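        """Apply the fused cross-attention with a residual connection, then the block's
        FFN modulated by ``src_e`` (shift/scale/gate at indices 3-5)."""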
        src_seq = src_seq + self.single_fusion_cross_attention_forward(attn_block.cross_attn,
                                                                       attn_block.norm3(src_seq),
                                                                       src_grid_sizes=src_grid_sizes,
                                                                       src_freqs=src_freqs,
                                                                       target_seq=target_seq,
                                                                       target_seq_lens=target_seq_lens,
                                                                       target_grid_sizes=target_grid_sizes,
                                                                       target_freqs=target_freqs,
                                                                       context=context,
                                                                       context_lens=context_lens)

        y = attn_block.ffn(attn_block.norm2(src_seq).bfloat16() * (1 + src_e[4].squeeze(2)) + src_e[3].squeeze(2))
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            src_seq = src_seq + y * src_e[5].squeeze(2)
        return src_seq

    def single_fusion_block_forward(self,
                                    vid_block,
                                    audio_block,
                                    vid,
                                    audio,
                                    vid_e,
                                    vid_seq_lens,
                                    vid_grid_sizes,
                                    vid_freqs,
                                    vid_context,
                                    vid_context_lens,
                                    audio_e,
                                    audio_seq_lens,
                                    audio_grid_sizes,
                                    audio_freqs,
                                    audio_context,
                                    audio_context_lens):
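        """Run one fusion block: modulated self-attention for each modality, then
        cross-attention + FFN in which audio attends to the self-attended video and
        video attends to the pre-fusion audio (``og_audio``)."""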
        assert audio_e.dtype == torch.bfloat16
        assert len(audio_e.shape) == 4 and audio_e.size(2) == 6 and audio_e.shape[1] == audio.shape[1], f"{audio_e.shape}, {audio.shape}"
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            audio_e = audio_block.modulation(audio_e).chunk(6, dim=2)
        assert audio_e[0].dtype == torch.bfloat16
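
        # The six modulation chunks act as (shift, scale, gate) for self-attention
        # (indices 0-2) and for the FFN (indices 3-5); the FFN portion is consumed in
        # single_fusion_cross_attention_ffn_forward.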
        audio_y = audio_block.self_attn(
            audio_block.norm1(audio).bfloat16() * (1 + audio_e[1].squeeze(2)) + audio_e[0].squeeze(2), audio_seq_lens, audio_grid_sizes,
            audio_freqs)
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            audio = audio + audio_y * audio_e[2].squeeze(2)

        assert len(vid_e.shape) == 4 and vid_e.size(2) == 6 and vid_e.shape[1] == vid.shape[1], f"{vid_e.shape}, {vid.shape}"
        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            vid_e = vid_block.modulation(vid_e).chunk(6, dim=2)

        vid_y = vid_block.self_attn(
            vid_block.norm1(vid).bfloat16() * (1 + vid_e[1].squeeze(2)) + vid_e[0].squeeze(2), vid_seq_lens, vid_grid_sizes,
            vid_freqs)

        with torch.amp.autocast('cuda', dtype=torch.bfloat16):
            vid = vid + vid_y * vid_e[2].squeeze(2)

        og_audio = audio
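
        # Bidirectional fusion: audio attends to the self-attended video, while video
        # attends to og_audio (the audio state captured before its own fusion step), so
        # neither direction sees the other's fused output within the same block.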
        audio = self.single_fusion_cross_attention_ffn_forward(
            audio_block,
            audio,
            audio_grid_sizes,
            audio_freqs,
            vid,
            vid_seq_lens,
            vid_grid_sizes,
            vid_freqs,
            audio_context,
            audio_context_lens,
            audio_e
        )

        assert not torch.equal(og_audio, audio), "Audio should be changed after cross-attention!"

        vid = self.single_fusion_cross_attention_ffn_forward(
            vid_block,
            vid,
            vid_grid_sizes,
            vid_freqs,
            og_audio,
            audio_seq_lens,
            audio_grid_sizes,
            audio_freqs,
            vid_context,
            vid_context_lens,
            vid_e
        )

        return vid, audio

    def forward(
        self,
        vid,
        audio,
        t,
        vid_context,
        audio_context,
        vid_seq_len,
        audio_seq_len,
        clip_fea=None,
        clip_fea_audio=None,
        y=None,
        first_frame_is_clean=False,
        slg_layer=False
    ):
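        """Joint forward pass.

        Falls back to the single-modality WanModel forward when either ``vid`` or
        ``audio`` is missing; otherwise runs both backbones block-by-block with fusion.
        ``slg_layer`` selects the index of a fusion block to skip entirely (False/0
        disables skipping).
        """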
        assert clip_fea is None
        assert y is None

        # Audio-only path: fall through to the plain audio WanModel.
        if vid is None or all([x is None for x in vid]):
            assert vid_context is None
            assert vid_seq_len is None
            assert self.audio_model is not None

            return None, self.audio_model(x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None)

        # Video-only path: fall through to the plain video WanModel.
        if audio is None or all([x is None for x in audio]):
            assert clip_fea_audio is None
            assert audio_context is None
            assert audio_seq_len is None
            assert self.video_model is not None

            return self.video_model(x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean), None
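
        # Fused path: prepare each modality's per-block kwargs, then run the paired
        # video/audio blocks with cross-modal fusion.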
        vid, vid_e, vid_kwargs = self.video_model.prepare_transformer_block_kwargs(
            x=vid, t=t, context=vid_context, seq_len=vid_seq_len, clip_fea=clip_fea, y=y, first_frame_is_clean=first_frame_is_clean
        )

        audio, audio_e, audio_kwargs = self.audio_model.prepare_transformer_block_kwargs(
            x=audio, t=t, context=audio_context, seq_len=audio_seq_len, clip_fea=clip_fea_audio, y=None, first_frame_is_clean=False
        )

        kwargs = self.merge_kwargs(vid_kwargs, audio_kwargs)

        # One fusion block pairs one video block with one audio block.
        for i in range(self.num_blocks):
            # Optionally skip a single fusion block (e.g. for skip-layer guidance).
            if slg_layer > 0 and i == slg_layer:
                continue
            vid_block = self.video_model.blocks[i]
            audio_block = self.audio_model.blocks[i]
            vid, audio = gradient_checkpointing(
                enabled=(self.training and self.gradient_checkpointing),
                module=self.single_fusion_block_forward,
                vid_block=vid_block,
                audio_block=audio_block,
                vid=vid,
                audio=audio,
                **kwargs
            )

        vid = self.video_model.post_transformer_block_out(vid, vid_kwargs['grid_sizes'], vid_e)
        audio = self.audio_model.post_transformer_block_out(audio, audio_kwargs['grid_sizes'], audio_e)

        return vid, audio

    def init_weights(self):
        if self.audio_model is not None:
            self.audio_model.init_weights()

        if self.video_model is not None:
            self.video_model.init_weights()
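
            # Shrink the injected fusion projection weights by 10x so the fused
            # cross-attention starts with a small contribution.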
            for name, mod in self.video_model.named_modules():
                if "fusion" in name and isinstance(mod, nn.Linear):
                    with torch.no_grad():
                        mod.weight.div_(10.0)

    def set_rope_params(self):
        self.video_model.set_rope_params()
        self.audio_model.set_rope_params()