import numpy as np
import smplx
import torch

from .optimize.smplify import SMPLify3D

# Default initial SMPL shape coefficients (10 values), used when
# ``joint2smpl`` is called without ``init_params``.
_DEFAULT_INIT_BETAS = (
    1.47646511e+00, 4.79749501e-01, -6.36047006e-01, -1.52980864e+00,
    -1.11884427e+00, -5.40487289e-01, 3.93005997e-01, -1.88832569e+00,
    -2.78680950e-01, -5.49529344e-02,
)

# Default initial SMPL pose (72 values), used when ``joint2smpl`` is called
# without ``init_params``. Presumably 24 joints x 3 axis-angle components --
# TODO confirm against SMPLify3D's expected pose layout.
_DEFAULT_INIT_POSE = (
    0.4531, 0.3044, 0.2968, -0.2239, 0.0174, 0.0925,
    -0.2378, -0.0465, -0.0786, 0.2782, 0.0141, 0.0138,
    0.4328, -0.0629, -0.0961, 0.5043, 0.0035, 0.0610,
    0.0230, -0.0317, 0.0058, 0.0070, 0.1317, -0.0544,
    -0.0589, -0.1752, 0.1355, 0.0134, -0.0037, 0.0089,
    -0.2093, 0.1600, 0.1092, -0.0387, 0.0824, -0.2041,
    -0.0056, -0.0075, -0.0035, -0.0237, -0.1248, -0.2736,
    -0.0459, 0.1991, 0.2373, 0.0667, -0.0405, 0.0329,
    0.0536, -0.2914, -0.6969, 0.0559, 0.2858, 0.6525,
    0.1222, -0.9116, 0.2383, -0.0366, 0.9237, -0.2554,
    -0.0657, -0.1045, 0.0501, -0.0388, 0.0909, -0.0707,
    -0.1437, -0.0590, -0.1801, -0.0875, 0.1093, 0.2009,
)

# Foot/ankle joint indices that receive extra confidence when ``fix_foot`` is on.
_FOOT_JOINT_INDICES = (7, 8, 10, 11)
_FOOT_JOINT_WEIGHT = 1.5


