# Spaces: Running Running  (appears to be a status banner captured from the hosting page; not part of the program)
import os | |
import orjson | |
import torch | |
import numpy as np | |
from model import TMR_textencoder | |
# Directory from which the `<split>.keyids` and `<split>_motion_embs_unit.npy`
# files are read.
EMBS = "data/unit_motion_embs"
def load_json(path):
    """Read the file at ``path`` and return its content parsed as JSON.

    The file is opened in binary mode and the raw bytes are handed to
    ``orjson.loads``.

    Args:
        path: Location of the JSON file.

    Returns:
        The Python object produced by ``orjson.loads`` for the file content.

    Raises:
        OSError: If the file cannot be opened or read.
        orjson.JSONDecodeError: If the content is not valid JSON.
    """
    with open(path, "rb") as handle:
        payload = handle.read()
    return orjson.loads(payload)
def load_keyids(split: str) -> np.ndarray:
    """Load the key ids listed in ``<EMBS>/<split>.keyids``.

    Every line of the file yields exactly one entry, with surrounding
    whitespace (including the newline) stripped. Lines are not filtered, so
    the i-th entry always corresponds to the i-th line of the file.

    Args:
        split: Name of the split; used to build the file name.

    Returns:
        A one-dimensional NumPy array of strings, one per line of the file.
        An empty file gives an empty string array.

    Raises:
        OSError: If the key-id file cannot be opened.
    """
    path = os.path.join(EMBS, f"{split}.keyids")
    # Explicit encoding so decoding does not depend on the platform locale.
    with open(path, encoding="utf-8") as handle:
        stripped = [line.strip() for line in handle]
    # dtype=str keeps a string array even when the file has no lines
    # (a bare np.array([]) would default to float64).
    return np.array(stripped, dtype=str)
def load_keyids_splits(splits) -> dict:
    """Load the key ids of several splits at once.

    Args:
        splits: Iterable of split names; each one is passed to
            ``load_keyids``.

    Returns:
        A dict mapping every split name to the value returned by
        ``load_keyids`` for that split. An empty iterable gives ``{}``.
    """
    return {split: load_keyids(split) for split in splits}
def load_unit_motion_embs(split: str, device) -> torch.Tensor:
    """Load ``<EMBS>/<split>_motion_embs_unit.npy`` as a tensor on ``device``.

    Args:
        split: Name of the split; used to build the file name.
        device: Target passed to ``Tensor.to`` (for example ``"cpu"`` or a
            ``torch.device``).

    Returns:
        A tensor holding the array stored in the ``.npy`` file, placed on
        ``device``. Shape and dtype are whatever the file contains.

    Raises:
        OSError: If the ``.npy`` file cannot be opened.
    """
    path = os.path.join(EMBS, f"{split}_motion_embs_unit.npy")
    embs = np.load(path)
    return torch.from_numpy(embs).to(device)
def load_unit_motion_embs_splits(splits, device) -> dict:
    """Load the unit motion embeddings of several splits at once.

    Args:
        splits: Iterable of split names; each one is passed to
            ``load_unit_motion_embs``.
        device: Target device forwarded unchanged to
            ``load_unit_motion_embs`` for every split.

    Returns:
        A dict mapping every split name to the value returned by
        ``load_unit_motion_embs`` for that split. An empty iterable gives
        ``{}``.
    """
    return {split: load_unit_motion_embs(split, device) for split in splits}
def load_model(device):
    """Build the text encoder, load its checkpoint and move it to ``device``.

    A ``TMR_textencoder`` is constructed with fixed hyper-parameters, the
    state dict stored in ``data/textencoder.pt`` is loaded into it
    non-strictly, and the model is switched to evaluation mode.

    Args:
        device: Device used both as ``map_location`` when reading the
            checkpoint and as the final placement of the model.

    Returns:
        The ``TMR_textencoder`` instance in eval mode, on ``device``.

    Raises:
        OSError: If ``data/textencoder.pt`` cannot be opened.
    """
    text_params = {
        "latent_dim": 256,
        "ff_size": 1024,
        "num_layers": 6,
        "num_heads": 4,
        "activation": "gelu",
        "modelpath": "distilbert-base-uncased",
    }
    model = TMR_textencoder(**text_params)

    # NOTE(review): torch.load unpickles the file, so only trusted checkpoints
    # should be placed at this path (consider weights_only=True on torch
    # versions that support it).
    state_dict = torch.load("data/textencoder.pt", map_location=device)

    # load values for the transformer only: strict=False makes
    # load_state_dict tolerate missing/unexpected keys instead of raising.
    model.load_state_dict(state_dict, strict=False)

    model = model.eval()
    return model.to(device)