File size: 3,338 Bytes
352b049
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
# This code is based on https://github.com/Mathux/ACTOR.git
import torch
import numpy as np
import torch.nn.functional as F
from . import rotation_conversions as geometry
from .smpl import SMPL

# Accepted values for the `jointstype` argument of Rotation2xyz.__call__;
# the chosen value is also the key used to read from the SMPL model output.
JOINTSTYPES = [
    "a2m",
    "a2mpl",
    "smpl",
    "vibe",
    "vertices",
]


class Rotation2xyz:
    """Convert rotation-parameterized motion into xyz positions via an SMPL model.

    The SMPL model is constructed once in ``__init__`` and reused by every
    call to the instance.
    """

    def __init__(self, device, dataset="amass"):
        """Build the converter.

        Args:
            device: Device the SMPL model is moved to.
            dataset: Stored on the instance; not used elsewhere in this class.
        """
        self.device = device
        self.dataset = dataset
        self.smpl_model = SMPL().eval().to(device)

    def __call__(
        self,
        x,
        mask,
        pose_rep,
        translation,
        glob,
        jointstype,
        vertstrans,
        betas=None,
        beta=0,
        glob_rot=None,
        get_rotations_back=False,
        **kwargs
    ):
        """Convert one motion sequence to xyz joint (or vertex) positions.

        Args:
            x: 3-D tensor laid out as (rows, feats, time). Each row holds the
                rotation features of one joint; when ``translation`` is true
                the last row instead carries the root translation in its first
                3 features.
            mask: Boolean tensor selecting the time steps to convert; ``None``
                selects every frame. Unselected frames are zero in the output.
                NOTE(review): assumed to be 1-D over time like the default --
                confirm for external callers.
            pose_rep: "xyz" (``x`` is returned unchanged), "rotvec", "rotmat",
                "rotquat" or "rot6d".
            translation: Whether the last row of ``x`` is a translation row.
            glob: If true, the first rotation row is the global orientation;
                otherwise ``glob_rot`` supplies it for every frame.
            jointstype: Key of the SMPL output to return; must be one of
                ``JOINTSTYPES``.
            vertstrans: If true (and ``translation`` is true), add the root
                translation, shifted so the first frame is at the origin, to
                every output point.
            betas: SMPL shape coefficients, one row per converted frame. When
                ``None``, zeros are used with coefficient index 1 set to ``beta``.
            beta: Value for shape coefficient 1 when ``betas`` is ``None``.
            glob_rot: Axis-angle global orientation; required when ``glob``
                is false.
            get_rotations_back: Also return the body-pose rotation matrices and
                the global orientation that were passed to SMPL.
            **kwargs: Accepted and ignored.

        Returns:
            Tensor of shape (points, 3, time), or the tuple
            ``(x_xyz, rotations, global_orient)`` when ``get_rotations_back``.

        Raises:
            TypeError: ``glob`` is false and ``glob_rot`` is ``None``.
            NotImplementedError: Unknown ``jointstype`` or ``pose_rep``.
        """
        if pose_rep == "xyz":
            return x

        if mask is None:
            mask = torch.ones((x.shape[-1],), dtype=bool, device=x.device)

        if not glob and glob_rot is None:
            raise TypeError("You must specify global rotation if glob is False")

        if jointstype not in JOINTSTYPES:
            raise NotImplementedError("This jointstype is not implemented.")

        if translation:
            # Last row holds the root translation in its first 3 features.
            x_translations = x[-1, :3]
            x_rotations = x[:-1]
        else:
            x_rotations = x

        # (rows, feats, time) -> (time, rows, feats) so the time mask indexes dim 0.
        x_rotations = x_rotations.permute(2, 0, 1)
        time, njoints, feats = x_rotations.shape

        # Compute rotation matrices for the masked frames only.
        if pose_rep == "rotvec":
            rotations = geometry.axis_angle_to_matrix(x_rotations[mask])
        elif pose_rep == "rotmat":
            rotations = x_rotations[mask].view(-1, njoints, 3, 3)
        elif pose_rep == "rotquat":
            rotations = geometry.quaternion_to_matrix(x_rotations[mask])
        elif pose_rep == "rot6d":
            rotations = geometry.rotation_6d_to_matrix(x_rotations[mask])
        else:
            raise NotImplementedError("No geometry for this one.")

        if not glob:
            # A single fixed orientation, repeated for every converted frame.
            # Cast to x's dtype so the global orientation and body pose handed
            # to SMPL share one dtype (torch.tensor would keep glob_rot's own
            # dtype, e.g. float64 for a numpy array).
            global_orient = torch.as_tensor(glob_rot, dtype=x.dtype, device=x.device)
            global_orient = geometry.axis_angle_to_matrix(global_orient).view(
                1, 1, 3, 3
            )
            global_orient = global_orient.repeat(len(rotations), 1, 1, 1)
        else:
            # The first rotation row is the global (root) orientation.
            global_orient = rotations[:, 0]
            rotations = rotations[:, 1:]

        if betas is None:
            # All-zero shape coefficients except index 1, which is set to `beta`.
            betas = torch.zeros(
                [rotations.shape[0], self.smpl_model.num_betas],
                dtype=rotations.dtype,
                device=rotations.device,
            )
            betas[:, 1] = beta

        out = self.smpl_model(
            body_pose=rotations, global_orient=global_orient, betas=betas
        )

        # Pick the requested point set from the SMPL output.
        joints = out[jointstype]

        # Scatter converted frames back into the full timeline; unmasked frames are 0.
        x_xyz = torch.empty(time, joints.shape[1], 3, device=x.device, dtype=x.dtype)
        x_xyz[~mask] = 0
        x_xyz[mask] = joints

        # (time, points, 3) -> (points, 3, time)
        x_xyz = x_xyz.permute(1, 2, 0).contiguous()

        if translation and vertstrans:
            # Shift the trajectory so the first frame's root is at the origin.
            x_translations = x_translations - x_translations[:, [0]]

            # Broadcast the (3, time) translation over every point.
            x_xyz = x_xyz + x_translations[None, :, :]

        if get_rotations_back:
            return x_xyz, rotations, global_orient
        return x_xyz