# Copyright 2024-2025 The Alibaba Wan Team Authors. All rights reserved.
import math

from einops import rearrange
import torch

def after_patch_embedding(self, x: torch.Tensor, pose_latents: torch.Tensor, face_pixel_values: torch.Tensor):
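    """Inject pose and face-motion conditioning right after patch embedding.

    `x` is the patch-embedded video latent, laid out (B, C, F, h, w); the
    embedded `pose_latents` are added to every frame except the first, which
    presumably stays as the clean reference frame. `face_pixel_values`
    (b, c, T, H, W) is encoded into per-frame motion vectors, and a zero
    slot is prepended so the returned `motion_vec` has T + 1 entries.

    Returns the conditioned `x` and `motion_vec` of shape (B, T + 1, H, C).
    """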
    # Embed the driving-pose latents and add them to every frame token grid
    # except the first, which is left untouched.
    pose_latents = self.pose_patch_embedding(pose_latents)
    x[:, :, 1:] += pose_latents
    
    # Encode the face frames into motion features in mini-batches to bound
    # peak memory, then stitch the chunks back together.
    b, c, T, h, w = face_pixel_values.shape
    face_pixel_values = rearrange(face_pixel_values, "b c t h w -> (b t) c h w")
    encode_bs = 8
    motion_chunks = []
    for i in range(math.ceil(face_pixel_values.shape[0] / encode_bs)):
        chunk = face_pixel_values[i * encode_bs:(i + 1) * encode_bs]
        motion_chunks.append(self.motion_encoder.get_motion(chunk))
    motion_vec = torch.cat(motion_chunks)
    
    # Regroup the motion vectors per video and project them with the face
    # encoder into per-frame, multi-head conditioning features.
    motion_vec = rearrange(motion_vec, "(b t) c -> b t c", t=T)
    motion_vec = self.face_encoder(motion_vec)

    # Prepend a zero slot at t=0, mirroring the first frame of x that the
    # pose latents skip.
    B, L, H, C = motion_vec.shape
    pad_face = torch.zeros(B, 1, H, C).type_as(motion_vec)
    motion_vec = torch.cat([pad_face, motion_vec], dim=1)
    return x, motion_vec
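

# --------------------------------------------------------------------------
# Minimal usage sketch. The stub modules and shapes below are hypothetical
# stand-ins (the real pose_patch_embedding, motion_encoder, and face_encoder
# are Wan's modules); they only illustrate how `after_patch_embedding` can be
# bound onto a model instance with `types.MethodType`.
# --------------------------------------------------------------------------
if __name__ == "__main__":
    import types

    class _StubMotionEncoder(torch.nn.Module):
        def get_motion(self, frames):
            # (n, c, h, w) -> (n, c): stand-in for the real motion encoder.
            return frames.mean(dim=(2, 3))

    class _StubFaceEncoder(torch.nn.Module):
        def forward(self, m):
            # (b, t, c) -> (b, t, heads, c // heads), here with 4 heads.
            b, t, c = m.shape
            return m.view(b, t, 4, c // 4)

    class _StubModel(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.pose_patch_embedding = torch.nn.Identity()  # real model: a patchifying Conv3d
            self.motion_encoder = _StubMotionEncoder()
            self.face_encoder = _StubFaceEncoder()

    model = _StubModel()
    # Bind the function as a method so `self` resolves to the model instance.
    model.after_patch_embedding = types.MethodType(after_patch_embedding, model)

    x = torch.randn(1, 32, 5, 4, 4)      # (B, C, F, h, w) patch-embedded video
    pose = torch.randn(1, 32, 4, 4, 4)   # matches x[:, :, 1:] after the Identity stub
    face = torch.randn(1, 8, 5, 16, 16)  # (b, c, T, H, W) face crops
    x, motion_vec = model.after_patch_embedding(x, pose, face)
    print(motion_vec.shape)              # torch.Size([1, 6, 4, 2])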