import torch
from utils.word_vectorizer import WordVectorizer
from torch.utils.data import Dataset, DataLoader
from os.path import join as pjoin
from tqdm import tqdm
import numpy as np
from eval.evaluator_modules import *
from torch.utils.data._utils.collate import default_collate
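# Datasets wrapping motions produced by a text-to-motion generation pipeline so
# they can be scored with the standard T2M evaluators (see eval.evaluator_modules).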
class GeneratedDataset(Dataset):
    """Generated motions paired with their captions, in T2M evaluator format.

    Fields read from ``opt``: dataset_name, max_motion_length, unit_length, device.
    """
    def __init__(
        self, opt, pipeline, dataset, w_vectorizer, mm_num_samples, mm_num_repeats
    ):
        assert mm_num_samples < len(dataset)
        self.dataset = dataset
        dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
        generated_motion = []
        min_mov_length = 10 if opt.dataset_name == "t2m" else 6

        # Pre-process all target captions
        mm_generated_motions = []
        if mm_num_samples > 0:
            mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
            mm_idxs = np.sort(mm_idxs)

        all_caption = []
        all_m_lens = []
        all_data = []
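        # First pass: cache every batch and collect the captions / target lengths
        # that will be sent to the generation pipeline in a single call.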
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader)):
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                all_data.append(data)
                tokens = tokens[0].split("_")
                mm_num_now = len(mm_generated_motions)
                is_mm = (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
                repeat_times = mm_num_repeats if is_mm else 1
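                # Snap the target length to a multiple of unit_length, then clamp it
                # to [min_mov_length * unit_length, max_motion_length].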
                m_lens = max(
                    torch.div(m_lens, opt.unit_length, rounding_mode="trunc")
                    * opt.unit_length,
                    min_mov_length * opt.unit_length,
                )
                m_lens = min(m_lens, opt.max_motion_length)
                if isinstance(m_lens, int):
                    m_lens = torch.LongTensor([m_lens]).to(opt.device)
                else:
                    m_lens = m_lens.to(opt.device)
                for t in range(repeat_times):
                    all_m_lens.append(m_lens)
                    all_caption.extend(caption)
                if is_mm:
                    mm_generated_motions.append(0)
        all_m_lens = torch.stack(all_m_lens)

        # Generate all sequences
        with torch.no_grad():
            all_pred_motions, t_eval = pipeline.generate(all_caption, all_m_lens)
            self.eval_generate_time = t_eval

        cur_idx = 0
        mm_generated_motions = []
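        # Second pass: slice the generated batch back into per-sample entries. The
        # dataloader is only iterated for its length; the batches cached in all_data
        # during the first pass are reused, so the reshuffled order does not matter.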
        with torch.no_grad():
            for i, data_dummy in tqdm(enumerate(dataloader)):
                data = all_data[i]
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                tokens = tokens[0].split("_")
                mm_num_now = len(mm_generated_motions)
                is_mm = (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
                repeat_times = mm_num_repeats if is_mm else 1
                mm_motions = []
                for t in range(repeat_times):
                    pred_motions = all_pred_motions[cur_idx]
                    cur_idx += 1
                    if t == 0:
                        sub_dict = {
                            "motion": pred_motions.cpu().numpy(),
                            "length": pred_motions.shape[0],  # m_lens[0].item()
                            "caption": caption[0],
                            "cap_len": cap_lens[0].item(),
                            "tokens": tokens,
                        }
                        generated_motion.append(sub_dict)
                    if is_mm:
                        mm_motions.append(
                            {
                                "motion": pred_motions.cpu().numpy(),
                                "length": pred_motions.shape[0],  # m_lens[0].item()
                            }
                        )
                if is_mm:
                    mm_generated_motions.append(
                        {
                            "caption": caption[0],
                            "tokens": tokens,
                            "cap_len": cap_lens[0].item(),
                            "mm_motions": mm_motions,
                        }
                    )

        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.opt = opt
        self.w_vectorizer = w_vectorizer
    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        data = self.generated_motion[item]
        motion, m_length, caption, tokens = (
            data["motion"],
            data["length"],
            data["caption"],
            data["tokens"],
        )
        sent_len = data["cap_len"]

        # This step is needed because the T2M evaluators expect their own
        # normalization convention
        normed_motion = motion
        denormed_motion = self.dataset.inv_transform(normed_motion)
        renormed_motion = (
            denormed_motion - self.dataset.mean_for_eval
        ) / self.dataset.std_for_eval  # according to T2M norms
        motion = renormed_motion
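        # Look up the GloVe embedding and POS one-hot vector for every caption token.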
        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)
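        # Zero-pad the motion along the time axis up to max_motion_length.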
        length = len(motion)
        if length < self.opt.max_motion_length:
            motion = np.concatenate(
                [
                    motion,
                    np.zeros((self.opt.max_motion_length - length, motion.shape[1])),
                ],
                axis=0,
            )
        return (
            word_embeddings,
            pos_one_hots,
            caption,
            sent_len,
            motion,
            m_length,
            "_".join(tokens),
        )
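# Sort each batch by caption length (element 3, sent_len) in descending order
# before the default collation; the T2M evaluators expect length-sorted batches.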
def collate_fn(batch):
    batch.sort(key=lambda x: x[3], reverse=True)
    return default_collate(batch)
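# Repeated generations per caption, used by the evaluators to measure multimodality.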
class MMGeneratedDataset(Dataset):
    def __init__(self, opt, motion_dataset, w_vectorizer):
        self.opt = opt
        self.dataset = motion_dataset.mm_generated_motion
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        data = self.dataset[item]
        mm_motions = data["mm_motions"]
        m_lens = []
        motions = []
        for mm_motion in mm_motions:
            m_lens.append(mm_motion["length"])
            motion = mm_motion["motion"]
            if len(motion) < self.opt.max_motion_length:
                motion = np.concatenate(
                    [
                        motion,
                        np.zeros(
                            (self.opt.max_motion_length - len(motion), motion.shape[1])
                        ),
                    ],
                    axis=0,
                )
            motion = motion[None, :]
            motions.append(motion)
        m_lens = np.array(m_lens, dtype=np.int32)
        motions = np.concatenate(motions, axis=0)
        sort_indx = np.argsort(m_lens)[::-1].copy()
        m_lens = m_lens[sort_indx]
        motions = motions[sort_indx]
        return motions, m_lens
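# Builds the two evaluation loaders: one over single generations per caption and
# one over the repeated (multimodality) generations, plus the total generation time.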
def get_motion_loader(
    opt, batch_size, pipeline, ground_truth_dataset, mm_num_samples, mm_num_repeats
):
    # Currently the configurations of the two datasets are almost the same
    if opt.dataset_name == "t2m" or opt.dataset_name == "kit":
        w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab")
    else:
        raise KeyError("Dataset not recognized!!")

    dataset = GeneratedDataset(
        opt,
        pipeline,
        ground_truth_dataset,
        w_vectorizer,
        mm_num_samples,
        mm_num_repeats,
    )
    mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)

    motion_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=True,
        num_workers=4,
    )
    mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)

    return motion_loader, mm_motion_loader, dataset.eval_generate_time
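# Minimal usage sketch (the `opt` object, pipeline, and ground-truth dataset come
# from the surrounding evaluation script; the numbers below are illustrative only):
#
#   motion_loader, mm_motion_loader, gen_time = get_motion_loader(
#       opt,
#       batch_size=32,
#       pipeline=pipeline,
#       ground_truth_dataset=gt_dataset,
#       mm_num_samples=100,
#       mm_num_repeats=30,
#   )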