|
import torch |
|
import copy |
|
import numpy as np |
|
|
|
from typing import OrderedDict |
|
from scipy.ndimage import gaussian_filter1d |
|
|
|
from transformers import PreTrainedModel |
|
from in2in.utils.configs import get_config |
|
from in2in.models.in2in import in2IN |
|
from in2in.utils.preprocess import MotionNormalizer |
|
|
|
from .config import in2INConfig |
|
|
|
class in2INModel(PreTrainedModel):
    """``PreTrainedModel`` wrapper that runs in2IN inference from text prompts.

    ``forward`` builds a single-sample batch from the prompts, runs
    ``in2IN.forward_test`` on it and post-processes the first output sample
    into two smoothed per-person numpy arrays.
    """

    config_class = in2INConfig

    def __init__(self, config):
        super().__init__(config)
        # Generation mode taken from the config; when it equals "individual",
        # the per-person prompts are not put into the batch (see below).
        self.mode = config.MODE
        self.model = in2IN(config, mode=config.MODE)
        self.normalizer = MotionNormalizer()

    def forward(
        self,
        prompt_interaction,
        prompt_individual1=None,
        prompt_individual2=None,
        window_size=210,
    ):
        """Generate motion for two people from text prompts.

        Args:
            prompt_interaction: Prompt describing the interaction.
            prompt_individual1: Prompt for the first person. Ignored when
                ``self.mode == "individual"``; otherwise it is passed through
                to the model as-is.
            prompt_individual2: Prompt for the second person; same rules as
                ``prompt_individual1``.
            window_size: Value written into ``motion_lens`` before generation
                (the requested sequence length, presumably in frames).
                Defaults to 210, the previously hard-coded value.

        Returns:
            The list of two numpy arrays produced by ``generate_loop``.
        """
        self.model.eval()

        batch = OrderedDict()
        # Placeholder tensor; generate_loop overwrites its contents with
        # ``window_size`` on a deep copy of this batch.
        # NOTE(review): created on CPU — confirm forward_test moves it to the
        # model's device if the model runs on GPU.
        batch["motion_lens"] = torch.zeros(1, 1).long()
        batch["prompt_interaction"] = prompt_interaction

        if self.mode != "individual":
            batch["prompt_individual1"] = prompt_individual1
            batch["prompt_individual2"] = prompt_individual2

        return self.generate_loop(batch, window_size)

    def generate_loop(self, batch, window_size):
        """Run in2IN on ``batch`` and return smoothed per-person arrays.

        Args:
            batch: Mapping with a ``"motion_lens"`` tensor, a
                ``"prompt_interaction"`` entry and, unless
                ``self.mode == "individual"``, ``"prompt_individual1"`` and
                ``"prompt_individual2"``. It is deep-copied and not mutated.
            window_size: Value assigned to every element of ``motion_lens``.

        Returns:
            A list of two numpy arrays, one per person. Each array holds the
            first ``22 * 3`` features of every output row, reshaped to
            ``(-1, 22, 3)`` (presumably 3D positions of 22 joints) and
            Gaussian-smoothed (sigma 1) along axis 0.
        """
        batch = copy.deepcopy(batch)
        batch["motion_lens"][:] = window_size

        batch["text"] = [batch["prompt_interaction"]]
        if self.mode != "individual":
            batch["text_individual1"] = [batch["prompt_individual1"]]
            batch["text_individual2"] = [batch["prompt_individual2"]]

        # Pure inference: don't record autograd state for the sampling pass.
        with torch.no_grad():
            batch = self.model.forward_test(batch)

        # Take the first sample and split its trailing features into two
        # people: (rows, 2, features_per_person).
        first_sample = batch["output"][0]
        motion_both = first_sample.reshape(first_sample.shape[0], 2, -1)
        # Presumably undoes the training-time normalization — TODO confirm.
        motion_both = self.normalizer.backward(motion_both.detach().cpu().numpy())

        sequences = []
        for person in range(2):
            joints3d = motion_both[:, person, :22 * 3].reshape(-1, 22, 3)
            # Light smoothing along axis 0 (assumed to be the time axis).
            sequences.append(
                gaussian_filter1d(joints3d, 1, axis=0, mode="nearest")
            )
        return sequences
|
|