File size: 3,665 Bytes
352b049
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import torch
import numpy as np
import os

from mesh.smpl2mesh import SMPL2Mesh

from skeleton import SkeletonAMASS, convert2humanml
from mesh.io import save_obj, to_mesh
from skeleton2smpl.skeleton2smpl import Skeleton2Obj
import json


def get_humanml_motion(npy_file, skeleton, remove_global_translation=False):
    """Load a joint-position motion from ``npy_file`` and map it to HumanML order.

    Args:
        npy_file: Path to a ``.npy`` file holding the motion array. The joint
            axis is the second-to-last one and coordinates the last one, as the
            root subtraction below indexes ``[..., 0:1, :]``.
        skeleton: Skeleton description exposing ``LANDMARKS`` and
            ``TO_HUMANML_NAMES``, both forwarded to ``convert2humanml``.
        remove_global_translation: When True, subtract joint 0 (the hip/root)
            from every joint so the motion no longer moves through space.

    Returns:
        Whatever ``convert2humanml`` returns for the (optionally root-relative)
        tensor. The caller concatenates it with ``torch.cat``, so presumably a
        torch tensor.
    """
    # NOTE(review): allow_pickle=True lets a crafted .npy run arbitrary code on
    # load. Only use trusted files, or drop the flag if they hold plain arrays.
    raw_array = np.load(npy_file, allow_pickle=True)
    joints = torch.from_numpy(raw_array)

    if remove_global_translation:
        # Joint 0 is the hip; keep a singleton joint axis so it broadcasts.
        root_joint = joints[..., 0:1, :]
        joints = joints - root_joint

    return convert2humanml(joints, skeleton.LANDMARKS, skeleton.TO_HUMANML_NAMES)

def save_mesh(vertices, faces, npy_file, num_obs_frames=30):
    """Write one OBJ mesh per frame into a folder next to ``npy_file``.

    The output folder is ``<dirname(npy_file)>/<basename(npy_file)>_obj``. It
    contains two sub-folders:

    * ``obs_obj/`` for the first ``num_obs_frames`` frames;
    * ``pred_obj/`` for all remaining frames.

    Files are named ``frameNNN.obj`` using the zero-padded absolute frame index.

    Args:
        vertices: Per-frame vertex array whose LAST axis indexes frames
            (``vertices[..., i]`` is frame ``i``).
        faces: Mesh faces, forwarded unchanged to ``to_mesh``.
        npy_file: Path of the source ``.npy`` file. It is only used to derive
            the output folder location.
        num_obs_frames: Number of leading frames treated as observation and
            written to ``obs_obj/``. Defaults to 30, the previously hard-coded
            value.
    """
    npy_dir, npy_name = os.path.split(npy_file)
    # Join the directory with the basename only. Joining it with the full
    # relative path duplicated the directory ("a/b.npy" -> "a/a/b.npy_obj").
    results_dir = os.path.join(npy_dir, f"{npy_name}_obj")
    obs_obj_dir = os.path.join(results_dir, "obs_obj")
    pred_obj_dir = os.path.join(results_dir, "pred_obj")
    # makedirs on the sub-folders also creates results_dir itself.
    os.makedirs(obs_obj_dir, exist_ok=True)
    os.makedirs(pred_obj_dir, exist_ok=True)

    for frame_i in range(vertices.shape[-1]):
        target_dir = obs_obj_dir if frame_i < num_obs_frames else pred_obj_dir
        file_path = os.path.join(target_dir, f"frame{frame_i:03d}.obj")
        mesh = to_mesh(vertices[..., frame_i], faces)
        save_obj(mesh, file_path)
    print(f"Saved obj files to [{results_dir}]")


def main():
    """Fit SMPL to predicted motions, warm-started from the last observed frame, and export OBJs.

    Steps:
      1. Load the observation SMPL parameters from
         ``src_joints2smpl_demo/obs_data.npz``.
      2. Load each prediction ``.npy``, make it root-relative and convert it to
         HumanML joint order.
      3. Fit SMPL with ``Skeleton2Obj``, initialising every frame from the last
         observed frame's betas/pose/cam, with betas held fixed.
      4. Convert the fitted parameters to meshes and write one OBJ per frame
         via ``save_mesh``.
    """
    # NOTE(review): test_directory is currently unused. pred_files below are
    # resolved relative to the working directory, not this folder. Confirm
    # whether they should be joined with it.
    test_directory = '/usr/wiss/curreli/work/my_exps/final_predictions_storage/hmp/visuals_50samples/amass/SkeletonDiffusion/test_optimization'
    num_smplify_iters = 20
    device = "cuda"
    pred_files = ['pred_closest_GT.npy']

    # Load the dictionary of arrays from the npz file.
    # - allow_pickle is required because 'smpl_dict_obs' is stored as an object
    #   array (hence .item()), so only load trusted files.
    # - The context manager closes the NpzFile's underlying file handle.
    obs_file = "src_joints2smpl_demo/obs_data.npz"
    with np.load(obs_file, allow_pickle=True) as loaded_data:
        smpl_dict_obs = loaded_data['smpl_dict_obs'].item()
    smpl_dict_obs = {k: torch.from_numpy(v).to(device) for k, v in smpl_dict_obs.items()}
    print("Loaded observation data from npz file.")

    skeleton = SkeletonAMASS
    skeleton2obj = Skeleton2Obj(
        device=device,
        num_smplify_iters=num_smplify_iters,
        smpl_model_dir="./models/body_models/",  # path to smpl body models
        gmm_model_dir="./models/joint2smpl_models/",  # path to gmm model
    )

    pred_motions = torch.cat(
        [
            get_humanml_motion(npy_file, skeleton=skeleton, remove_global_translation=True)
            for npy_file in pred_files
        ],
        dim=0,
    )
    # Flatten all leading axes into one batch axis of 22-joint, 3-D poses.
    pred_motions = pred_motions.view(-1, 22, 3).to(device)
    num_poses = pred_motions.shape[0]

    # Warm-start every pose from the LAST observed frame's SMPL parameters.
    # smpl_dict_obs tensors are already on `device`, and expand/view keep that
    # device, so no further .to(device) is needed.
    last_betas = smpl_dict_obs["betas"][-1]
    last_pose = smpl_dict_obs["pose"][-1]
    last_cam = smpl_dict_obs["cam"][-1]
    init_params = {
        "betas": last_betas.unsqueeze(0).expand(num_poses, -1),
        "pose": last_pose.unsqueeze(0).expand(num_poses, -1, -1).view(num_poses, -1),
        "cam": last_cam.unsqueeze(0).expand(num_poses, -1, -1),
    }

    rot_motions, _ = skeleton2obj.convert_motion_2smpl(
        pred_motions, hmp=True, init_params=init_params, fix_betas=True
    )

    smpl2mesh = SMPL2Mesh(device)
    vertices, faces = smpl2mesh.convert_smpl_to_mesh(rot_motions, pred_motions)

    # Split the trailing (concatenated) axis back per prediction file.
    # Assumes vertices is (num_verts, 3, total_frames), with files in pred_files
    # order along the last axis. TODO confirm against SMPL2Mesh.
    vertices = vertices.reshape(*vertices.shape[:2], len(pred_files), -1)
    for v, npy_file in zip(np.moveaxis(vertices, 2, 0), pred_files):
        save_mesh(v, faces, npy_file)
 


if __name__ == "__main__":
    # Run the SMPL fitting / OBJ export only when executed as a script, not on import.
    main()