# in2IN / model.py
# Last committed by: pabloruizponce
# Commit: "Update model.py" (8b7bf6e, verified)
import torch
import copy
import numpy as np
from typing import OrderedDict
from scipy.ndimage import gaussian_filter1d
from transformers import PreTrainedModel
from in2in.utils.configs import get_config
from in2in.models.in2in import in2IN
from in2in.utils.preprocess import MotionNormalizer
from .config import in2INConfig
class in2INModel(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the in2IN generator.

    Builds an inference batch from text prompts, runs ``in2IN.forward_test``
    and converts the de-normalized output into Gaussian-smoothed 3D joint
    positions for each of the two generated people.
    """

    config_class = in2INConfig

    # Only the first _NUM_JOINTS * 3 features of each person's motion vector
    # are read back as joint coordinates.
    _NUM_JOINTS = 22

    def __init__(self, config):
        super().__init__(config)
        # config.MODE controls which prompts are forwarded: in "individual"
        # mode only the interaction prompt is placed in the batch.
        self.mode = config.MODE
        self.model = in2IN(config, mode=config.MODE)
        self.normalizer = MotionNormalizer()

    def forward(self, prompt_interaction, prompt_individual1, prompt_individual2,
                window_size=210):
        """Generate a two-person motion from text prompts.

        Args:
            prompt_interaction: Text describing the interaction.
            prompt_individual1: Text for the first person; not used when
                ``self.mode == "individual"``.
            prompt_individual2: Text for the second person; not used when
                ``self.mode == "individual"``.
            window_size: Value written into ``motion_lens`` before sampling.
                Defaults to 210, the value that was previously hard-coded.

        Returns:
            A list of two numpy arrays (one per person) of shape
            ``(frames, 22, 3)``, smoothed along the frame axis.
        """
        self.model.eval()
        batch = OrderedDict({})
        # Allocate on the model's device so the tensor presumably matches the
        # rest of forward_test's inputs after the model is moved (e.g. to GPU).
        batch["motion_lens"] = torch.zeros(1, 1, dtype=torch.long, device=self.device)
        batch["prompt_interaction"] = prompt_interaction
        if self.mode != "individual":
            batch["prompt_individual1"] = prompt_individual1
            batch["prompt_individual2"] = prompt_individual2
        return self.generate_loop(batch, window_size)

    def generate_loop(self, batch, window_size):
        """Run one sampling pass and decode it into per-person joint arrays.

        Args:
            batch: Dict built by ``forward`` containing ``motion_lens``,
                ``prompt_interaction`` and, unless in "individual" mode,
                ``prompt_individual1`` / ``prompt_individual2``. It is
                deep-copied, so the caller's dict is not modified.
            window_size: Value assigned to every entry of ``motion_lens``.

        Returns:
            ``[person0, person1]``: numpy arrays of shape ``(frames, 22, 3)``,
            filtered with a Gaussian (sigma=1, ``mode="nearest"``) along the
            frame axis.
        """
        prompt_interaction = batch["prompt_interaction"]
        if self.mode != "individual":
            prompt_individual1 = batch["prompt_individual1"]
            prompt_individual2 = batch["prompt_individual2"]

        batch = copy.deepcopy(batch)
        batch["motion_lens"][:] = window_size
        batch["text"] = [prompt_interaction]
        if self.mode != "individual":
            batch["text_individual1"] = [prompt_individual1]
            batch["text_individual2"] = [prompt_individual2]

        # Inference only: skip autograd graph construction (the output is
        # detached below anyway), which keeps memory use down while sampling.
        with torch.no_grad():
            batch = self.model.forward_test(batch)

        output = batch["output"][0]
        # Assumed layout: (frames, 2 * features) -> (frames, person, features).
        motion_both = output.reshape(output.shape[0], 2, -1)
        motion_both = self.normalizer.backward(motion_both.cpu().detach().numpy())

        num_coords = self._NUM_JOINTS * 3
        sequences = []
        for person in range(2):
            joints3d = motion_both[:, person, :num_coords].reshape(-1, self._NUM_JOINTS, 3)
            sequences.append(gaussian_filter1d(joints3d, 1, axis=0, mode="nearest"))
        return sequences