Spaces:
Running
Running
| from cv2 import norm | |
| import torch | |
| from torch import layer_norm, nn | |
| from mmcv.runner import BaseModule | |
| import numpy as np | |
| from ..builder import SUBMODULES | |
| from .position_encoding import SinusoidalPositionalEncoding, LearnedPositionalEncoding | |
| import math | |
class ACTOREncoder(BaseModule):
    """Transformer encoder from ACTOR.

    Encodes a motion sequence ``(B, T, ...)`` into a latent vector ``mu``
    (and optionally ``sigma``) by prepending learned/conditioned query
    tokens and reading them back after a Transformer encoder pass.

    Args:
        max_seq_len (int): Maximum motion sequence length.
        njoints (int, optional): Number of joints; used with ``nfeats`` to
            derive ``input_feats`` when the latter is not given.
        nfeats (int, optional): Features per joint.
        input_feats (int, optional): Flattened per-frame feature size.
            If ``None``, computed as ``njoints * nfeats``.
        latent_dim (int): Transformer model dimension.
        output_dim (int): Output dimension of the final projection layers
            (only used when ``use_final_proj=True``).
        condition_dim (int, optional): Dimension of the condition vector
            (only used when ``use_condition=True`` and ``num_class`` is None).
        num_heads (int): Number of attention heads.
        ff_size (int): Feed-forward hidden size.
        num_layers (int): Number of encoder layers.
        activation (str): Activation in the Transformer layers.
        dropout (float): Dropout probability.
        use_condition (bool): Whether queries are derived from a condition.
        num_class (int, optional): If set, conditions are class indices into
            learned embedding tables instead of vectors fed through an MLP.
        use_final_proj (bool): Whether to project the query outputs with
            linear layers to ``output_dim``.
        output_var (bool): Whether to also produce a ``sigma`` output.
        pos_embedding (str): ``'sinusoidal'`` or learned positional encoding.
        init_cfg (dict, optional): mmcv initialization config.
    """

    def __init__(self,
                 max_seq_len=16,
                 njoints=None,
                 nfeats=None,
                 input_feats=None,
                 latent_dim=256,
                 output_dim=256,
                 condition_dim=None,
                 num_heads=4,
                 ff_size=1024,
                 num_layers=8,
                 activation='gelu',
                 dropout=0.1,
                 use_condition=False,
                 num_class=None,
                 use_final_proj=False,
                 output_var=False,
                 pos_embedding='sinusoidal',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        self.njoints = njoints
        self.nfeats = nfeats
        if input_feats is None:
            assert self.njoints is not None and self.nfeats is not None
            self.input_feats = njoints * nfeats
        else:
            self.input_feats = input_feats
        self.max_seq_len = max_seq_len
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim
        self.use_condition = use_condition
        self.num_class = num_class
        self.use_final_proj = use_final_proj
        self.output_var = output_var
        self.skelEmbedding = nn.Linear(self.input_feats, self.latent_dim)
        if self.use_condition:
            if num_class is None:
                # Condition is a vector: map it to the query token(s).
                # NOTE(review): build_MLP is defined elsewhere in the
                # project — assumed to return an nn.Module mapping
                # condition_dim -> latent_dim.
                self.mu_layer = build_MLP(self.condition_dim, self.latent_dim)
                if self.output_var:
                    self.sigma_layer = build_MLP(self.condition_dim,
                                                 self.latent_dim)
            else:
                # Condition is a class index: one learned token per class.
                self.mu_layer = nn.Parameter(
                    torch.randn(num_class, self.latent_dim))
                if self.output_var:
                    self.sigma_layer = nn.Parameter(
                        torch.randn(num_class, self.latent_dim))
        else:
            # Unconditional: learned query token(s), 2 when sigma is needed.
            if self.output_var:
                self.query = nn.Parameter(torch.randn(2, self.latent_dim))
            else:
                self.query = nn.Parameter(torch.randn(1, self.latent_dim))
        if pos_embedding == 'sinusoidal':
            self.pos_encoder = SinusoidalPositionalEncoding(
                latent_dim, dropout)
        else:
            # +2 accounts for the prepended mu/sigma query tokens.
            self.pos_encoder = LearnedPositionalEncoding(
                latent_dim, dropout, max_len=max_seq_len + 2)
        seqTransEncoderLayer = nn.TransformerEncoderLayer(
            d_model=self.latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation)
        self.seqTransEncoder = nn.TransformerEncoder(
            seqTransEncoderLayer, num_layers=num_layers)
        if self.use_final_proj:
            # Bug fix: forward() uses self.final_mu / self.final_sigma when
            # use_final_proj=True, but they were never created (and the
            # output_dim argument was silently ignored), causing an
            # AttributeError at runtime.
            self.final_mu = nn.Linear(self.latent_dim, output_dim)
            if self.output_var:
                self.final_sigma = nn.Linear(self.latent_dim, output_dim)

    def forward(self, motion, motion_mask=None, condition=None):
        """Encode a motion sequence.

        Args:
            motion (Tensor): Motion of shape ``(B, T, ...)``; trailing dims
                are flattened to ``input_feats``.
            motion_mask (Tensor, optional): ``(B, T)`` mask with 1 for valid
                frames and 0 for padding. Defaults to all-valid when None.
            condition (Tensor, optional): Condition vector ``(B,
                condition_dim)`` or class indices, required when
                ``use_condition=True``.

        Returns:
            Tensor or tuple: ``mu`` of shape ``(B, latent_dim)`` (or
            ``output_dim`` after projection); ``(mu, sigma)`` when
            ``output_var=True``.
        """
        B, T = motion.shape[:2]
        motion = motion.view(B, T, -1)
        feature = self.skelEmbedding(motion)
        # Bug fix: motion_mask defaults to None but was used unconditionally
        # below; treat None as "no padding".
        if motion_mask is None:
            motion_mask = torch.ones(B, T, device=motion.device)
        if self.use_condition:
            if self.output_var:
                if self.num_class is None:
                    sigma_query = self.sigma_layer(condition).view(B, 1, -1)
                else:
                    sigma_query = self.sigma_layer[condition.long()].view(
                        B, 1, -1)
                feature = torch.cat((sigma_query, feature), dim=1)
            if self.num_class is None:
                mu_query = self.mu_layer(condition).view(B, 1, -1)
            else:
                mu_query = self.mu_layer[condition.long()].view(B, 1, -1)
            feature = torch.cat((mu_query, feature), dim=1)
        else:
            query = self.query.view(1, -1, self.latent_dim).repeat(B, 1, 1)
            feature = torch.cat((query, feature), dim=1)
        # Query tokens are always attended (mask value 0 == keep);
        # key-padding mask expects True for padded positions, hence
        # 1 - motion_mask.
        if self.output_var:
            motion_mask = torch.cat(
                (torch.zeros(B, 2).to(motion.device), 1 - motion_mask),
                dim=1).bool()
        else:
            motion_mask = torch.cat(
                (torch.zeros(B, 1).to(motion.device), 1 - motion_mask),
                dim=1).bool()
        # nn.Transformer expects (T, B, C).
        feature = feature.permute(1, 0, 2).contiguous()
        feature = self.pos_encoder(feature)
        feature = self.seqTransEncoder(
            feature, src_key_padding_mask=motion_mask)
        if self.use_final_proj:
            mu = self.final_mu(feature[0])
            if self.output_var:
                sigma = self.final_sigma(feature[1])
                return mu, sigma
            return mu
        else:
            if self.output_var:
                return feature[0], feature[1]
            else:
                return feature[0]
class ACTORDecoder(BaseModule):
    """Transformer decoder from ACTOR.

    Decodes a latent vector ``(B, input_dim)`` into a motion sequence
    ``(B, max_seq_len, input_feats)`` by cross-attending positional queries
    against the (optionally condition-biased) latent memory.

    Args:
        max_seq_len (int): Number of frames to generate.
        njoints (int, optional): Number of joints; used with ``nfeats`` to
            derive ``input_feats`` when the latter is not given.
        nfeats (int, optional): Features per joint.
        input_feats (int, optional): Flattened per-frame feature size.
            If ``None``, computed as ``njoints * nfeats``.
        input_dim (int): Dimension of the incoming latent vector.
        latent_dim (int): Transformer model dimension.
        condition_dim (int, optional): Dimension of the condition vector
            (used when ``use_condition=True`` and ``num_class`` is None).
        num_heads (int): Number of attention heads.
        ff_size (int): Feed-forward hidden size.
        num_layers (int): Number of decoder layers.
        activation (str): Activation in the Transformer layers.
        dropout (float): Dropout probability.
        use_condition (bool): Whether to add a condition bias to the latent.
        num_class (int, optional): If set, conditions are class indices into
            a learned embedding table instead of vectors fed through an MLP.
        pos_embedding (str): ``'sinusoidal'`` or learned positional encoding.
        init_cfg (dict, optional): mmcv initialization config.
    """

    def __init__(self,
                 max_seq_len=16,
                 njoints=None,
                 nfeats=None,
                 input_feats=None,
                 input_dim=256,
                 latent_dim=256,
                 condition_dim=None,
                 num_heads=4,
                 ff_size=1024,
                 num_layers=8,
                 activation='gelu',
                 dropout=0.1,
                 use_condition=False,
                 num_class=None,
                 pos_embedding='sinusoidal',
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Project the latent to the model dimension only when they differ.
        if input_dim != latent_dim:
            self.linear = nn.Linear(input_dim, latent_dim)
        else:
            self.linear = nn.Identity()
        self.njoints = njoints
        self.nfeats = nfeats
        if input_feats is None:
            assert self.njoints is not None and self.nfeats is not None
            self.input_feats = njoints * nfeats
        else:
            self.input_feats = input_feats
        self.max_seq_len = max_seq_len
        self.input_dim = input_dim
        self.latent_dim = latent_dim
        self.condition_dim = condition_dim
        self.use_condition = use_condition
        self.num_class = num_class
        if self.use_condition:
            if num_class is None:
                # NOTE(review): build_MLP is defined elsewhere in the
                # project — assumed to return an nn.Module mapping
                # condition_dim -> latent_dim.
                self.condition_bias = build_MLP(condition_dim, latent_dim)
            else:
                self.condition_bias = nn.Parameter(
                    torch.randn(num_class, latent_dim))
        if pos_embedding == 'sinusoidal':
            self.pos_encoder = SinusoidalPositionalEncoding(
                latent_dim, dropout)
        else:
            self.pos_encoder = LearnedPositionalEncoding(
                latent_dim, dropout, max_len=max_seq_len)
        seqTransDecoderLayer = nn.TransformerDecoderLayer(
            d_model=self.latent_dim,
            nhead=num_heads,
            dim_feedforward=ff_size,
            dropout=dropout,
            activation=activation)
        self.seqTransDecoder = nn.TransformerDecoder(
            seqTransDecoderLayer, num_layers=num_layers)
        self.final = nn.Linear(self.latent_dim, self.input_feats)

    def forward(self, input, motion_mask=None, condition=None):
        """Decode a latent vector into a motion sequence.

        Args:
            input (Tensor): Latent of shape ``(B, input_dim)``.
            motion_mask (Tensor, optional): ``(B, max_seq_len)`` mask with 1
                for valid frames and 0 for padding. Defaults to all-valid
                when None.
            condition (Tensor, optional): Condition vector or class indices,
                required when ``use_condition=True``.

        Returns:
            Tensor: Poses of shape ``(B, max_seq_len, input_feats)``.
        """
        B = input.shape[0]
        T = self.max_seq_len
        input = self.linear(input)
        if self.use_condition:
            if self.num_class is None:
                condition = self.condition_bias(condition)
            else:
                condition = self.condition_bias[condition.long()].squeeze(1)
            input = input + condition
        # Bug fix: motion_mask defaults to None but was used unconditionally
        # below; treat None as "no padding".
        if motion_mask is None:
            motion_mask = torch.ones(B, T, device=input.device)
        # Positional encodings serve as the decoder queries.
        # NOTE(review): assumes pos_encoder.pe is indexable as (max_len, ...)
        # — confirm against the positional-encoding implementations.
        query = self.pos_encoder.pe[:T, :].view(T, 1, -1).repeat(1, B, 1)
        # Latent acts as a length-1 memory sequence: (1, B, latent_dim).
        input = input.view(1, B, -1)
        # tgt_key_padding_mask expects True for padded positions.
        feature = self.seqTransDecoder(
            tgt=query,
            memory=input,
            tgt_key_padding_mask=(1 - motion_mask).bool())
        # Back to batch-first: (B, T, input_feats).
        pose = self.final(feature).permute(1, 0, 2).contiguous()
        return pose