import torch
from torch.utils.data import Dataset, DataLoader
from torch.utils.data._utils.collate import default_collate
from os.path import join as pjoin
from tqdm import tqdm
import numpy as np

from utils.word_vectorizer import WordVectorizer
from eval.evaluator_modules import *


class GeneratedDataset(Dataset):
    """Dataset of motions generated by `pipeline` for the captions in `dataset`.

    Fields read from `opt`:
        opt.dataset_name
        opt.max_motion_length
        opt.unit_length
        opt.device
    """

    def __init__(
        self, opt, pipeline, dataset, w_vectorizer, mm_num_samples, mm_num_repeats
    ):
        assert mm_num_samples < len(dataset)
        self.dataset = dataset
        dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
        generated_motion = []
        min_mov_length = 10 if opt.dataset_name == "t2m" else 6

        # Pre-process all target captions
        mm_generated_motions = []
        if mm_num_samples > 0:
            mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
            mm_idxs = np.sort(mm_idxs)

        all_caption = []
        all_m_lens = []
        all_data = []
        with torch.no_grad():
            for i, data in tqdm(enumerate(dataloader)):
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                all_data.append(data)
                tokens = tokens[0].split("_")
                mm_num_now = len(mm_generated_motions)
                is_mm = (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
                repeat_times = mm_num_repeats if is_mm else 1
                # Round lengths down to a multiple of unit_length, clamped to
                # [min_mov_length * unit_length, max_motion_length].
                m_lens = max(
                    torch.div(m_lens, opt.unit_length, rounding_mode="trunc")
                    * opt.unit_length,
                    min_mov_length * opt.unit_length,
                )
                m_lens = min(m_lens, opt.max_motion_length)
                if isinstance(m_lens, int):
                    m_lens = torch.LongTensor([m_lens]).to(opt.device)
                else:
                    m_lens = m_lens.to(opt.device)
                for t in range(repeat_times):
                    all_m_lens.append(m_lens)
                    all_caption.extend(caption)
                if is_mm:
                    mm_generated_motions.append(0)
        all_m_lens = torch.stack(all_m_lens)

        # Generate all sequences
        with torch.no_grad():
            all_pred_motions, t_eval = pipeline.generate(all_caption, all_m_lens)
            self.eval_generate_time = t_eval

        # Second pass: assign each generated motion back to its caption.
        # The dataloader is shuffled, so reuse the cached batches in all_data.
        cur_idx = 0
        mm_generated_motions = []
        with torch.no_grad():
            for i, data_dummy in tqdm(enumerate(dataloader)):
                data = all_data[i]
                word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
                tokens = tokens[0].split("_")
                mm_num_now = len(mm_generated_motions)
                is_mm = (mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])
                repeat_times = mm_num_repeats if is_mm else 1
                mm_motions = []
                for t in range(repeat_times):
                    pred_motions = all_pred_motions[cur_idx]
                    cur_idx += 1
                    if t == 0:
                        sub_dict = {
                            "motion": pred_motions.cpu().numpy(),
                            "length": pred_motions.shape[0],  # m_lens[0].item()
                            "caption": caption[0],
                            "cap_len": cap_lens[0].item(),
                            "tokens": tokens,
                        }
                        generated_motion.append(sub_dict)
                    if is_mm:
                        mm_motions.append(
                            {
                                "motion": pred_motions.cpu().numpy(),
                                "length": pred_motions.shape[0],  # m_lens[0].item()
                            }
                        )
                if is_mm:
                    mm_generated_motions.append(
                        {
                            "caption": caption[0],
                            "tokens": tokens,
                            "cap_len": cap_lens[0].item(),
                            "mm_motions": mm_motions,
                        }
                    )

        self.generated_motion = generated_motion
        self.mm_generated_motion = mm_generated_motions
        self.opt = opt
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.generated_motion)

    def __getitem__(self, item):
        data = self.generated_motion[item]
        motion, m_length, caption, tokens = (
            data["motion"],
            data["length"],
            data["caption"],
            data["tokens"],
        )
        sent_len = data["cap_len"]

        # This step is needed because T2M evaluators expect their norm convention
        normed_motion = motion
        denormed_motion = self.dataset.inv_transform(normed_motion)
        renormed_motion = (
            denormed_motion - self.dataset.mean_for_eval
        ) / self.dataset.std_for_eval  # according to T2M norms
        motion = renormed_motion

        pos_one_hots = []
        word_embeddings = []
        for token in tokens:
            word_emb, pos_oh = self.w_vectorizer[token]
            pos_one_hots.append(pos_oh[None, :])
            word_embeddings.append(word_emb[None, :])
        pos_one_hots = np.concatenate(pos_one_hots, axis=0)
        word_embeddings = np.concatenate(word_embeddings, axis=0)

        # Zero-pad to max_motion_length so samples can be collated into batches
        length = len(motion)
        if length < self.opt.max_motion_length:
            motion = np.concatenate(
                [
                    motion,
                    np.zeros((self.opt.max_motion_length - length, motion.shape[1])),
                ],
                axis=0,
            )
        return (
            word_embeddings,
            pos_one_hots,
            caption,
            sent_len,
            motion,
            m_length,
            "_".join(tokens),
        )


def collate_fn(batch):
    # Sort by caption length (descending) before default collation
    batch.sort(key=lambda x: x[3], reverse=True)
    return default_collate(batch)


class MMGeneratedDataset(Dataset):
    def __init__(self, opt, motion_dataset, w_vectorizer):
        self.opt = opt
        self.dataset = motion_dataset.mm_generated_motion
        self.w_vectorizer = w_vectorizer

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, item):
        data = self.dataset[item]
        mm_motions = data["mm_motions"]
        m_lens = []
        motions = []
        for mm_motion in mm_motions:
            m_lens.append(mm_motion["length"])
            motion = mm_motion["motion"]
            if len(motion) < self.opt.max_motion_length:
                motion = np.concatenate(
                    [
                        motion,
                        np.zeros(
                            (self.opt.max_motion_length - len(motion), motion.shape[1])
                        ),
                    ],
                    axis=0,
                )
            motion = motion[None, :]
            motions.append(motion)
        m_lens = np.array(m_lens, dtype=np.int32)
        motions = np.concatenate(motions, axis=0)
        # Sort the repeated generations by motion length, longest first
        sort_indx = np.argsort(m_lens)[::-1].copy()
        m_lens = m_lens[sort_indx]
        motions = motions[sort_indx]
        return motions, m_lens


def get_motion_loader(
    opt, batch_size, pipeline, ground_truth_dataset, mm_num_samples, mm_num_repeats
):
    # Currently the configurations of the two datasets are almost the same
    if opt.dataset_name == "t2m" or opt.dataset_name == "kit":
        w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab")
    else:
        raise KeyError("Dataset not recognized!!")

    dataset = GeneratedDataset(
        opt,
        pipeline,
        ground_truth_dataset,
        w_vectorizer,
        mm_num_samples,
        mm_num_repeats,
    )
    mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)

    motion_loader = DataLoader(
        dataset,
        batch_size=batch_size,
        collate_fn=collate_fn,
        drop_last=True,
        num_workers=4,
    )
    mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)

    return motion_loader, mm_motion_loader, dataset.eval_generate_time
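

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of the evaluation pipeline).
# It lists the `opt` fields this module actually reads; the values below are
# assumptions, and `pipeline` / `ground_truth_dataset` are hypothetical
# placeholders that must come from the surrounding project.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from argparse import Namespace

    opt = Namespace(
        dataset_name="t2m",       # "t2m" or "kit"
        max_motion_length=196,    # assumed value; use your config's setting
        unit_length=4,            # assumed value; use your config's setting
        device="cuda" if torch.cuda.is_available() else "cpu",
        glove_dir="./glove",      # directory read by WordVectorizer
    )

    # pipeline = ...              # any object exposing generate(captions, m_lens)
    # ground_truth_dataset = ...  # dataset with inv_transform, mean_for_eval, std_for_eval
    # motion_loader, mm_motion_loader, gen_time = get_motion_loader(
    #     opt,
    #     batch_size=32,
    #     pipeline=pipeline,
    #     ground_truth_dataset=ground_truth_dataset,
    #     mm_num_samples=100,
    #     mm_num_repeats=30,
    # )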