Spaces:
Running
Running
| import os | |
| import orjson | |
| import torch | |
| import numpy as np | |
| from model import TMR_textencoder | |
| EMBS = "data/unit_motion_embs" | |
def load_json(path):
    """Read the file at ``path`` as raw bytes and return it parsed by orjson."""
    # Binary mode: the bytes are handed to orjson.loads without decoding first.
    with open(path, "rb") as handle:
        raw = handle.read()
    return orjson.loads(raw)
def load_keyids(split):
    """Return the key ids listed for ``split`` as a NumPy array.

    The ids are read from ``<EMBS>/<split>.keyids``, one per line, with
    surrounding whitespace stripped from every line.
    """
    keyids_file = os.path.join(EMBS, f"{split}.keyids")
    with open(keyids_file) as handle:
        stripped = list(map(str.strip, handle))
    return np.array(stripped)
def load_keyids_splits(splits):
    """Return a dict mapping each name in ``splits`` to its loaded key ids."""
    keyids_by_split = {}
    for name in splits:
        keyids_by_split[name] = load_keyids(name)
    return keyids_by_split
def load_unit_motion_embs(split, device):
    """Load ``<EMBS>/<split>_motion_embs_unit.npy`` and return it as a tensor on ``device``."""
    filename = f"{split}_motion_embs_unit.npy"
    array = np.load(os.path.join(EMBS, filename))
    return torch.from_numpy(array).to(device)
def load_unit_motion_embs_splits(splits, device):
    """Return a dict mapping each name in ``splits`` to its motion-embedding tensor on ``device``."""
    embs_by_split = {}
    for name in splits:
        embs_by_split[name] = load_unit_motion_embs(name, device)
    return embs_by_split
def load_model(device):
    """Build the TMR text encoder, load its checkpoint and return it on ``device``.

    The encoder is constructed with the fixed hyper-parameters below, its
    weights are read from ``data/textencoder.pt``, and the model is put in
    eval mode before being moved to ``device``.
    """
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)
    # NOTE(security): depending on the PyTorch version / ``weights_only``
    # setting, torch.load may unpickle arbitrary objects; only point this at
    # a trusted checkpoint file.
    state_dict = torch.load("data/textencoder.pt", map_location=device)
    # load values for the transformer only: strict=False does not enforce an
    # exact key match between the checkpoint and the module. Presumably the
    # checkpoint omits the pretrained language-model weights -- TODO confirm.
    model.load_state_dict(state_dict, strict=False)
    model = model.eval()
    return model.to(device)