class joints2smpl:
    """Fit SMPL body-model parameters to 3D joint positions using SMPLify3D.

    A neutral SMPL model is created with ``smplx.create`` and wrapped in a
    ``SMPLify3D`` optimizer configured for the ``"AMASS"`` joint category
    (22 joints, as used by HumanML3D).
    """

    def __init__(
        self,
        device,
        num_smplify_iters=150,
        SMPL_MODEL_DIR="./body_models/",
        GMM_MODEL_DIR="./joint2smpl_models/",
    ):
        """Build the SMPL model and the SMPLify3D optimizer.

        Args:
            device: Anything accepted by ``torch.device`` (e.g. ``"cpu"``,
                ``"cuda:0"``).
            num_smplify_iters: Passed to ``SMPLify3D`` as ``num_iters``.
            SMPL_MODEL_DIR: Directory passed to ``smplx.create`` to load the
                neutral ``.pkl`` SMPL model.
            GMM_MODEL_DIR: Forwarded to ``SMPLify3D`` (presumably the location
                of its GMM pose-prior files -- confirm).
        """
        self.device = torch.device(device)
        self.num_joints = 22  # for HumanML3D
        self.joint_category = "AMASS"
        self.num_smplify_iters = num_smplify_iters
        # When True, the foot/ankle joints get a higher weight in the fit.
        self.fix_foot = False

        smplmodel = smplx.create(
            SMPL_MODEL_DIR, model_type="smpl", gender="neutral", ext="pkl"
        ).to(self.device)

        self.smplify = SMPLify3D(
            smplxmodel=smplmodel,
            joints_category=self.joint_category,
            num_iters=self.num_smplify_iters,
            device=self.device,
            GMM_MODEL_DIR=GMM_MODEL_DIR,
        )

    def _default_init_params(self, batch_size):
        """Return the built-in ``(pose, betas, cam_t)`` starting values.

        ``pose`` is ``(batch_size, 72)`` and ``betas`` is ``(batch_size, 10)``.
        ``cam_t`` is ``(1, 3)`` zeros and is deliberately not expanded to the
        batch, as in the original code (presumably SMPLify3D broadcasts it --
        confirm).
        """
        pose = (
            torch.tensor(_DEFAULT_INIT_POSE, dtype=torch.float32, device=self.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        betas = (
            torch.tensor(_DEFAULT_INIT_BETAS, dtype=torch.float32, device=self.device)
            .unsqueeze(0)
            .repeat(batch_size, 1)
        )
        cam_t = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0).to(self.device)
        return pose, betas, cam_t

    def _joint_confidence(self):
        """Return the per-joint confidence weights for the configured category.

        Raises:
            ValueError: If ``joint_category`` is not ``"AMASS"``.
        """
        if self.joint_category != "AMASS":
            raise ValueError(
                f"Unsupported joint category: {self.joint_category!r} "
                "(only 'AMASS' is supported)"
            )
        confidence = torch.ones(self.num_joints, device=self.device)
        if self.fix_foot:
            # Emphasise the foot and ankle joints.
            confidence[list(_FOOT_JOINT_INDICES)] = _FOOT_JOINT_WEIGHT
        return confidence

    def joint2smpl(self, input_joints, init_params=None, hmp=True, fix_betas=False):
        """Optimise SMPL parameters so the model's joints match ``input_joints``.

        Args:
            input_joints: 3D joint positions, convertible with
                ``torch.as_tensor`` (numpy array, tensor or nested list). The
                first dimension is the batch. Presumably shaped
                ``(batch, 22, 3)`` for HumanML3D -- TODO confirm.
            init_params: Optional dict with ``"betas"``, ``"pose"`` and
                ``"cam"`` starting values. When ``None``, built-in defaults
                are used.
            hmp: Forwarded to SMPLify3D as ``if_simple_hmp_optimizes``.
            fix_betas: Forwarded to SMPLify3D as ``fix_betas``.

        Returns:
            ``(new_opt_pose, params)``. ``new_opt_pose`` is the optimised pose
            from SMPLify3D. ``params`` is a dict with detached copies under
            ``"pose"``, ``"betas"`` and ``"cam"``.

        Raises:
            ValueError: If ``joint_category`` is not supported.
        """
        # as_tensor accepts numpy arrays, lists and tensors of any float dtype
        # or device (the legacy torch.Tensor constructor rejected e.g. CUDA or
        # float64 tensors).
        keypoints_3d = torch.as_tensor(
            input_joints, dtype=torch.float32, device=self.device
        )
        batch_size = keypoints_3d.shape[0]

        if init_params is None:
            pred_pose, pred_betas, pred_cam_t = self._default_init_params(batch_size)
        else:
            pred_betas = init_params["betas"]
            pred_pose = init_params["pose"]
            pred_cam_t = init_params["cam"]

        confidence = self._joint_confidence()

        (
            _opt_vertices,
            new_opt_joints,
            new_opt_pose,
            new_opt_betas,
            new_opt_cam_t,
        ) = self.smplify(
            pred_pose.detach(),
            pred_betas.detach(),
            pred_cam_t.detach(),
            keypoints_3d,
            conf_3d=confidence,
            if_simple_hmp_optimizes=hmp,
            fix_betas=fix_betas,
        )

        # NOTE(review): "pose" holds the first 24 optimised *joint positions*,
        # not the pose parameters returned as the first value. Feeding this
        # dict back as ``init_params`` would pass joint positions as the
        # initial pose -- confirm this is intended.
        return new_opt_pose, {
            "pose": new_opt_joints[..., :24, :].clone().detach(),
            "betas": new_opt_betas.clone().detach(),
            "cam": new_opt_cam_t.clone().detach(),
        }


class Skeleton2Obj:
    """Convenience wrapper that converts joint-position motion to SMPL via joints2smpl."""

    def __init__(
        self,
        device: str = "cpu",
        num_smplify_iters=150,
        smpl_model_dir="./body_models/",
        gmm_model_dir="./joint2smpl_models/",
    ):
        """Create the underlying :class:`joints2smpl` fitter with the given settings."""
        # Attribute name ``simplify`` is kept as-is for existing callers.
        self.simplify = joints2smpl(
            device=device,
            num_smplify_iters=num_smplify_iters,
            SMPL_MODEL_DIR=smpl_model_dir,
            GMM_MODEL_DIR=gmm_model_dir,
        )

    def convert_motion_2smpl(
        self, motion, hmp=True, init_params=None, fix_betas=False
    ) -> tuple[torch.Tensor, dict[str, torch.Tensor]]:
        """Fit SMPL to ``motion`` joint positions.

        Returns the ``(optimised_pose, params_dict)`` pair produced by
        :meth:`joints2smpl.joint2smpl`.
        """
        return self.simplify.joint2smpl(
            motion, hmp=hmp, init_params=init_params, fix_betas=fix_betas
        )