Spaces:
Running
on
Zero
Running
on
Zero
| # This code is based on https://github.com/Mathux/ACTOR.git | |
| import torch | |
| import numpy as np | |
| import torch.nn.functional as F | |
| from . import rotation_conversions as geometry | |
| from .smpl import SMPL | |
# Values accepted for the `jointstype` argument of Rotation2xyz.__call__;
# the chosen value is also used as the key into the SMPL model output.
JOINTSTYPES = [
    "a2m",
    "a2mpl",
    "smpl",
    "vibe",
    "vertices",
]
class Rotation2xyz:
    """Turn a sequence of joint rotations into xyz positions via an SMPL model.

    The instance owns one SMPL model (put in eval mode and moved to `device`)
    and is used as a callable; see `__call__` for the conversion itself.
    """

    def __init__(self, device, dataset="amass"):
        """Build the SMPL model on `device`.

        Args:
            device: Device the SMPL model is moved to.
            dataset: Stored on the instance; not read anywhere in this class.
        """
        self.device = device
        self.dataset = dataset
        self.smpl_model = SMPL().eval().to(device)

    def __call__(
        self,
        x,
        mask,
        pose_rep,
        translation,
        glob,
        jointstype,
        vertstrans,
        betas=None,
        beta=0,
        glob_rot=None,
        get_rotations_back=False,
        **kwargs
    ):
        """Convert a rotation sequence to xyz coordinates.

        Args:
            x: Tensor indexed as (entries, feats, time). When `translation`
                is true the last entry along the first axis carries the root
                translation in its first three features; all other entries
                are per-joint rotations in the `pose_rep` encoding.
            mask: Boolean tensor over the time axis selecting the frames to
                convert, or None to convert every frame. Unselected frames
                are zero in the output.
            pose_rep: "xyz" (x is returned untouched), "rotvec", "rotmat",
                "rotquat" or "rot6d".
            translation: Whether x carries a translation entry (see `x`).
            glob: If true, the first rotation is used as the global
                orientation and removed from the body pose. If false,
                `glob_rot` supplies the global orientation instead.
            jointstype: One of JOINTSTYPES; key read from the SMPL output.
            vertstrans: If true (together with `translation`), the
                translation, re-based so the first frame sits at the origin,
                is added to every output point.
            betas: Optional shape coefficients passed to the SMPL model.
                When None, zeros are used with coefficient 1 set to `beta`.
            beta: Value for shape coefficient 1 when `betas` is None.
            glob_rot: Axis-angle global rotation, required when `glob` is
                false.
            get_rotations_back: If true, also return the rotation matrices.
            **kwargs: Accepted and ignored.

        Returns:
            A tensor of shape (points, 3, time). If `get_rotations_back` is
            true, a tuple (xyz, body_rotations, global_orient) instead.

        Raises:
            TypeError: `glob` is false and `glob_rot` is None.
            NotImplementedError: Unknown `jointstype` or `pose_rep`.
        """
        if pose_rep == "xyz":
            return x

        if mask is None:
            mask = torch.ones((x.shape[-1]), dtype=torch.bool, device=x.device)

        if not glob and glob_rot is None:
            raise TypeError("You must specify global rotation if glob is False")

        if jointstype not in JOINTSTYPES:
            raise NotImplementedError("This jointstype is not implemented.")

        if translation:
            # Last entry holds the translation: (3, time).
            x_translations = x[-1, :3]
            x_rotations = x[:-1]
        else:
            x_rotations = x

        # (entries, feats, time) -> (time, njoints, feats)
        x_rotations = x_rotations.permute(2, 0, 1)
        time, njoints, feats = x_rotations.shape

        # Only the masked-in frames are converted to rotation matrices.
        selected = x_rotations[mask]
        to_matrix = {
            "rotvec": geometry.axis_angle_to_matrix,
            "rotmat": lambda r: r.view(-1, njoints, 3, 3),
            "rotquat": geometry.quaternion_to_matrix,
            "rot6d": geometry.rotation_6d_to_matrix,
        }
        if pose_rep not in to_matrix:
            raise NotImplementedError("No geometry for this one.")
        rotations = to_matrix[pose_rep](selected)

        if not glob:
            # as_tensor accepts lists, arrays and tensors alike (no
            # copy-construct warning for tensors), and the explicit dtype
            # keeps the global orientation consistent with the body pose.
            global_orient = torch.as_tensor(
                glob_rot, dtype=rotations.dtype, device=x.device
            )
            global_orient = geometry.axis_angle_to_matrix(global_orient).view(
                1, 1, 3, 3
            )
            global_orient = global_orient.repeat(len(rotations), 1, 1, 1)
        else:
            global_orient = rotations[:, 0]
            rotations = rotations[:, 1:]

        if betas is None:
            betas = torch.zeros(
                [rotations.shape[0], self.smpl_model.num_betas],
                dtype=rotations.dtype,
                device=rotations.device,
            )
            betas[:, 1] = beta

        out = self.smpl_model(
            body_pose=rotations, global_orient=global_orient, betas=betas
        )

        # Pick the requested output and scatter it back onto the full
        # timeline, leaving masked-out frames at zero.
        joints = out[jointstype]
        x_xyz = torch.empty(time, joints.shape[1], 3, device=x.device, dtype=x.dtype)
        x_xyz[~mask] = 0
        x_xyz[mask] = joints

        # (time, points, 3) -> (points, 3, time)
        x_xyz = x_xyz.permute(1, 2, 0).contiguous()

        if translation and vertstrans:
            # Re-base the trajectory so the first frame is at the origin,
            # then shift every point by it.
            x_translations = x_translations - x_translations[:, [0]]
            x_xyz = x_xyz + x_translations[None, :, :]

        if get_rotations_back:
            return x_xyz, rotations, global_orient
        return x_xyz