Spaces:
Running
on
Zero
Running
on
Zero
init demo
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- 1gpu.yaml +16 -0
- LICENSE +57 -0
- app.py +540 -0
- assets/motion_lens.txt +1 -0
- assets/prompts.txt +1 -0
- config/diffuser_params.yaml +26 -0
- config/evaluator.yaml +14 -0
- datasets/__init__.py +22 -0
- datasets/t2m_dataset.py +304 -0
- eval/__init__.py +2 -0
- eval/eval_t2m.py +222 -0
- eval/evaluator_modules.py +436 -0
- eval/evaluator_wrapper.py +78 -0
- models/__init__.py +26 -0
- models/gaussian_diffusion.py +134 -0
- models/unet.py +1073 -0
- motion_loader/__init__.py +2 -0
- motion_loader/dataset_motion_loaders.py +31 -0
- motion_loader/model_motion_loaders.py +243 -0
- options/__init__.py +0 -0
- options/edit.yaml +40 -0
- options/evaluate_options.py +53 -0
- options/generate_options.py +50 -0
- options/get_opt.py +74 -0
- options/noedit.yaml +21 -0
- options/train_options.py +126 -0
- prepare/download_glove.sh +12 -0
- prepare/download_t2m_evaluators.sh +17 -0
- prepare/prepare_clip.sh +5 -0
- requirements.txt +20 -0
- scripts/__init__.py +0 -0
- scripts/evaluation.py +107 -0
- scripts/generate.py +174 -0
- scripts/train.py +66 -0
- trainers/__init__.py +4 -0
- trainers/ddpm_trainer.py +245 -0
- utils/__init__.py +0 -0
- utils/constants.py +176 -0
- utils/ema.py +13 -0
- utils/eval_humanml.py +298 -0
- utils/kinematics.py +350 -0
- utils/metrics.py +142 -0
- utils/model_load.py +31 -0
- utils/motion_process.py +483 -0
- utils/paramUtil.py +152 -0
- utils/plot_script.py +117 -0
- utils/quaternion.py +423 -0
- utils/skeleton.py +199 -0
- utils/smpl.py +343 -0
- utils/transforms.py +600 -0
1gpu.yaml
ADDED
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
compute_environment: LOCAL_MACHINE
|
2 |
+
debug: false
|
3 |
+
distributed_type: 'NO'
|
4 |
+
downcast_bf16: 'no'
|
5 |
+
machine_rank: 0
|
6 |
+
main_training_function: main
|
7 |
+
mixed_precision: no
|
8 |
+
num_machines: 1
|
9 |
+
num_processes: 1
|
10 |
+
rdzv_backend: static
|
11 |
+
same_network: true
|
12 |
+
tpu_env: []
|
13 |
+
tpu_use_cluster: false
|
14 |
+
tpu_use_sudo: false
|
15 |
+
use_cpu: false
|
16 |
+
main_process_port: 21000
|
LICENSE
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#--------------------------------------------
|
2 |
+
# MotionCLR
|
3 |
+
# Copyright (c) 2024 IDEA. All Rights Reserved.
|
4 |
+
# Licensed under the IDEA License, Version 1.0 [see LICENSE for details]
|
5 |
+
#--------------------------------------------
|
6 |
+
|
7 |
+
IDEA License 1.0
|
8 |
+
|
9 |
+
This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and the International Digital Economy Academy (“IDEA” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by IDEA under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by IDEA related to the Software (“Documentation”).
|
10 |
+
|
11 |
+
By downloading the Software or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to IDEA that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.
|
12 |
+
|
13 |
+
1. LICENSE GRANT
|
14 |
+
|
15 |
+
a. You are granted a non-exclusive, worldwide, transferable, sublicensable, irrevocable, royalty free and limited license under IDEA’s copyright interests to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Software solely for your non-commercial research purposes.
|
16 |
+
|
17 |
+
b. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. IDEA and its licensors reserve all rights not expressly granted by this License.
|
18 |
+
|
19 |
+
c. If you intend to use the Software Products for any commercial purposes, you must request a license from IDEA, which IDEA may grant to you in its sole discretion.
|
20 |
+
|
21 |
+
2. REDISTRIBUTION AND USE
|
22 |
+
|
23 |
+
a. If you distribute or make the Software Products, or any derivative works thereof, available to a third party, you shall provide a copy of this Agreement to such third party.
|
24 |
+
|
25 |
+
b. You must retain in all copies of the Software Products that you distribute the following attribution notice: "MotionCLR is licensed under the IDEA License 1.0, Copyright (c) IDEA. All Rights Reserved."
|
26 |
+
|
27 |
+
d. Your use of the Software Products must comply with applicable laws and regulations (including trade compliance laws and regulations).
|
28 |
+
|
29 |
+
e. You will not, and will not permit, assist or cause any third party to use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for in any manner that infringes, misappropriates, or otherwise violates any third-party rights.
|
30 |
+
|
31 |
+
3. DISCLAIMER OF WARRANTY
|
32 |
+
|
33 |
+
UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
|
34 |
+
|
35 |
+
4. LIMITATION OF LIABILITY
|
36 |
+
|
37 |
+
IN NO EVENT WILL IDEA OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF IDEA OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
|
38 |
+
|
39 |
+
5. INDEMNIFICATION
|
40 |
+
|
41 |
+
You will indemnify, defend and hold harmless IDEA and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “IDEA Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any IDEA Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the IDEA Parties of any such Claims, and cooperate with IDEA Parties in defending such Claims. You will also grant the IDEA Parties sole control of the defense or settlement, at IDEA’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and IDEA or the other IDEA Parties.
|
42 |
+
|
43 |
+
6. TERMINATION; SURVIVAL
|
44 |
+
|
45 |
+
a. This License will automatically terminate upon any breach by you of the terms of this License.
|
46 |
+
|
47 |
+
b. If you institute litigation or other proceedings against IDEA or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted.
|
48 |
+
|
49 |
+
c. The following sections survive termination of this License: 2 (Redistribution and use), 3 (Disclaimers of Warranty), 4 (Limitation of Liability), 5 (Indemnification), 6 (Termination; Survival), 7 (Trademarks) and 8 (Applicable Law; Dispute Resolution).
|
50 |
+
|
51 |
+
7. TRADEMARKS
|
52 |
+
|
53 |
+
Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with IDEA without the prior written permission of IDEA, except to the extent necessary to make the reference required by the attribution notice of this Agreement.
|
54 |
+
|
55 |
+
8. APPLICABLE LAW; DISPUTE RESOLUTION
|
56 |
+
|
57 |
+
This License will be governed and construed under the laws of the People’s Republic of China without regard to conflicts of law provisions. The parties expressly agree that the United Nations Convention on Contracts for the International Sale of Goods will not apply. Any suit or proceeding arising out of or relating to this License will be brought in the courts, as applicable, in Shenzhen, Guangdong, and each party irrevocably submits to the jurisdiction and venue of such courts.
|
app.py
ADDED
@@ -0,0 +1,540 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import spaces
|
2 |
+
import gradio as gr
|
3 |
+
import sys
|
4 |
+
import os
|
5 |
+
import torch
|
6 |
+
import numpy as np
|
7 |
+
from os.path import join as pjoin
|
8 |
+
import utils.paramUtil as paramUtil
|
9 |
+
from utils.plot_script import *
|
10 |
+
from utils.utils import *
|
11 |
+
from utils.motion_process import recover_from_ric
|
12 |
+
from accelerate.utils import set_seed
|
13 |
+
from models.gaussian_diffusion import DiffusePipeline
|
14 |
+
from options.generate_options import GenerateOptions
|
15 |
+
from utils.model_load import load_model_weights
|
16 |
+
from motion_loader import get_dataset_loader
|
17 |
+
from models import build_models
|
18 |
+
import yaml
|
19 |
+
import time
|
20 |
+
from box import Box
|
21 |
+
import hashlib
|
22 |
+
from huggingface_hub import hf_hub_download
|
23 |
+
|
24 |
+
ckptdir = './checkpoints/t2m/release'
|
25 |
+
os.makedirs(ckptdir, exist_ok=True)
|
26 |
+
|
27 |
+
|
28 |
+
mean_path = hf_hub_download(
|
29 |
+
repo_id="EvanTHU/MotionCLR",
|
30 |
+
filename="meta/mean.npy",
|
31 |
+
local_dir=ckptdir,
|
32 |
+
local_dir_use_symlinks=False
|
33 |
+
)
|
34 |
+
|
35 |
+
std_path = hf_hub_download(
|
36 |
+
repo_id="EvanTHU/MotionCLR",
|
37 |
+
filename="meta/std.npy",
|
38 |
+
local_dir=ckptdir,
|
39 |
+
local_dir_use_symlinks=False
|
40 |
+
)
|
41 |
+
|
42 |
+
model_path = hf_hub_download(
|
43 |
+
repo_id="EvanTHU/MotionCLR",
|
44 |
+
filename="model/latest.tar",
|
45 |
+
local_dir=ckptdir,
|
46 |
+
local_dir_use_symlinks=False
|
47 |
+
)
|
48 |
+
|
49 |
+
opt_path = hf_hub_download(
|
50 |
+
repo_id="EvanTHU/MotionCLR",
|
51 |
+
filename="opt.txt",
|
52 |
+
local_dir=ckptdir,
|
53 |
+
local_dir_use_symlinks=False
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
|
58 |
+
os.makedirs("tmp", exist_ok=True)
|
59 |
+
os.environ['GRADIO_TEMP_DIR'] = './tmp'
|
60 |
+
|
61 |
+
def generate_md5(input_string):
|
62 |
+
# Encode the string and compute the MD5 hash
|
63 |
+
md5_hash = hashlib.md5(input_string.encode())
|
64 |
+
# Return the hexadecimal representation of the hash
|
65 |
+
return md5_hash.hexdigest()
|
66 |
+
|
67 |
+
def set_all_use_to_false(data):
|
68 |
+
for key, value in data.items():
|
69 |
+
if isinstance(value, Box):
|
70 |
+
set_all_use_to_false(value)
|
71 |
+
elif key == 'use':
|
72 |
+
data[key] = False
|
73 |
+
return data
|
74 |
+
|
75 |
+
def yaml_to_box(yaml_file):
|
76 |
+
with open(yaml_file, 'r') as file:
|
77 |
+
yaml_data = yaml.safe_load(file)
|
78 |
+
|
79 |
+
return Box(yaml_data)
|
80 |
+
|
81 |
+
HEAD = """<div class="embed_hidden">
|
82 |
+
<h1 style='text-align: center'> MotionCLR User Interaction Demo </h1>
|
83 |
+
"""
|
84 |
+
|
85 |
+
edit_config = yaml_to_box('options/edit.yaml')
|
86 |
+
os.environ['GRADIO_TEMP_DIR'] = './tmp'
|
87 |
+
CSS = """
|
88 |
+
.retrieved_video {
|
89 |
+
position: relative;
|
90 |
+
margin: 0;
|
91 |
+
box-shadow: var(--block-shadow);
|
92 |
+
border-width: var(--block-border-width);
|
93 |
+
border-color: #000000;
|
94 |
+
border-radius: var(--block-radius);
|
95 |
+
background: var(--block-background-fill);
|
96 |
+
width: 100%;
|
97 |
+
line-height: var(--line-sm);
|
98 |
+
}
|
99 |
+
.contour_video {
|
100 |
+
display: flex;
|
101 |
+
flex-direction: column;
|
102 |
+
justify-content: center;
|
103 |
+
align-items: center;
|
104 |
+
z-index: var(--layer-5);
|
105 |
+
border-radius: var(--block-radius);
|
106 |
+
background: var(--background-fill-primary);
|
107 |
+
padding: 0 var(--size-6);
|
108 |
+
max-height: var(--size-screen-h);
|
109 |
+
overflow: hidden;
|
110 |
+
}
|
111 |
+
"""
|
112 |
+
|
113 |
+
def generate_video_from_text(text, opt, pipeline):
|
114 |
+
width = 500
|
115 |
+
height = 500
|
116 |
+
texts = [text]
|
117 |
+
motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)]
|
118 |
+
|
119 |
+
save_dir = './tmp/gen/'
|
120 |
+
filename = generate_md5(str(time.time())) + ".mp4"
|
121 |
+
save_path = pjoin(save_dir, str(filename))
|
122 |
+
os.makedirs(save_dir, exist_ok=True)
|
123 |
+
|
124 |
+
start_time = time.perf_counter()
|
125 |
+
gr.Info("Generating motion...", duration = 3)
|
126 |
+
pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens]))
|
127 |
+
end_time = time.perf_counter()
|
128 |
+
exc = end_time - start_time
|
129 |
+
gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3)
|
130 |
+
start_time = time.perf_counter()
|
131 |
+
mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
|
132 |
+
std = np.load(pjoin(opt.meta_dir, 'std.npy'))
|
133 |
+
|
134 |
+
|
135 |
+
samples = []
|
136 |
+
|
137 |
+
root_list = []
|
138 |
+
for i, motion in enumerate(pred_motions):
|
139 |
+
motion = motion.cpu().numpy() * std + mean
|
140 |
+
# 1. recover 3d joints representation by ik
|
141 |
+
motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
|
142 |
+
# 2. put on Floor (Y axis)
|
143 |
+
floor_height = motion.min(dim=0)[0].min(dim=0)[0][1]
|
144 |
+
motion[:, :, 1] -= floor_height
|
145 |
+
motion = motion.numpy()
|
146 |
+
# 3. remove jitter
|
147 |
+
motion = motion_temporal_filter(motion, sigma=1)
|
148 |
+
|
149 |
+
samples.append(motion)
|
150 |
+
|
151 |
+
i = 0
|
152 |
+
title = texts[i]
|
153 |
+
motion = samples[i]
|
154 |
+
kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain
|
155 |
+
plot_3d_motion(save_path, kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius)
|
156 |
+
|
157 |
+
gr.Info("Rendered motion...", duration = 3)
|
158 |
+
end_time = time.perf_counter()
|
159 |
+
exc = end_time - start_time
|
160 |
+
gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3)
|
161 |
+
|
162 |
+
video_dis = f'<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_path}"></video>'
|
163 |
+
style_dis = video_dis + """<br> <p align="center"> Content Reference </p>"""
|
164 |
+
global edit_config
|
165 |
+
edit_config = set_all_use_to_false(edit_config)
|
166 |
+
return video_dis, video_dis, video_dis, video_dis, style_dis, video_dis, gr.update(visible=True)
|
167 |
+
|
168 |
+
def reweighting(text, idx, weight, opt, pipeline):
|
169 |
+
global edit_config
|
170 |
+
edit_config.reweighting_attn.use = True
|
171 |
+
edit_config.reweighting_attn.idx = idx
|
172 |
+
edit_config.reweighting_attn.reweighting_attn_weight = weight
|
173 |
+
|
174 |
+
|
175 |
+
gr.Info("Loading Configurations...", duration = 3)
|
176 |
+
model = build_models(opt, edit_config=edit_config)
|
177 |
+
ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
|
178 |
+
niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)
|
179 |
+
|
180 |
+
pipeline = DiffusePipeline(
|
181 |
+
opt = opt,
|
182 |
+
model = model,
|
183 |
+
diffuser_name = opt.diffuser_name,
|
184 |
+
device=opt.device,
|
185 |
+
num_inference_steps=opt.num_inference_steps,
|
186 |
+
torch_dtype=torch.float16,
|
187 |
+
)
|
188 |
+
|
189 |
+
print(edit_config)
|
190 |
+
|
191 |
+
width = 500
|
192 |
+
height = 500
|
193 |
+
texts = [text, text]
|
194 |
+
motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)]
|
195 |
+
|
196 |
+
save_dir = './tmp/gen/'
|
197 |
+
filenames = [generate_md5(str(time.time())) + ".mp4", generate_md5(str(time.time())) + ".mp4"]
|
198 |
+
save_paths = [pjoin(save_dir, str(filenames[0])), pjoin(save_dir, str(filenames[1]))]
|
199 |
+
os.makedirs(save_dir, exist_ok=True)
|
200 |
+
|
201 |
+
start_time = time.perf_counter()
|
202 |
+
gr.Info("Generating motion...", duration = 3)
|
203 |
+
pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens]))
|
204 |
+
end_time = time.perf_counter()
|
205 |
+
exc = end_time - start_time
|
206 |
+
gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3)
|
207 |
+
start_time = time.perf_counter()
|
208 |
+
mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
|
209 |
+
std = np.load(pjoin(opt.meta_dir, 'std.npy'))
|
210 |
+
|
211 |
+
|
212 |
+
samples = []
|
213 |
+
|
214 |
+
root_list = []
|
215 |
+
for i, motion in enumerate(pred_motions):
|
216 |
+
motion = motion.cpu().numpy() * std + mean
|
217 |
+
# 1. recover 3d joints representation by ik
|
218 |
+
motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
|
219 |
+
# 2. put on Floor (Y axis)
|
220 |
+
floor_height = motion.min(dim=0)[0].min(dim=0)[0][1]
|
221 |
+
motion[:, :, 1] -= floor_height
|
222 |
+
motion = motion.numpy()
|
223 |
+
# 3. remove jitter
|
224 |
+
motion = motion_temporal_filter(motion, sigma=1)
|
225 |
+
|
226 |
+
samples.append(motion)
|
227 |
+
|
228 |
+
i = 1
|
229 |
+
title = texts[i]
|
230 |
+
motion = samples[i]
|
231 |
+
kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain
|
232 |
+
plot_3d_motion(save_paths[1], kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius)
|
233 |
+
|
234 |
+
|
235 |
+
gr.Info("Rendered motion...", duration = 3)
|
236 |
+
end_time = time.perf_counter()
|
237 |
+
exc = end_time - start_time
|
238 |
+
gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3)
|
239 |
+
|
240 |
+
video_dis = f'<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_paths[1]}"></video>'
|
241 |
+
|
242 |
+
|
243 |
+
edit_config = set_all_use_to_false(edit_config)
|
244 |
+
return video_dis
|
245 |
+
|
246 |
+
def generate_example_based_motion(text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion, opt, pipeline):
|
247 |
+
global edit_config
|
248 |
+
edit_config.example_based.use = True
|
249 |
+
edit_config.example_based.chunk_size = chunk_size
|
250 |
+
edit_config.example_based.example_based_steps_end = example_based_steps_end
|
251 |
+
edit_config.example_based.temp_seed = temp_seed
|
252 |
+
edit_config.example_based.temp_seed_bar = temp_seed_bar
|
253 |
+
|
254 |
+
|
255 |
+
gr.Info("Loading Configurations...", duration = 3)
|
256 |
+
model = build_models(opt, edit_config=edit_config)
|
257 |
+
ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
|
258 |
+
niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)
|
259 |
+
|
260 |
+
pipeline = DiffusePipeline(
|
261 |
+
opt = opt,
|
262 |
+
model = model,
|
263 |
+
diffuser_name = opt.diffuser_name,
|
264 |
+
device=opt.device,
|
265 |
+
num_inference_steps=opt.num_inference_steps,
|
266 |
+
torch_dtype=torch.float16,
|
267 |
+
)
|
268 |
+
|
269 |
+
width = 500
|
270 |
+
height = 500
|
271 |
+
texts = [text for _ in range(num_motion)]
|
272 |
+
motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)]
|
273 |
+
|
274 |
+
save_dir = './tmp/gen/'
|
275 |
+
filenames = [generate_md5(str(time.time())) + ".mp4" for _ in range(num_motion)]
|
276 |
+
save_paths = [pjoin(save_dir, str(filenames[i])) for i in range(num_motion)]
|
277 |
+
os.makedirs(save_dir, exist_ok=True)
|
278 |
+
|
279 |
+
start_time = time.perf_counter()
|
280 |
+
gr.Info("Generating motion...", duration = 3)
|
281 |
+
pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens]))
|
282 |
+
end_time = time.perf_counter()
|
283 |
+
exc = end_time - start_time
|
284 |
+
gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3)
|
285 |
+
start_time = time.perf_counter()
|
286 |
+
mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
|
287 |
+
std = np.load(pjoin(opt.meta_dir, 'std.npy'))
|
288 |
+
|
289 |
+
|
290 |
+
samples = []
|
291 |
+
|
292 |
+
root_list = []
|
293 |
+
progress=gr.Progress()
|
294 |
+
progress(0, desc="Starting...")
|
295 |
+
for i, motion in enumerate(pred_motions):
|
296 |
+
motion = motion.cpu().numpy() * std + mean
|
297 |
+
# 1. recover 3d joints representation by ik
|
298 |
+
motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
|
299 |
+
# 2. put on Floor (Y axis)
|
300 |
+
floor_height = motion.min(dim=0)[0].min(dim=0)[0][1]
|
301 |
+
motion[:, :, 1] -= floor_height
|
302 |
+
motion = motion.numpy()
|
303 |
+
# 3. remove jitter
|
304 |
+
motion = motion_temporal_filter(motion, sigma=1)
|
305 |
+
|
306 |
+
samples.append(motion)
|
307 |
+
|
308 |
+
video_dis = []
|
309 |
+
i = 0
|
310 |
+
for title in progress.tqdm(texts):
|
311 |
+
print(save_paths[i])
|
312 |
+
title = texts[i]
|
313 |
+
motion = samples[i]
|
314 |
+
kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain
|
315 |
+
plot_3d_motion(save_paths[i], kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius)
|
316 |
+
video_html = f'''
|
317 |
+
<video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()" autoplay loop disablepictureinpicture src="./file={save_paths[i]}"> </video>
|
318 |
+
'''
|
319 |
+
video_dis.append(video_html)
|
320 |
+
i += 1
|
321 |
+
|
322 |
+
for _ in range(24 - num_motion):
|
323 |
+
video_dis.append(None)
|
324 |
+
gr.Info("Rendered motion...", duration = 3)
|
325 |
+
end_time = time.perf_counter()
|
326 |
+
exc = end_time - start_time
|
327 |
+
gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3)
|
328 |
+
|
329 |
+
edit_config = set_all_use_to_false(edit_config)
|
330 |
+
return video_dis
|
331 |
+
|
332 |
+
def transfer_style(text, style_text, style_transfer_steps_end, opt, pipeline):
|
333 |
+
global edit_config
|
334 |
+
edit_config.style_tranfer.use = True
|
335 |
+
edit_config.style_tranfer.style_transfer_steps_end = style_transfer_steps_end
|
336 |
+
|
337 |
+
gr.Info("Loading Configurations...", duration = 3)
|
338 |
+
model = build_models(opt, edit_config=edit_config)
|
339 |
+
ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
|
340 |
+
niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)
|
341 |
+
|
342 |
+
pipeline = DiffusePipeline(
|
343 |
+
opt = opt,
|
344 |
+
model = model,
|
345 |
+
diffuser_name = opt.diffuser_name,
|
346 |
+
device=opt.device,
|
347 |
+
num_inference_steps=opt.num_inference_steps,
|
348 |
+
torch_dtype=torch.float16,
|
349 |
+
)
|
350 |
+
|
351 |
+
print(edit_config)
|
352 |
+
|
353 |
+
width = 500
|
354 |
+
height = 500
|
355 |
+
texts = [style_text, text, text]
|
356 |
+
motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)]
|
357 |
+
|
358 |
+
save_dir = './tmp/gen/'
|
359 |
+
filenames = [generate_md5(str(time.time())) + ".mp4", generate_md5(str(time.time())) + ".mp4", generate_md5(str(time.time())) + ".mp4"]
|
360 |
+
save_paths = [pjoin(save_dir, str(filenames[0])), pjoin(save_dir, str(filenames[1])), pjoin(save_dir, str(filenames[2]))]
|
361 |
+
os.makedirs(save_dir, exist_ok=True)
|
362 |
+
|
363 |
+
start_time = time.perf_counter()
|
364 |
+
gr.Info("Generating motion...", duration = 3)
|
365 |
+
pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens]))
|
366 |
+
end_time = time.perf_counter()
|
367 |
+
exc = end_time - start_time
|
368 |
+
gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3)
|
369 |
+
start_time = time.perf_counter()
|
370 |
+
mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
|
371 |
+
std = np.load(pjoin(opt.meta_dir, 'std.npy'))
|
372 |
+
|
373 |
+
samples = []
|
374 |
+
|
375 |
+
root_list = []
|
376 |
+
for i, motion in enumerate(pred_motions):
|
377 |
+
motion = motion.cpu().numpy() * std + mean
|
378 |
+
# 1. recover 3d joints representation by ik
|
379 |
+
motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
|
380 |
+
# 2. put on Floor (Y axis)
|
381 |
+
floor_height = motion.min(dim=0)[0].min(dim=0)[0][1]
|
382 |
+
motion[:, :, 1] -= floor_height
|
383 |
+
motion = motion.numpy()
|
384 |
+
# 3. remove jitter
|
385 |
+
motion = motion_temporal_filter(motion, sigma=1)
|
386 |
+
|
387 |
+
samples.append(motion)
|
388 |
+
|
389 |
+
for i,title in enumerate(texts):
|
390 |
+
title = texts[i]
|
391 |
+
motion = samples[i]
|
392 |
+
kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain
|
393 |
+
plot_3d_motion(save_paths[i], kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius)
|
394 |
+
|
395 |
+
gr.Info("Rendered motion...", duration = 3)
|
396 |
+
end_time = time.perf_counter()
|
397 |
+
exc = end_time - start_time
|
398 |
+
gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3)
|
399 |
+
|
400 |
+
video_dis0 = f"""<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_paths[0]}"></video> <br> <p align="center"> Style Reference </p>"""
|
401 |
+
video_dis1 = f"""<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_paths[2]}"></video> <br> <p align="center"> Content Reference </p>"""
|
402 |
+
video_dis2 = f"""<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_paths[1]}"></video> <br> <p align="center"> Transfered Result </p>"""
|
403 |
+
|
404 |
+
edit_config = set_all_use_to_false(edit_config)
|
405 |
+
return video_dis0, video_dis2
|
406 |
+
|
407 |
+
|
408 |
+
@spaces.GPU
|
409 |
+
def main():
|
410 |
+
parser = GenerateOptions()
|
411 |
+
opt = parser.parse_app()
|
412 |
+
set_seed(opt.seed)
|
413 |
+
device_id = opt.gpu_id
|
414 |
+
device = torch.device('cuda:%d' % device_id if torch.cuda.is_available() else 'cpu')
|
415 |
+
opt.device = device
|
416 |
+
|
417 |
+
|
418 |
+
# load model
|
419 |
+
model = build_models(opt, edit_config=edit_config)
|
420 |
+
ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
|
421 |
+
niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)
|
422 |
+
|
423 |
+
pipeline = DiffusePipeline(
|
424 |
+
opt = opt,
|
425 |
+
model = model,
|
426 |
+
diffuser_name = opt.diffuser_name,
|
427 |
+
device=device,
|
428 |
+
num_inference_steps=opt.num_inference_steps,
|
429 |
+
torch_dtype=torch.float16,
|
430 |
+
)
|
431 |
+
|
432 |
+
with gr.Blocks() as demo:
|
433 |
+
gr.Markdown(HEAD)
|
434 |
+
with gr.Row():
|
435 |
+
with gr.Column(scale=7):
|
436 |
+
text_input = gr.Textbox(label="Input the text prompt to generate motion...")
|
437 |
+
with gr.Column(scale=3):
|
438 |
+
sequence_length = gr.Slider(minimum=1, maximum=9.6, step=0.1, label="Motion length", value=8)
|
439 |
+
with gr.Row():
|
440 |
+
generate_button = gr.Button("Generate motion")
|
441 |
+
|
442 |
+
with gr.Row():
|
443 |
+
video_display = gr.HTML(label="生成的视频", visible=True)
|
444 |
+
|
445 |
+
|
446 |
+
tabs = gr.Tabs(visible=True)
|
447 |
+
with tabs:
|
448 |
+
with gr.Tab("Motion (de-)emphasizing"):
|
449 |
+
with gr.Row():
|
450 |
+
int_input = gr.Number(label="Editing word index", minimum=0, maximum=70)
|
451 |
+
weight_input = gr.Slider(minimum=-1, maximum=1, step=0.01, label="Input weight for (de-)emphasizing [-1, 1]", value=0)
|
452 |
+
|
453 |
+
trim_button = gr.Button("Edit reweighting")
|
454 |
+
|
455 |
+
with gr.Row():
|
456 |
+
original_video1 = gr.HTML(label="before editing", visible=False)
|
457 |
+
edited_video = gr.HTML(label="after editing")
|
458 |
+
|
459 |
+
trim_button.click(
|
460 |
+
fn=lambda x, int_input, weight_input : reweighting(x, int_input, weight_input, opt, pipeline),
|
461 |
+
inputs=[text_input, int_input, weight_input],
|
462 |
+
outputs=edited_video,
|
463 |
+
)
|
464 |
+
|
465 |
+
with gr.Tab("Example-based motion genration"):
|
466 |
+
with gr.Row():
|
467 |
+
with gr.Column(scale=4):
|
468 |
+
chunk_size = gr.Number(minimum=10, maximum=20, step=10,label="Chunk size (#frames)", value=20)
|
469 |
+
example_based_steps_end = gr.Number(minimum=0, maximum=9,label="Ending step of manipulation", value=6)
|
470 |
+
with gr.Column(scale=3):
|
471 |
+
temp_seed = gr.Number(label="Seed for random", value=200, minimum=0)
|
472 |
+
temp_seed_bar = gr.Slider(minimum=0, maximum=100, step=1, label="Seed for random bar", value=15)
|
473 |
+
with gr.Column(scale=3):
|
474 |
+
num_motion = gr.Radio(choices=[4, 8, 12, 16, 24], value=8, label="Select number of motions")
|
475 |
+
|
476 |
+
gen_button = gr.Button("Generate example-based motion")
|
477 |
+
|
478 |
+
|
479 |
+
example_video_display = []
|
480 |
+
for _ in range(6):
|
481 |
+
with gr.Row():
|
482 |
+
for _ in range(4):
|
483 |
+
video = gr.HTML(label="Example-based motion", visible=True)
|
484 |
+
example_video_display.append(video)
|
485 |
+
|
486 |
+
gen_button.click(
|
487 |
+
fn=lambda text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion: generate_example_based_motion(text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion, opt, pipeline),
|
488 |
+
inputs=[text_input, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion],
|
489 |
+
outputs=example_video_display
|
490 |
+
)
|
491 |
+
|
492 |
+
with gr.Tab("Style transfer"):
|
493 |
+
with gr.Row():
|
494 |
+
style_text = gr.Textbox(label="Reference prompt (e.g. 'a man walks.')", value="a man walks.")
|
495 |
+
style_transfer_steps_end = gr.Number(label="The end step of diffusion (0~9)", minimum=0, maximum=9, value=5)
|
496 |
+
|
497 |
+
style_transfer_button = gr.Button("Transfer style")
|
498 |
+
|
499 |
+
with gr.Row():
|
500 |
+
style_reference = gr.HTML(label="style reference")
|
501 |
+
original_video4 = gr.HTML(label="before style transfer", visible=False)
|
502 |
+
styled_video = gr.HTML(label="after style transfer")
|
503 |
+
|
504 |
+
style_transfer_button.click(
|
505 |
+
fn=lambda text, style_text, style_transfer_steps_end: transfer_style(text, style_text, style_transfer_steps_end, opt, pipeline),
|
506 |
+
inputs=[text_input, style_text, style_transfer_steps_end],
|
507 |
+
outputs=[style_reference, styled_video],
|
508 |
+
)
|
509 |
+
|
510 |
+
def update_motion_length(sequence_length):
|
511 |
+
opt.motion_length = sequence_length
|
512 |
+
|
513 |
+
def on_generate(text, length, pipeline):
|
514 |
+
update_motion_length(length)
|
515 |
+
return generate_video_from_text(text, opt, pipeline)
|
516 |
+
|
517 |
+
|
518 |
+
generate_button.click(
|
519 |
+
fn=lambda text, length: on_generate(text, length, pipeline),
|
520 |
+
inputs=[text_input, sequence_length],
|
521 |
+
outputs=[
|
522 |
+
video_display,
|
523 |
+
original_video1,
|
524 |
+
original_video4,
|
525 |
+
tabs,
|
526 |
+
],
|
527 |
+
show_progress=True
|
528 |
+
)
|
529 |
+
|
530 |
+
generate_button.click(
|
531 |
+
fn=lambda: [gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)],
|
532 |
+
inputs=None,
|
533 |
+
outputs=[video_display, original_video1, original_video4]
|
534 |
+
)
|
535 |
+
|
536 |
+
demo.launch()
|
537 |
+
|
538 |
+
|
539 |
+
if __name__ == '__main__':
|
540 |
+
main()
|
assets/motion_lens.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
160
|
assets/prompts.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
a man jumps.
|
config/diffuser_params.yaml
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
dpmsolver:
|
2 |
+
scheduler_class: DPMSolverMultistepScheduler
|
3 |
+
additional_params:
|
4 |
+
algorithm_type: sde-dpmsolver++
|
5 |
+
use_karras_sigmas: true
|
6 |
+
|
7 |
+
ddpm:
|
8 |
+
scheduler_class: DDPMScheduler
|
9 |
+
additional_params:
|
10 |
+
variance_type: fixed_small
|
11 |
+
clip_sample: false
|
12 |
+
|
13 |
+
ddim:
|
14 |
+
scheduler_class: DDIMScheduler
|
15 |
+
additional_params:
|
16 |
+
clip_sample: false
|
17 |
+
|
18 |
+
deis:
|
19 |
+
scheduler_class: DEISMultistepScheduler
|
20 |
+
additional_params:
|
21 |
+
num_train_timesteps: 1000
|
22 |
+
|
23 |
+
pndm:
|
24 |
+
scheduler_class: PNDMScheduler
|
25 |
+
additional_params:
|
26 |
+
num_train_timesteps: 1000
|
config/evaluator.yaml
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
unit_length: 4
|
2 |
+
max_text_len: 20
|
3 |
+
text_enc_mod: bigru
|
4 |
+
estimator_mod: bigru
|
5 |
+
dim_text_hidden: 512
|
6 |
+
dim_att_vec: 512
|
7 |
+
dim_z: 128
|
8 |
+
dim_movement_enc_hidden: 512
|
9 |
+
dim_movement_dec_hidden: 512
|
10 |
+
dim_movement_latent: 512
|
11 |
+
dim_word: 300
|
12 |
+
dim_pos_ohot: 15
|
13 |
+
dim_motion_hidden: 1024
|
14 |
+
dim_coemb_hidden: 512
|
datasets/__init__.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
from .t2m_dataset import HumanML3D,KIT
|
3 |
+
|
4 |
+
from os.path import join as pjoin
|
5 |
+
__all__ = [
|
6 |
+
'HumanML3D', 'KIT', 'get_dataset',]
|
7 |
+
|
8 |
+
def get_dataset(opt, split='train', mode='train', accelerator=None):
|
9 |
+
if opt.dataset_name == 't2m' :
|
10 |
+
dataset = HumanML3D(opt, split, mode, accelerator)
|
11 |
+
elif opt.dataset_name == 'kit' :
|
12 |
+
dataset = KIT(opt,split, mode, accelerator)
|
13 |
+
else:
|
14 |
+
raise KeyError('Dataset Does Not Exist')
|
15 |
+
|
16 |
+
if accelerator:
|
17 |
+
accelerator.print('Completing loading %s dataset' % opt.dataset_name)
|
18 |
+
else:
|
19 |
+
print('Completing loading %s dataset' % opt.dataset_name)
|
20 |
+
|
21 |
+
return dataset
|
22 |
+
|
datasets/t2m_dataset.py
ADDED
@@ -0,0 +1,304 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils import data
|
3 |
+
import numpy as np
|
4 |
+
from os.path import join as pjoin
|
5 |
+
import random
|
6 |
+
import codecs as cs
|
7 |
+
from tqdm.auto import tqdm
|
8 |
+
from utils.word_vectorizer import WordVectorizer, POS_enumerator
|
9 |
+
from utils.motion_process import recover_from_ric
|
10 |
+
|
11 |
+
|
12 |
+
class Text2MotionDataset(data.Dataset):
|
13 |
+
"""
|
14 |
+
Dataset for Text2Motion generation task.
|
15 |
+
"""
|
16 |
+
|
17 |
+
data_root = ""
|
18 |
+
min_motion_len = 40
|
19 |
+
joints_num = None
|
20 |
+
dim_pose = None
|
21 |
+
max_motion_length = 196
|
22 |
+
|
23 |
+
def __init__(self, opt, split, mode="train", accelerator=None):
|
24 |
+
self.max_text_len = getattr(opt, "max_text_len", 20)
|
25 |
+
self.unit_length = getattr(opt, "unit_length", 4)
|
26 |
+
self.mode = mode
|
27 |
+
motion_dir = pjoin(self.data_root, "new_joint_vecs")
|
28 |
+
text_dir = pjoin(self.data_root, "texts")
|
29 |
+
|
30 |
+
if mode not in ["train", "eval", "gt_eval", "xyz_gt", "hml_gt"]:
|
31 |
+
raise ValueError(
|
32 |
+
f"Mode '{mode}' is not supported. Please use one of: 'train', 'eval', 'gt_eval', 'xyz_gt','hml_gt'."
|
33 |
+
)
|
34 |
+
|
35 |
+
mean, std = None, None
|
36 |
+
if mode == "gt_eval":
|
37 |
+
print(pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_std.npy"))
|
38 |
+
# used by T2M models (including evaluators)
|
39 |
+
mean = np.load(pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_mean.npy"))
|
40 |
+
std = np.load(pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_std.npy"))
|
41 |
+
elif mode in ["eval"]:
|
42 |
+
print(pjoin(opt.meta_dir, "std.npy"))
|
43 |
+
# used by our models during inference
|
44 |
+
mean = np.load(pjoin(opt.meta_dir, "mean.npy"))
|
45 |
+
std = np.load(pjoin(opt.meta_dir, "std.npy"))
|
46 |
+
else:
|
47 |
+
# used by our models during train
|
48 |
+
mean = np.load(pjoin(self.data_root, "Mean.npy"))
|
49 |
+
std = np.load(pjoin(self.data_root, "Std.npy"))
|
50 |
+
|
51 |
+
if mode == "eval":
|
52 |
+
# used by T2M models (including evaluators)
|
53 |
+
# this is to translate ours norms to theirs
|
54 |
+
self.mean_for_eval = np.load(
|
55 |
+
pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_mean.npy")
|
56 |
+
)
|
57 |
+
self.std_for_eval = np.load(
|
58 |
+
pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_std.npy")
|
59 |
+
)
|
60 |
+
if mode in ["gt_eval", "eval"]:
|
61 |
+
self.w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab")
|
62 |
+
|
63 |
+
data_dict = {}
|
64 |
+
id_list = []
|
65 |
+
split_file = pjoin(self.data_root, f"{split}.txt")
|
66 |
+
with cs.open(split_file, "r") as f:
|
67 |
+
for line in f.readlines():
|
68 |
+
id_list.append(line.strip())
|
69 |
+
|
70 |
+
if opt.debug == True:
|
71 |
+
id_list = id_list[:1000]
|
72 |
+
|
73 |
+
new_name_list = []
|
74 |
+
length_list = []
|
75 |
+
for name in tqdm(
|
76 |
+
id_list,
|
77 |
+
disable=(
|
78 |
+
not accelerator.is_local_main_process
|
79 |
+
if accelerator is not None
|
80 |
+
else False
|
81 |
+
),
|
82 |
+
):
|
83 |
+
motion = np.load(pjoin(motion_dir, name + ".npy"))
|
84 |
+
if (len(motion)) < self.min_motion_len or (len(motion) >= 200):
|
85 |
+
continue
|
86 |
+
text_data = []
|
87 |
+
flag = False
|
88 |
+
with cs.open(pjoin(text_dir, name + ".txt")) as f:
|
89 |
+
for line in f.readlines():
|
90 |
+
text_dict = {}
|
91 |
+
line_split = line.strip().split("#")
|
92 |
+
caption = line_split[0]
|
93 |
+
try:
|
94 |
+
tokens = line_split[1].split(" ")
|
95 |
+
f_tag = float(line_split[2])
|
96 |
+
to_tag = float(line_split[3])
|
97 |
+
f_tag = 0.0 if np.isnan(f_tag) else f_tag
|
98 |
+
to_tag = 0.0 if np.isnan(to_tag) else to_tag
|
99 |
+
except:
|
100 |
+
tokens = ["a/NUM", "a/NUM"]
|
101 |
+
f_tag = 0.0
|
102 |
+
to_tag = 8.0
|
103 |
+
text_dict["caption"] = caption
|
104 |
+
text_dict["tokens"] = tokens
|
105 |
+
if f_tag == 0.0 and to_tag == 0.0:
|
106 |
+
flag = True
|
107 |
+
text_data.append(text_dict)
|
108 |
+
else:
|
109 |
+
n_motion = motion[int(f_tag * 20) : int(to_tag * 20)]
|
110 |
+
if (len(n_motion)) < self.min_motion_len or (
|
111 |
+
len(n_motion) >= 200
|
112 |
+
):
|
113 |
+
continue
|
114 |
+
new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
|
115 |
+
while new_name in data_dict:
|
116 |
+
new_name = (
|
117 |
+
random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
|
118 |
+
)
|
119 |
+
data_dict[new_name] = {
|
120 |
+
"motion": n_motion,
|
121 |
+
"length": len(n_motion),
|
122 |
+
"text": [text_dict],
|
123 |
+
}
|
124 |
+
new_name_list.append(new_name)
|
125 |
+
length_list.append(len(n_motion))
|
126 |
+
if flag:
|
127 |
+
data_dict[name] = {
|
128 |
+
"motion": motion,
|
129 |
+
"length": len(motion),
|
130 |
+
"text": text_data,
|
131 |
+
}
|
132 |
+
new_name_list.append(name)
|
133 |
+
length_list.append(len(motion))
|
134 |
+
|
135 |
+
name_list, length_list = zip(
|
136 |
+
*sorted(zip(new_name_list, length_list), key=lambda x: x[1])
|
137 |
+
)
|
138 |
+
|
139 |
+
if mode == "train":
|
140 |
+
if opt.dataset_name != "amass":
|
141 |
+
joints_num = self.joints_num
|
142 |
+
# root_rot_velocity (B, seq_len, 1)
|
143 |
+
std[0:1] = std[0:1] / opt.feat_bias
|
144 |
+
# root_linear_velocity (B, seq_len, 2)
|
145 |
+
std[1:3] = std[1:3] / opt.feat_bias
|
146 |
+
# root_y (B, seq_len, 1)
|
147 |
+
std[3:4] = std[3:4] / opt.feat_bias
|
148 |
+
# ric_data (B, seq_len, (joint_num - 1)*3)
|
149 |
+
std[4 : 4 + (joints_num - 1) * 3] = (
|
150 |
+
std[4 : 4 + (joints_num - 1) * 3] / 1.0
|
151 |
+
)
|
152 |
+
# rot_data (B, seq_len, (joint_num - 1)*6)
|
153 |
+
std[4 + (joints_num - 1) * 3 : 4 + (joints_num - 1) * 9] = (
|
154 |
+
std[4 + (joints_num - 1) * 3 : 4 + (joints_num - 1) * 9] / 1.0
|
155 |
+
)
|
156 |
+
# local_velocity (B, seq_len, joint_num*3)
|
157 |
+
std[
|
158 |
+
4 + (joints_num - 1) * 9 : 4 + (joints_num - 1) * 9 + joints_num * 3
|
159 |
+
] = (
|
160 |
+
std[
|
161 |
+
4
|
162 |
+
+ (joints_num - 1) * 9 : 4
|
163 |
+
+ (joints_num - 1) * 9
|
164 |
+
+ joints_num * 3
|
165 |
+
]
|
166 |
+
/ 1.0
|
167 |
+
)
|
168 |
+
# foot contact (B, seq_len, 4)
|
169 |
+
std[4 + (joints_num - 1) * 9 + joints_num * 3 :] = (
|
170 |
+
std[4 + (joints_num - 1) * 9 + joints_num * 3 :] / opt.feat_bias
|
171 |
+
)
|
172 |
+
|
173 |
+
assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
|
174 |
+
|
175 |
+
if accelerator is not None and accelerator.is_main_process:
|
176 |
+
np.save(pjoin(opt.meta_dir, "mean.npy"), mean)
|
177 |
+
np.save(pjoin(opt.meta_dir, "std.npy"), std)
|
178 |
+
|
179 |
+
self.mean = mean
|
180 |
+
self.std = std
|
181 |
+
self.data_dict = data_dict
|
182 |
+
self.name_list = name_list
|
183 |
+
|
184 |
+
def inv_transform(self, data):
|
185 |
+
return data * self.std + self.mean
|
186 |
+
|
187 |
+
def __len__(self):
|
188 |
+
return len(self.data_dict)
|
189 |
+
|
190 |
+
def __getitem__(self, idx):
|
191 |
+
data = self.data_dict[self.name_list[idx]]
|
192 |
+
motion, m_length, text_list = data["motion"], data["length"], data["text"]
|
193 |
+
|
194 |
+
# Randomly select a caption
|
195 |
+
text_data = random.choice(text_list)
|
196 |
+
caption = text_data["caption"]
|
197 |
+
|
198 |
+
"Z Normalization"
|
199 |
+
if self.mode not in ["xyz_gt", "hml_gt"]:
|
200 |
+
motion = (motion - self.mean) / self.std
|
201 |
+
|
202 |
+
"crop motion"
|
203 |
+
if self.mode in ["eval", "gt_eval"]:
|
204 |
+
# Crop the motions in to times of 4, and introduce small variations
|
205 |
+
if self.unit_length < 10:
|
206 |
+
coin2 = np.random.choice(["single", "single", "double"])
|
207 |
+
else:
|
208 |
+
coin2 = "single"
|
209 |
+
if coin2 == "double":
|
210 |
+
m_length = (m_length // self.unit_length - 1) * self.unit_length
|
211 |
+
elif coin2 == "single":
|
212 |
+
m_length = (m_length // self.unit_length) * self.unit_length
|
213 |
+
idx = random.randint(0, len(motion) - m_length)
|
214 |
+
motion = motion[idx : idx + m_length]
|
215 |
+
elif m_length >= self.max_motion_length:
|
216 |
+
idx = random.randint(0, len(motion) - self.max_motion_length)
|
217 |
+
motion = motion[idx : idx + self.max_motion_length]
|
218 |
+
m_length = self.max_motion_length
|
219 |
+
|
220 |
+
"pad motion"
|
221 |
+
if m_length < self.max_motion_length:
|
222 |
+
motion = np.concatenate(
|
223 |
+
[
|
224 |
+
motion,
|
225 |
+
np.zeros((self.max_motion_length - m_length, motion.shape[1])),
|
226 |
+
],
|
227 |
+
axis=0,
|
228 |
+
)
|
229 |
+
assert len(motion) == self.max_motion_length
|
230 |
+
|
231 |
+
if self.mode in ["gt_eval", "eval"]:
|
232 |
+
"word embedding for text-to-motion evaluation"
|
233 |
+
tokens = text_data["tokens"]
|
234 |
+
if len(tokens) < self.max_text_len:
|
235 |
+
# pad with "unk"
|
236 |
+
tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
|
237 |
+
sent_len = len(tokens)
|
238 |
+
tokens = tokens + ["unk/OTHER"] * (self.max_text_len + 2 - sent_len)
|
239 |
+
else:
|
240 |
+
# crop
|
241 |
+
tokens = tokens[: self.max_text_len]
|
242 |
+
tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
|
243 |
+
sent_len = len(tokens)
|
244 |
+
pos_one_hots = []
|
245 |
+
word_embeddings = []
|
246 |
+
for token in tokens:
|
247 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
248 |
+
pos_one_hots.append(pos_oh[None, :])
|
249 |
+
word_embeddings.append(word_emb[None, :])
|
250 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
251 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
252 |
+
return (
|
253 |
+
word_embeddings,
|
254 |
+
pos_one_hots,
|
255 |
+
caption,
|
256 |
+
sent_len,
|
257 |
+
motion,
|
258 |
+
m_length,
|
259 |
+
"_".join(tokens),
|
260 |
+
)
|
261 |
+
elif self.mode in ["xyz_gt"]:
|
262 |
+
"Convert motion hml representation to skeleton points xyz"
|
263 |
+
# 1. Use kn to get the keypoints position (the padding position after kn is all zero)
|
264 |
+
motion = torch.from_numpy(motion).float()
|
265 |
+
pred_joints = recover_from_ric(
|
266 |
+
motion, self.joints_num
|
267 |
+
) # (nframe, njoints, 3)
|
268 |
+
|
269 |
+
# 2. Put on Floor (Y axis)
|
270 |
+
floor_height = pred_joints.min(dim=0)[0].min(dim=0)[0][1]
|
271 |
+
pred_joints[:, :, 1] -= floor_height
|
272 |
+
return pred_joints
|
273 |
+
|
274 |
+
return caption, motion, m_length
|
275 |
+
|
276 |
+
|
277 |
+
class HumanML3D(Text2MotionDataset):
|
278 |
+
def __init__(self, opt, split="train", mode="train", accelerator=None):
|
279 |
+
self.data_root = "./data/HumanML3D"
|
280 |
+
self.min_motion_len = 40
|
281 |
+
self.joints_num = 22
|
282 |
+
self.dim_pose = 263
|
283 |
+
self.max_motion_length = 196
|
284 |
+
if accelerator:
|
285 |
+
accelerator.print(
|
286 |
+
"\n Loading %s mode HumanML3D %s dataset ..." % (mode, split)
|
287 |
+
)
|
288 |
+
else:
|
289 |
+
print("\n Loading %s mode HumanML3D dataset ..." % mode)
|
290 |
+
super(HumanML3D, self).__init__(opt, split, mode, accelerator)
|
291 |
+
|
292 |
+
|
293 |
+
class KIT(Text2MotionDataset):
|
294 |
+
def __init__(self, opt, split="train", mode="train", accelerator=None):
|
295 |
+
self.data_root = "./data/KIT-ML"
|
296 |
+
self.min_motion_len = 24
|
297 |
+
self.joints_num = 21
|
298 |
+
self.dim_pose = 251
|
299 |
+
self.max_motion_length = 196
|
300 |
+
if accelerator:
|
301 |
+
accelerator.print("\n Loading %s mode KIT %s dataset ..." % (mode, split))
|
302 |
+
else:
|
303 |
+
print("\n Loading %s mode KIT dataset ..." % mode)
|
304 |
+
super(KIT, self).__init__(opt, split, mode, accelerator)
|
eval/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .evaluator_wrapper import EvaluatorModelWrapper
|
2 |
+
from .eval_t2m import evaluation
|
eval/eval_t2m.py
ADDED
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE.
|
2 |
+
# Copyright (c) 2022 Chuan Guo
|
3 |
+
from datetime import datetime
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from utils.metrics import *
|
7 |
+
from collections import OrderedDict
|
8 |
+
|
9 |
+
|
10 |
+
def evaluate_matching_score(eval_wrapper,motion_loaders, file):
|
11 |
+
match_score_dict = OrderedDict({})
|
12 |
+
R_precision_dict = OrderedDict({})
|
13 |
+
activation_dict = OrderedDict({})
|
14 |
+
# print(motion_loaders.keys())
|
15 |
+
print('========== Evaluating Matching Score ==========')
|
16 |
+
for motion_loader_name, motion_loader in motion_loaders.items():
|
17 |
+
all_motion_embeddings = []
|
18 |
+
score_list = []
|
19 |
+
all_size = 0
|
20 |
+
matching_score_sum = 0
|
21 |
+
top_k_count = 0
|
22 |
+
# print(motion_loader_name)
|
23 |
+
with torch.no_grad():
|
24 |
+
for idx, batch in enumerate(motion_loader):
|
25 |
+
word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch
|
26 |
+
text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(
|
27 |
+
word_embs=word_embeddings,
|
28 |
+
pos_ohot=pos_one_hots,
|
29 |
+
cap_lens=sent_lens,
|
30 |
+
motions=motions,
|
31 |
+
m_lens=m_lens
|
32 |
+
)
|
33 |
+
dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
|
34 |
+
motion_embeddings.cpu().numpy())
|
35 |
+
matching_score_sum += dist_mat.trace()
|
36 |
+
# import pdb;pdb.set_trace()
|
37 |
+
|
38 |
+
argsmax = np.argsort(dist_mat, axis=1)
|
39 |
+
top_k_mat = calculate_top_k(argsmax, top_k=3)
|
40 |
+
top_k_count += top_k_mat.sum(axis=0)
|
41 |
+
|
42 |
+
all_size += text_embeddings.shape[0]
|
43 |
+
|
44 |
+
all_motion_embeddings.append(motion_embeddings.cpu().numpy())
|
45 |
+
|
46 |
+
all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
|
47 |
+
# import pdb;pdb.set_trace()
|
48 |
+
matching_score = matching_score_sum / all_size
|
49 |
+
R_precision = top_k_count / all_size
|
50 |
+
match_score_dict[motion_loader_name] = matching_score
|
51 |
+
R_precision_dict[motion_loader_name] = R_precision
|
52 |
+
activation_dict[motion_loader_name] = all_motion_embeddings
|
53 |
+
|
54 |
+
print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}')
|
55 |
+
print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True)
|
56 |
+
|
57 |
+
line = f'---> [{motion_loader_name}] R_precision: '
|
58 |
+
for i in range(len(R_precision)):
|
59 |
+
line += '(top %d): %.4f ' % (i+1, R_precision[i])
|
60 |
+
print(line)
|
61 |
+
print(line, file=file, flush=True)
|
62 |
+
|
63 |
+
return match_score_dict, R_precision_dict, activation_dict
|
64 |
+
|
65 |
+
|
66 |
+
def evaluate_fid(eval_wrapper,groundtruth_loader, activation_dict, file):
|
67 |
+
eval_dict = OrderedDict({})
|
68 |
+
gt_motion_embeddings = []
|
69 |
+
print('========== Evaluating FID ==========')
|
70 |
+
with torch.no_grad():
|
71 |
+
for idx, batch in enumerate(groundtruth_loader):
|
72 |
+
_, _, _, sent_lens, motions, m_lens, _ = batch
|
73 |
+
motion_embeddings = eval_wrapper.get_motion_embeddings(
|
74 |
+
motions=motions,
|
75 |
+
m_lens=m_lens
|
76 |
+
)
|
77 |
+
gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
|
78 |
+
gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
|
79 |
+
gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
|
80 |
+
|
81 |
+
for model_name, motion_embeddings in activation_dict.items():
|
82 |
+
mu, cov = calculate_activation_statistics(motion_embeddings)
|
83 |
+
# print(mu)
|
84 |
+
fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
|
85 |
+
print(f'---> [{model_name}] FID: {fid:.4f}')
|
86 |
+
print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
|
87 |
+
eval_dict[model_name] = fid
|
88 |
+
return eval_dict
|
89 |
+
|
90 |
+
|
91 |
+
def evaluate_diversity(activation_dict, file, diversity_times):
|
92 |
+
eval_dict = OrderedDict({})
|
93 |
+
print('========== Evaluating Diversity ==========')
|
94 |
+
for model_name, motion_embeddings in activation_dict.items():
|
95 |
+
diversity = calculate_diversity(motion_embeddings, diversity_times)
|
96 |
+
eval_dict[model_name] = diversity
|
97 |
+
print(f'---> [{model_name}] Diversity: {diversity:.4f}')
|
98 |
+
print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
|
99 |
+
return eval_dict
|
100 |
+
|
101 |
+
|
102 |
+
def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times):
|
103 |
+
eval_dict = OrderedDict({})
|
104 |
+
print('========== Evaluating MultiModality ==========')
|
105 |
+
for model_name, mm_motion_loader in mm_motion_loaders.items():
|
106 |
+
mm_motion_embeddings = []
|
107 |
+
with torch.no_grad():
|
108 |
+
for idx, batch in enumerate(mm_motion_loader):
|
109 |
+
# (1, mm_replications, dim_pos)
|
110 |
+
motions, m_lens = batch
|
111 |
+
motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0])
|
112 |
+
mm_motion_embeddings.append(motion_embedings.unsqueeze(0))
|
113 |
+
if len(mm_motion_embeddings) == 0:
|
114 |
+
multimodality = 0
|
115 |
+
else:
|
116 |
+
mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
|
117 |
+
multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
|
118 |
+
print(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
|
119 |
+
print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
|
120 |
+
eval_dict[model_name] = multimodality
|
121 |
+
return eval_dict
|
122 |
+
|
123 |
+
|
124 |
+
def get_metric_statistics(values, replication_times):
|
125 |
+
mean = np.mean(values, axis=0)
|
126 |
+
std = np.std(values, axis=0)
|
127 |
+
conf_interval = 1.96 * std / np.sqrt(replication_times)
|
128 |
+
return mean, conf_interval
|
129 |
+
|
130 |
+
|
131 |
+
def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False):
|
132 |
+
with open(log_file, 'a') as f:
|
133 |
+
all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
|
134 |
+
'R_precision': OrderedDict({}),
|
135 |
+
'FID': OrderedDict({}),
|
136 |
+
'Diversity': OrderedDict({}),
|
137 |
+
'MultiModality': OrderedDict({})})
|
138 |
+
|
139 |
+
for replication in range(replication_times):
|
140 |
+
print(f'Time: {datetime.now()}')
|
141 |
+
print(f'Time: {datetime.now()}', file=f, flush=True)
|
142 |
+
motion_loaders = {}
|
143 |
+
motion_loaders['ground truth'] = gt_loader
|
144 |
+
mm_motion_loaders = {}
|
145 |
+
# motion_loaders['ground truth'] = gt_loader
|
146 |
+
for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
|
147 |
+
motion_loader, mm_motion_loader,eval_generate_time = motion_loader_getter()
|
148 |
+
print(f'---> [{motion_loader_name}] batch_generate_time: {eval_generate_time}s', file=f, flush=True)
|
149 |
+
motion_loaders[motion_loader_name] = motion_loader
|
150 |
+
mm_motion_loaders[motion_loader_name] = mm_motion_loader
|
151 |
+
|
152 |
+
if replication_times>1:
|
153 |
+
print(f'==================== Replication {replication} ====================')
|
154 |
+
print(f'==================== Replication {replication} ====================', file=f, flush=True)
|
155 |
+
|
156 |
+
|
157 |
+
mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)
|
158 |
+
|
159 |
+
fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)
|
160 |
+
|
161 |
+
div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)
|
162 |
+
|
163 |
+
if run_mm:
|
164 |
+
mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times)
|
165 |
+
|
166 |
+
print(f'!!! DONE !!!')
|
167 |
+
print(f'!!! DONE !!!', file=f, flush=True)
|
168 |
+
|
169 |
+
for key, item in mat_score_dict.items():
|
170 |
+
if key not in all_metrics['Matching Score']:
|
171 |
+
all_metrics['Matching Score'][key] = [item]
|
172 |
+
else:
|
173 |
+
all_metrics['Matching Score'][key] += [item]
|
174 |
+
|
175 |
+
for key, item in R_precision_dict.items():
|
176 |
+
if key not in all_metrics['R_precision']:
|
177 |
+
all_metrics['R_precision'][key] = [item]
|
178 |
+
else:
|
179 |
+
all_metrics['R_precision'][key] += [item]
|
180 |
+
|
181 |
+
for key, item in fid_score_dict.items():
|
182 |
+
if key not in all_metrics['FID']:
|
183 |
+
all_metrics['FID'][key] = [item]
|
184 |
+
else:
|
185 |
+
all_metrics['FID'][key] += [item]
|
186 |
+
|
187 |
+
for key, item in div_score_dict.items():
|
188 |
+
if key not in all_metrics['Diversity']:
|
189 |
+
all_metrics['Diversity'][key] = [item]
|
190 |
+
else:
|
191 |
+
all_metrics['Diversity'][key] += [item]
|
192 |
+
|
193 |
+
for key, item in mm_score_dict.items():
|
194 |
+
if key not in all_metrics['MultiModality']:
|
195 |
+
all_metrics['MultiModality'][key] = [item]
|
196 |
+
else:
|
197 |
+
all_metrics['MultiModality'][key] += [item]
|
198 |
+
|
199 |
+
|
200 |
+
mean_dict = {}
|
201 |
+
if replication_times>1:
|
202 |
+
for metric_name, metric_dict in all_metrics.items():
|
203 |
+
print('========== %s Summary ==========' % metric_name)
|
204 |
+
print('========== %s Summary ==========' % metric_name, file=f, flush=True)
|
205 |
+
|
206 |
+
for model_name, values in metric_dict.items():
|
207 |
+
# print(metric_name, model_name)
|
208 |
+
mean, conf_interval = get_metric_statistics(np.array(values),replication_times)
|
209 |
+
mean_dict[metric_name + '_' + model_name] = mean
|
210 |
+
# print(mean, mean.dtype)
|
211 |
+
if isinstance(mean, np.float64) or isinstance(mean, np.float32):
|
212 |
+
print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
|
213 |
+
print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
|
214 |
+
elif isinstance(mean, np.ndarray):
|
215 |
+
line = f'---> [{model_name}]'
|
216 |
+
for i in range(len(mean)):
|
217 |
+
line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
|
218 |
+
print(line)
|
219 |
+
print(line, file=f, flush=True)
|
220 |
+
return mean_dict
|
221 |
+
else:
|
222 |
+
return all_metrics
|
eval/evaluator_modules.py
ADDED
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE.
|
2 |
+
# Copyright (c) 2022 Chuan Guo
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import numpy as np
|
6 |
+
import time
|
7 |
+
import math
|
8 |
+
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
|
9 |
+
import torch.nn.functional as F
|
10 |
+
|
11 |
+
|
12 |
+
class ContrastiveLoss(torch.nn.Module):
|
13 |
+
"""
|
14 |
+
Contrastive loss function.
|
15 |
+
Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
|
16 |
+
"""
|
17 |
+
def __init__(self, margin=3.0):
|
18 |
+
super(ContrastiveLoss, self).__init__()
|
19 |
+
self.margin = margin
|
20 |
+
|
21 |
+
def forward(self, output1, output2, label):
|
22 |
+
euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
|
23 |
+
loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
|
24 |
+
(label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
|
25 |
+
return loss_contrastive
|
26 |
+
|
27 |
+
|
28 |
+
def init_weight(m):
|
29 |
+
if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
|
30 |
+
nn.init.xavier_normal_(m.weight)
|
31 |
+
# m.bias.data.fill_(0.01)
|
32 |
+
if m.bias is not None:
|
33 |
+
nn.init.constant_(m.bias, 0)
|
34 |
+
|
35 |
+
|
36 |
+
def reparameterize(mu, logvar):
|
37 |
+
s_var = logvar.mul(0.5).exp_()
|
38 |
+
eps = s_var.data.new(s_var.size()).normal_()
|
39 |
+
return eps.mul(s_var).add_(mu)
|
40 |
+
|
41 |
+
|
42 |
+
# batch_size, dimension and position
|
43 |
+
# output: (batch_size, dim)
|
44 |
+
def positional_encoding(batch_size, dim, pos):
|
45 |
+
assert batch_size == pos.shape[0]
|
46 |
+
positions_enc = np.array([
|
47 |
+
[pos[j] / np.power(10000, (i-i%2)/dim) for i in range(dim)]
|
48 |
+
for j in range(batch_size)
|
49 |
+
], dtype=np.float32)
|
50 |
+
positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2])
|
51 |
+
positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2])
|
52 |
+
return torch.from_numpy(positions_enc).float()
|
53 |
+
|
54 |
+
|
55 |
+
def get_padding_mask(batch_size, seq_len, cap_lens):
|
56 |
+
cap_lens = cap_lens.data.tolist()
|
57 |
+
mask_2d = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32)
|
58 |
+
for i, cap_len in enumerate(cap_lens):
|
59 |
+
mask_2d[i, :, :cap_len] = 0
|
60 |
+
return mask_2d.bool(), 1 - mask_2d[:, :, 0].clone()
|
61 |
+
|
62 |
+
|
63 |
+
class PositionalEncoding(nn.Module):
|
64 |
+
|
65 |
+
def __init__(self, d_model, max_len=300):
|
66 |
+
super(PositionalEncoding, self).__init__()
|
67 |
+
|
68 |
+
pe = torch.zeros(max_len, d_model)
|
69 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
70 |
+
div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
|
71 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
72 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
73 |
+
# pe = pe.unsqueeze(0).transpose(0, 1)
|
74 |
+
self.register_buffer('pe', pe)
|
75 |
+
|
76 |
+
def forward(self, pos):
|
77 |
+
return self.pe[pos]
|
78 |
+
|
79 |
+
|
80 |
+
class MovementConvEncoder(nn.Module):
|
81 |
+
def __init__(self, input_size, hidden_size, output_size):
|
82 |
+
super(MovementConvEncoder, self).__init__()
|
83 |
+
self.main = nn.Sequential(
|
84 |
+
nn.Conv1d(input_size, hidden_size, 4, 2, 1),
|
85 |
+
nn.Dropout(0.2, inplace=True),
|
86 |
+
nn.LeakyReLU(0.2, inplace=True),
|
87 |
+
nn.Conv1d(hidden_size, output_size, 4, 2, 1),
|
88 |
+
nn.Dropout(0.2, inplace=True),
|
89 |
+
nn.LeakyReLU(0.2, inplace=True),
|
90 |
+
)
|
91 |
+
self.out_net = nn.Linear(output_size, output_size)
|
92 |
+
self.main.apply(init_weight)
|
93 |
+
self.out_net.apply(init_weight)
|
94 |
+
|
95 |
+
def forward(self, inputs):
|
96 |
+
inputs = inputs.permute(0, 2, 1)
|
97 |
+
outputs = self.main(inputs).permute(0, 2, 1)
|
98 |
+
# print(outputs.shape)
|
99 |
+
return self.out_net(outputs)
|
100 |
+
|
101 |
+
|
102 |
+
class MovementConvDecoder(nn.Module):
|
103 |
+
def __init__(self, input_size, hidden_size, output_size):
|
104 |
+
super(MovementConvDecoder, self).__init__()
|
105 |
+
self.main = nn.Sequential(
|
106 |
+
nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1),
|
107 |
+
# nn.Dropout(0.2, inplace=True),
|
108 |
+
nn.LeakyReLU(0.2, inplace=True),
|
109 |
+
nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1),
|
110 |
+
# nn.Dropout(0.2, inplace=True),
|
111 |
+
nn.LeakyReLU(0.2, inplace=True),
|
112 |
+
)
|
113 |
+
self.out_net = nn.Linear(output_size, output_size)
|
114 |
+
|
115 |
+
self.main.apply(init_weight)
|
116 |
+
self.out_net.apply(init_weight)
|
117 |
+
|
118 |
+
def forward(self, inputs):
|
119 |
+
inputs = inputs.permute(0, 2, 1)
|
120 |
+
outputs = self.main(inputs).permute(0, 2, 1)
|
121 |
+
return self.out_net(outputs)
|
122 |
+
|
123 |
+
|
124 |
+
class TextVAEDecoder(nn.Module):
|
125 |
+
def __init__(self, text_size, input_size, output_size, hidden_size, n_layers):
|
126 |
+
super(TextVAEDecoder, self).__init__()
|
127 |
+
self.input_size = input_size
|
128 |
+
self.output_size = output_size
|
129 |
+
self.hidden_size = hidden_size
|
130 |
+
self.n_layers = n_layers
|
131 |
+
self.emb = nn.Sequential(
|
132 |
+
nn.Linear(input_size, hidden_size),
|
133 |
+
nn.LayerNorm(hidden_size),
|
134 |
+
nn.LeakyReLU(0.2, inplace=True))
|
135 |
+
|
136 |
+
self.z2init = nn.Linear(text_size, hidden_size * n_layers)
|
137 |
+
self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)])
|
138 |
+
self.positional_encoder = PositionalEncoding(hidden_size)
|
139 |
+
|
140 |
+
|
141 |
+
self.output = nn.Sequential(
|
142 |
+
nn.Linear(hidden_size, hidden_size),
|
143 |
+
nn.LayerNorm(hidden_size),
|
144 |
+
nn.LeakyReLU(0.2, inplace=True),
|
145 |
+
nn.Linear(hidden_size, output_size)
|
146 |
+
)
|
147 |
+
|
148 |
+
#
|
149 |
+
# self.output = nn.Sequential(
|
150 |
+
# nn.Linear(hidden_size, hidden_size),
|
151 |
+
# nn.LayerNorm(hidden_size),
|
152 |
+
# nn.LeakyReLU(0.2, inplace=True),
|
153 |
+
# nn.Linear(hidden_size, output_size-4)
|
154 |
+
# )
|
155 |
+
|
156 |
+
# self.contact_net = nn.Sequential(
|
157 |
+
# nn.Linear(output_size-4, 64),
|
158 |
+
# nn.LayerNorm(64),
|
159 |
+
# nn.LeakyReLU(0.2, inplace=True),
|
160 |
+
# nn.Linear(64, 4)
|
161 |
+
# )
|
162 |
+
|
163 |
+
self.output.apply(init_weight)
|
164 |
+
self.emb.apply(init_weight)
|
165 |
+
self.z2init.apply(init_weight)
|
166 |
+
# self.contact_net.apply(init_weight)
|
167 |
+
|
168 |
+
def get_init_hidden(self, latent):
|
169 |
+
hidden = self.z2init(latent)
|
170 |
+
hidden = torch.split(hidden, self.hidden_size, dim=-1)
|
171 |
+
return list(hidden)
|
172 |
+
|
173 |
+
def forward(self, inputs, last_pred, hidden, p):
|
174 |
+
h_in = self.emb(inputs)
|
175 |
+
pos_enc = self.positional_encoder(p).to(inputs.device).detach()
|
176 |
+
h_in = h_in + pos_enc
|
177 |
+
for i in range(self.n_layers):
|
178 |
+
# print(h_in.shape)
|
179 |
+
hidden[i] = self.gru[i](h_in, hidden[i])
|
180 |
+
h_in = hidden[i]
|
181 |
+
pose_pred = self.output(h_in)
|
182 |
+
# pose_pred = self.output(h_in) + last_pred.detach()
|
183 |
+
# contact = self.contact_net(pose_pred)
|
184 |
+
# return torch.cat([pose_pred, contact], dim=-1), hidden
|
185 |
+
return pose_pred, hidden
|
186 |
+
|
187 |
+
|
188 |
+
class TextDecoder(nn.Module):
|
189 |
+
def __init__(self, text_size, input_size, output_size, hidden_size, n_layers):
|
190 |
+
super(TextDecoder, self).__init__()
|
191 |
+
self.input_size = input_size
|
192 |
+
self.output_size = output_size
|
193 |
+
self.hidden_size = hidden_size
|
194 |
+
self.n_layers = n_layers
|
195 |
+
self.emb = nn.Sequential(
|
196 |
+
nn.Linear(input_size, hidden_size),
|
197 |
+
nn.LayerNorm(hidden_size),
|
198 |
+
nn.LeakyReLU(0.2, inplace=True))
|
199 |
+
|
200 |
+
self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)])
|
201 |
+
self.z2init = nn.Linear(text_size, hidden_size * n_layers)
|
202 |
+
self.positional_encoder = PositionalEncoding(hidden_size)
|
203 |
+
|
204 |
+
self.mu_net = nn.Linear(hidden_size, output_size)
|
205 |
+
self.logvar_net = nn.Linear(hidden_size, output_size)
|
206 |
+
|
207 |
+
self.emb.apply(init_weight)
|
208 |
+
self.z2init.apply(init_weight)
|
209 |
+
self.mu_net.apply(init_weight)
|
210 |
+
self.logvar_net.apply(init_weight)
|
211 |
+
|
212 |
+
def get_init_hidden(self, latent):
|
213 |
+
|
214 |
+
hidden = self.z2init(latent)
|
215 |
+
hidden = torch.split(hidden, self.hidden_size, dim=-1)
|
216 |
+
|
217 |
+
return list(hidden)
|
218 |
+
|
219 |
+
def forward(self, inputs, hidden, p):
|
220 |
+
# print(inputs.shape)
|
221 |
+
x_in = self.emb(inputs)
|
222 |
+
pos_enc = self.positional_encoder(p).to(inputs.device).detach()
|
223 |
+
x_in = x_in + pos_enc
|
224 |
+
|
225 |
+
for i in range(self.n_layers):
|
226 |
+
hidden[i] = self.gru[i](x_in, hidden[i])
|
227 |
+
h_in = hidden[i]
|
228 |
+
mu = self.mu_net(h_in)
|
229 |
+
logvar = self.logvar_net(h_in)
|
230 |
+
z = reparameterize(mu, logvar)
|
231 |
+
return z, mu, logvar, hidden
|
232 |
+
|
233 |
+
class AttLayer(nn.Module):
|
234 |
+
def __init__(self, query_dim, key_dim, value_dim):
|
235 |
+
super(AttLayer, self).__init__()
|
236 |
+
self.W_q = nn.Linear(query_dim, value_dim)
|
237 |
+
self.W_k = nn.Linear(key_dim, value_dim, bias=False)
|
238 |
+
self.W_v = nn.Linear(key_dim, value_dim)
|
239 |
+
|
240 |
+
self.softmax = nn.Softmax(dim=1)
|
241 |
+
self.dim = value_dim
|
242 |
+
|
243 |
+
self.W_q.apply(init_weight)
|
244 |
+
self.W_k.apply(init_weight)
|
245 |
+
self.W_v.apply(init_weight)
|
246 |
+
|
247 |
+
def forward(self, query, key_mat):
|
248 |
+
'''
|
249 |
+
query (batch, query_dim)
|
250 |
+
key (batch, seq_len, key_dim)
|
251 |
+
'''
|
252 |
+
# print(query.shape)
|
253 |
+
query_vec = self.W_q(query).unsqueeze(-1) # (batch, value_dim, 1)
|
254 |
+
val_set = self.W_v(key_mat) # (batch, seq_len, value_dim)
|
255 |
+
key_set = self.W_k(key_mat) # (batch, seq_len, value_dim)
|
256 |
+
|
257 |
+
weights = torch.matmul(key_set, query_vec) / np.sqrt(self.dim)
|
258 |
+
|
259 |
+
co_weights = self.softmax(weights) # (batch, seq_len, 1)
|
260 |
+
values = val_set * co_weights # (batch, seq_len, value_dim)
|
261 |
+
pred = values.sum(dim=1) # (batch, value_dim)
|
262 |
+
return pred, co_weights
|
263 |
+
|
264 |
+
def short_cut(self, querys, keys):
|
265 |
+
return self.W_q(querys), self.W_k(keys)
|
266 |
+
|
267 |
+
|
268 |
+
class TextEncoderBiGRU(nn.Module):
|
269 |
+
def __init__(self, word_size, pos_size, hidden_size, device):
|
270 |
+
super(TextEncoderBiGRU, self).__init__()
|
271 |
+
self.device = device
|
272 |
+
|
273 |
+
self.pos_emb = nn.Linear(pos_size, word_size)
|
274 |
+
self.input_emb = nn.Linear(word_size, hidden_size)
|
275 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
276 |
+
# self.linear2 = nn.Linear(hidden_size, output_size)
|
277 |
+
|
278 |
+
self.input_emb.apply(init_weight)
|
279 |
+
self.pos_emb.apply(init_weight)
|
280 |
+
# self.linear2.apply(init_weight)
|
281 |
+
# self.batch_size = batch_size
|
282 |
+
self.hidden_size = hidden_size
|
283 |
+
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
|
284 |
+
|
285 |
+
# input(batch_size, seq_len, dim)
|
286 |
+
def forward(self, word_embs, pos_onehot, cap_lens):
|
287 |
+
num_samples = word_embs.shape[0]
|
288 |
+
|
289 |
+
pos_embs = self.pos_emb(pos_onehot)
|
290 |
+
inputs = word_embs + pos_embs
|
291 |
+
input_embs = self.input_emb(inputs)
|
292 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
293 |
+
|
294 |
+
cap_lens = cap_lens.data.tolist()
|
295 |
+
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
|
296 |
+
|
297 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
298 |
+
|
299 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
300 |
+
gru_seq = pad_packed_sequence(gru_seq, batch_first=True)[0]
|
301 |
+
forward_seq = gru_seq[..., :self.hidden_size]
|
302 |
+
backward_seq = gru_seq[..., self.hidden_size:].clone()
|
303 |
+
|
304 |
+
# Concate the forward and backward word embeddings
|
305 |
+
for i, length in enumerate(cap_lens):
|
306 |
+
backward_seq[i:i+1, :length] = torch.flip(backward_seq[i:i+1, :length].clone(), dims=[1])
|
307 |
+
gru_seq = torch.cat([forward_seq, backward_seq], dim=-1)
|
308 |
+
|
309 |
+
return gru_seq, gru_last
|
310 |
+
|
311 |
+
|
312 |
+
class TextEncoderBiGRUCo(nn.Module):
|
313 |
+
def __init__(self, word_size, pos_size, hidden_size, output_size, device):
|
314 |
+
super(TextEncoderBiGRUCo, self).__init__()
|
315 |
+
self.device = device
|
316 |
+
|
317 |
+
self.pos_emb = nn.Linear(pos_size, word_size)
|
318 |
+
self.input_emb = nn.Linear(word_size, hidden_size)
|
319 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
320 |
+
self.output_net = nn.Sequential(
|
321 |
+
nn.Linear(hidden_size * 2, hidden_size),
|
322 |
+
nn.LayerNorm(hidden_size),
|
323 |
+
nn.LeakyReLU(0.2, inplace=True),
|
324 |
+
nn.Linear(hidden_size, output_size)
|
325 |
+
)
|
326 |
+
|
327 |
+
self.input_emb.apply(init_weight)
|
328 |
+
self.pos_emb.apply(init_weight)
|
329 |
+
self.output_net.apply(init_weight)
|
330 |
+
# self.linear2.apply(init_weight)
|
331 |
+
# self.batch_size = batch_size
|
332 |
+
self.hidden_size = hidden_size
|
333 |
+
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
|
334 |
+
|
335 |
+
# input(batch_size, seq_len, dim)
|
336 |
+
def forward(self, word_embs, pos_onehot, cap_lens):
|
337 |
+
num_samples = word_embs.shape[0]
|
338 |
+
|
339 |
+
pos_embs = self.pos_emb(pos_onehot)
|
340 |
+
inputs = word_embs + pos_embs
|
341 |
+
input_embs = self.input_emb(inputs)
|
342 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
343 |
+
|
344 |
+
cap_lens = cap_lens.data.tolist()
|
345 |
+
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
|
346 |
+
|
347 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
348 |
+
|
349 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
350 |
+
|
351 |
+
return self.output_net(gru_last)
|
352 |
+
|
353 |
+
|
354 |
+
class MotionEncoderBiGRUCo(nn.Module):
|
355 |
+
def __init__(self, input_size, hidden_size, output_size, device):
|
356 |
+
super(MotionEncoderBiGRUCo, self).__init__()
|
357 |
+
self.device = device
|
358 |
+
|
359 |
+
self.input_emb = nn.Linear(input_size, hidden_size)
|
360 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
361 |
+
self.output_net = nn.Sequential(
|
362 |
+
nn.Linear(hidden_size*2, hidden_size),
|
363 |
+
nn.LayerNorm(hidden_size),
|
364 |
+
nn.LeakyReLU(0.2, inplace=True),
|
365 |
+
nn.Linear(hidden_size, output_size)
|
366 |
+
)
|
367 |
+
|
368 |
+
self.input_emb.apply(init_weight)
|
369 |
+
self.output_net.apply(init_weight)
|
370 |
+
self.hidden_size = hidden_size
|
371 |
+
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
|
372 |
+
|
373 |
+
# input(batch_size, seq_len, dim)
|
374 |
+
def forward(self, inputs, m_lens):
|
375 |
+
num_samples = inputs.shape[0]
|
376 |
+
|
377 |
+
input_embs = self.input_emb(inputs)
|
378 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
379 |
+
|
380 |
+
cap_lens = m_lens.data.tolist()
|
381 |
+
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
|
382 |
+
|
383 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
384 |
+
|
385 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
386 |
+
|
387 |
+
return self.output_net(gru_last)
|
388 |
+
|
389 |
+
|
390 |
+
class MotionLenEstimatorBiGRU(nn.Module):
|
391 |
+
def __init__(self, word_size, pos_size, hidden_size, output_size):
|
392 |
+
super(MotionLenEstimatorBiGRU, self).__init__()
|
393 |
+
|
394 |
+
self.pos_emb = nn.Linear(pos_size, word_size)
|
395 |
+
self.input_emb = nn.Linear(word_size, hidden_size)
|
396 |
+
self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
|
397 |
+
nd = 512
|
398 |
+
self.output = nn.Sequential(
|
399 |
+
nn.Linear(hidden_size*2, nd),
|
400 |
+
nn.LayerNorm(nd),
|
401 |
+
nn.LeakyReLU(0.2, inplace=True),
|
402 |
+
|
403 |
+
nn.Linear(nd, nd // 2),
|
404 |
+
nn.LayerNorm(nd // 2),
|
405 |
+
nn.LeakyReLU(0.2, inplace=True),
|
406 |
+
|
407 |
+
nn.Linear(nd // 2, nd // 4),
|
408 |
+
nn.LayerNorm(nd // 4),
|
409 |
+
nn.LeakyReLU(0.2, inplace=True),
|
410 |
+
|
411 |
+
nn.Linear(nd // 4, output_size)
|
412 |
+
)
|
413 |
+
|
414 |
+
self.input_emb.apply(init_weight)
|
415 |
+
self.pos_emb.apply(init_weight)
|
416 |
+
self.output.apply(init_weight)
|
417 |
+
self.hidden_size = hidden_size
|
418 |
+
self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
|
419 |
+
|
420 |
+
# input(batch_size, seq_len, dim)
|
421 |
+
def forward(self, word_embs, pos_onehot, cap_lens):
|
422 |
+
num_samples = word_embs.shape[0]
|
423 |
+
|
424 |
+
pos_embs = self.pos_emb(pos_onehot)
|
425 |
+
inputs = word_embs + pos_embs
|
426 |
+
input_embs = self.input_emb(inputs)
|
427 |
+
hidden = self.hidden.repeat(1, num_samples, 1)
|
428 |
+
|
429 |
+
cap_lens = cap_lens.data.tolist()
|
430 |
+
emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
|
431 |
+
|
432 |
+
gru_seq, gru_last = self.gru(emb, hidden)
|
433 |
+
|
434 |
+
gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
|
435 |
+
|
436 |
+
return self.output(gru_last)
|
eval/evaluator_wrapper.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE.
|
2 |
+
# Copyright (c) 2022 Chuan Guo
|
3 |
+
import torch
|
4 |
+
from os.path import join as pjoin
|
5 |
+
import numpy as np
|
6 |
+
from .evaluator_modules import *
|
7 |
+
|
8 |
+
|
9 |
+
def build_models(opt):
|
10 |
+
movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
|
11 |
+
text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
|
12 |
+
pos_size=opt.dim_pos_ohot,
|
13 |
+
hidden_size=opt.dim_text_hidden,
|
14 |
+
output_size=opt.dim_coemb_hidden,
|
15 |
+
device=opt.device)
|
16 |
+
|
17 |
+
motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
|
18 |
+
hidden_size=opt.dim_motion_hidden,
|
19 |
+
output_size=opt.dim_coemb_hidden,
|
20 |
+
device=opt.device)
|
21 |
+
checkpoint = torch.load(pjoin(opt.evaluator_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
|
22 |
+
map_location=opt.device)
|
23 |
+
movement_enc.load_state_dict(checkpoint['movement_encoder'])
|
24 |
+
text_enc.load_state_dict(checkpoint['text_encoder'])
|
25 |
+
motion_enc.load_state_dict(checkpoint['motion_encoder'])
|
26 |
+
print('\nLoading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
|
27 |
+
return text_enc, motion_enc, movement_enc
|
28 |
+
|
29 |
+
class EvaluatorModelWrapper(object):
|
30 |
+
|
31 |
+
def __init__(self, opt):
|
32 |
+
self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
|
33 |
+
self.opt = opt
|
34 |
+
self.device = opt.device
|
35 |
+
|
36 |
+
self.text_encoder.to(opt.device)
|
37 |
+
self.motion_encoder.to(opt.device)
|
38 |
+
self.movement_encoder.to(opt.device)
|
39 |
+
|
40 |
+
self.text_encoder.eval()
|
41 |
+
self.motion_encoder.eval()
|
42 |
+
self.movement_encoder.eval()
|
43 |
+
|
44 |
+
# Please note that the results does not following the order of inputs
|
45 |
+
def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
|
46 |
+
with torch.no_grad():
|
47 |
+
word_embs = word_embs.detach().to(self.device).float()
|
48 |
+
pos_ohot = pos_ohot.detach().to(self.device).float()
|
49 |
+
motions = motions.detach().to(self.device).float()
|
50 |
+
|
51 |
+
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
|
52 |
+
motions = motions[align_idx]
|
53 |
+
m_lens = m_lens[align_idx]
|
54 |
+
|
55 |
+
'''Movement Encoding'''
|
56 |
+
movements = self.movement_encoder(motions[..., :-4]).detach()
|
57 |
+
m_lens = torch.div(m_lens, self.opt.unit_length, rounding_mode='trunc')
|
58 |
+
motion_embedding = self.motion_encoder(movements, m_lens)
|
59 |
+
|
60 |
+
'''Text Encoding'''
|
61 |
+
text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
|
62 |
+
text_embedding = text_embedding[align_idx]
|
63 |
+
return text_embedding, motion_embedding
|
64 |
+
|
65 |
+
# Please note that the results does not following the order of inputs
|
66 |
+
def get_motion_embeddings(self, motions, m_lens):
|
67 |
+
with torch.no_grad():
|
68 |
+
motions = motions.detach().to(self.device).float()
|
69 |
+
|
70 |
+
align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
|
71 |
+
motions = motions[align_idx]
|
72 |
+
m_lens = m_lens[align_idx]
|
73 |
+
|
74 |
+
'''Movement Encoding'''
|
75 |
+
movements = self.movement_encoder(motions[..., :-4]).detach()
|
76 |
+
m_lens = torch.div(m_lens, self.opt.unit_length, rounding_mode='trunc')
|
77 |
+
motion_embedding = self.motion_encoder(movements, m_lens)
|
78 |
+
return motion_embedding
|
models/__init__.py
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .unet import MotionCLR
|
2 |
+
|
3 |
+
|
4 |
+
__all__ = ["MotionCLR"]
|
5 |
+
|
6 |
+
|
7 |
+
def build_models(opt, edit_config=None, out_path=None):
|
8 |
+
print("\nInitializing model ...")
|
9 |
+
model = MotionCLR(
|
10 |
+
input_feats=opt.dim_pose,
|
11 |
+
text_latent_dim=opt.text_latent_dim,
|
12 |
+
base_dim=opt.base_dim,
|
13 |
+
dim_mults=opt.dim_mults,
|
14 |
+
time_dim=opt.time_dim,
|
15 |
+
adagn=not opt.no_adagn,
|
16 |
+
zero=True,
|
17 |
+
dropout=opt.dropout,
|
18 |
+
no_eff=opt.no_eff,
|
19 |
+
cond_mask_prob=getattr(opt, "cond_mask_prob", 0.0),
|
20 |
+
self_attention=opt.self_attention,
|
21 |
+
vis_attn=opt.vis_attn,
|
22 |
+
edit_config=edit_config,
|
23 |
+
out_path=out_path,
|
24 |
+
)
|
25 |
+
|
26 |
+
return model
|
models/gaussian_diffusion.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from diffusers import (
|
2 |
+
DPMSolverMultistepScheduler,
|
3 |
+
DDPMScheduler,
|
4 |
+
DDIMScheduler,
|
5 |
+
PNDMScheduler,
|
6 |
+
DEISMultistepScheduler,
|
7 |
+
)
|
8 |
+
import torch
|
9 |
+
import yaml
|
10 |
+
import math
|
11 |
+
import tqdm
|
12 |
+
import time
|
13 |
+
|
14 |
+
|
15 |
+
class DiffusePipeline(object):
|
16 |
+
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
opt,
|
20 |
+
model,
|
21 |
+
diffuser_name,
|
22 |
+
num_inference_steps,
|
23 |
+
device,
|
24 |
+
torch_dtype=torch.float16,
|
25 |
+
):
|
26 |
+
self.device = device
|
27 |
+
self.torch_dtype = torch_dtype
|
28 |
+
self.diffuser_name = diffuser_name
|
29 |
+
self.num_inference_steps = num_inference_steps
|
30 |
+
if self.torch_dtype == torch.float16:
|
31 |
+
model = model.half()
|
32 |
+
self.model = model.to(device)
|
33 |
+
self.opt = opt
|
34 |
+
|
35 |
+
# Load parameters from YAML file
|
36 |
+
with open("config/diffuser_params.yaml", "r") as yaml_file:
|
37 |
+
diffuser_params = yaml.safe_load(yaml_file)
|
38 |
+
|
39 |
+
# Select diffusion'parameters based on diffuser_name
|
40 |
+
if diffuser_name in diffuser_params:
|
41 |
+
params = diffuser_params[diffuser_name]
|
42 |
+
scheduler_class_name = params["scheduler_class"]
|
43 |
+
additional_params = params["additional_params"]
|
44 |
+
|
45 |
+
# align training parameters
|
46 |
+
additional_params["num_train_timesteps"] = opt.diffusion_steps
|
47 |
+
additional_params["beta_schedule"] = opt.beta_schedule
|
48 |
+
additional_params["prediction_type"] = opt.prediction_type
|
49 |
+
|
50 |
+
try:
|
51 |
+
scheduler_class = globals()[scheduler_class_name]
|
52 |
+
except KeyError:
|
53 |
+
raise ValueError(f"Class '{scheduler_class_name}' not found.")
|
54 |
+
|
55 |
+
self.scheduler = scheduler_class(**additional_params)
|
56 |
+
else:
|
57 |
+
raise ValueError(f"Unsupported diffuser_name: {diffuser_name}")
|
58 |
+
|
59 |
+
def generate_batch(self, caption, m_lens):
|
60 |
+
B = len(caption)
|
61 |
+
T = m_lens.max()
|
62 |
+
shape = (B, T, self.model.input_feats)
|
63 |
+
|
64 |
+
# random sampling noise x_T
|
65 |
+
sample = torch.randn(shape, device=self.device, dtype=self.torch_dtype)
|
66 |
+
|
67 |
+
# set timesteps
|
68 |
+
self.scheduler.set_timesteps(self.num_inference_steps, self.device)
|
69 |
+
timesteps = [
|
70 |
+
torch.tensor([t] * B, device=self.device).long()
|
71 |
+
for t in self.scheduler.timesteps
|
72 |
+
]
|
73 |
+
|
74 |
+
# cache text_embedded
|
75 |
+
enc_text = self.model.encode_text(caption, self.device)
|
76 |
+
|
77 |
+
for i, t in enumerate(timesteps):
|
78 |
+
# 1. model predict
|
79 |
+
with torch.no_grad():
|
80 |
+
if getattr(self.model, "cond_mask_prob", 0) > 0:
|
81 |
+
predict = self.model.forward_with_cfg(sample, t, enc_text=enc_text)
|
82 |
+
else:
|
83 |
+
|
84 |
+
predict = self.model(sample, t, enc_text=enc_text)
|
85 |
+
|
86 |
+
# 2. compute less noisy motion and set x_t -> x_t-1
|
87 |
+
sample = self.scheduler.step(predict, t[0], sample).prev_sample
|
88 |
+
|
89 |
+
return sample
|
90 |
+
|
91 |
+
def generate(self, caption, m_lens, batch_size=32):
|
92 |
+
N = len(caption)
|
93 |
+
infer_mode = ""
|
94 |
+
if getattr(self.model, "cond_mask_prob", 0) > 0:
|
95 |
+
infer_mode = "classifier-free-guidance"
|
96 |
+
print(
|
97 |
+
f"\nUsing {self.diffuser_name} diffusion scheduler to {infer_mode} generate {N} motions, sampling {self.num_inference_steps} steps."
|
98 |
+
)
|
99 |
+
self.model.eval()
|
100 |
+
|
101 |
+
all_output = []
|
102 |
+
t_sum = 0
|
103 |
+
cur_idx = 0
|
104 |
+
for bacth_idx in tqdm.tqdm(range(math.ceil(N / batch_size))):
|
105 |
+
if cur_idx + batch_size >= N:
|
106 |
+
batch_caption = caption[cur_idx:]
|
107 |
+
batch_m_lens = m_lens[cur_idx:]
|
108 |
+
else:
|
109 |
+
batch_caption = caption[cur_idx : cur_idx + batch_size]
|
110 |
+
batch_m_lens = m_lens[cur_idx : cur_idx + batch_size]
|
111 |
+
torch.cuda.synchronize()
|
112 |
+
start_time = time.time()
|
113 |
+
output = self.generate_batch(batch_caption, batch_m_lens)
|
114 |
+
torch.cuda.synchronize()
|
115 |
+
now_time = time.time()
|
116 |
+
|
117 |
+
# The average inference time is calculated after GPU warm-up in the first 50 steps.
|
118 |
+
if (bacth_idx + 1) * self.num_inference_steps >= 50:
|
119 |
+
t_sum += now_time - start_time
|
120 |
+
|
121 |
+
# Crop motion with gt/predicted motion length
|
122 |
+
B = output.shape[0]
|
123 |
+
for i in range(B):
|
124 |
+
all_output.append(output[i, : batch_m_lens[i]])
|
125 |
+
|
126 |
+
cur_idx += batch_size
|
127 |
+
|
128 |
+
# calcalate average inference time
|
129 |
+
t_eval = t_sum / (bacth_idx - 1)
|
130 |
+
print(
|
131 |
+
"The average generation time of a batch motion (bs=%d) is %f seconds"
|
132 |
+
% (batch_size, t_eval)
|
133 |
+
)
|
134 |
+
return all_output, t_eval
|
models/unet.py
ADDED
@@ -0,0 +1,1073 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import clip
|
2 |
+
import math
|
3 |
+
|
4 |
+
import torch
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from torch import nn
|
7 |
+
import numpy as np
|
8 |
+
from einops.layers.torch import Rearrange
|
9 |
+
from einops import rearrange
|
10 |
+
import matplotlib.pyplot as plt
|
11 |
+
import os
|
12 |
+
|
13 |
+
|
14 |
+
MONITOR_ATTN = []
|
15 |
+
SELF_ATTN = []
|
16 |
+
|
17 |
+
|
18 |
+
def vis_attn(att, out_path, step, layer, shape, type_="self", lines=True):
|
19 |
+
if lines:
|
20 |
+
plt.figure(figsize=(10, 3))
|
21 |
+
for token_index in range(att.shape[1]):
|
22 |
+
plt.plot(att[:, token_index], label=f"Token {token_index}")
|
23 |
+
|
24 |
+
plt.title("Attention Values for Each Token")
|
25 |
+
plt.xlabel("time")
|
26 |
+
plt.ylabel("Attention Value")
|
27 |
+
plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1))
|
28 |
+
|
29 |
+
# save image
|
30 |
+
savepath = os.path.join(out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_lines_{shape}.png")
|
31 |
+
os.makedirs(os.path.dirname(savepath), exist_ok=True)
|
32 |
+
plt.savefig(savepath, bbox_inches="tight")
|
33 |
+
np.save(savepath.replace(".png", ".npy"), att)
|
34 |
+
else:
|
35 |
+
plt.figure(figsize=(10, 10))
|
36 |
+
plt.imshow(att.transpose(), cmap="viridis", aspect="auto")
|
37 |
+
plt.colorbar()
|
38 |
+
plt.title("Attention Matrix Heatmap")
|
39 |
+
plt.ylabel("time")
|
40 |
+
plt.xlabel("time")
|
41 |
+
|
42 |
+
# save image
|
43 |
+
savepath = os.path.join(out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_heatmap_{shape}.png")
|
44 |
+
os.makedirs(os.path.dirname(savepath), exist_ok=True)
|
45 |
+
plt.savefig(savepath, bbox_inches="tight")
|
46 |
+
np.save(savepath.replace(".png", ".npy"), att)
|
47 |
+
|
48 |
+
|
49 |
+
def zero_module(module):
|
50 |
+
"""
|
51 |
+
Zero out the parameters of a module and return it.
|
52 |
+
"""
|
53 |
+
for p in module.parameters():
|
54 |
+
p.detach().zero_()
|
55 |
+
return module
|
56 |
+
|
57 |
+
|
58 |
+
class FFN(nn.Module):
|
59 |
+
|
60 |
+
def __init__(self, latent_dim, ffn_dim, dropout):
|
61 |
+
super().__init__()
|
62 |
+
self.linear1 = nn.Linear(latent_dim, ffn_dim)
|
63 |
+
self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
|
64 |
+
self.activation = nn.GELU()
|
65 |
+
self.dropout = nn.Dropout(dropout)
|
66 |
+
|
67 |
+
def forward(self, x):
|
68 |
+
y = self.linear2(self.dropout(self.activation(self.linear1(x))))
|
69 |
+
y = x + y
|
70 |
+
return y
|
71 |
+
|
72 |
+
|
73 |
+
class Conv1dAdaGNBlock(nn.Module):
|
74 |
+
"""
|
75 |
+
Conv1d --> GroupNorm --> scale,shift --> Mish
|
76 |
+
"""
|
77 |
+
|
78 |
+
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4):
|
79 |
+
super().__init__()
|
80 |
+
self.out_channels = out_channels
|
81 |
+
self.block = nn.Conv1d(
|
82 |
+
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
|
83 |
+
)
|
84 |
+
self.group_norm = nn.GroupNorm(n_groups, out_channels)
|
85 |
+
self.avtication = nn.Mish()
|
86 |
+
|
87 |
+
def forward(self, x, scale, shift):
|
88 |
+
"""
|
89 |
+
Args:
|
90 |
+
x: [bs, nfeat, nframes]
|
91 |
+
scale: [bs, out_feat, 1]
|
92 |
+
shift: [bs, out_feat, 1]
|
93 |
+
"""
|
94 |
+
x = self.block(x)
|
95 |
+
|
96 |
+
batch_size, channels, horizon = x.size()
|
97 |
+
x = rearrange(
|
98 |
+
x, "batch channels horizon -> (batch horizon) channels"
|
99 |
+
) # [bs*seq, nfeats]
|
100 |
+
x = self.group_norm(x)
|
101 |
+
x = rearrange(
|
102 |
+
x.reshape(batch_size, horizon, channels),
|
103 |
+
"batch horizon channels -> batch channels horizon",
|
104 |
+
)
|
105 |
+
x = ada_shift_scale(x, shift, scale)
|
106 |
+
|
107 |
+
return self.avtication(x)
|
108 |
+
|
109 |
+
|
110 |
+
class SelfAttention(nn.Module):
|
111 |
+
|
112 |
+
def __init__(
|
113 |
+
self,
|
114 |
+
latent_dim,
|
115 |
+
text_latent_dim,
|
116 |
+
num_heads: int = 8,
|
117 |
+
dropout: float = 0.0,
|
118 |
+
log_attn=False,
|
119 |
+
edit_config=None,
|
120 |
+
):
|
121 |
+
super().__init__()
|
122 |
+
self.num_head = num_heads
|
123 |
+
self.norm = nn.LayerNorm(latent_dim)
|
124 |
+
self.query = nn.Linear(latent_dim, latent_dim)
|
125 |
+
self.key = nn.Linear(latent_dim, latent_dim)
|
126 |
+
self.value = nn.Linear(latent_dim, latent_dim)
|
127 |
+
self.dropout = nn.Dropout(dropout)
|
128 |
+
|
129 |
+
self.edit_config = edit_config
|
130 |
+
self.log_attn = log_attn
|
131 |
+
|
132 |
+
def forward(self, x):
|
133 |
+
"""
|
134 |
+
x: B, T, D
|
135 |
+
xf: B, N, L
|
136 |
+
"""
|
137 |
+
B, T, D = x.shape
|
138 |
+
N = x.shape[1]
|
139 |
+
assert N == T
|
140 |
+
H = self.num_head
|
141 |
+
|
142 |
+
# B, T, 1, D
|
143 |
+
query = self.query(self.norm(x)).unsqueeze(2)
|
144 |
+
# B, 1, N, D
|
145 |
+
key = self.key(self.norm(x)).unsqueeze(1)
|
146 |
+
query = query.view(B, T, H, -1)
|
147 |
+
key = key.view(B, N, H, -1)
|
148 |
+
|
149 |
+
# style transfer motion editing
|
150 |
+
style_tranfer = self.edit_config.style_tranfer.use
|
151 |
+
if style_tranfer:
|
152 |
+
if (
|
153 |
+
len(SELF_ATTN)
|
154 |
+
<= self.edit_config.style_tranfer.style_transfer_steps_end
|
155 |
+
):
|
156 |
+
query[1] = query[0]
|
157 |
+
|
158 |
+
# example based motion generation
|
159 |
+
example_based = self.edit_config.example_based.use
|
160 |
+
if example_based:
|
161 |
+
if len(SELF_ATTN) == self.edit_config.example_based.example_based_steps_end:
|
162 |
+
|
163 |
+
temp_seed = self.edit_config.example_based.temp_seed
|
164 |
+
for id_ in range(query.shape[0] - 1):
|
165 |
+
with torch.random.fork_rng():
|
166 |
+
torch.manual_seed(temp_seed)
|
167 |
+
tensor = query[0]
|
168 |
+
chunks = torch.split(
|
169 |
+
tensor, self.edit_config.example_based.chunk_size, dim=0
|
170 |
+
)
|
171 |
+
shuffled_indices = torch.randperm(len(chunks))
|
172 |
+
shuffled_chunks = [chunks[i] for i in shuffled_indices]
|
173 |
+
shuffled_tensor = torch.cat(shuffled_chunks, dim=0)
|
174 |
+
query[1 + id_] = shuffled_tensor
|
175 |
+
temp_seed += self.edit_config.example_based.temp_seed_bar
|
176 |
+
|
177 |
+
# time shift motion editing (q, k)
|
178 |
+
time_shift = self.edit_config.time_shift.use
|
179 |
+
if time_shift:
|
180 |
+
if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end:
|
181 |
+
part1 = int(
|
182 |
+
key.shape[1] * self.edit_config.time_shift.time_shift_ratio // 1
|
183 |
+
)
|
184 |
+
part2 = int(
|
185 |
+
key.shape[1]
|
186 |
+
* (1 - self.edit_config.time_shift.time_shift_ratio)
|
187 |
+
// 1
|
188 |
+
)
|
189 |
+
q_front_part = query[0, :part1, :, :]
|
190 |
+
q_back_part = query[0, -part2:, :, :]
|
191 |
+
|
192 |
+
new_q = torch.cat((q_back_part, q_front_part), dim=0)
|
193 |
+
query[1] = new_q
|
194 |
+
|
195 |
+
k_front_part = key[0, :part1, :, :]
|
196 |
+
k_back_part = key[0, -part2:, :, :]
|
197 |
+
new_k = torch.cat((k_back_part, k_front_part), dim=0)
|
198 |
+
key[1] = new_k
|
199 |
+
|
200 |
+
# B, T, N, H
|
201 |
+
attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H)
|
202 |
+
weight = self.dropout(F.softmax(attention, dim=2))
|
203 |
+
|
204 |
+
# for counting the step and logging attention maps
|
205 |
+
try:
|
206 |
+
attention_matrix = (
|
207 |
+
weight[0, :, :].mean(dim=-1).detach().cpu().numpy().astype(float)
|
208 |
+
)
|
209 |
+
SELF_ATTN[-1].append(attention_matrix)
|
210 |
+
except:
|
211 |
+
pass
|
212 |
+
|
213 |
+
# attention manipulation for replacement
|
214 |
+
attention_manipulation = self.edit_config.manipulation.use
|
215 |
+
if attention_manipulation:
|
216 |
+
if len(SELF_ATTN) <= self.edit_config.manipulation.manipulation_steps_end:
|
217 |
+
weight[1, :, :, :] = weight[0, :, :, :]
|
218 |
+
|
219 |
+
value = self.value(self.norm(x)).view(B, N, H, -1)
|
220 |
+
|
221 |
+
# time shift motion editing (v)
|
222 |
+
if time_shift:
|
223 |
+
if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end:
|
224 |
+
v_front_part = value[0, :part1, :, :]
|
225 |
+
v_back_part = value[0, -part2:, :, :]
|
226 |
+
new_v = torch.cat((v_back_part, v_front_part), dim=0)
|
227 |
+
value[1] = new_v
|
228 |
+
y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D)
|
229 |
+
return y
|
230 |
+
|
231 |
+
|
232 |
+
class TimestepEmbedder(nn.Module):
|
233 |
+
def __init__(self, d_model, max_len=5000):
|
234 |
+
super(TimestepEmbedder, self).__init__()
|
235 |
+
|
236 |
+
pe = torch.zeros(max_len, d_model)
|
237 |
+
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
|
238 |
+
div_term = torch.exp(
|
239 |
+
torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
|
240 |
+
)
|
241 |
+
pe[:, 0::2] = torch.sin(position * div_term)
|
242 |
+
pe[:, 1::2] = torch.cos(position * div_term)
|
243 |
+
|
244 |
+
self.register_buffer("pe", pe)
|
245 |
+
|
246 |
+
def forward(self, x):
|
247 |
+
return self.pe[x]
|
248 |
+
|
249 |
+
|
250 |
+
class Downsample1d(nn.Module):
|
251 |
+
def __init__(self, dim):
|
252 |
+
super().__init__()
|
253 |
+
self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
|
254 |
+
|
255 |
+
def forward(self, x):
|
256 |
+
return self.conv(x)
|
257 |
+
|
258 |
+
|
259 |
+
class Upsample1d(nn.Module):
|
260 |
+
def __init__(self, dim_in, dim_out=None):
|
261 |
+
super().__init__()
|
262 |
+
dim_out = dim_out or dim_in
|
263 |
+
self.conv = nn.ConvTranspose1d(dim_in, dim_out, 4, 2, 1)
|
264 |
+
|
265 |
+
def forward(self, x):
|
266 |
+
return self.conv(x)
|
267 |
+
|
268 |
+
|
269 |
+
class Conv1dBlock(nn.Module):
|
270 |
+
"""
|
271 |
+
Conv1d --> GroupNorm --> Mish
|
272 |
+
"""
|
273 |
+
|
274 |
+
def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4, zero=False):
|
275 |
+
super().__init__()
|
276 |
+
self.out_channels = out_channels
|
277 |
+
self.block = nn.Conv1d(
|
278 |
+
inp_channels, out_channels, kernel_size, padding=kernel_size // 2
|
279 |
+
)
|
280 |
+
self.norm = nn.GroupNorm(n_groups, out_channels)
|
281 |
+
self.activation = nn.Mish()
|
282 |
+
|
283 |
+
if zero:
|
284 |
+
# zero init the convolution
|
285 |
+
nn.init.zeros_(self.block.weight)
|
286 |
+
nn.init.zeros_(self.block.bias)
|
287 |
+
|
288 |
+
def forward(self, x):
|
289 |
+
"""
|
290 |
+
Args:
|
291 |
+
x: [bs, nfeat, nframes]
|
292 |
+
"""
|
293 |
+
x = self.block(x)
|
294 |
+
|
295 |
+
batch_size, channels, horizon = x.size()
|
296 |
+
x = rearrange(
|
297 |
+
x, "batch channels horizon -> (batch horizon) channels"
|
298 |
+
) # [bs*seq, nfeats]
|
299 |
+
x = self.norm(x)
|
300 |
+
x = rearrange(
|
301 |
+
x.reshape(batch_size, horizon, channels),
|
302 |
+
"batch horizon channels -> batch channels horizon",
|
303 |
+
)
|
304 |
+
|
305 |
+
return self.activation(x)
|
306 |
+
|
307 |
+
|
308 |
+
def ada_shift_scale(x, shift, scale):
|
309 |
+
return x * (1 + scale) + shift
|
310 |
+
|
311 |
+
|
312 |
+
class ResidualTemporalBlock(nn.Module):
|
313 |
+
def __init__(
|
314 |
+
self,
|
315 |
+
inp_channels,
|
316 |
+
out_channels,
|
317 |
+
embed_dim,
|
318 |
+
kernel_size=5,
|
319 |
+
zero=True,
|
320 |
+
n_groups=8,
|
321 |
+
dropout: float = 0.1,
|
322 |
+
adagn=True,
|
323 |
+
):
|
324 |
+
super().__init__()
|
325 |
+
self.adagn = adagn
|
326 |
+
|
327 |
+
self.blocks = nn.ModuleList(
|
328 |
+
[
|
329 |
+
# adagn only the first conv (following guided-diffusion)
|
330 |
+
(
|
331 |
+
Conv1dAdaGNBlock(inp_channels, out_channels, kernel_size, n_groups)
|
332 |
+
if adagn
|
333 |
+
else Conv1dBlock(inp_channels, out_channels, kernel_size)
|
334 |
+
),
|
335 |
+
Conv1dBlock(
|
336 |
+
out_channels, out_channels, kernel_size, n_groups, zero=zero
|
337 |
+
),
|
338 |
+
]
|
339 |
+
)
|
340 |
+
|
341 |
+
self.time_mlp = nn.Sequential(
|
342 |
+
nn.Mish(),
|
343 |
+
# adagn = scale and shift
|
344 |
+
nn.Linear(embed_dim, out_channels * 2 if adagn else out_channels),
|
345 |
+
Rearrange("batch t -> batch t 1"),
|
346 |
+
)
|
347 |
+
self.dropout = nn.Dropout(dropout)
|
348 |
+
if zero:
|
349 |
+
nn.init.zeros_(self.time_mlp[1].weight)
|
350 |
+
nn.init.zeros_(self.time_mlp[1].bias)
|
351 |
+
|
352 |
+
self.residual_conv = (
|
353 |
+
nn.Conv1d(inp_channels, out_channels, 1)
|
354 |
+
if inp_channels != out_channels
|
355 |
+
else nn.Identity()
|
356 |
+
)
|
357 |
+
|
358 |
+
def forward(self, x, time_embeds=None):
|
359 |
+
"""
|
360 |
+
x : [ batch_size x inp_channels x nframes ]
|
361 |
+
t : [ batch_size x embed_dim ]
|
362 |
+
returns: [ batch_size x out_channels x nframes ]
|
363 |
+
"""
|
364 |
+
if self.adagn:
|
365 |
+
scale, shift = self.time_mlp(time_embeds).chunk(2, dim=1)
|
366 |
+
out = self.blocks[0](x, scale, shift)
|
367 |
+
else:
|
368 |
+
out = self.blocks[0](x) + self.time_mlp(time_embeds)
|
369 |
+
out = self.blocks[1](out)
|
370 |
+
out = self.dropout(out)
|
371 |
+
return out + self.residual_conv(x)
|
372 |
+
|
373 |
+
|
374 |
+
class CrossAttention(nn.Module):
|
375 |
+
|
376 |
+
def __init__(
|
377 |
+
self,
|
378 |
+
latent_dim,
|
379 |
+
text_latent_dim,
|
380 |
+
num_heads: int = 8,
|
381 |
+
dropout: float = 0.0,
|
382 |
+
log_attn=False,
|
383 |
+
edit_config=None,
|
384 |
+
):
|
385 |
+
super().__init__()
|
386 |
+
self.num_head = num_heads
|
387 |
+
self.norm = nn.LayerNorm(latent_dim)
|
388 |
+
self.text_norm = nn.LayerNorm(text_latent_dim)
|
389 |
+
self.query = nn.Linear(latent_dim, latent_dim)
|
390 |
+
self.key = nn.Linear(text_latent_dim, latent_dim)
|
391 |
+
self.value = nn.Linear(text_latent_dim, latent_dim)
|
392 |
+
self.dropout = nn.Dropout(dropout)
|
393 |
+
|
394 |
+
self.edit_config = edit_config
|
395 |
+
self.log_attn = log_attn
|
396 |
+
|
397 |
+
def forward(self, x, xf):
|
398 |
+
"""
|
399 |
+
x: B, T, D
|
400 |
+
xf: B, N, L
|
401 |
+
"""
|
402 |
+
B, T, D = x.shape
|
403 |
+
N = xf.shape[1]
|
404 |
+
H = self.num_head
|
405 |
+
# B, T, 1, D
|
406 |
+
query = self.query(self.norm(x)).unsqueeze(2)
|
407 |
+
# B, 1, N, D
|
408 |
+
key = self.key(self.text_norm(xf)).unsqueeze(1)
|
409 |
+
query = query.view(B, T, H, -1)
|
410 |
+
key = key.view(B, N, H, -1)
|
411 |
+
# B, T, N, H
|
412 |
+
attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H)
|
413 |
+
weight = self.dropout(F.softmax(attention, dim=2))
|
414 |
+
|
415 |
+
# attention reweighting for (de)-emphasizing motion
|
416 |
+
if self.edit_config.reweighting_attn.use:
|
417 |
+
reweighting_attn = self.edit_config.reweighting_attn.reweighting_attn_weight
|
418 |
+
if self.edit_config.reweighting_attn.idx == -1:
|
419 |
+
# read idxs from txt file
|
420 |
+
with open("./assets/reweighting_idx.txt", "r") as f:
|
421 |
+
idxs = f.readlines()
|
422 |
+
else:
|
423 |
+
# gradio demo mode
|
424 |
+
idxs = [0, self.edit_config.reweighting_attn.idx]
|
425 |
+
idxs = [int(idx) for idx in idxs]
|
426 |
+
for i in range(len(idxs)):
|
427 |
+
weight[i, :, 1 + idxs[i]] = weight[i, :, 1 + idxs[i]] + reweighting_attn
|
428 |
+
weight[i, :, 1 + idxs[i] + 1] = (
|
429 |
+
weight[i, :, 1 + idxs[i] + 1] + reweighting_attn
|
430 |
+
)
|
431 |
+
|
432 |
+
# for counting the step and logging attention maps
|
433 |
+
try:
|
434 |
+
attention_matrix = (
|
435 |
+
weight[0, :, 1 : 1 + 3]
|
436 |
+
.mean(dim=-1)
|
437 |
+
.detach()
|
438 |
+
.cpu()
|
439 |
+
.numpy()
|
440 |
+
.astype(float)
|
441 |
+
)
|
442 |
+
MONITOR_ATTN[-1].append(attention_matrix)
|
443 |
+
except:
|
444 |
+
pass
|
445 |
+
|
446 |
+
# erasing motion (autually is the deemphasizing motion)
|
447 |
+
erasing_motion = self.edit_config.erasing_motion.use
|
448 |
+
if erasing_motion:
|
449 |
+
reweighting_attn = self.edit_config.erasing_motion.erasing_motion_weight
|
450 |
+
begin = self.edit_config.erasing_motion.time_start
|
451 |
+
end = self.edit_config.erasing_motion.time_end
|
452 |
+
idx = self.edit_config.erasing_motion.idx
|
453 |
+
if reweighting_attn > 0.01 or reweighting_attn < -0.01:
|
454 |
+
weight[1, int(T * begin) : int(T * end), idx] = (
|
455 |
+
weight[1, int(T * begin) : int(T * end) :, idx] * reweighting_attn
|
456 |
+
)
|
457 |
+
weight[1, int(T * begin) : int(T * end), idx + 1] = (
|
458 |
+
weight[1, int(T * begin) : int(T * end), idx + 1] * reweighting_attn
|
459 |
+
)
|
460 |
+
|
461 |
+
# attention manipulation for motion replacement
|
462 |
+
manipulation = self.edit_config.manipulation.use
|
463 |
+
if manipulation:
|
464 |
+
if (
|
465 |
+
len(MONITOR_ATTN)
|
466 |
+
<= self.edit_config.manipulation.manipulation_steps_end_crossattn
|
467 |
+
):
|
468 |
+
word_idx = self.edit_config.manipulation.word_idx
|
469 |
+
weight[1, :, : 1 + word_idx, :] = weight[0, :, : 1 + word_idx, :]
|
470 |
+
weight[1, :, 1 + word_idx + 1 :, :] = weight[
|
471 |
+
0, :, 1 + word_idx + 1 :, :
|
472 |
+
]
|
473 |
+
|
474 |
+
value = self.value(self.text_norm(xf)).view(B, N, H, -1)
|
475 |
+
y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D)
|
476 |
+
return y
|
477 |
+
|
478 |
+
|
479 |
+
class ResidualCLRAttentionLayer(nn.Module):
|
480 |
+
def __init__(
|
481 |
+
self,
|
482 |
+
dim1,
|
483 |
+
dim2,
|
484 |
+
num_heads: int = 8,
|
485 |
+
dropout: float = 0.1,
|
486 |
+
no_eff: bool = False,
|
487 |
+
self_attention: bool = False,
|
488 |
+
log_attn=False,
|
489 |
+
edit_config=None,
|
490 |
+
):
|
491 |
+
super(ResidualCLRAttentionLayer, self).__init__()
|
492 |
+
self.dim1 = dim1
|
493 |
+
self.dim2 = dim2
|
494 |
+
self.num_heads = num_heads
|
495 |
+
|
496 |
+
# Multi-Head Attention Layer
|
497 |
+
if no_eff:
|
498 |
+
self.cross_attention = CrossAttention(
|
499 |
+
latent_dim=dim1,
|
500 |
+
text_latent_dim=dim2,
|
501 |
+
num_heads=num_heads,
|
502 |
+
dropout=dropout,
|
503 |
+
log_attn=log_attn,
|
504 |
+
edit_config=edit_config,
|
505 |
+
)
|
506 |
+
else:
|
507 |
+
self.cross_attention = LinearCrossAttention(
|
508 |
+
latent_dim=dim1,
|
509 |
+
text_latent_dim=dim2,
|
510 |
+
num_heads=num_heads,
|
511 |
+
dropout=dropout,
|
512 |
+
log_attn=log_attn,
|
513 |
+
)
|
514 |
+
if self_attention:
|
515 |
+
self.self_attn_use = True
|
516 |
+
self.self_attention = SelfAttention(
|
517 |
+
latent_dim=dim1,
|
518 |
+
text_latent_dim=dim2,
|
519 |
+
num_heads=num_heads,
|
520 |
+
dropout=dropout,
|
521 |
+
log_attn=log_attn,
|
522 |
+
edit_config=edit_config,
|
523 |
+
)
|
524 |
+
else:
|
525 |
+
self.self_attn_use = False
|
526 |
+
|
527 |
+
def forward(self, input_tensor, condition_tensor, cond_indices):
|
528 |
+
"""
|
529 |
+
input_tensor :B, D, L
|
530 |
+
condition_tensor: B, L, D
|
531 |
+
"""
|
532 |
+
if cond_indices.numel() == 0:
|
533 |
+
return input_tensor
|
534 |
+
|
535 |
+
# self attention
|
536 |
+
if self.self_attn_use:
|
537 |
+
x = input_tensor
|
538 |
+
x = x.permute(0, 2, 1) # (batch_size, seq_length, feat_dim)
|
539 |
+
x = self.self_attention(x)
|
540 |
+
x = x.permute(0, 2, 1) # (batch_size, feat_dim, seq_length)
|
541 |
+
input_tensor = input_tensor + x
|
542 |
+
x = input_tensor
|
543 |
+
|
544 |
+
# cross attention
|
545 |
+
x = x[cond_indices].permute(0, 2, 1) # (batch_size, seq_length, feat_dim)
|
546 |
+
x = self.cross_attention(x, condition_tensor[cond_indices])
|
547 |
+
x = x.permute(0, 2, 1) # (batch_size, feat_dim, seq_length)
|
548 |
+
|
549 |
+
input_tensor[cond_indices] = input_tensor[cond_indices] + x
|
550 |
+
|
551 |
+
return input_tensor
|
552 |
+
|
553 |
+
|
554 |
+
class CLRBlock(nn.Module):
|
555 |
+
def __init__(
|
556 |
+
self,
|
557 |
+
dim_in,
|
558 |
+
dim_out,
|
559 |
+
cond_dim,
|
560 |
+
time_dim,
|
561 |
+
adagn=True,
|
562 |
+
zero=True,
|
563 |
+
no_eff=False,
|
564 |
+
self_attention=False,
|
565 |
+
dropout: float = 0.1,
|
566 |
+
log_attn=False,
|
567 |
+
edit_config=None,
|
568 |
+
) -> None:
|
569 |
+
super().__init__()
|
570 |
+
self.conv1d = ResidualTemporalBlock(
|
571 |
+
dim_in, dim_out, embed_dim=time_dim, adagn=adagn, zero=zero, dropout=dropout
|
572 |
+
)
|
573 |
+
self.clr_attn = ResidualCLRAttentionLayer(
|
574 |
+
dim1=dim_out,
|
575 |
+
dim2=cond_dim,
|
576 |
+
no_eff=no_eff,
|
577 |
+
dropout=dropout,
|
578 |
+
self_attention=self_attention,
|
579 |
+
log_attn=log_attn,
|
580 |
+
edit_config=edit_config,
|
581 |
+
)
|
582 |
+
# import pdb; pdb.set_trace()
|
583 |
+
self.ffn = FFN(dim_out, dim_out * 4, dropout=dropout)
|
584 |
+
|
585 |
+
def forward(self, x, t, cond, cond_indices=None):
|
586 |
+
x = self.conv1d(x, t)
|
587 |
+
x = self.clr_attn(x, cond, cond_indices)
|
588 |
+
x = self.ffn(x.permute(0, 2, 1)).permute(0, 2, 1)
|
589 |
+
return x
|
590 |
+
|
591 |
+
|
592 |
+
class CondUnet1D(nn.Module):
|
593 |
+
"""
|
594 |
+
Diffusion's style UNET with 1D convolution and adaptive group normalization for motion suquence denoising,
|
595 |
+
cross-attention to introduce conditional prompts (like text).
|
596 |
+
"""
|
597 |
+
|
598 |
+
def __init__(
|
599 |
+
self,
|
600 |
+
input_dim,
|
601 |
+
cond_dim,
|
602 |
+
dim=128,
|
603 |
+
dim_mults=(1, 2, 4, 8),
|
604 |
+
dims=None,
|
605 |
+
time_dim=512,
|
606 |
+
adagn=True,
|
607 |
+
zero=True,
|
608 |
+
dropout=0.1,
|
609 |
+
no_eff=False,
|
610 |
+
self_attention=False,
|
611 |
+
log_attn=False,
|
612 |
+
edit_config=None,
|
613 |
+
):
|
614 |
+
super().__init__()
|
615 |
+
if not dims:
|
616 |
+
dims = [input_dim, *map(lambda m: int(dim * m), dim_mults)] ##[d, d,2d,4d]
|
617 |
+
print("dims: ", dims, "mults: ", dim_mults)
|
618 |
+
in_out = list(zip(dims[:-1], dims[1:]))
|
619 |
+
|
620 |
+
self.time_mlp = nn.Sequential(
|
621 |
+
TimestepEmbedder(time_dim),
|
622 |
+
nn.Linear(time_dim, time_dim * 4),
|
623 |
+
nn.Mish(),
|
624 |
+
nn.Linear(time_dim * 4, time_dim),
|
625 |
+
)
|
626 |
+
|
627 |
+
self.downs = nn.ModuleList([])
|
628 |
+
self.ups = nn.ModuleList([])
|
629 |
+
|
630 |
+
for ind, (dim_in, dim_out) in enumerate(in_out):
|
631 |
+
self.downs.append(
|
632 |
+
nn.ModuleList(
|
633 |
+
[
|
634 |
+
CLRBlock(
|
635 |
+
dim_in,
|
636 |
+
dim_out,
|
637 |
+
cond_dim,
|
638 |
+
time_dim,
|
639 |
+
adagn=adagn,
|
640 |
+
zero=zero,
|
641 |
+
no_eff=no_eff,
|
642 |
+
dropout=dropout,
|
643 |
+
self_attention=self_attention,
|
644 |
+
log_attn=log_attn,
|
645 |
+
edit_config=edit_config,
|
646 |
+
),
|
647 |
+
CLRBlock(
|
648 |
+
dim_out,
|
649 |
+
dim_out,
|
650 |
+
cond_dim,
|
651 |
+
time_dim,
|
652 |
+
adagn=adagn,
|
653 |
+
zero=zero,
|
654 |
+
no_eff=no_eff,
|
655 |
+
dropout=dropout,
|
656 |
+
self_attention=self_attention,
|
657 |
+
log_attn=log_attn,
|
658 |
+
edit_config=edit_config,
|
659 |
+
),
|
660 |
+
Downsample1d(dim_out),
|
661 |
+
]
|
662 |
+
)
|
663 |
+
)
|
664 |
+
|
665 |
+
mid_dim = dims[-1]
|
666 |
+
self.mid_block1 = CLRBlock(
|
667 |
+
dim_in=mid_dim,
|
668 |
+
dim_out=mid_dim,
|
669 |
+
cond_dim=cond_dim,
|
670 |
+
time_dim=time_dim,
|
671 |
+
adagn=adagn,
|
672 |
+
zero=zero,
|
673 |
+
no_eff=no_eff,
|
674 |
+
dropout=dropout,
|
675 |
+
self_attention=self_attention,
|
676 |
+
log_attn=log_attn,
|
677 |
+
edit_config=edit_config,
|
678 |
+
)
|
679 |
+
self.mid_block2 = CLRBlock(
|
680 |
+
dim_in=mid_dim,
|
681 |
+
dim_out=mid_dim,
|
682 |
+
cond_dim=cond_dim,
|
683 |
+
time_dim=time_dim,
|
684 |
+
adagn=adagn,
|
685 |
+
zero=zero,
|
686 |
+
no_eff=no_eff,
|
687 |
+
dropout=dropout,
|
688 |
+
self_attention=self_attention,
|
689 |
+
log_attn=log_attn,
|
690 |
+
edit_config=edit_config,
|
691 |
+
)
|
692 |
+
|
693 |
+
last_dim = mid_dim
|
694 |
+
for ind, dim_out in enumerate(reversed(dims[1:])):
|
695 |
+
self.ups.append(
|
696 |
+
nn.ModuleList(
|
697 |
+
[
|
698 |
+
Upsample1d(last_dim, dim_out),
|
699 |
+
CLRBlock(
|
700 |
+
dim_out * 2,
|
701 |
+
dim_out,
|
702 |
+
cond_dim,
|
703 |
+
time_dim,
|
704 |
+
adagn=adagn,
|
705 |
+
zero=zero,
|
706 |
+
no_eff=no_eff,
|
707 |
+
dropout=dropout,
|
708 |
+
self_attention=self_attention,
|
709 |
+
log_attn=log_attn,
|
710 |
+
edit_config=edit_config,
|
711 |
+
),
|
712 |
+
CLRBlock(
|
713 |
+
dim_out,
|
714 |
+
dim_out,
|
715 |
+
cond_dim,
|
716 |
+
time_dim,
|
717 |
+
adagn=adagn,
|
718 |
+
zero=zero,
|
719 |
+
no_eff=no_eff,
|
720 |
+
dropout=dropout,
|
721 |
+
self_attention=self_attention,
|
722 |
+
log_attn=log_attn,
|
723 |
+
edit_config=edit_config,
|
724 |
+
),
|
725 |
+
]
|
726 |
+
)
|
727 |
+
)
|
728 |
+
last_dim = dim_out
|
729 |
+
self.final_conv = nn.Conv1d(dim_out, input_dim, 1)
|
730 |
+
|
731 |
+
if zero:
|
732 |
+
nn.init.zeros_(self.final_conv.weight)
|
733 |
+
nn.init.zeros_(self.final_conv.bias)
|
734 |
+
|
735 |
+
def forward(
|
736 |
+
self,
|
737 |
+
x,
|
738 |
+
t,
|
739 |
+
cond,
|
740 |
+
cond_indices,
|
741 |
+
):
|
742 |
+
temb = self.time_mlp(t)
|
743 |
+
|
744 |
+
h = []
|
745 |
+
for block1, block2, downsample in self.downs:
|
746 |
+
x = block1(x, temb, cond, cond_indices)
|
747 |
+
x = block2(x, temb, cond, cond_indices)
|
748 |
+
h.append(x)
|
749 |
+
x = downsample(x)
|
750 |
+
|
751 |
+
x = self.mid_block1(x, temb, cond, cond_indices)
|
752 |
+
x = self.mid_block2(x, temb, cond, cond_indices)
|
753 |
+
|
754 |
+
for upsample, block1, block2 in self.ups:
|
755 |
+
x = upsample(x)
|
756 |
+
x = torch.cat((x, h.pop()), dim=1)
|
757 |
+
x = block1(x, temb, cond, cond_indices)
|
758 |
+
x = block2(x, temb, cond, cond_indices)
|
759 |
+
|
760 |
+
x = self.final_conv(x)
|
761 |
+
return x
|
762 |
+
|
763 |
+
|
764 |
+
class MotionCLR(nn.Module):
|
765 |
+
"""
|
766 |
+
Diffuser's style UNET for text-to-motion task.
|
767 |
+
"""
|
768 |
+
|
769 |
+
def __init__(
|
770 |
+
self,
|
771 |
+
input_feats,
|
772 |
+
base_dim=128,
|
773 |
+
dim_mults=(1, 2, 2, 2),
|
774 |
+
dims=None,
|
775 |
+
adagn=True,
|
776 |
+
zero=True,
|
777 |
+
dropout=0.1,
|
778 |
+
no_eff=False,
|
779 |
+
time_dim=512,
|
780 |
+
latent_dim=256,
|
781 |
+
cond_mask_prob=0.1,
|
782 |
+
clip_dim=512,
|
783 |
+
clip_version="ViT-B/32",
|
784 |
+
text_latent_dim=256,
|
785 |
+
text_ff_size=2048,
|
786 |
+
text_num_heads=4,
|
787 |
+
activation="gelu",
|
788 |
+
num_text_layers=4,
|
789 |
+
self_attention=False,
|
790 |
+
vis_attn=False,
|
791 |
+
edit_config=None,
|
792 |
+
out_path=None,
|
793 |
+
):
|
794 |
+
super().__init__()
|
795 |
+
self.input_feats = input_feats
|
796 |
+
self.dim_mults = dim_mults
|
797 |
+
self.base_dim = base_dim
|
798 |
+
self.latent_dim = latent_dim
|
799 |
+
self.cond_mask_prob = cond_mask_prob
|
800 |
+
self.vis_attn = vis_attn
|
801 |
+
self.counting_map = []
|
802 |
+
self.out_path = out_path
|
803 |
+
|
804 |
+
print(
|
805 |
+
f"The T2M Unet mask the text prompt by {self.cond_mask_prob} prob. in training"
|
806 |
+
)
|
807 |
+
|
808 |
+
# text encoder
|
809 |
+
self.embed_text = nn.Linear(clip_dim, text_latent_dim)
|
810 |
+
self.clip_version = clip_version
|
811 |
+
self.clip_model = self.load_and_freeze_clip(clip_version)
|
812 |
+
textTransEncoderLayer = nn.TransformerEncoderLayer(
|
813 |
+
d_model=text_latent_dim,
|
814 |
+
nhead=text_num_heads,
|
815 |
+
dim_feedforward=text_ff_size,
|
816 |
+
dropout=dropout,
|
817 |
+
activation=activation,
|
818 |
+
)
|
819 |
+
self.textTransEncoder = nn.TransformerEncoder(
|
820 |
+
textTransEncoderLayer, num_layers=num_text_layers
|
821 |
+
)
|
822 |
+
self.text_ln = nn.LayerNorm(text_latent_dim)
|
823 |
+
|
824 |
+
self.unet = CondUnet1D(
|
825 |
+
input_dim=self.input_feats,
|
826 |
+
cond_dim=text_latent_dim,
|
827 |
+
dim=self.base_dim,
|
828 |
+
dim_mults=self.dim_mults,
|
829 |
+
adagn=adagn,
|
830 |
+
zero=zero,
|
831 |
+
dropout=dropout,
|
832 |
+
no_eff=no_eff,
|
833 |
+
dims=dims,
|
834 |
+
time_dim=time_dim,
|
835 |
+
self_attention=self_attention,
|
836 |
+
log_attn=self.vis_attn,
|
837 |
+
edit_config=edit_config,
|
838 |
+
)
|
839 |
+
|
840 |
+
def encode_text(self, raw_text, device):
|
841 |
+
with torch.no_grad():
|
842 |
+
texts = clip.tokenize(raw_text, truncate=True).to(
|
843 |
+
device
|
844 |
+
) # [bs, context_length] # if n_tokens > 77 -> will truncate
|
845 |
+
x = self.clip_model.token_embedding(texts).type(
|
846 |
+
self.clip_model.dtype
|
847 |
+
) # [batch_size, n_ctx, d_model]
|
848 |
+
x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype)
|
849 |
+
x = x.permute(1, 0, 2) # NLD -> LND
|
850 |
+
x = self.clip_model.transformer(x)
|
851 |
+
x = self.clip_model.ln_final(x).type(
|
852 |
+
self.clip_model.dtype
|
853 |
+
) # [len, batch_size, 512]
|
854 |
+
|
855 |
+
x = self.embed_text(x) # [len, batch_size, 256]
|
856 |
+
x = self.textTransEncoder(x)
|
857 |
+
x = self.text_ln(x)
|
858 |
+
|
859 |
+
# T, B, D -> B, T, D
|
860 |
+
xf_out = x.permute(1, 0, 2)
|
861 |
+
|
862 |
+
ablation_text = False
|
863 |
+
if ablation_text:
|
864 |
+
xf_out[:, 1:, :] = xf_out[:, 0, :].unsqueeze(1)
|
865 |
+
return xf_out
|
866 |
+
|
867 |
+
def load_and_freeze_clip(self, clip_version):
|
868 |
+
clip_model, _ = clip.load( # clip_model.dtype=float32
|
869 |
+
clip_version, device="cpu", jit=False
|
870 |
+
) # Must set jit=False for training
|
871 |
+
|
872 |
+
# Freeze CLIP weights
|
873 |
+
clip_model.eval()
|
874 |
+
for p in clip_model.parameters():
|
875 |
+
p.requires_grad = False
|
876 |
+
|
877 |
+
return clip_model
|
878 |
+
|
879 |
+
def mask_cond(self, bs, force_mask=False):
|
880 |
+
"""
|
881 |
+
mask motion condition , return contitional motion index in the batch
|
882 |
+
"""
|
883 |
+
if force_mask:
|
884 |
+
cond_indices = torch.empty(0)
|
885 |
+
elif self.training and self.cond_mask_prob > 0.0:
|
886 |
+
mask = torch.bernoulli(
|
887 |
+
torch.ones(
|
888 |
+
bs,
|
889 |
+
)
|
890 |
+
* self.cond_mask_prob
|
891 |
+
) # 1-> use null_cond, 0-> use real cond
|
892 |
+
mask = 1.0 - mask
|
893 |
+
cond_indices = torch.nonzero(mask).squeeze(-1)
|
894 |
+
else:
|
895 |
+
cond_indices = torch.arange(bs)
|
896 |
+
|
897 |
+
return cond_indices
|
898 |
+
|
899 |
+
def forward(
|
900 |
+
self,
|
901 |
+
x,
|
902 |
+
timesteps,
|
903 |
+
text=None,
|
904 |
+
uncond=False,
|
905 |
+
enc_text=None,
|
906 |
+
):
|
907 |
+
"""
|
908 |
+
Args:
|
909 |
+
x: [batch_size, nframes, nfeats],
|
910 |
+
timesteps: [batch_size] (int)
|
911 |
+
text: list (batch_size length) of strings with input text prompts
|
912 |
+
uncond: whethere using text condition
|
913 |
+
|
914 |
+
Returns: [batch_size, seq_length, nfeats]
|
915 |
+
"""
|
916 |
+
B, T, _ = x.shape
|
917 |
+
x = x.transpose(1, 2) # [bs, nfeats, nframes]
|
918 |
+
|
919 |
+
if enc_text is None:
|
920 |
+
enc_text = self.encode_text(text, x.device) # [bs, seqlen, text_dim]
|
921 |
+
|
922 |
+
cond_indices = self.mask_cond(x.shape[0], force_mask=uncond)
|
923 |
+
|
924 |
+
# NOTE: need to pad to be the multiplier of 8 for the unet
|
925 |
+
PADDING_NEEEDED = (16 - (T % 16)) % 16
|
926 |
+
|
927 |
+
padding = (0, PADDING_NEEEDED)
|
928 |
+
x = F.pad(x, padding, value=0)
|
929 |
+
|
930 |
+
x = self.unet(
|
931 |
+
x,
|
932 |
+
t=timesteps,
|
933 |
+
cond=enc_text,
|
934 |
+
cond_indices=cond_indices,
|
935 |
+
) # [bs, nfeats,, nframes]
|
936 |
+
|
937 |
+
x = x[:, :, :T].transpose(1, 2) # [bs, nframes, nfeats,]
|
938 |
+
|
939 |
+
return x
|
940 |
+
|
941 |
+
def forward_with_cfg(self, x, timesteps, text=None, enc_text=None, cfg_scale=2.5):
|
942 |
+
"""
|
943 |
+
Args:
|
944 |
+
x: [batch_size, nframes, nfeats],
|
945 |
+
timesteps: [batch_size] (int)
|
946 |
+
text: list (batch_size length) of strings with input text prompts
|
947 |
+
|
948 |
+
Returns: [batch_size, max_frames, nfeats]
|
949 |
+
"""
|
950 |
+
global SELF_ATTN
|
951 |
+
global MONITOR_ATTN
|
952 |
+
MONITOR_ATTN.append([])
|
953 |
+
SELF_ATTN.append([])
|
954 |
+
|
955 |
+
B, T, _ = x.shape
|
956 |
+
x = x.transpose(1, 2) # [bs, nfeats, nframes]
|
957 |
+
if enc_text is None:
|
958 |
+
enc_text = self.encode_text(text, x.device) # [bs, seqlen, text_dim]
|
959 |
+
|
960 |
+
cond_indices = self.mask_cond(B)
|
961 |
+
|
962 |
+
# NOTE: need to pad to be the multiplier of 8 for the unet
|
963 |
+
PADDING_NEEEDED = (16 - (T % 16)) % 16
|
964 |
+
|
965 |
+
padding = (0, PADDING_NEEEDED)
|
966 |
+
x = F.pad(x, padding, value=0)
|
967 |
+
|
968 |
+
combined_x = torch.cat([x, x], dim=0)
|
969 |
+
combined_t = torch.cat([timesteps, timesteps], dim=0)
|
970 |
+
out = self.unet(
|
971 |
+
x=combined_x,
|
972 |
+
t=combined_t,
|
973 |
+
cond=enc_text,
|
974 |
+
cond_indices=cond_indices,
|
975 |
+
) # [bs, nfeats, nframes]
|
976 |
+
|
977 |
+
out = out[:, :, :T].transpose(1, 2) # [bs, nframes, nfeats,]
|
978 |
+
|
979 |
+
out_cond, out_uncond = torch.split(out, len(out) // 2, dim=0)
|
980 |
+
|
981 |
+
if self.vis_attn == True:
|
982 |
+
i = len(MONITOR_ATTN)
|
983 |
+
attnlist = MONITOR_ATTN[-1]
|
984 |
+
print(i, "cross", len(attnlist))
|
985 |
+
for j, att in enumerate(attnlist):
|
986 |
+
vis_attn(
|
987 |
+
att,
|
988 |
+
out_path=self.out_path,
|
989 |
+
step=i,
|
990 |
+
layer=j,
|
991 |
+
shape="_".join(map(str, att.shape)),
|
992 |
+
type_="cross",
|
993 |
+
)
|
994 |
+
|
995 |
+
attnlist = SELF_ATTN[-1]
|
996 |
+
print(i, "self", len(attnlist))
|
997 |
+
for j, att in enumerate(attnlist):
|
998 |
+
vis_attn(
|
999 |
+
att,
|
1000 |
+
out_path=self.out_path,
|
1001 |
+
step=i,
|
1002 |
+
layer=j,
|
1003 |
+
shape="_".join(map(str, att.shape)),
|
1004 |
+
type_="self",
|
1005 |
+
lines=False,
|
1006 |
+
)
|
1007 |
+
|
1008 |
+
if len(SELF_ATTN) % 10 == 0:
|
1009 |
+
SELF_ATTN = []
|
1010 |
+
MONITOR_ATTN = []
|
1011 |
+
|
1012 |
+
return out_uncond + (cfg_scale * (out_cond - out_uncond))
|
1013 |
+
|
1014 |
+
|
1015 |
+
if __name__ == "__main__":
|
1016 |
+
|
1017 |
+
device = "cuda:0"
|
1018 |
+
n_feats = 263
|
1019 |
+
num_frames = 196
|
1020 |
+
text_latent_dim = 256
|
1021 |
+
dim_mults = [2, 2, 2, 2]
|
1022 |
+
base_dim = 512
|
1023 |
+
model = MotionCLR(
|
1024 |
+
input_feats=n_feats,
|
1025 |
+
text_latent_dim=text_latent_dim,
|
1026 |
+
base_dim=base_dim,
|
1027 |
+
dim_mults=dim_mults,
|
1028 |
+
adagn=True,
|
1029 |
+
zero=True,
|
1030 |
+
dropout=0.1,
|
1031 |
+
no_eff=True,
|
1032 |
+
cond_mask_prob=0.1,
|
1033 |
+
self_attention=True,
|
1034 |
+
)
|
1035 |
+
|
1036 |
+
model = model.to(device)
|
1037 |
+
from utils.model_load import load_model_weights
|
1038 |
+
|
1039 |
+
checkpoint_path = "/comp_robot/chenlinghao/StableMoFusion/checkpoints/t2m/self_attn—fulllayer-ffn-drop0_1-lr1e4/model/latest.tar"
|
1040 |
+
new_state_dict = {}
|
1041 |
+
checkpoint = torch.load(checkpoint_path)
|
1042 |
+
ckpt2 = checkpoint.copy()
|
1043 |
+
ckpt2["model_ema"] = {}
|
1044 |
+
ckpt2["encoder"] = {}
|
1045 |
+
|
1046 |
+
for key, value in list(checkpoint["model_ema"].items()):
|
1047 |
+
new_key = key.replace(
|
1048 |
+
"cross_attn", "clr_attn"
|
1049 |
+
) # Replace 'cross_attn' with 'clr_attn'
|
1050 |
+
ckpt2["model_ema"][new_key] = value
|
1051 |
+
for key, value in list(checkpoint["encoder"].items()):
|
1052 |
+
new_key = key.replace(
|
1053 |
+
"cross_attn", "clr_attn"
|
1054 |
+
) # Replace 'cross_attn' with 'clr_attn'
|
1055 |
+
ckpt2["encoder"][new_key] = value
|
1056 |
+
|
1057 |
+
torch.save(
|
1058 |
+
ckpt2,
|
1059 |
+
"/comp_robot/chenlinghao/CLRpreview/checkpoints/t2m/release/model/latest.tar",
|
1060 |
+
)
|
1061 |
+
|
1062 |
+
dtype = torch.float32
|
1063 |
+
bs = 1
|
1064 |
+
x = torch.rand((bs, 196, 263), dtype=dtype).to(device)
|
1065 |
+
timesteps = torch.randint(low=0, high=1000, size=(bs,)).to(device)
|
1066 |
+
y = ["A man jumps to his left." for i in range(bs)]
|
1067 |
+
length = torch.randint(low=20, high=196, size=(bs,)).to(device)
|
1068 |
+
|
1069 |
+
out = model(x, timesteps, text=y)
|
1070 |
+
print(out.shape)
|
1071 |
+
model.eval()
|
1072 |
+
out = model.forward_with_cfg(x, timesteps, text=y)
|
1073 |
+
print(out.shape)
|
motion_loader/__init__.py
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
from .model_motion_loaders import get_motion_loader
|
2 |
+
from .dataset_motion_loaders import get_dataset_loader
|
motion_loader/dataset_motion_loaders.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from datasets import get_dataset
|
2 |
+
from torch.utils.data import DataLoader
|
3 |
+
from torch.utils.data._utils.collate import default_collate
|
4 |
+
|
5 |
+
|
6 |
+
def collate_fn(batch):
|
7 |
+
batch.sort(key=lambda x: x[3], reverse=True)
|
8 |
+
return default_collate(batch)
|
9 |
+
|
10 |
+
|
11 |
+
def get_dataset_loader(opt, batch_size, mode="eval", split="test", accelerator=None):
|
12 |
+
dataset = get_dataset(opt, split, mode, accelerator)
|
13 |
+
if mode in ["eval", "gt_eval"]:
|
14 |
+
dataloader = DataLoader(
|
15 |
+
dataset,
|
16 |
+
batch_size=batch_size,
|
17 |
+
shuffle=True,
|
18 |
+
num_workers=4,
|
19 |
+
drop_last=True,
|
20 |
+
collate_fn=collate_fn,
|
21 |
+
)
|
22 |
+
else:
|
23 |
+
dataloader = DataLoader(
|
24 |
+
dataset,
|
25 |
+
batch_size=batch_size,
|
26 |
+
shuffle=True,
|
27 |
+
num_workers=4,
|
28 |
+
drop_last=True,
|
29 |
+
persistent_workers=True,
|
30 |
+
)
|
31 |
+
return dataloader
|
motion_loader/model_motion_loaders.py
ADDED
@@ -0,0 +1,243 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from utils.word_vectorizer import WordVectorizer
|
3 |
+
from torch.utils.data import Dataset, DataLoader
|
4 |
+
from os.path import join as pjoin
|
5 |
+
from tqdm import tqdm
|
6 |
+
import numpy as np
|
7 |
+
from eval.evaluator_modules import *
|
8 |
+
|
9 |
+
from torch.utils.data._utils.collate import default_collate
|
10 |
+
|
11 |
+
|
12 |
+
class GeneratedDataset(Dataset):
|
13 |
+
"""
|
14 |
+
opt.dataset_name
|
15 |
+
opt.max_motion_length
|
16 |
+
opt.unit_length
|
17 |
+
"""
|
18 |
+
|
19 |
+
def __init__(
|
20 |
+
self, opt, pipeline, dataset, w_vectorizer, mm_num_samples, mm_num_repeats
|
21 |
+
):
|
22 |
+
assert mm_num_samples < len(dataset)
|
23 |
+
self.dataset = dataset
|
24 |
+
dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
|
25 |
+
generated_motion = []
|
26 |
+
min_mov_length = 10 if opt.dataset_name == "t2m" else 6
|
27 |
+
|
28 |
+
# Pre-process all target captions
|
29 |
+
mm_generated_motions = []
|
30 |
+
if mm_num_samples > 0:
|
31 |
+
mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
|
32 |
+
mm_idxs = np.sort(mm_idxs)
|
33 |
+
|
34 |
+
all_caption = []
|
35 |
+
all_m_lens = []
|
36 |
+
all_data = []
|
37 |
+
with torch.no_grad():
|
38 |
+
for i, data in tqdm(enumerate(dataloader)):
|
39 |
+
word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
|
40 |
+
all_data.append(data)
|
41 |
+
tokens = tokens[0].split("_")
|
42 |
+
mm_num_now = len(mm_generated_motions)
|
43 |
+
is_mm = (
|
44 |
+
True
|
45 |
+
if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now]))
|
46 |
+
else False
|
47 |
+
)
|
48 |
+
repeat_times = mm_num_repeats if is_mm else 1
|
49 |
+
m_lens = max(
|
50 |
+
torch.div(m_lens, opt.unit_length, rounding_mode="trunc")
|
51 |
+
* opt.unit_length,
|
52 |
+
min_mov_length * opt.unit_length,
|
53 |
+
)
|
54 |
+
m_lens = min(m_lens, opt.max_motion_length)
|
55 |
+
if isinstance(m_lens, int):
|
56 |
+
m_lens = torch.LongTensor([m_lens]).to(opt.device)
|
57 |
+
else:
|
58 |
+
m_lens = m_lens.to(opt.device)
|
59 |
+
for t in range(repeat_times):
|
60 |
+
all_m_lens.append(m_lens)
|
61 |
+
all_caption.extend(caption)
|
62 |
+
if is_mm:
|
63 |
+
mm_generated_motions.append(0)
|
64 |
+
all_m_lens = torch.stack(all_m_lens)
|
65 |
+
|
66 |
+
# Generate all sequences
|
67 |
+
with torch.no_grad():
|
68 |
+
all_pred_motions, t_eval = pipeline.generate(all_caption, all_m_lens)
|
69 |
+
self.eval_generate_time = t_eval
|
70 |
+
|
71 |
+
cur_idx = 0
|
72 |
+
mm_generated_motions = []
|
73 |
+
with torch.no_grad():
|
74 |
+
for i, data_dummy in tqdm(enumerate(dataloader)):
|
75 |
+
data = all_data[i]
|
76 |
+
word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
|
77 |
+
tokens = tokens[0].split("_")
|
78 |
+
mm_num_now = len(mm_generated_motions)
|
79 |
+
is_mm = (
|
80 |
+
True
|
81 |
+
if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now]))
|
82 |
+
else False
|
83 |
+
)
|
84 |
+
repeat_times = mm_num_repeats if is_mm else 1
|
85 |
+
mm_motions = []
|
86 |
+
for t in range(repeat_times):
|
87 |
+
pred_motions = all_pred_motions[cur_idx]
|
88 |
+
cur_idx += 1
|
89 |
+
if t == 0:
|
90 |
+
sub_dict = {
|
91 |
+
"motion": pred_motions.cpu().numpy(),
|
92 |
+
"length": pred_motions.shape[0], # m_lens[0].item(), #
|
93 |
+
"caption": caption[0],
|
94 |
+
"cap_len": cap_lens[0].item(),
|
95 |
+
"tokens": tokens,
|
96 |
+
}
|
97 |
+
generated_motion.append(sub_dict)
|
98 |
+
|
99 |
+
if is_mm:
|
100 |
+
mm_motions.append(
|
101 |
+
{
|
102 |
+
"motion": pred_motions.cpu().numpy(),
|
103 |
+
"length": pred_motions.shape[
|
104 |
+
0
|
105 |
+
], # m_lens[0].item(), #m_lens[0].item()
|
106 |
+
}
|
107 |
+
)
|
108 |
+
if is_mm:
|
109 |
+
mm_generated_motions.append(
|
110 |
+
{
|
111 |
+
"caption": caption[0],
|
112 |
+
"tokens": tokens,
|
113 |
+
"cap_len": cap_lens[0].item(),
|
114 |
+
"mm_motions": mm_motions,
|
115 |
+
}
|
116 |
+
)
|
117 |
+
self.generated_motion = generated_motion
|
118 |
+
self.mm_generated_motion = mm_generated_motions
|
119 |
+
self.opt = opt
|
120 |
+
self.w_vectorizer = w_vectorizer
|
121 |
+
|
122 |
+
def __len__(self):
|
123 |
+
return len(self.generated_motion)
|
124 |
+
|
125 |
+
def __getitem__(self, item):
|
126 |
+
data = self.generated_motion[item]
|
127 |
+
motion, m_length, caption, tokens = (
|
128 |
+
data["motion"],
|
129 |
+
data["length"],
|
130 |
+
data["caption"],
|
131 |
+
data["tokens"],
|
132 |
+
)
|
133 |
+
sent_len = data["cap_len"]
|
134 |
+
|
135 |
+
# This step is needed because T2M evaluators expect their norm convention
|
136 |
+
normed_motion = motion
|
137 |
+
denormed_motion = self.dataset.inv_transform(normed_motion)
|
138 |
+
renormed_motion = (
|
139 |
+
denormed_motion - self.dataset.mean_for_eval
|
140 |
+
) / self.dataset.std_for_eval # according to T2M norms
|
141 |
+
motion = renormed_motion
|
142 |
+
|
143 |
+
pos_one_hots = []
|
144 |
+
word_embeddings = []
|
145 |
+
for token in tokens:
|
146 |
+
word_emb, pos_oh = self.w_vectorizer[token]
|
147 |
+
pos_one_hots.append(pos_oh[None, :])
|
148 |
+
word_embeddings.append(word_emb[None, :])
|
149 |
+
pos_one_hots = np.concatenate(pos_one_hots, axis=0)
|
150 |
+
word_embeddings = np.concatenate(word_embeddings, axis=0)
|
151 |
+
length = len(motion)
|
152 |
+
if length < self.opt.max_motion_length:
|
153 |
+
motion = np.concatenate(
|
154 |
+
[
|
155 |
+
motion,
|
156 |
+
np.zeros((self.opt.max_motion_length - length, motion.shape[1])),
|
157 |
+
],
|
158 |
+
axis=0,
|
159 |
+
)
|
160 |
+
return (
|
161 |
+
word_embeddings,
|
162 |
+
pos_one_hots,
|
163 |
+
caption,
|
164 |
+
sent_len,
|
165 |
+
motion,
|
166 |
+
m_length,
|
167 |
+
"_".join(tokens),
|
168 |
+
)
|
169 |
+
|
170 |
+
|
171 |
+
def collate_fn(batch):
|
172 |
+
batch.sort(key=lambda x: x[3], reverse=True)
|
173 |
+
return default_collate(batch)
|
174 |
+
|
175 |
+
|
176 |
+
class MMGeneratedDataset(Dataset):
|
177 |
+
def __init__(self, opt, motion_dataset, w_vectorizer):
|
178 |
+
self.opt = opt
|
179 |
+
self.dataset = motion_dataset.mm_generated_motion
|
180 |
+
self.w_vectorizer = w_vectorizer
|
181 |
+
|
182 |
+
def __len__(self):
|
183 |
+
return len(self.dataset)
|
184 |
+
|
185 |
+
def __getitem__(self, item):
|
186 |
+
data = self.dataset[item]
|
187 |
+
mm_motions = data["mm_motions"]
|
188 |
+
m_lens = []
|
189 |
+
motions = []
|
190 |
+
for mm_motion in mm_motions:
|
191 |
+
m_lens.append(mm_motion["length"])
|
192 |
+
motion = mm_motion["motion"]
|
193 |
+
if len(motion) < self.opt.max_motion_length:
|
194 |
+
motion = np.concatenate(
|
195 |
+
[
|
196 |
+
motion,
|
197 |
+
np.zeros(
|
198 |
+
(self.opt.max_motion_length - len(motion), motion.shape[1])
|
199 |
+
),
|
200 |
+
],
|
201 |
+
axis=0,
|
202 |
+
)
|
203 |
+
motion = motion[None, :]
|
204 |
+
motions.append(motion)
|
205 |
+
m_lens = np.array(m_lens, dtype=np.int32)
|
206 |
+
motions = np.concatenate(motions, axis=0)
|
207 |
+
sort_indx = np.argsort(m_lens)[::-1].copy()
|
208 |
+
|
209 |
+
m_lens = m_lens[sort_indx]
|
210 |
+
motions = motions[sort_indx]
|
211 |
+
return motions, m_lens
|
212 |
+
|
213 |
+
|
214 |
+
def get_motion_loader(
|
215 |
+
opt, batch_size, pipeline, ground_truth_dataset, mm_num_samples, mm_num_repeats
|
216 |
+
):
|
217 |
+
|
218 |
+
# Currently the configurations of two datasets are almost the same
|
219 |
+
if opt.dataset_name == "t2m" or opt.dataset_name == "kit":
|
220 |
+
w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab")
|
221 |
+
else:
|
222 |
+
raise KeyError("Dataset not recognized!!")
|
223 |
+
|
224 |
+
dataset = GeneratedDataset(
|
225 |
+
opt,
|
226 |
+
pipeline,
|
227 |
+
ground_truth_dataset,
|
228 |
+
w_vectorizer,
|
229 |
+
mm_num_samples,
|
230 |
+
mm_num_repeats,
|
231 |
+
)
|
232 |
+
mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)
|
233 |
+
|
234 |
+
motion_loader = DataLoader(
|
235 |
+
dataset,
|
236 |
+
batch_size=batch_size,
|
237 |
+
collate_fn=collate_fn,
|
238 |
+
drop_last=True,
|
239 |
+
num_workers=4,
|
240 |
+
)
|
241 |
+
mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)
|
242 |
+
|
243 |
+
return motion_loader, mm_motion_loader, dataset.eval_generate_time
|
options/__init__.py
ADDED
File without changes
|
options/edit.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# edit.yaml
|
2 |
+
reweighting_attn:
|
3 |
+
use: False
|
4 |
+
reweighting_attn_weight: 0.0 # the weight of reweighting attention for motion emphasizing and de-emphasizing
|
5 |
+
idx: -1 # the index of the word to be emphasized or de-emphasized (0 ~ 10)
|
6 |
+
|
7 |
+
erasing_motion:
|
8 |
+
use: False
|
9 |
+
erasing_motion_weight: 0.1 # the weight of motion erasing
|
10 |
+
time_start: 0.5 # the start time of motion erasing (0.0 ~ 1.0), ratio of the total time
|
11 |
+
time_end: 1.0 # the end time of motion erasing (0.0 ~ 1.0), ratio of the total time
|
12 |
+
idx: -1
|
13 |
+
|
14 |
+
manipulation: # motion manipulation means in-place motion replacement
|
15 |
+
use: False
|
16 |
+
manipulation_steps_start: 0 # the start step of motion manipulation, 0 ~ 10
|
17 |
+
manipulation_steps_end: 3 # the end step of motion manipulation, 0 ~ 10
|
18 |
+
manipulation_steps_end_crossattn: 3 # the end step of cross-attention for motion manipulation, 0 ~ 10
|
19 |
+
word_idx: 3 # the index of the word to be manipulated
|
20 |
+
|
21 |
+
time_shift:
|
22 |
+
use: False
|
23 |
+
time_shift_steps_start: 0 # the start step of time shifting, 0 ~ 10
|
24 |
+
time_shift_steps_end: 4 # the end step of time shifting, 0 ~ 10
|
25 |
+
time_shift_ratio: 0.5 # the ratio of time shifting, 0.0 ~ 1.0
|
26 |
+
|
27 |
+
example_based:
|
28 |
+
use: False
|
29 |
+
chunk_size: 20 # the size of the chunk for example-based generation
|
30 |
+
example_based_steps_end: 6 # the end step of example-based generation, 0 ~ 10
|
31 |
+
temp_seed: 200 # the inintial seed for example-based generation
|
32 |
+
temp_seed_bar: 15 # the the seed bar for example-based generation
|
33 |
+
|
34 |
+
style_tranfer:
|
35 |
+
use: False
|
36 |
+
style_transfer_steps_start: 0 # the start step of style transfer, 0 ~ 10
|
37 |
+
style_transfer_steps_end: 5 # the end step of style transfer, 0 ~ 10
|
38 |
+
|
39 |
+
grounded_generation:
|
40 |
+
use: False
|
options/evaluate_options.py
ADDED
@@ -0,0 +1,53 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from .get_opt import get_opt
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
class TestOptions():
|
6 |
+
def __init__(self):
|
7 |
+
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
8 |
+
self.initialize()
|
9 |
+
|
10 |
+
def initialize(self):
|
11 |
+
self.parser.add_argument("--opt_path", type=str, default='./checkpoints/t2m/t2m_condunet1d_batch64/opt.txt',help='option file path for loading model')
|
12 |
+
self.parser.add_argument("--gpu_id", type=int, default=0, help='GPU id')
|
13 |
+
|
14 |
+
# evaluator
|
15 |
+
self.parser.add_argument("--evaluator_dir", type=str, default='./data/checkpoints', help='Directory path where save T2M evaluator\'s checkpoints')
|
16 |
+
self.parser.add_argument("--eval_meta_dir", type=str, default='./data', help='Directory path where save T2M evaluator\'s normalization data.')
|
17 |
+
self.parser.add_argument("--glove_dir",type=str,default='./data/glove', help='Directory path where save glove')
|
18 |
+
|
19 |
+
# inference
|
20 |
+
self.parser.add_argument("--num_inference_steps", type=int, default=10, help='Number of iterative denoising steps during inference.')
|
21 |
+
self.parser.add_argument("--which_ckpt", type=str, default='latest', help='name of checkpoint to load')
|
22 |
+
self.parser.add_argument("--diffuser_name", type=str, default='dpmsolver', help='sampler\'s scheduler class name in the diffuser library')
|
23 |
+
self.parser.add_argument("--no_ema", action="store_true", help='Where use EMA model in inference')
|
24 |
+
self.parser.add_argument("--no_fp16", action="store_true", help='Whether use FP16 in inference')
|
25 |
+
self.parser.add_argument('--debug', action="store_true", help='debug mode')
|
26 |
+
self.parser.add_argument('--self_attention', action="store_true", help='self_attention use or not')
|
27 |
+
self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient linear attention')
|
28 |
+
self.parser.add_argument('--vis_attn', action='store_true', help='vis attention value or not')
|
29 |
+
self.parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
|
30 |
+
|
31 |
+
# evaluation
|
32 |
+
self.parser.add_argument("--replication_times", type=int, default=1, help='Number of generation rounds for each text description')
|
33 |
+
self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size for eval')
|
34 |
+
self.parser.add_argument('--diversity_times', type=int, default=300, help='')
|
35 |
+
self.parser.add_argument('--mm_num_samples', type=int, default=100, help='Number of samples for evaluating multimodality')
|
36 |
+
self.parser.add_argument('--mm_num_repeats', type=int, default=30, help='Number of generation rounds for each text description when evaluating multimodality')
|
37 |
+
self.parser.add_argument('--mm_num_times', type=int, default=10, help='')
|
38 |
+
self.parser.add_argument('--edit_mode', action='store_true', help='editing mode')
|
39 |
+
|
40 |
+
def parse(self):
|
41 |
+
# load evaluation options
|
42 |
+
self.opt = self.parser.parse_args()
|
43 |
+
opt_dict = vars(self.opt)
|
44 |
+
|
45 |
+
# load the model options of T2m evaluator
|
46 |
+
with open('./config/evaluator.yaml', 'r') as yaml_file:
|
47 |
+
yaml_config = yaml.safe_load(yaml_file)
|
48 |
+
opt_dict.update(yaml_config)
|
49 |
+
|
50 |
+
# load the training options of the selected checkpoint
|
51 |
+
get_opt(self.opt, self.opt.opt_path)
|
52 |
+
|
53 |
+
return self.opt
|
options/generate_options.py
ADDED
@@ -0,0 +1,50 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from .get_opt import get_opt
|
3 |
+
|
4 |
+
class GenerateOptions():
|
5 |
+
def __init__(self, app=False):
|
6 |
+
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
7 |
+
self.initialize()
|
8 |
+
|
9 |
+
def initialize(self):
|
10 |
+
self.parser.add_argument("--opt_path", type=str, default='./checkpoints/t2m/t2m_condunet1d_batch64/opt.txt', help='option file path for loading model')
|
11 |
+
self.parser.add_argument("--gpu_id", type=int, default=0, help='GPU id')
|
12 |
+
self.parser.add_argument("--output_dir", type=str, default='', help='Directory path to save generation result')
|
13 |
+
self.parser.add_argument("--footskate_cleanup", action="store_true", help='Where use footskate cleanup in inference')
|
14 |
+
|
15 |
+
# inference
|
16 |
+
self.parser.add_argument("--num_inference_steps", type=int, default=10, help='Number of iterative denoising steps during inference.')
|
17 |
+
self.parser.add_argument("--which_ckpt", type=str, default='latest', help='name of checkpoint to load')
|
18 |
+
self.parser.add_argument("--diffuser_name", type=str, default='dpmsolver', help='sampler\'s scheduler class name in the diffuser library')
|
19 |
+
self.parser.add_argument("--no_ema", action="store_true", help='Where use EMA model in inference')
|
20 |
+
self.parser.add_argument("--no_fp16", action="store_true", help='Whether use FP16 in inference')
|
21 |
+
self.parser.add_argument('--batch_size', type=int, default=1, help='Batch size for generate')
|
22 |
+
self.parser.add_argument("--seed", default=0, type=int, help="For fixing random seed.")
|
23 |
+
self.parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
|
24 |
+
|
25 |
+
# generate prompts
|
26 |
+
self.parser.add_argument('--text_prompt', type=str, default="", help='One text description pompt for motion generation')
|
27 |
+
self.parser.add_argument("--motion_length", default=4.0, type=float, help="The length of the generated motion [in seconds] when using prompts. Maximum is 9.8 for HumanML3D (text-to-motion)")
|
28 |
+
self.parser.add_argument('--input_text', type=str, default='', help='File path of texts when using multiple texts.')
|
29 |
+
self.parser.add_argument('--input_lens', type=str, default='', help='File path of expected motion frame lengths when using multitext.')
|
30 |
+
self.parser.add_argument("--num_samples", type=int, default=10, help='Number of samples for generate when using dataset.')
|
31 |
+
self.parser.add_argument('--debug', action="store_true", help='debug mode')
|
32 |
+
self.parser.add_argument('--self_attention', action="store_true", help='self_attention use or not')
|
33 |
+
self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient linear attention')
|
34 |
+
self.parser.add_argument('--vis_attn', action='store_true', help='vis attention value or not')
|
35 |
+
self.parser.add_argument('--edit_mode', action='store_true', help='editing mode')
|
36 |
+
|
37 |
+
|
38 |
+
def parse(self):
|
39 |
+
self.opt = self.parser.parse_args()
|
40 |
+
opt_path = self.opt.opt_path
|
41 |
+
get_opt(self.opt, opt_path)
|
42 |
+
return self.opt
|
43 |
+
|
44 |
+
def parse_app(self):
|
45 |
+
self.opt = self.parser.parse_args(
|
46 |
+
args=['--motion_length', '8', '--self_attention', '--no_eff', '--opt_path', './checkpoints/t2m/release/opt.txt', '--edit_mode']
|
47 |
+
)
|
48 |
+
opt_path = self.opt.opt_path
|
49 |
+
get_opt(self.opt, opt_path)
|
50 |
+
return self.opt
|
options/get_opt.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from argparse import Namespace
|
3 |
+
import re
|
4 |
+
from os.path import join as pjoin
|
5 |
+
|
6 |
+
|
7 |
+
def is_float(numStr):
|
8 |
+
flag = False
|
9 |
+
numStr = str(numStr).strip().lstrip("-").lstrip("+")
|
10 |
+
try:
|
11 |
+
reg = re.compile(r"^[-+]?[0-9]+\.[0-9]+$")
|
12 |
+
res = reg.match(str(numStr))
|
13 |
+
if res:
|
14 |
+
flag = True
|
15 |
+
except Exception as ex:
|
16 |
+
print("is_float() - error: " + str(ex))
|
17 |
+
return flag
|
18 |
+
|
19 |
+
|
20 |
+
def is_number(numStr):
|
21 |
+
flag = False
|
22 |
+
numStr = str(numStr).strip().lstrip("-").lstrip("+")
|
23 |
+
if str(numStr).isdigit():
|
24 |
+
flag = True
|
25 |
+
return flag
|
26 |
+
|
27 |
+
|
28 |
+
def get_opt(opt, opt_path):
|
29 |
+
opt_dict = vars(opt)
|
30 |
+
|
31 |
+
skip = (
|
32 |
+
"-------------- End ----------------",
|
33 |
+
"------------ Options -------------",
|
34 |
+
"\n",
|
35 |
+
)
|
36 |
+
print("Reading", opt_path)
|
37 |
+
with open(opt_path) as f:
|
38 |
+
for line in f:
|
39 |
+
if line.strip() not in skip:
|
40 |
+
print(line.strip())
|
41 |
+
key, value = line.strip().split(": ")
|
42 |
+
if getattr(opt, key, None) is not None:
|
43 |
+
continue
|
44 |
+
if value in ("True", "False"):
|
45 |
+
opt_dict[key] = True if value == "True" else False
|
46 |
+
elif is_float(value):
|
47 |
+
opt_dict[key] = float(value)
|
48 |
+
elif is_number(value):
|
49 |
+
opt_dict[key] = int(value)
|
50 |
+
elif "," in value:
|
51 |
+
value = value[1:-1].split(",")
|
52 |
+
opt_dict[key] = [int(i) for i in value]
|
53 |
+
else:
|
54 |
+
opt_dict[key] = str(value)
|
55 |
+
|
56 |
+
# opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
|
57 |
+
opt.save_root = os.path.dirname(opt_path)
|
58 |
+
opt.model_dir = pjoin(opt.save_root, "model")
|
59 |
+
opt.meta_dir = pjoin(opt.save_root, "meta")
|
60 |
+
|
61 |
+
if opt.dataset_name == "t2m" or opt.dataset_name == "humanml":
|
62 |
+
opt.joints_num = 22
|
63 |
+
opt.dim_pose = 263
|
64 |
+
opt.max_motion_length = 196
|
65 |
+
opt.radius = 4
|
66 |
+
opt.fps = 20
|
67 |
+
elif opt.dataset_name == "kit":
|
68 |
+
opt.joints_num = 21
|
69 |
+
opt.dim_pose = 251
|
70 |
+
opt.max_motion_length = 196
|
71 |
+
opt.radius = 240 * 8
|
72 |
+
opt.fps = 12.5
|
73 |
+
else:
|
74 |
+
raise KeyError("Dataset not recognized")
|
options/noedit.yaml
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# noedit.yaml
|
2 |
+
reweighting_attn:
|
3 |
+
use: False
|
4 |
+
|
5 |
+
erasing_motion:
|
6 |
+
use: False
|
7 |
+
|
8 |
+
manipulation:
|
9 |
+
use: False
|
10 |
+
|
11 |
+
time_shift:
|
12 |
+
use: False
|
13 |
+
|
14 |
+
example_based:
|
15 |
+
use: False
|
16 |
+
|
17 |
+
style_tranfer:
|
18 |
+
use: False
|
19 |
+
|
20 |
+
grounded_generation:
|
21 |
+
use: False
|
options/train_options.py
ADDED
@@ -0,0 +1,126 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from .get_opt import get_opt
|
3 |
+
from os.path import join as pjoin
|
4 |
+
import os
|
5 |
+
|
6 |
+
class TrainOptions():
|
7 |
+
def __init__(self):
|
8 |
+
self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
9 |
+
self.initialized = False
|
10 |
+
|
11 |
+
def initialize(self):
|
12 |
+
# base set
|
13 |
+
self.parser.add_argument('--name', type=str, default="test", help='Name of this trial')
|
14 |
+
self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name')
|
15 |
+
self.parser.add_argument('--feat_bias', type=float, default=5, help='Scales for global motion features and foot contact')
|
16 |
+
self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
|
17 |
+
self.parser.add_argument('--log_every', type=int, default=5, help='Frequency of printing training progress (by iteration)')
|
18 |
+
self.parser.add_argument('--save_interval', type=int, default=10_000, help='Frequency of evaluateing and saving models (by iteration)')
|
19 |
+
|
20 |
+
|
21 |
+
# network hyperparams
|
22 |
+
self.parser.add_argument('--num_layers', type=int, default=8, help='num_layers of transformer')
|
23 |
+
self.parser.add_argument('--latent_dim', type=int, default=512, help='latent_dim of transformer')
|
24 |
+
self.parser.add_argument('--text_latent_dim', type=int, default=256, help='latent_dim of text embeding')
|
25 |
+
self.parser.add_argument('--time_dim', type=int, default=512, help='latent_dim of timesteps')
|
26 |
+
self.parser.add_argument('--base_dim', type=int, default=512, help='Dimension of Unet base channel')
|
27 |
+
self.parser.add_argument('--dim_mults', type=int, default=[2,2,2,2], nargs='+', help='Unet channel multipliers.')
|
28 |
+
self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient linear attention')
|
29 |
+
self.parser.add_argument('--no_adagn', action='store_true', help='whether use adagn block')
|
30 |
+
self.parser.add_argument('--diffusion_steps', type=int, default=1000, help='diffusion_steps of transformer')
|
31 |
+
self.parser.add_argument('--prediction_type', type=str, default='sample', help='diffusion_steps of transformer')
|
32 |
+
|
33 |
+
# train hyperparams
|
34 |
+
self.parser.add_argument('--seed', type=int, default=0, help='seed for train')
|
35 |
+
self.parser.add_argument('--num_train_steps', type=int, default=50_000, help='Number of training iterations')
|
36 |
+
self.parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate')
|
37 |
+
self.parser.add_argument("--decay_rate", default=0.9, type=float, help="the decay rate of lr (0-1 default 0.9)")
|
38 |
+
self.parser.add_argument("--update_lr_steps", default=5_000, type=int, help="")
|
39 |
+
self.parser.add_argument("--cond_mask_prob", default=0.1, type=float,
|
40 |
+
help="The probability of masking the condition during training."
|
41 |
+
" For classifier-free guidance learning.")
|
42 |
+
self.parser.add_argument('--clip_grad_norm', type=float, default=1, help='Gradient clip')
|
43 |
+
self.parser.add_argument('--weight_decay', type=float, default=1e-2, help='Learning rate weight_decay')
|
44 |
+
self.parser.add_argument('--batch_size', type=int, default=64, help='Batch size per GPU')
|
45 |
+
self.parser.add_argument("--beta_schedule", default='linear', type=str, help="Types of beta in diffusion (e.g. linear, cosine)")
|
46 |
+
self.parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
|
47 |
+
|
48 |
+
# continue training
|
49 |
+
self.parser.add_argument('--is_continue', action="store_true", help='Is this trail continued from previous trail?')
|
50 |
+
self.parser.add_argument('--continue_ckpt', type=str, default="latest.tar", help='previous trail to continue')
|
51 |
+
self.parser.add_argument("--opt_path", type=str, default='',help='option file path for loading model')
|
52 |
+
self.parser.add_argument('--debug', action="store_true", help='debug mode')
|
53 |
+
self.parser.add_argument('--self_attention', action="store_true", help='self_attention use or not')
|
54 |
+
self.parser.add_argument('--vis_attn', action='store_true', help='vis attention value or not')
|
55 |
+
|
56 |
+
self.parser.add_argument('--edit_mode', action='store_true', help='editing mode')
|
57 |
+
|
58 |
+
# EMA params
|
59 |
+
self.parser.add_argument(
|
60 |
+
"--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
|
61 |
+
)
|
62 |
+
self.parser.add_argument(
|
63 |
+
"--model-ema-steps",
|
64 |
+
type=int,
|
65 |
+
default=32,
|
66 |
+
help="the number of iterations that controls how often to update the EMA model (default: 32)",
|
67 |
+
)
|
68 |
+
self.parser.add_argument(
|
69 |
+
"--model-ema-decay",
|
70 |
+
type=float,
|
71 |
+
default=0.9999,
|
72 |
+
help="decay factor for Exponential Moving Average of model parameters (default: 0.99988)",
|
73 |
+
)
|
74 |
+
|
75 |
+
self.initialized = True
|
76 |
+
|
77 |
+
def parse(self,accelerator):
|
78 |
+
if not self.initialized:
|
79 |
+
self.initialize()
|
80 |
+
|
81 |
+
self.opt = self.parser.parse_args()
|
82 |
+
|
83 |
+
if self.opt.is_continue:
|
84 |
+
assert self.opt.opt_path.endswith('.txt')
|
85 |
+
get_opt(self.opt, self.opt.opt_path)
|
86 |
+
self.opt.is_train = True
|
87 |
+
self.opt.is_continue=True
|
88 |
+
elif accelerator.is_main_process:
|
89 |
+
args = vars(self.opt)
|
90 |
+
accelerator.print('------------ Options -------------')
|
91 |
+
for k, v in sorted(args.items()):
|
92 |
+
accelerator.print('%s: %s' % (str(k), str(v)))
|
93 |
+
accelerator.print('-------------- End ----------------')
|
94 |
+
# save to the disk
|
95 |
+
expr_dir = pjoin(self.opt.checkpoints_dir, self.opt.dataset_name, self.opt.name)
|
96 |
+
os.makedirs(expr_dir,exist_ok=True)
|
97 |
+
file_name = pjoin(expr_dir, 'opt.txt')
|
98 |
+
with open(file_name, 'wt') as opt_file:
|
99 |
+
opt_file.write('------------ Options -------------\n')
|
100 |
+
for k, v in sorted(args.items()):
|
101 |
+
if k =='opt_path':
|
102 |
+
continue
|
103 |
+
opt_file.write('%s: %s\n' % (str(k), str(v)))
|
104 |
+
opt_file.write('-------------- End ----------------\n')
|
105 |
+
|
106 |
+
|
107 |
+
if self.opt.dataset_name == 't2m' or self.opt.dataset_name == 'humanml':
|
108 |
+
self.opt.joints_num = 22
|
109 |
+
self.opt.dim_pose = 263
|
110 |
+
self.opt.max_motion_length = 196
|
111 |
+
self.opt.radius = 4
|
112 |
+
self.opt.fps = 20
|
113 |
+
elif self.opt.dataset_name == 'kit':
|
114 |
+
self.opt.joints_num = 21
|
115 |
+
self.opt.dim_pose = 251
|
116 |
+
self.opt.max_motion_length = 196
|
117 |
+
self.opt.radius = 240 * 8
|
118 |
+
self.opt.fps = 12.5
|
119 |
+
else:
|
120 |
+
raise KeyError('Dataset not recognized')
|
121 |
+
|
122 |
+
self.opt.device = accelerator.device
|
123 |
+
self.opt.is_train = True
|
124 |
+
return self.opt
|
125 |
+
|
126 |
+
|
prepare/download_glove.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cd ./data/
|
2 |
+
|
3 |
+
echo -e "Downloading glove (in use by the evaluators)"
|
4 |
+
gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing
|
5 |
+
rm -rf glove
|
6 |
+
|
7 |
+
unzip glove.zip
|
8 |
+
echo -e "Cleaning\n"
|
9 |
+
rm glove.zip
|
10 |
+
cd ..
|
11 |
+
|
12 |
+
echo -e "Downloading done!"
|
prepare/download_t2m_evaluators.sh
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mkdir -p data/
|
2 |
+
cd data/
|
3 |
+
mkdir -p checkpoints/
|
4 |
+
cd checkpoints/
|
5 |
+
|
6 |
+
echo "The t2m evaluators will be stored in the './deps' folder"
|
7 |
+
|
8 |
+
echo "Downloading"
|
9 |
+
gdown --fuzzy https://drive.google.com/file/d/16hyR4XlEyksVyNVjhIWK684Lrm_7_pvX/view?usp=sharing
|
10 |
+
echo "Extracting"
|
11 |
+
unzip t2m.zip
|
12 |
+
echo "Cleaning"
|
13 |
+
rm t2m.zip
|
14 |
+
|
15 |
+
cd ../..
|
16 |
+
|
17 |
+
echo "Downloading done!"
|
prepare/prepare_clip.sh
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mkdir -p deps/
|
2 |
+
cd deps/
|
3 |
+
git lfs install
|
4 |
+
git clone https://huggingface.co/openai/clip-vit-large-patch14
|
5 |
+
cd ..
|
requirements.txt
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
tqdm
|
2 |
+
opencv-python
|
3 |
+
scipy
|
4 |
+
matplotlib==3.3.1
|
5 |
+
spacy
|
6 |
+
accelerate
|
7 |
+
transformers
|
8 |
+
einops
|
9 |
+
diffusers
|
10 |
+
panda3d
|
11 |
+
numpy==1.23.0
|
12 |
+
git+https://github.com/openai/CLIP.git
|
13 |
+
diffusers==0.30.3
|
14 |
+
transformers==4.45.2
|
15 |
+
|
16 |
+
# for train
|
17 |
+
tensorboard
|
18 |
+
accelerate==1.0.1
|
19 |
+
smplx
|
20 |
+
python-box
|
scripts/__init__.py
ADDED
File without changes
|
scripts/evaluation.py
ADDED
@@ -0,0 +1,107 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import torch
|
3 |
+
from motion_loader import get_dataset_loader, get_motion_loader
|
4 |
+
from datasets import get_dataset
|
5 |
+
from models import build_models
|
6 |
+
from eval import EvaluatorModelWrapper, evaluation
|
7 |
+
from utils.utils import *
|
8 |
+
from utils.model_load import load_model_weights
|
9 |
+
import os
|
10 |
+
from os.path import join as pjoin
|
11 |
+
|
12 |
+
from models.gaussian_diffusion import DiffusePipeline
|
13 |
+
from accelerate.utils import set_seed
|
14 |
+
|
15 |
+
from options.evaluate_options import TestOptions
|
16 |
+
|
17 |
+
import yaml
|
18 |
+
from box import Box
|
19 |
+
|
20 |
+
|
21 |
+
def yaml_to_box(yaml_file):
|
22 |
+
with open(yaml_file, "r") as file:
|
23 |
+
yaml_data = yaml.safe_load(file)
|
24 |
+
|
25 |
+
return Box(yaml_data)
|
26 |
+
|
27 |
+
|
28 |
+
if __name__ == "__main__":
|
29 |
+
parser = TestOptions()
|
30 |
+
opt = parser.parse()
|
31 |
+
set_seed(0)
|
32 |
+
|
33 |
+
if opt.edit_mode:
|
34 |
+
edit_config = yaml_to_box("options/edit.yaml")
|
35 |
+
else:
|
36 |
+
edit_config = yaml_to_box("options/noedit.yaml")
|
37 |
+
|
38 |
+
device_id = opt.gpu_id
|
39 |
+
device = torch.device("cuda:%d" % device_id if torch.cuda.is_available() else "cpu")
|
40 |
+
torch.cuda.set_device(device)
|
41 |
+
opt.device = device
|
42 |
+
|
43 |
+
# load evaluator
|
44 |
+
eval_wrapper = EvaluatorModelWrapper(opt)
|
45 |
+
|
46 |
+
# load dataset
|
47 |
+
gt_loader = get_dataset_loader(opt, opt.batch_size, mode="gt_eval", split="test")
|
48 |
+
gen_dataset = get_dataset(opt, mode="eval", split="test")
|
49 |
+
|
50 |
+
# load model
|
51 |
+
model = build_models(opt, edit_config=edit_config)
|
52 |
+
ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + ".tar")
|
53 |
+
load_model_weights(model, ckpt_path, use_ema=not opt.no_ema, device=device)
|
54 |
+
|
55 |
+
# Create a pipeline for generation in diffusion model framework
|
56 |
+
pipeline = DiffusePipeline(
|
57 |
+
opt=opt,
|
58 |
+
model=model,
|
59 |
+
diffuser_name=opt.diffuser_name,
|
60 |
+
device=device,
|
61 |
+
num_inference_steps=opt.num_inference_steps,
|
62 |
+
torch_dtype=torch.float32 if opt.no_fp16 else torch.float16,
|
63 |
+
)
|
64 |
+
|
65 |
+
eval_motion_loaders = {
|
66 |
+
"text2motion": lambda: get_motion_loader(
|
67 |
+
opt,
|
68 |
+
opt.batch_size,
|
69 |
+
pipeline,
|
70 |
+
gen_dataset,
|
71 |
+
opt.mm_num_samples,
|
72 |
+
opt.mm_num_repeats,
|
73 |
+
)
|
74 |
+
}
|
75 |
+
|
76 |
+
save_dir = pjoin(opt.save_root, "eval")
|
77 |
+
os.makedirs(save_dir, exist_ok=True)
|
78 |
+
if opt.no_ema:
|
79 |
+
log_file = (
|
80 |
+
pjoin(save_dir, opt.diffuser_name)
|
81 |
+
+ f"_{str(opt.num_inference_steps)}setps.log"
|
82 |
+
)
|
83 |
+
else:
|
84 |
+
log_file = (
|
85 |
+
pjoin(save_dir, opt.diffuser_name)
|
86 |
+
+ f"_{str(opt.num_inference_steps)}steps_ema.log"
|
87 |
+
)
|
88 |
+
|
89 |
+
if not os.path.exists(log_file):
|
90 |
+
config_dict = dict(pipeline.scheduler.config)
|
91 |
+
config_dict["no_ema"] = opt.no_ema
|
92 |
+
with open(log_file, "wt") as f:
|
93 |
+
f.write("------------ Options -------------\n")
|
94 |
+
for k, v in sorted(config_dict.items()):
|
95 |
+
f.write("%s: %s\n" % (str(k), str(v)))
|
96 |
+
f.write("-------------- End ----------------\n")
|
97 |
+
|
98 |
+
all_metrics = evaluation(
|
99 |
+
eval_wrapper,
|
100 |
+
gt_loader,
|
101 |
+
eval_motion_loaders,
|
102 |
+
log_file,
|
103 |
+
opt.replication_times,
|
104 |
+
opt.diversity_times,
|
105 |
+
opt.mm_num_times,
|
106 |
+
run_mm=True,
|
107 |
+
)
|
scripts/generate.py
ADDED
@@ -0,0 +1,174 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
from os.path import join as pjoin
|
6 |
+
import utils.paramUtil as paramUtil
|
7 |
+
from utils.plot_script import *
|
8 |
+
|
9 |
+
from utils.utils import *
|
10 |
+
from utils.motion_process import recover_from_ric
|
11 |
+
from accelerate.utils import set_seed
|
12 |
+
from models.gaussian_diffusion import DiffusePipeline
|
13 |
+
from options.generate_options import GenerateOptions
|
14 |
+
from utils.model_load import load_model_weights
|
15 |
+
from motion_loader import get_dataset_loader
|
16 |
+
from models import build_models
|
17 |
+
import yaml
|
18 |
+
from box import Box
|
19 |
+
|
20 |
+
|
21 |
+
def yaml_to_box(yaml_file):
|
22 |
+
with open(yaml_file, "r") as file:
|
23 |
+
yaml_data = yaml.safe_load(file)
|
24 |
+
return Box(yaml_data)
|
25 |
+
|
26 |
+
|
27 |
+
if __name__ == "__main__":
|
28 |
+
parser = GenerateOptions()
|
29 |
+
opt = parser.parse()
|
30 |
+
set_seed(opt.seed)
|
31 |
+
device_id = opt.gpu_id
|
32 |
+
device = torch.device("cuda:%d" % device_id if torch.cuda.is_available() else "cpu")
|
33 |
+
opt.device = device
|
34 |
+
|
35 |
+
assert opt.dataset_name == "t2m" or "kit"
|
36 |
+
|
37 |
+
# Using a text prompt for generation
|
38 |
+
if opt.text_prompt != "":
|
39 |
+
texts = [opt.text_prompt]
|
40 |
+
opt.num_samples = 1
|
41 |
+
motion_lens = [opt.motion_length * opt.fps]
|
42 |
+
|
43 |
+
# Or using texts (in .txt file) for generation
|
44 |
+
elif opt.input_text != "":
|
45 |
+
with open(opt.input_text, "r") as fr:
|
46 |
+
texts = [line.strip() for line in fr.readlines()]
|
47 |
+
opt.num_samples = len(texts)
|
48 |
+
if opt.input_lens != "":
|
49 |
+
with open(opt.input_lens, "r") as fr:
|
50 |
+
motion_lens = [int(line.strip()) for line in fr.readlines()]
|
51 |
+
assert len(texts) == len(
|
52 |
+
motion_lens
|
53 |
+
), f"Please ensure that the motion length in {opt.input_lens} corresponds to the text in {opt.input_text}."
|
54 |
+
else:
|
55 |
+
motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)]
|
56 |
+
|
57 |
+
# Or usining texts in dataset
|
58 |
+
else:
|
59 |
+
gen_datasetloader = get_dataset_loader(
|
60 |
+
opt, opt.num_samples, mode="hml_gt", split="test"
|
61 |
+
)
|
62 |
+
texts, _, motion_lens = next(iter(gen_datasetloader))
|
63 |
+
|
64 |
+
# edit mode
|
65 |
+
if opt.edit_mode:
|
66 |
+
edit_config = yaml_to_box("options/edit.yaml")
|
67 |
+
else:
|
68 |
+
edit_config = yaml_to_box("options/noedit.yaml")
|
69 |
+
print(edit_config)
|
70 |
+
|
71 |
+
ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + ".tar")
|
72 |
+
checkpoint = torch.load(ckpt_path,map_location={'cuda:0': str(device)})
|
73 |
+
niter = checkpoint.get('total_it', 0)
|
74 |
+
# make save dir
|
75 |
+
out_path = opt.output_dir
|
76 |
+
if out_path == "":
|
77 |
+
out_path = pjoin(opt.save_root, "samples_iter{}_seed{}".format(niter, opt.seed))
|
78 |
+
if opt.text_prompt != "":
|
79 |
+
out_path += "_" + opt.text_prompt.replace(" ", "_").replace(".", "")
|
80 |
+
elif opt.input_text != "":
|
81 |
+
out_path += "_" + os.path.basename(opt.input_text).replace(
|
82 |
+
".txt", ""
|
83 |
+
).replace(" ", "_").replace(".", "")
|
84 |
+
os.makedirs(out_path, exist_ok=True)
|
85 |
+
|
86 |
+
# load model
|
87 |
+
model = build_models(opt, edit_config=edit_config, out_path=out_path)
|
88 |
+
niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)
|
89 |
+
|
90 |
+
# Create a pipeline for generation in diffusion model framework
|
91 |
+
pipeline = DiffusePipeline(
|
92 |
+
opt=opt,
|
93 |
+
model=model,
|
94 |
+
diffuser_name=opt.diffuser_name,
|
95 |
+
device=device,
|
96 |
+
num_inference_steps=opt.num_inference_steps,
|
97 |
+
torch_dtype=torch.float16,
|
98 |
+
)
|
99 |
+
|
100 |
+
# generate
|
101 |
+
pred_motions, _ = pipeline.generate(
|
102 |
+
texts, torch.LongTensor([int(x) for x in motion_lens])
|
103 |
+
)
|
104 |
+
|
105 |
+
# Convert the generated motion representaion into 3D joint coordinates and save as npy file
|
106 |
+
npy_dir = pjoin(out_path, "joints_npy")
|
107 |
+
root_dir = pjoin(out_path, "root_npy")
|
108 |
+
os.makedirs(npy_dir, exist_ok=True)
|
109 |
+
os.makedirs(root_dir, exist_ok=True)
|
110 |
+
print(f"saving results npy file (3d joints) to [{npy_dir}]")
|
111 |
+
mean = np.load(pjoin(opt.meta_dir, "mean.npy"))
|
112 |
+
std = np.load(pjoin(opt.meta_dir, "std.npy"))
|
113 |
+
samples = []
|
114 |
+
|
115 |
+
root_list = []
|
116 |
+
for i, motion in enumerate(pred_motions):
|
117 |
+
motion = motion.cpu().numpy() * std + mean
|
118 |
+
np.save(pjoin(npy_dir, f"raw_{i:02}.npy"), motion)
|
119 |
+
npy_name = f"{i:02}.npy"
|
120 |
+
# 1. recover 3d joints representation by ik
|
121 |
+
motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
|
122 |
+
# 2. put on Floor (Y axis)
|
123 |
+
floor_height = motion.min(dim=0)[0].min(dim=0)[0][1]
|
124 |
+
motion[:, :, 1] -= floor_height
|
125 |
+
motion = motion.numpy()
|
126 |
+
# 3. remove jitter
|
127 |
+
motion = motion_temporal_filter(motion, sigma=1)
|
128 |
+
|
129 |
+
# save root trajectory (Y axis)
|
130 |
+
root_trajectory = motion[:, 0, :]
|
131 |
+
root_list.append(root_trajectory)
|
132 |
+
np.save(pjoin(root_dir, f"root_{i:02}.npy"), root_trajectory)
|
133 |
+
y = root_trajectory[:, 1]
|
134 |
+
|
135 |
+
plt.figure()
|
136 |
+
plt.plot(y)
|
137 |
+
|
138 |
+
plt.legend()
|
139 |
+
|
140 |
+
plt.title("Root Joint Trajectory")
|
141 |
+
plt.xlabel("Frame")
|
142 |
+
plt.ylabel("Position")
|
143 |
+
|
144 |
+
plt.savefig("./root_trajectory_xyz.png")
|
145 |
+
np.save(pjoin(npy_dir, npy_name), motion)
|
146 |
+
samples.append(motion)
|
147 |
+
|
148 |
+
root_list_res = np.concatenate(root_list, axis=0)
|
149 |
+
np.save("root_list.npy", root_list_res)
|
150 |
+
|
151 |
+
# save the text and length conditions used for this generation
|
152 |
+
with open(pjoin(out_path, "results.txt"), "w") as fw:
|
153 |
+
fw.write("\n".join(texts))
|
154 |
+
with open(pjoin(out_path, "results_lens.txt"), "w") as fw:
|
155 |
+
fw.write("\n".join([str(l) for l in motion_lens]))
|
156 |
+
|
157 |
+
# skeletal animation visualization
|
158 |
+
print(f"saving motion videos to [{out_path}]...")
|
159 |
+
for i, title in enumerate(texts):
|
160 |
+
motion = samples[i]
|
161 |
+
fname = f"{i:02}.mp4"
|
162 |
+
kinematic_tree = (
|
163 |
+
paramUtil.t2m_kinematic_chain
|
164 |
+
if (opt.dataset_name == "t2m")
|
165 |
+
else paramUtil.kit_kinematic_chain
|
166 |
+
)
|
167 |
+
plot_3d_motion(
|
168 |
+
pjoin(out_path, fname),
|
169 |
+
kinematic_tree,
|
170 |
+
motion,
|
171 |
+
title=title,
|
172 |
+
fps=opt.fps,
|
173 |
+
radius=opt.radius,
|
174 |
+
)
|
scripts/train.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
import os
|
3 |
+
from os.path import join as pjoin
|
4 |
+
from options.train_options import TrainOptions
|
5 |
+
from utils.plot_script import *
|
6 |
+
|
7 |
+
from models import build_models
|
8 |
+
from utils.ema import ExponentialMovingAverage
|
9 |
+
from trainers import DDPMTrainer
|
10 |
+
from motion_loader import get_dataset_loader
|
11 |
+
|
12 |
+
from accelerate.utils import set_seed
|
13 |
+
from accelerate import Accelerator
|
14 |
+
import torch
|
15 |
+
|
16 |
+
import yaml
|
17 |
+
from box import Box
|
18 |
+
|
19 |
+
def yaml_to_box(yaml_file):
|
20 |
+
with open(yaml_file, 'r') as file:
|
21 |
+
yaml_data = yaml.safe_load(file)
|
22 |
+
|
23 |
+
return Box(yaml_data)
|
24 |
+
|
25 |
+
if __name__ == '__main__':
|
26 |
+
accelerator = Accelerator()
|
27 |
+
|
28 |
+
parser = TrainOptions()
|
29 |
+
opt = parser.parse(accelerator)
|
30 |
+
set_seed(opt.seed)
|
31 |
+
torch.autograd.set_detect_anomaly(True)
|
32 |
+
|
33 |
+
opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
|
34 |
+
opt.model_dir = pjoin(opt.save_root, 'model')
|
35 |
+
opt.meta_dir = pjoin(opt.save_root, 'meta')
|
36 |
+
|
37 |
+
if opt.edit_mode:
|
38 |
+
edit_config = yaml_to_box('options/edit.yaml')
|
39 |
+
else:
|
40 |
+
edit_config = yaml_to_box('options/noedit.yaml')
|
41 |
+
|
42 |
+
if accelerator.is_main_process:
|
43 |
+
os.makedirs(opt.model_dir, exist_ok=True)
|
44 |
+
os.makedirs(opt.meta_dir, exist_ok=True)
|
45 |
+
|
46 |
+
train_datasetloader = get_dataset_loader(opt, batch_size = opt.batch_size, split='train', accelerator=accelerator, mode='train') # 7169
|
47 |
+
|
48 |
+
|
49 |
+
accelerator.print('\nInitializing model ...' )
|
50 |
+
encoder = build_models(opt, edit_config=edit_config)
|
51 |
+
model_ema = None
|
52 |
+
if opt.model_ema:
|
53 |
+
# Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
|
54 |
+
# https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
|
55 |
+
adjust = 106_667 * opt.model_ema_steps / opt.num_train_steps
|
56 |
+
alpha = 1.0 - opt.model_ema_decay
|
57 |
+
alpha = min(1.0, alpha * adjust)
|
58 |
+
print('EMA alpha:',alpha)
|
59 |
+
model_ema = ExponentialMovingAverage(encoder, decay=1.0 - alpha)
|
60 |
+
accelerator.print('Finish building Model.\n')
|
61 |
+
|
62 |
+
trainer = DDPMTrainer(opt, encoder,accelerator, model_ema)
|
63 |
+
|
64 |
+
trainer.train(train_datasetloader)
|
65 |
+
|
66 |
+
|
trainers/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .ddpm_trainer import DDPMTrainer
|
2 |
+
|
3 |
+
|
4 |
+
__all__ = ['DDPMTrainer']
|
trainers/ddpm_trainer.py
ADDED
@@ -0,0 +1,245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import time
|
3 |
+
import torch.optim as optim
|
4 |
+
from collections import OrderedDict
|
5 |
+
from utils.utils import print_current_loss
|
6 |
+
from os.path import join as pjoin
|
7 |
+
|
8 |
+
from diffusers import DDPMScheduler
|
9 |
+
from torch.utils.tensorboard import SummaryWriter
|
10 |
+
import time
|
11 |
+
import pdb
|
12 |
+
import sys
|
13 |
+
import os
|
14 |
+
from torch.optim.lr_scheduler import ExponentialLR
|
15 |
+
|
16 |
+
|
17 |
+
class DDPMTrainer(object):
|
18 |
+
|
19 |
+
def __init__(self, args, model, accelerator, model_ema=None):
|
20 |
+
self.opt = args
|
21 |
+
self.accelerator = accelerator
|
22 |
+
self.device = self.accelerator.device
|
23 |
+
self.model = model
|
24 |
+
self.diffusion_steps = args.diffusion_steps
|
25 |
+
self.noise_scheduler = DDPMScheduler(
|
26 |
+
num_train_timesteps=self.diffusion_steps,
|
27 |
+
beta_schedule=args.beta_schedule,
|
28 |
+
variance_type="fixed_small",
|
29 |
+
prediction_type=args.prediction_type,
|
30 |
+
clip_sample=False,
|
31 |
+
)
|
32 |
+
self.model_ema = model_ema
|
33 |
+
if args.is_train:
|
34 |
+
self.mse_criterion = torch.nn.MSELoss(reduction="none")
|
35 |
+
|
36 |
+
accelerator.print("Diffusion_config:\n", self.noise_scheduler.config)
|
37 |
+
|
38 |
+
if self.accelerator.is_main_process:
|
39 |
+
starttime = time.strftime("%Y-%m-%d_%H:%M:%S")
|
40 |
+
print("Start experiment:", starttime)
|
41 |
+
self.writer = SummaryWriter(
|
42 |
+
log_dir=pjoin(args.save_root, "logs_") + starttime[:16],
|
43 |
+
comment=starttime[:16],
|
44 |
+
flush_secs=60,
|
45 |
+
)
|
46 |
+
self.accelerator.wait_for_everyone()
|
47 |
+
|
48 |
+
self.optimizer = optim.AdamW(
|
49 |
+
self.model.parameters(), lr=self.opt.lr, weight_decay=self.opt.weight_decay
|
50 |
+
)
|
51 |
+
self.scheduler = (
|
52 |
+
ExponentialLR(self.optimizer, gamma=args.decay_rate)
|
53 |
+
if args.decay_rate > 0
|
54 |
+
else None
|
55 |
+
)
|
56 |
+
|
57 |
+
@staticmethod
|
58 |
+
def zero_grad(opt_list):
|
59 |
+
for opt in opt_list:
|
60 |
+
opt.zero_grad()
|
61 |
+
|
62 |
+
def clip_norm(self, network_list):
|
63 |
+
for network in network_list:
|
64 |
+
self.accelerator.clip_grad_norm_(
|
65 |
+
network.parameters(), self.opt.clip_grad_norm
|
66 |
+
) # 0.5 -> 1
|
67 |
+
|
68 |
+
@staticmethod
|
69 |
+
def step(opt_list):
|
70 |
+
for opt in opt_list:
|
71 |
+
opt.step()
|
72 |
+
|
73 |
+
def forward(self, batch_data):
|
74 |
+
caption, motions, m_lens = batch_data
|
75 |
+
motions = motions.detach().float()
|
76 |
+
|
77 |
+
x_start = motions
|
78 |
+
B, T = x_start.shape[:2]
|
79 |
+
cur_len = torch.LongTensor([min(T, m_len) for m_len in m_lens]).to(self.device)
|
80 |
+
self.src_mask = self.generate_src_mask(T, cur_len).to(x_start.device)
|
81 |
+
|
82 |
+
# 1. Sample noise that we'll add to the motion
|
83 |
+
real_noise = torch.randn_like(x_start)
|
84 |
+
|
85 |
+
# 2. Sample a random timestep for each motion
|
86 |
+
t = torch.randint(0, self.diffusion_steps, (B,), device=self.device)
|
87 |
+
self.timesteps = t
|
88 |
+
|
89 |
+
# 3. Add noise to the motion according to the noise magnitude at each timestep
|
90 |
+
# (this is the forward diffusion process)
|
91 |
+
x_t = self.noise_scheduler.add_noise(x_start, real_noise, t)
|
92 |
+
|
93 |
+
# 4. network prediction
|
94 |
+
self.prediction = self.model(x_t, t, text=caption)
|
95 |
+
|
96 |
+
if self.opt.prediction_type == "sample":
|
97 |
+
self.target = x_start
|
98 |
+
elif self.opt.prediction_type == "epsilon":
|
99 |
+
self.target = real_noise
|
100 |
+
elif self.opt.prediction_type == "v_prediction":
|
101 |
+
self.target = self.noise_scheduler.get_velocity(x_start, real_noise, t)
|
102 |
+
|
103 |
+
def masked_l2(self, a, b, mask, weights):
|
104 |
+
|
105 |
+
loss = self.mse_criterion(a, b).mean(dim=-1) # (bath_size, motion_length)
|
106 |
+
|
107 |
+
loss = (loss * mask).sum(-1) / mask.sum(-1) # (batch_size, )
|
108 |
+
|
109 |
+
loss = (loss * weights).mean()
|
110 |
+
|
111 |
+
return loss
|
112 |
+
|
113 |
+
def backward_G(self):
|
114 |
+
loss_logs = OrderedDict({})
|
115 |
+
mse_loss_weights = torch.ones_like(self.timesteps)
|
116 |
+
loss_logs["loss_mot_rec"] = self.masked_l2(
|
117 |
+
self.prediction, self.target, self.src_mask, mse_loss_weights
|
118 |
+
)
|
119 |
+
|
120 |
+
self.loss = loss_logs["loss_mot_rec"]
|
121 |
+
|
122 |
+
return loss_logs
|
123 |
+
|
124 |
+
def update(self):
|
125 |
+
self.zero_grad([self.optimizer])
|
126 |
+
loss_logs = self.backward_G()
|
127 |
+
self.accelerator.backward(self.loss)
|
128 |
+
self.clip_norm([self.model])
|
129 |
+
self.step([self.optimizer])
|
130 |
+
|
131 |
+
return loss_logs
|
132 |
+
|
133 |
+
def generate_src_mask(self, T, length):
|
134 |
+
B = len(length)
|
135 |
+
src_mask = torch.ones(B, T)
|
136 |
+
for i in range(B):
|
137 |
+
for j in range(length[i], T):
|
138 |
+
src_mask[i, j] = 0
|
139 |
+
return src_mask
|
140 |
+
|
141 |
+
def train_mode(self):
|
142 |
+
self.model.train()
|
143 |
+
if self.model_ema:
|
144 |
+
self.model_ema.train()
|
145 |
+
|
146 |
+
def eval_mode(self):
|
147 |
+
self.model.eval()
|
148 |
+
if self.model_ema:
|
149 |
+
self.model_ema.eval()
|
150 |
+
|
151 |
+
def save(self, file_name, total_it):
|
152 |
+
state = {
|
153 |
+
"opt_encoder": self.optimizer.state_dict(),
|
154 |
+
"total_it": total_it,
|
155 |
+
"encoder": self.accelerator.unwrap_model(self.model).state_dict(),
|
156 |
+
}
|
157 |
+
if self.model_ema:
|
158 |
+
state["model_ema"] = self.accelerator.unwrap_model(
|
159 |
+
self.model_ema
|
160 |
+
).module.state_dict()
|
161 |
+
torch.save(state, file_name)
|
162 |
+
return
|
163 |
+
|
164 |
+
def load(self, model_dir):
|
165 |
+
checkpoint = torch.load(model_dir, map_location=self.device)
|
166 |
+
self.optimizer.load_state_dict(checkpoint["opt_encoder"])
|
167 |
+
if self.model_ema:
|
168 |
+
self.model_ema.load_state_dict(checkpoint["model_ema"], strict=True)
|
169 |
+
self.model.load_state_dict(checkpoint["encoder"], strict=True)
|
170 |
+
|
171 |
+
return checkpoint.get("total_it", 0)
|
172 |
+
|
173 |
+
def train(self, train_loader):
|
174 |
+
|
175 |
+
it = 0
|
176 |
+
if self.opt.is_continue:
|
177 |
+
model_path = pjoin(self.opt.model_dir, self.opt.continue_ckpt)
|
178 |
+
it = self.load(model_path)
|
179 |
+
self.accelerator.print(f"continue train from {it} iters in {model_path}")
|
180 |
+
start_time = time.time()
|
181 |
+
|
182 |
+
logs = OrderedDict()
|
183 |
+
self.dataset = train_loader.dataset
|
184 |
+
self.model, self.mse_criterion, self.optimizer, train_loader, self.model_ema = (
|
185 |
+
self.accelerator.prepare(
|
186 |
+
self.model,
|
187 |
+
self.mse_criterion,
|
188 |
+
self.optimizer,
|
189 |
+
train_loader,
|
190 |
+
self.model_ema,
|
191 |
+
)
|
192 |
+
)
|
193 |
+
|
194 |
+
num_epochs = (self.opt.num_train_steps - it) // len(train_loader) + 1
|
195 |
+
self.accelerator.print(f"need to train for {num_epochs} epochs....")
|
196 |
+
|
197 |
+
for epoch in range(0, num_epochs):
|
198 |
+
self.train_mode()
|
199 |
+
for i, batch_data in enumerate(train_loader):
|
200 |
+
self.forward(batch_data)
|
201 |
+
log_dict = self.update()
|
202 |
+
it += 1
|
203 |
+
|
204 |
+
if self.model_ema and it % self.opt.model_ema_steps == 0:
|
205 |
+
self.accelerator.unwrap_model(self.model_ema).update_parameters(
|
206 |
+
self.model
|
207 |
+
)
|
208 |
+
|
209 |
+
# update logger
|
210 |
+
for k, v in log_dict.items():
|
211 |
+
if k not in logs:
|
212 |
+
logs[k] = v
|
213 |
+
else:
|
214 |
+
logs[k] += v
|
215 |
+
|
216 |
+
if it % self.opt.log_every == 0:
|
217 |
+
mean_loss = OrderedDict({})
|
218 |
+
for tag, value in logs.items():
|
219 |
+
mean_loss[tag] = value / self.opt.log_every
|
220 |
+
logs = OrderedDict()
|
221 |
+
print_current_loss(
|
222 |
+
self.accelerator, start_time, it, mean_loss, epoch, inner_iter=i
|
223 |
+
)
|
224 |
+
if self.accelerator.is_main_process:
|
225 |
+
self.writer.add_scalar("loss", mean_loss["loss_mot_rec"], it)
|
226 |
+
self.accelerator.wait_for_everyone()
|
227 |
+
|
228 |
+
if (
|
229 |
+
it % self.opt.save_interval == 0
|
230 |
+
and self.accelerator.is_main_process
|
231 |
+
): # Save model
|
232 |
+
self.save(pjoin(self.opt.model_dir, "latest.tar").format(it), it)
|
233 |
+
self.accelerator.wait_for_everyone()
|
234 |
+
|
235 |
+
if (self.scheduler is not None) and (
|
236 |
+
it % self.opt.update_lr_steps == 0
|
237 |
+
):
|
238 |
+
self.scheduler.step()
|
239 |
+
|
240 |
+
# Save the last checkpoint if it wasn't already saved.
|
241 |
+
if it % self.opt.save_interval != 0 and self.accelerator.is_main_process:
|
242 |
+
self.save(pjoin(self.opt.model_dir, "latest.tar"), it)
|
243 |
+
|
244 |
+
self.accelerator.wait_for_everyone()
|
245 |
+
self.accelerator.print("FINISH")
|
utils/__init__.py
ADDED
File without changes
|
utils/constants.py
ADDED
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
SMPL_FOOT_R = [8, 11]
|
2 |
+
SMPL_FOOT_L = [7, 10]
|
3 |
+
SMPL_FACE_FORWARD_JOINTS = [2, 1, 17, 16]
|
4 |
+
|
5 |
+
# Define a kinematic tree for the skeletal struture
|
6 |
+
SMPL_BODY_CHAIN = [
|
7 |
+
[0, 2, 5, 8, 11],
|
8 |
+
[0, 1, 4, 7, 10],
|
9 |
+
[0, 3, 6, 9, 12, 15],
|
10 |
+
[9, 14, 17, 19, 21],
|
11 |
+
[9, 13, 16, 18, 20],
|
12 |
+
]
|
13 |
+
SMPL_LEFT_HAND_CHAIN = [
|
14 |
+
[20, 22, 23, 24],
|
15 |
+
[20, 34, 35, 36],
|
16 |
+
[20, 25, 26, 27],
|
17 |
+
[20, 31, 32, 33],
|
18 |
+
[20, 28, 29, 30],
|
19 |
+
]
|
20 |
+
SMPL_RIGHT_HAND_CHAIN = [
|
21 |
+
[21, 43, 44, 45],
|
22 |
+
[21, 46, 47, 48],
|
23 |
+
[21, 40, 41, 42],
|
24 |
+
[21, 37, 38, 39],
|
25 |
+
[21, 49, 50, 51],
|
26 |
+
]
|
27 |
+
|
28 |
+
SMPL_BODY_BONES = [
|
29 |
+
-0.0018,
|
30 |
+
-0.2233,
|
31 |
+
0.0282,
|
32 |
+
0.0695,
|
33 |
+
-0.0914,
|
34 |
+
-0.0068,
|
35 |
+
-0.0677,
|
36 |
+
-0.0905,
|
37 |
+
-0.0043,
|
38 |
+
-0.0025,
|
39 |
+
0.1090,
|
40 |
+
-0.0267,
|
41 |
+
0.0343,
|
42 |
+
-0.3752,
|
43 |
+
-0.0045,
|
44 |
+
-0.0383,
|
45 |
+
-0.3826,
|
46 |
+
-0.0089,
|
47 |
+
0.0055,
|
48 |
+
0.1352,
|
49 |
+
0.0011,
|
50 |
+
-0.0136,
|
51 |
+
-0.3980,
|
52 |
+
-0.0437,
|
53 |
+
0.0158,
|
54 |
+
-0.3984,
|
55 |
+
-0.0423,
|
56 |
+
0.0015,
|
57 |
+
0.0529,
|
58 |
+
0.0254,
|
59 |
+
0.0264,
|
60 |
+
-0.0558,
|
61 |
+
0.1193,
|
62 |
+
-0.0254,
|
63 |
+
-0.0481,
|
64 |
+
0.1233,
|
65 |
+
-0.0028,
|
66 |
+
0.2139,
|
67 |
+
-0.0429,
|
68 |
+
0.0788,
|
69 |
+
0.1217,
|
70 |
+
-0.0341,
|
71 |
+
-0.0818,
|
72 |
+
0.1188,
|
73 |
+
-0.0386,
|
74 |
+
0.0052,
|
75 |
+
0.0650,
|
76 |
+
0.0513,
|
77 |
+
0.0910,
|
78 |
+
0.0305,
|
79 |
+
-0.0089,
|
80 |
+
-0.0960,
|
81 |
+
0.0326,
|
82 |
+
-0.0091,
|
83 |
+
0.2596,
|
84 |
+
-0.0128,
|
85 |
+
-0.0275,
|
86 |
+
-0.2537,
|
87 |
+
-0.0133,
|
88 |
+
-0.0214,
|
89 |
+
0.2492,
|
90 |
+
0.0090,
|
91 |
+
-0.0012,
|
92 |
+
-0.2553,
|
93 |
+
0.0078,
|
94 |
+
-0.0056,
|
95 |
+
0.0840,
|
96 |
+
-0.0082,
|
97 |
+
-0.0149,
|
98 |
+
-0.0846,
|
99 |
+
-0.0061,
|
100 |
+
-0.0103,
|
101 |
+
]
|
102 |
+
|
103 |
+
SMPL_HYBRIK = [
|
104 |
+
0,
|
105 |
+
0,
|
106 |
+
0,
|
107 |
+
0,
|
108 |
+
0,
|
109 |
+
0,
|
110 |
+
0,
|
111 |
+
0,
|
112 |
+
0,
|
113 |
+
0,
|
114 |
+
0,
|
115 |
+
0,
|
116 |
+
1,
|
117 |
+
1,
|
118 |
+
1,
|
119 |
+
1,
|
120 |
+
0,
|
121 |
+
0,
|
122 |
+
0,
|
123 |
+
0,
|
124 |
+
0,
|
125 |
+
0,
|
126 |
+
]
|
127 |
+
|
128 |
+
SMPL_BODY_PARENTS = [
|
129 |
+
0,
|
130 |
+
0,
|
131 |
+
0,
|
132 |
+
0,
|
133 |
+
1,
|
134 |
+
2,
|
135 |
+
3,
|
136 |
+
4,
|
137 |
+
5,
|
138 |
+
6,
|
139 |
+
7,
|
140 |
+
8,
|
141 |
+
9,
|
142 |
+
9,
|
143 |
+
9,
|
144 |
+
12,
|
145 |
+
13,
|
146 |
+
14,
|
147 |
+
16,
|
148 |
+
17,
|
149 |
+
18,
|
150 |
+
19,
|
151 |
+
]
|
152 |
+
|
153 |
+
SMPL_BODY_CHILDS = [
|
154 |
+
-1,
|
155 |
+
4,
|
156 |
+
5,
|
157 |
+
6,
|
158 |
+
7,
|
159 |
+
8,
|
160 |
+
9,
|
161 |
+
10,
|
162 |
+
11,
|
163 |
+
-1,
|
164 |
+
-2,
|
165 |
+
-2,
|
166 |
+
15,
|
167 |
+
16,
|
168 |
+
17,
|
169 |
+
-2,
|
170 |
+
18,
|
171 |
+
19,
|
172 |
+
20,
|
173 |
+
21,
|
174 |
+
-2,
|
175 |
+
-2,
|
176 |
+
]
|
utils/ema.py
ADDED
@@ -0,0 +1,13 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
|
3 |
+
class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
|
4 |
+
"""Maintains moving averages of model parameters using an exponential decay.
|
5 |
+
``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
|
6 |
+
`torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
|
7 |
+
is used to compute the EMA.
|
8 |
+
"""
|
9 |
+
|
10 |
+
def __init__(self, model, decay, device="cpu"):
|
11 |
+
def ema_avg(avg_model_param, model_param, num_averaged):
|
12 |
+
return decay * avg_model_param + (1 - decay) * model_param
|
13 |
+
super().__init__(model, device, ema_avg)
|
utils/eval_humanml.py
ADDED
@@ -0,0 +1,298 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import sys
|
2 |
+
from datetime import datetime
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
from utils.metrics import *
|
6 |
+
from collections import OrderedDict
|
7 |
+
import os
|
8 |
+
|
9 |
+
|
10 |
+
|
11 |
+
def evaluate_matching_score(eval_wrapper,motion_loaders, file):
|
12 |
+
match_score_dict = OrderedDict({})
|
13 |
+
R_precision_dict = OrderedDict({})
|
14 |
+
activation_dict = OrderedDict({})
|
15 |
+
# print(motion_loaders.keys())
|
16 |
+
print('========== Evaluating Matching Score ==========')
|
17 |
+
for motion_loader_name, motion_loader in motion_loaders.items():
|
18 |
+
all_motion_embeddings = []
|
19 |
+
score_list = []
|
20 |
+
all_size = 0
|
21 |
+
matching_score_sum = 0
|
22 |
+
top_k_count = 0
|
23 |
+
# print(motion_loader_name)
|
24 |
+
with torch.no_grad():
|
25 |
+
for idx, batch in enumerate(motion_loader):
|
26 |
+
word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch
|
27 |
+
text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(
|
28 |
+
word_embs=word_embeddings,
|
29 |
+
pos_ohot=pos_one_hots,
|
30 |
+
cap_lens=sent_lens,
|
31 |
+
motions=motions,
|
32 |
+
m_lens=m_lens
|
33 |
+
)
|
34 |
+
dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
|
35 |
+
motion_embeddings.cpu().numpy())
|
36 |
+
matching_score_sum += dist_mat.trace()
|
37 |
+
# import pdb;pdb.set_trace()
|
38 |
+
|
39 |
+
argsmax = np.argsort(dist_mat, axis=1)
|
40 |
+
top_k_mat = calculate_top_k(argsmax, top_k=3)
|
41 |
+
top_k_count += top_k_mat.sum(axis=0)
|
42 |
+
|
43 |
+
all_size += text_embeddings.shape[0]
|
44 |
+
|
45 |
+
all_motion_embeddings.append(motion_embeddings.cpu().numpy())
|
46 |
+
|
47 |
+
all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
|
48 |
+
# import pdb;pdb.set_trace()
|
49 |
+
matching_score = matching_score_sum / all_size
|
50 |
+
R_precision = top_k_count / all_size
|
51 |
+
match_score_dict[motion_loader_name] = matching_score
|
52 |
+
R_precision_dict[motion_loader_name] = R_precision
|
53 |
+
activation_dict[motion_loader_name] = all_motion_embeddings
|
54 |
+
|
55 |
+
print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}')
|
56 |
+
print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True)
|
57 |
+
|
58 |
+
line = f'---> [{motion_loader_name}] R_precision: '
|
59 |
+
for i in range(len(R_precision)):
|
60 |
+
line += '(top %d): %.4f ' % (i+1, R_precision[i])
|
61 |
+
print(line)
|
62 |
+
print(line, file=file, flush=True)
|
63 |
+
|
64 |
+
return match_score_dict, R_precision_dict, activation_dict
|
65 |
+
|
66 |
+
|
67 |
+
def evaluate_fid(eval_wrapper,groundtruth_loader, activation_dict, file):
|
68 |
+
eval_dict = OrderedDict({})
|
69 |
+
gt_motion_embeddings = []
|
70 |
+
print('========== Evaluating FID ==========')
|
71 |
+
with torch.no_grad():
|
72 |
+
for idx, batch in enumerate(groundtruth_loader):
|
73 |
+
_, _, _, sent_lens, motions, m_lens, _ = batch
|
74 |
+
motion_embeddings = eval_wrapper.get_motion_embeddings(
|
75 |
+
motions=motions,
|
76 |
+
m_lens=m_lens
|
77 |
+
)
|
78 |
+
gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
|
79 |
+
gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
|
80 |
+
gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
|
81 |
+
|
82 |
+
for model_name, motion_embeddings in activation_dict.items():
|
83 |
+
mu, cov = calculate_activation_statistics(motion_embeddings)
|
84 |
+
# print(mu)
|
85 |
+
fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
|
86 |
+
print(f'---> [{model_name}] FID: {fid:.4f}')
|
87 |
+
print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
|
88 |
+
eval_dict[model_name] = fid
|
89 |
+
return eval_dict
|
90 |
+
|
91 |
+
|
92 |
+
def evaluate_diversity(activation_dict, file, diversity_times):
|
93 |
+
eval_dict = OrderedDict({})
|
94 |
+
print('========== Evaluating Diversity ==========')
|
95 |
+
for model_name, motion_embeddings in activation_dict.items():
|
96 |
+
diversity = calculate_diversity(motion_embeddings, diversity_times)
|
97 |
+
eval_dict[model_name] = diversity
|
98 |
+
print(f'---> [{model_name}] Diversity: {diversity:.4f}')
|
99 |
+
print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
|
100 |
+
return eval_dict
|
101 |
+
|
102 |
+
|
103 |
+
def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times):
|
104 |
+
eval_dict = OrderedDict({})
|
105 |
+
print('========== Evaluating MultiModality ==========')
|
106 |
+
for model_name, mm_motion_loader in mm_motion_loaders.items():
|
107 |
+
mm_motion_embeddings = []
|
108 |
+
with torch.no_grad():
|
109 |
+
for idx, batch in enumerate(mm_motion_loader):
|
110 |
+
# (1, mm_replications, dim_pos)
|
111 |
+
motions, m_lens = batch
|
112 |
+
motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0])
|
113 |
+
mm_motion_embeddings.append(motion_embedings.unsqueeze(0))
|
114 |
+
if len(mm_motion_embeddings) == 0:
|
115 |
+
multimodality = 0
|
116 |
+
else:
|
117 |
+
mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
|
118 |
+
multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
|
119 |
+
print(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
|
120 |
+
print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
|
121 |
+
eval_dict[model_name] = multimodality
|
122 |
+
return eval_dict
|
123 |
+
|
124 |
+
|
125 |
+
def get_metric_statistics(values, replication_times):
|
126 |
+
mean = np.mean(values, axis=0)
|
127 |
+
std = np.std(values, axis=0)
|
128 |
+
conf_interval = 1.96 * std / np.sqrt(replication_times)
|
129 |
+
return mean, conf_interval
|
130 |
+
|
131 |
+
|
132 |
+
def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False):
|
133 |
+
with open(log_file, 'a') as f:
|
134 |
+
all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
|
135 |
+
'R_precision': OrderedDict({}),
|
136 |
+
'FID': OrderedDict({}),
|
137 |
+
'Diversity': OrderedDict({}),
|
138 |
+
'MultiModality': OrderedDict({})})
|
139 |
+
|
140 |
+
for replication in range(replication_times):
|
141 |
+
print(f'Time: {datetime.now()}')
|
142 |
+
print(f'Time: {datetime.now()}', file=f, flush=True)
|
143 |
+
motion_loaders = {}
|
144 |
+
motion_loaders['ground truth'] = gt_loader
|
145 |
+
mm_motion_loaders = {}
|
146 |
+
# motion_loaders['ground truth'] = gt_loader
|
147 |
+
for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
|
148 |
+
motion_loader, mm_motion_loader,eval_generate_time = motion_loader_getter()
|
149 |
+
print(f'---> [{motion_loader_name}] batch_generate_time: {eval_generate_time}s', file=f, flush=True)
|
150 |
+
motion_loaders[motion_loader_name] = motion_loader
|
151 |
+
mm_motion_loaders[motion_loader_name] = mm_motion_loader
|
152 |
+
|
153 |
+
if replication_times>1:
|
154 |
+
print(f'==================== Replication {replication} ====================')
|
155 |
+
print(f'==================== Replication {replication} ====================', file=f, flush=True)
|
156 |
+
|
157 |
+
|
158 |
+
mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)
|
159 |
+
|
160 |
+
fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)
|
161 |
+
|
162 |
+
div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)
|
163 |
+
|
164 |
+
if run_mm:
|
165 |
+
mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times)
|
166 |
+
|
167 |
+
print(f'!!! DONE !!!')
|
168 |
+
print(f'!!! DONE !!!', file=f, flush=True)
|
169 |
+
|
170 |
+
for key, item in mat_score_dict.items():
|
171 |
+
if key not in all_metrics['Matching Score']:
|
172 |
+
all_metrics['Matching Score'][key] = [item]
|
173 |
+
else:
|
174 |
+
all_metrics['Matching Score'][key] += [item]
|
175 |
+
|
176 |
+
for key, item in R_precision_dict.items():
|
177 |
+
if key not in all_metrics['R_precision']:
|
178 |
+
all_metrics['R_precision'][key] = [item]
|
179 |
+
else:
|
180 |
+
all_metrics['R_precision'][key] += [item]
|
181 |
+
|
182 |
+
for key, item in fid_score_dict.items():
|
183 |
+
if key not in all_metrics['FID']:
|
184 |
+
all_metrics['FID'][key] = [item]
|
185 |
+
else:
|
186 |
+
all_metrics['FID'][key] += [item]
|
187 |
+
|
188 |
+
for key, item in div_score_dict.items():
|
189 |
+
if key not in all_metrics['Diversity']:
|
190 |
+
all_metrics['Diversity'][key] = [item]
|
191 |
+
else:
|
192 |
+
all_metrics['Diversity'][key] += [item]
|
193 |
+
|
194 |
+
for key, item in mm_score_dict.items():
|
195 |
+
if key not in all_metrics['MultiModality']:
|
196 |
+
all_metrics['MultiModality'][key] = [item]
|
197 |
+
else:
|
198 |
+
all_metrics['MultiModality'][key] += [item]
|
199 |
+
|
200 |
+
|
201 |
+
mean_dict = {}
|
202 |
+
if replication_times>1:
|
203 |
+
for metric_name, metric_dict in all_metrics.items():
|
204 |
+
print('========== %s Summary ==========' % metric_name)
|
205 |
+
print('========== %s Summary ==========' % metric_name, file=f, flush=True)
|
206 |
+
|
207 |
+
for model_name, values in metric_dict.items():
|
208 |
+
# print(metric_name, model_name)
|
209 |
+
mean, conf_interval = get_metric_statistics(np.array(values),replication_times)
|
210 |
+
mean_dict[metric_name + '_' + model_name] = mean
|
211 |
+
# print(mean, mean.dtype)
|
212 |
+
if isinstance(mean, np.float64) or isinstance(mean, np.float32):
|
213 |
+
print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
|
214 |
+
print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
|
215 |
+
elif isinstance(mean, np.ndarray):
|
216 |
+
line = f'---> [{model_name}]'
|
217 |
+
for i in range(len(mean)):
|
218 |
+
line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
|
219 |
+
print(line)
|
220 |
+
print(line, file=f, flush=True)
|
221 |
+
return mean_dict
|
222 |
+
else:
|
223 |
+
return all_metrics
|
224 |
+
|
225 |
+
|
226 |
+
def distributed_evaluation(eval_wrapper, gt_loader, eval_motion_loader, log_file, replication_times, diversity_times):
|
227 |
+
with open(log_file, 'a') as f:
|
228 |
+
all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
|
229 |
+
'R_precision': OrderedDict({}),
|
230 |
+
'FID': OrderedDict({}),
|
231 |
+
'Diversity': OrderedDict({}),
|
232 |
+
'MultiModality': OrderedDict({})})
|
233 |
+
|
234 |
+
for replication in range(replication_times):
|
235 |
+
print(f'Time: {datetime.now()}')
|
236 |
+
print(f'Time: {datetime.now()}', file=f, flush=True)
|
237 |
+
motion_loaders = {'test':eval_motion_loader}
|
238 |
+
|
239 |
+
if replication_times>1:
|
240 |
+
print(f'==================== Replication {replication} ====================')
|
241 |
+
print(f'==================== Replication {replication} ====================', file=f, flush=True)
|
242 |
+
|
243 |
+
|
244 |
+
mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)
|
245 |
+
|
246 |
+
fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)
|
247 |
+
|
248 |
+
div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)
|
249 |
+
|
250 |
+
|
251 |
+
print(f'!!! DONE !!!')
|
252 |
+
print(f'!!! DONE !!!', file=f, flush=True)
|
253 |
+
|
254 |
+
for key, item in mat_score_dict.items():
|
255 |
+
if key not in all_metrics['Matching Score']:
|
256 |
+
all_metrics['Matching Score'][key] = [item]
|
257 |
+
else:
|
258 |
+
all_metrics['Matching Score'][key] += [item]
|
259 |
+
|
260 |
+
for key, item in R_precision_dict.items():
|
261 |
+
if key not in all_metrics['R_precision']:
|
262 |
+
all_metrics['R_precision'][key] = [item]
|
263 |
+
else:
|
264 |
+
all_metrics['R_precision'][key] += [item]
|
265 |
+
|
266 |
+
for key, item in fid_score_dict.items():
|
267 |
+
if key not in all_metrics['FID']:
|
268 |
+
all_metrics['FID'][key] = [item]
|
269 |
+
else:
|
270 |
+
all_metrics['FID'][key] += [item]
|
271 |
+
|
272 |
+
for key, item in div_score_dict.items():
|
273 |
+
if key not in all_metrics['Diversity']:
|
274 |
+
all_metrics['Diversity'][key] = [item]
|
275 |
+
else:
|
276 |
+
all_metrics['Diversity'][key] += [item]
|
277 |
+
|
278 |
+
mean_dict = {}
|
279 |
+
for metric_name, metric_dict in all_metrics.items():
|
280 |
+
print('========== %s Summary ==========' % metric_name)
|
281 |
+
print('========== %s Summary ==========' % metric_name, file=f, flush=True)
|
282 |
+
|
283 |
+
for model_name, values in metric_dict.items():
|
284 |
+
# print(metric_name, model_name)
|
285 |
+
mean, conf_interval = get_metric_statistics(np.array(values),replication_times)
|
286 |
+
mean_dict[metric_name + '_' + model_name] = mean
|
287 |
+
# print(mean, mean.dtype)
|
288 |
+
if isinstance(mean, np.float64) or isinstance(mean, np.float32):
|
289 |
+
print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
|
290 |
+
print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
|
291 |
+
elif isinstance(mean, np.ndarray):
|
292 |
+
line = f'---> [{model_name}]'
|
293 |
+
for i in range(len(mean)):
|
294 |
+
line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
|
295 |
+
print(line)
|
296 |
+
print(line, file=f, flush=True)
|
297 |
+
return mean_dict
|
298 |
+
|
utils/kinematics.py
ADDED
@@ -0,0 +1,350 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Tuple
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
|
5 |
+
import utils.constants as constants
|
6 |
+
import torch
|
7 |
+
|
8 |
+
class HybrIKJointsToRotmat:
|
9 |
+
def __init__(self):
|
10 |
+
self.naive_hybrik = constants.SMPL_HYBRIK
|
11 |
+
self.num_nodes = 22
|
12 |
+
self.parents = constants.SMPL_BODY_PARENTS
|
13 |
+
self.child = constants.SMPL_BODY_CHILDS
|
14 |
+
self.bones = np.array(constants.SMPL_BODY_BONES).reshape(24, 3)[
|
15 |
+
: self.num_nodes
|
16 |
+
]
|
17 |
+
|
18 |
+
def multi_child_rot(
|
19 |
+
self, t: np.ndarray, p: np.ndarray, pose_global_parent: np.ndarray
|
20 |
+
) -> Tuple[np.ndarray]:
|
21 |
+
"""
|
22 |
+
t: B x 3 x child_num
|
23 |
+
p: B x 3 x child_num
|
24 |
+
pose_global_parent: B x 3 x 3
|
25 |
+
"""
|
26 |
+
m = np.matmul(
|
27 |
+
t, np.transpose(np.matmul(np.linalg.inv(pose_global_parent), p), [0, 2, 1])
|
28 |
+
)
|
29 |
+
u, s, vt = np.linalg.svd(m)
|
30 |
+
r = np.matmul(np.transpose(vt, [0, 2, 1]), np.transpose(u, [0, 2, 1]))
|
31 |
+
err_det_mask = (np.linalg.det(r) < 0.0).reshape(-1, 1, 1)
|
32 |
+
id_fix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]).reshape(
|
33 |
+
1, 3, 3
|
34 |
+
)
|
35 |
+
r_fix = np.matmul(
|
36 |
+
np.transpose(vt, [0, 2, 1]), np.matmul(id_fix, np.transpose(u, [0, 2, 1]))
|
37 |
+
)
|
38 |
+
r = r * (1.0 - err_det_mask) + r_fix * err_det_mask
|
39 |
+
return r, np.matmul(pose_global_parent, r)
|
40 |
+
|
41 |
+
def single_child_rot(
|
42 |
+
self,
|
43 |
+
t: np.ndarray,
|
44 |
+
p: np.ndarray,
|
45 |
+
pose_global_parent: np.ndarray,
|
46 |
+
twist: np.ndarray = None,
|
47 |
+
) -> Tuple[np.ndarray]:
|
48 |
+
"""
|
49 |
+
t: B x 3 x 1
|
50 |
+
p: B x 3 x 1
|
51 |
+
pose_global_parent: B x 3 x 3
|
52 |
+
twist: B x 2 if given, default to None
|
53 |
+
"""
|
54 |
+
p_rot = np.matmul(np.linalg.inv(pose_global_parent), p)
|
55 |
+
cross = np.cross(t, p_rot, axisa=1, axisb=1, axisc=1)
|
56 |
+
sina = np.linalg.norm(cross, axis=1, keepdims=True) / (
|
57 |
+
np.linalg.norm(t, axis=1, keepdims=True)
|
58 |
+
* np.linalg.norm(p_rot, axis=1, keepdims=True)
|
59 |
+
)
|
60 |
+
cross = cross / np.linalg.norm(cross, axis=1, keepdims=True)
|
61 |
+
cosa = np.sum(t * p_rot, axis=1, keepdims=True) / (
|
62 |
+
np.linalg.norm(t, axis=1, keepdims=True)
|
63 |
+
* np.linalg.norm(p_rot, axis=1, keepdims=True)
|
64 |
+
)
|
65 |
+
sina = sina.reshape(-1, 1, 1)
|
66 |
+
cosa = cosa.reshape(-1, 1, 1)
|
67 |
+
skew_sym_t = np.stack(
|
68 |
+
[
|
69 |
+
0.0 * cross[:, 0],
|
70 |
+
-cross[:, 2],
|
71 |
+
cross[:, 1],
|
72 |
+
cross[:, 2],
|
73 |
+
0.0 * cross[:, 0],
|
74 |
+
-cross[:, 0],
|
75 |
+
-cross[:, 1],
|
76 |
+
cross[:, 0],
|
77 |
+
0.0 * cross[:, 0],
|
78 |
+
],
|
79 |
+
1,
|
80 |
+
)
|
81 |
+
skew_sym_t = skew_sym_t.reshape(-1, 3, 3)
|
82 |
+
dsw_rotmat = (
|
83 |
+
np.eye(3).reshape(1, 3, 3)
|
84 |
+
+ sina * skew_sym_t
|
85 |
+
+ (1.0 - cosa) * np.matmul(skew_sym_t, skew_sym_t)
|
86 |
+
)
|
87 |
+
if twist is not None:
|
88 |
+
skew_sym_t = np.stack(
|
89 |
+
[
|
90 |
+
0.0 * t[:, 0],
|
91 |
+
-t[:, 2],
|
92 |
+
t[:, 1],
|
93 |
+
t[:, 2],
|
94 |
+
0.0 * t[:, 0],
|
95 |
+
-t[:, 0],
|
96 |
+
-t[:, 1],
|
97 |
+
t[:, 0],
|
98 |
+
0.0 * t[:, 0],
|
99 |
+
],
|
100 |
+
1,
|
101 |
+
)
|
102 |
+
skew_sym_t = skew_sym_t.reshape(-1, 3, 3)
|
103 |
+
sina = twist[:, 1].reshape(-1, 1, 1)
|
104 |
+
cosa = twist[:, 0].reshape(-1, 1, 1)
|
105 |
+
dtw_rotmat = (
|
106 |
+
np.eye(3).reshape([1, 3, 3])
|
107 |
+
+ sina * skew_sym_t
|
108 |
+
+ (1.0 - cosa) * np.matmul(skew_sym_t, skew_sym_t)
|
109 |
+
)
|
110 |
+
dsw_rotmat = np.matmul(dsw_rotmat, dtw_rotmat)
|
111 |
+
return dsw_rotmat, np.matmul(pose_global_parent, dsw_rotmat)
|
112 |
+
|
113 |
+
def __call__(self, joints: np.ndarray, twist: np.ndarray = None) -> np.ndarray:
|
114 |
+
"""
|
115 |
+
joints: B x N x 3
|
116 |
+
twist: B x N x 2 if given, default to None
|
117 |
+
"""
|
118 |
+
expand_dim = False
|
119 |
+
if len(joints.shape) == 2:
|
120 |
+
expand_dim = True
|
121 |
+
joints = np.expand_dims(joints, 0)
|
122 |
+
if twist is not None:
|
123 |
+
twist = np.expand_dims(twist, 0)
|
124 |
+
assert len(joints.shape) == 3
|
125 |
+
batch_size = np.shape(joints)[0]
|
126 |
+
joints_rel = joints - joints[:, self.parents]
|
127 |
+
joints_hybrik = 0.0 * joints_rel
|
128 |
+
pose_global = np.zeros([batch_size, self.num_nodes, 3, 3])
|
129 |
+
pose = np.zeros([batch_size, self.num_nodes, 3, 3])
|
130 |
+
for i in range(self.num_nodes):
|
131 |
+
if i == 0:
|
132 |
+
joints_hybrik[:, 0] = joints[:, 0]
|
133 |
+
else:
|
134 |
+
joints_hybrik[:, i] = (
|
135 |
+
np.matmul(
|
136 |
+
pose_global[:, self.parents[i]],
|
137 |
+
self.bones[i].reshape(1, 3, 1),
|
138 |
+
).reshape(-1, 3)
|
139 |
+
+ joints_hybrik[:, self.parents[i]]
|
140 |
+
)
|
141 |
+
if self.child[i] == -2:
|
142 |
+
pose[:, i] = pose[:, i] + np.eye(3).reshape(1, 3, 3)
|
143 |
+
pose_global[:, i] = pose_global[:, self.parents[i]]
|
144 |
+
continue
|
145 |
+
if i == 0:
|
146 |
+
r, rg = self.multi_child_rot(
|
147 |
+
np.transpose(self.bones[[1, 2, 3]].reshape(1, 3, 3), [0, 2, 1]),
|
148 |
+
np.transpose(joints_rel[:, [1, 2, 3]], [0, 2, 1]),
|
149 |
+
np.eye(3).reshape(1, 3, 3),
|
150 |
+
)
|
151 |
+
|
152 |
+
elif i == 9:
|
153 |
+
r, rg = self.multi_child_rot(
|
154 |
+
np.transpose(self.bones[[12, 13, 14]].reshape(1, 3, 3), [0, 2, 1]),
|
155 |
+
np.transpose(joints_rel[:, [12, 13, 14]], [0, 2, 1]),
|
156 |
+
pose_global[:, self.parents[9]],
|
157 |
+
)
|
158 |
+
else:
|
159 |
+
p = joints_rel[:, self.child[i]]
|
160 |
+
if self.naive_hybrik[i] == 0:
|
161 |
+
p = joints[:, self.child[i]] - joints_hybrik[:, i]
|
162 |
+
twi = None
|
163 |
+
if twist is not None:
|
164 |
+
twi = twist[:, i]
|
165 |
+
r, rg = self.single_child_rot(
|
166 |
+
self.bones[self.child[i]].reshape(1, 3, 1),
|
167 |
+
p.reshape(-1, 3, 1),
|
168 |
+
pose_global[:, self.parents[i]],
|
169 |
+
twi,
|
170 |
+
)
|
171 |
+
pose[:, i] = r
|
172 |
+
pose_global[:, i] = rg
|
173 |
+
if expand_dim:
|
174 |
+
pose = pose[0]
|
175 |
+
return pose
|
176 |
+
|
177 |
+
class HybrIKJointsToRotmat_Tensor:
|
178 |
+
def __init__(self):
|
179 |
+
self.naive_hybrik = constants.SMPL_HYBRIK
|
180 |
+
self.num_nodes = 22
|
181 |
+
self.parents = constants.SMPL_BODY_PARENTS
|
182 |
+
self.child = constants.SMPL_BODY_CHILDS
|
183 |
+
self.bones = torch.tensor(constants.SMPL_BODY_BONES).reshape(24, 3)[:self.num_nodes]
|
184 |
+
|
185 |
+
def multi_child_rot(self, t, p, pose_global_parent):
|
186 |
+
"""
|
187 |
+
t: B x 3 x child_num
|
188 |
+
p: B x 3 x child_num
|
189 |
+
pose_global_parent: B x 3 x 3
|
190 |
+
"""
|
191 |
+
m = torch.matmul(
|
192 |
+
t, torch.transpose(torch.matmul(torch.inverse(pose_global_parent), p), 1, 2)
|
193 |
+
)
|
194 |
+
u, s, vt = torch.linalg.svd(m)
|
195 |
+
r = torch.matmul(torch.transpose(vt, 1, 2), torch.transpose(u, 1, 2))
|
196 |
+
err_det_mask = (torch.det(r) < 0.0).reshape(-1, 1, 1)
|
197 |
+
id_fix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]).reshape(1, 3, 3)
|
198 |
+
r_fix = torch.matmul(
|
199 |
+
torch.transpose(vt, 1, 2), torch.matmul(id_fix, torch.transpose(u, 1, 2))
|
200 |
+
)
|
201 |
+
r = r * (~err_det_mask) + r_fix * err_det_mask
|
202 |
+
return r, torch.matmul(pose_global_parent, r)
|
203 |
+
|
204 |
+
def single_child_rot(
|
205 |
+
self,
|
206 |
+
t,
|
207 |
+
p,
|
208 |
+
pose_global_parent,
|
209 |
+
twist = None,
|
210 |
+
) -> Tuple[torch.Tensor]:
|
211 |
+
"""
|
212 |
+
t: B x 3 x 1
|
213 |
+
p: B x 3 x 1
|
214 |
+
pose_global_parent: B x 3 x 3
|
215 |
+
twist: B x 2 if given, default to None
|
216 |
+
"""
|
217 |
+
t_tensor = t.clone().detach()#torch.tensor(t)
|
218 |
+
p_tensor = p.clone().detach()#torch.tensor(p)
|
219 |
+
pose_global_parent_tensor = pose_global_parent.clone().detach()#torch.tensor(pose_global_parent)
|
220 |
+
|
221 |
+
p_rot = torch.matmul(torch.linalg.inv(pose_global_parent_tensor), p_tensor)
|
222 |
+
cross = torch.cross(t_tensor, p_rot, dim=1)
|
223 |
+
sina = torch.linalg.norm(cross, dim=1, keepdim=True) / (
|
224 |
+
torch.linalg.norm(t_tensor, dim=1, keepdim=True)
|
225 |
+
* torch.linalg.norm(p_rot, dim=1, keepdim=True)
|
226 |
+
)
|
227 |
+
cross = cross / torch.linalg.norm(cross, dim=1, keepdim=True)
|
228 |
+
cosa = torch.sum(t_tensor * p_rot, dim=1, keepdim=True) / (
|
229 |
+
torch.linalg.norm(t_tensor, dim=1, keepdim=True)
|
230 |
+
* torch.linalg.norm(p_rot, dim=1, keepdim=True)
|
231 |
+
)
|
232 |
+
sina = sina.reshape(-1, 1, 1)
|
233 |
+
cosa = cosa.reshape(-1, 1, 1)
|
234 |
+
skew_sym_t = torch.stack(
|
235 |
+
[
|
236 |
+
0.0 * cross[:, 0],
|
237 |
+
-cross[:, 2],
|
238 |
+
cross[:, 1],
|
239 |
+
cross[:, 2],
|
240 |
+
0.0 * cross[:, 0],
|
241 |
+
-cross[:, 0],
|
242 |
+
-cross[:, 1],
|
243 |
+
cross[:, 0],
|
244 |
+
0.0 * cross[:, 0],
|
245 |
+
],
|
246 |
+
1,
|
247 |
+
)
|
248 |
+
skew_sym_t = skew_sym_t.reshape(-1, 3, 3)
|
249 |
+
dsw_rotmat = (
|
250 |
+
torch.eye(3).reshape(1, 3, 3)
|
251 |
+
+ sina * skew_sym_t
|
252 |
+
+ (1.0 - cosa) * torch.matmul(skew_sym_t, skew_sym_t)
|
253 |
+
)
|
254 |
+
if twist is not None:
|
255 |
+
twist_tensor = torch.tensor(twist)
|
256 |
+
skew_sym_t = torch.stack(
|
257 |
+
[
|
258 |
+
0.0 * t_tensor[:, 0],
|
259 |
+
-t_tensor[:, 2],
|
260 |
+
t_tensor[:, 1],
|
261 |
+
t_tensor[:, 2],
|
262 |
+
0.0 * t_tensor[:, 0],
|
263 |
+
-t_tensor[:, 0],
|
264 |
+
-t_tensor[:, 1],
|
265 |
+
t_tensor[:, 0],
|
266 |
+
0.0 * t_tensor[:, 0],
|
267 |
+
],
|
268 |
+
1,
|
269 |
+
)
|
270 |
+
skew_sym_t = skew_sym_t.reshape(-1, 3, 3)
|
271 |
+
sina = twist_tensor[:, 1].reshape(-1, 1, 1)
|
272 |
+
cosa = twist_tensor[:, 0].reshape(-1, 1, 1)
|
273 |
+
dtw_rotmat = (
|
274 |
+
torch.eye(3).reshape([1, 3, 3])
|
275 |
+
+ sina * skew_sym_t
|
276 |
+
+ (1.0 - cosa) * torch.matmul(skew_sym_t, skew_sym_t)
|
277 |
+
)
|
278 |
+
dsw_rotmat = torch.matmul(dsw_rotmat, dtw_rotmat)
|
279 |
+
|
280 |
+
return dsw_rotmat, torch.matmul(pose_global_parent_tensor, dsw_rotmat)
|
281 |
+
|
282 |
+
def __call__(self, joints, twist = None) -> torch.Tensor:
|
283 |
+
"""
|
284 |
+
joints: B x N x 3
|
285 |
+
twist: B x N x 2 if given, default to None
|
286 |
+
"""
|
287 |
+
expand_dim = False
|
288 |
+
if len(joints.shape) == 2:
|
289 |
+
expand_dim = True
|
290 |
+
joints = joints.unsqueeze(0)
|
291 |
+
if twist is not None:
|
292 |
+
twist = twist.unsqueeze(0)
|
293 |
+
assert len(joints.shape) == 3
|
294 |
+
batch_size = joints.shape[0]
|
295 |
+
joints_rel = joints - joints[:, self.parents]
|
296 |
+
joints_hybrik = torch.zeros_like(joints_rel)
|
297 |
+
pose_global = torch.zeros([batch_size, self.num_nodes, 3, 3])
|
298 |
+
pose = torch.zeros([batch_size, self.num_nodes, 3, 3])
|
299 |
+
for i in range(self.num_nodes):
|
300 |
+
if i == 0:
|
301 |
+
joints_hybrik[:, 0] = joints[:, 0]
|
302 |
+
else:
|
303 |
+
joints_hybrik[:, i] = (
|
304 |
+
torch.matmul(
|
305 |
+
pose_global[:, self.parents[i]],
|
306 |
+
self.bones[i].reshape(1, 3, 1),
|
307 |
+
).reshape(-1, 3)
|
308 |
+
+ joints_hybrik[:, self.parents[i]]
|
309 |
+
)
|
310 |
+
if self.child[i] == -2:
|
311 |
+
pose[:, i] = pose[:, i] + torch.eye(3).reshape(1, 3, 3)
|
312 |
+
pose_global[:, i] = pose_global[:, self.parents[i]]
|
313 |
+
continue
|
314 |
+
if i == 0:
|
315 |
+
t = self.bones[[1, 2, 3]].reshape(1, 3, 3).permute(0, 2, 1)
|
316 |
+
p = joints_rel[:, [1, 2, 3]].permute(0, 2, 1)
|
317 |
+
pose_global_parent = torch.eye(3).reshape(1, 3, 3)
|
318 |
+
r, rg = self.multi_child_rot(t, p, pose_global_parent)
|
319 |
+
elif i == 9:
|
320 |
+
t = self.bones[[12, 13, 14]].reshape(1, 3, 3).permute(0, 2, 1)
|
321 |
+
p = joints_rel[:, [12, 13, 14]].permute(0, 2, 1)
|
322 |
+
r, rg = self.multi_child_rot(t, p, pose_global[:, self.parents[9]],)
|
323 |
+
else:
|
324 |
+
p = joints_rel[:, self.child[i]]
|
325 |
+
if self.naive_hybrik[i] == 0:
|
326 |
+
p = joints[:, self.child[i]] - joints_hybrik[:, i]
|
327 |
+
twi = None
|
328 |
+
if twist is not None:
|
329 |
+
twi = twist[:, i]
|
330 |
+
t = self.bones[self.child[i]].reshape(-1, 3, 1)
|
331 |
+
p = p.reshape(-1, 3, 1)
|
332 |
+
nframes, _, _ = p.shape
|
333 |
+
t = t.repeat(nframes, 1, 1)
|
334 |
+
r, rg = self.single_child_rot(t, p, pose_global[:, self.parents[i]], twi)
|
335 |
+
pose[:, i] = r
|
336 |
+
pose_global[:, i] = rg
|
337 |
+
if expand_dim:
|
338 |
+
pose = pose[0]
|
339 |
+
return pose
|
340 |
+
|
341 |
+
|
342 |
+
if __name__ == "__main__":
|
343 |
+
jts2rot_hybrik = HybrIKJointsToRotmat_Tensor()
|
344 |
+
joints = torch.tensor(constants.SMPL_BODY_BONES).reshape(1, 24, 3)[:, :22]
|
345 |
+
parents = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]
|
346 |
+
for i in range(1, 22):
|
347 |
+
joints[:, i] = joints[:, i] + joints[:, parents[i]]
|
348 |
+
print(joints.shape)
|
349 |
+
pose = jts2rot_hybrik(joints)
|
350 |
+
print(pose.shape)
|
utils/metrics.py
ADDED
@@ -0,0 +1,142 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from scipy import linalg
|
3 |
+
|
4 |
+
def euclidean_distance_matrix(matrix1, matrix2):
|
5 |
+
"""
|
6 |
+
Params:
|
7 |
+
-- matrix1: N1 x D
|
8 |
+
-- matrix2: N2 x D
|
9 |
+
Returns:
|
10 |
+
-- dist: N1 x N2
|
11 |
+
dist[i, j] == distance(matrix1[i], matrix2[j])
|
12 |
+
"""
|
13 |
+
assert matrix1.shape[1] == matrix2.shape[1]
|
14 |
+
d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
|
15 |
+
d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
|
16 |
+
d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
|
17 |
+
dists = np.sqrt(d1 + d2 + d3) # broadcasting
|
18 |
+
return dists
|
19 |
+
|
20 |
+
def calculate_top_k(mat, top_k):
|
21 |
+
size = mat.shape[0]
|
22 |
+
gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
|
23 |
+
bool_mat = (mat == gt_mat)
|
24 |
+
correct_vec = False
|
25 |
+
top_k_list = []
|
26 |
+
for i in range(top_k):
|
27 |
+
correct_vec = (correct_vec | bool_mat[:, i])
|
28 |
+
top_k_list.append(correct_vec[:, None])
|
29 |
+
top_k_mat = np.concatenate(top_k_list, axis=1)
|
30 |
+
return top_k_mat
|
31 |
+
|
32 |
+
|
33 |
+
def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
|
34 |
+
dist_mat = euclidean_distance_matrix(embedding1, embedding2)
|
35 |
+
argmax = np.argsort(dist_mat, axis=1)
|
36 |
+
top_k_mat = calculate_top_k(argmax, top_k)
|
37 |
+
if sum_all:
|
38 |
+
return top_k_mat.sum(axis=0)
|
39 |
+
else:
|
40 |
+
return top_k_mat
|
41 |
+
|
42 |
+
|
43 |
+
def calculate_matching_score(embedding1, embedding2, sum_all=False):
|
44 |
+
assert len(embedding1.shape) == 2
|
45 |
+
assert embedding1.shape[0] == embedding2.shape[0]
|
46 |
+
assert embedding1.shape[1] == embedding2.shape[1]
|
47 |
+
|
48 |
+
dist = linalg.norm(embedding1 - embedding2, axis=1)
|
49 |
+
if sum_all:
|
50 |
+
return dist.sum(axis=0)
|
51 |
+
else:
|
52 |
+
return dist
|
53 |
+
|
54 |
+
|
55 |
+
|
56 |
+
def calculate_activation_statistics(activations):
|
57 |
+
"""
|
58 |
+
Params:
|
59 |
+
-- activation: num_samples x dim_feat
|
60 |
+
Returns:
|
61 |
+
-- mu: dim_feat
|
62 |
+
-- sigma: dim_feat x dim_feat
|
63 |
+
"""
|
64 |
+
mu = np.mean(activations, axis=0)
|
65 |
+
cov = np.cov(activations, rowvar=False)
|
66 |
+
return mu, cov
|
67 |
+
|
68 |
+
|
69 |
+
def calculate_diversity(activation, diversity_times):
|
70 |
+
assert len(activation.shape) == 2
|
71 |
+
assert activation.shape[0] > diversity_times
|
72 |
+
num_samples = activation.shape[0]
|
73 |
+
|
74 |
+
first_indices = np.random.choice(num_samples, diversity_times, replace=False)
|
75 |
+
second_indices = np.random.choice(num_samples, diversity_times, replace=False)
|
76 |
+
dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
|
77 |
+
return dist.mean()
|
78 |
+
|
79 |
+
|
80 |
+
def calculate_multimodality(activation, multimodality_times):
|
81 |
+
assert len(activation.shape) == 3
|
82 |
+
assert activation.shape[1] > multimodality_times
|
83 |
+
num_per_sent = activation.shape[1]
|
84 |
+
|
85 |
+
first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
|
86 |
+
second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
|
87 |
+
dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
|
88 |
+
return dist.mean()
|
89 |
+
|
90 |
+
|
91 |
+
def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
|
92 |
+
"""Numpy implementation of the Frechet Distance.
|
93 |
+
The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
|
94 |
+
and X_2 ~ N(mu_2, C_2) is
|
95 |
+
d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
|
96 |
+
Stable version by Dougal J. Sutherland.
|
97 |
+
Params:
|
98 |
+
-- mu1 : Numpy array containing the activations of a layer of the
|
99 |
+
inception net (like returned by the function 'get_predictions')
|
100 |
+
for generated samples.
|
101 |
+
-- mu2 : The sample mean over activations, precalculated on an
|
102 |
+
representative data set.
|
103 |
+
-- sigma1: The covariance matrix over activations for generated samples.
|
104 |
+
-- sigma2: The covariance matrix over activations, precalculated on an
|
105 |
+
representative data set.
|
106 |
+
Returns:
|
107 |
+
-- : The Frechet Distance.
|
108 |
+
"""
|
109 |
+
|
110 |
+
mu1 = np.atleast_1d(mu1)
|
111 |
+
mu2 = np.atleast_1d(mu2)
|
112 |
+
|
113 |
+
sigma1 = np.atleast_2d(sigma1)
|
114 |
+
sigma2 = np.atleast_2d(sigma2)
|
115 |
+
|
116 |
+
assert mu1.shape == mu2.shape, \
|
117 |
+
'Training and test mean vectors have different lengths'
|
118 |
+
assert sigma1.shape == sigma2.shape, \
|
119 |
+
'Training and test covariances have different dimensions'
|
120 |
+
|
121 |
+
diff = mu1 - mu2
|
122 |
+
|
123 |
+
# Product might be almost singular
|
124 |
+
covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
|
125 |
+
if not np.isfinite(covmean).all():
|
126 |
+
msg = ('fid calculation produces singular product; '
|
127 |
+
'adding %s to diagonal of cov estimates') % eps
|
128 |
+
print(msg)
|
129 |
+
offset = np.eye(sigma1.shape[0]) * eps
|
130 |
+
covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
|
131 |
+
|
132 |
+
# Numerical error might give slight imaginary component
|
133 |
+
if np.iscomplexobj(covmean):
|
134 |
+
if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
|
135 |
+
m = np.max(np.abs(covmean.imag))
|
136 |
+
raise ValueError('Imaginary component {}'.format(m))
|
137 |
+
covmean = covmean.real
|
138 |
+
|
139 |
+
tr_covmean = np.trace(covmean)
|
140 |
+
|
141 |
+
return (diff.dot(diff) + np.trace(sigma1) +
|
142 |
+
np.trace(sigma2) - 2 * tr_covmean)
|
utils/model_load.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from .ema import ExponentialMovingAverage
|
3 |
+
|
4 |
+
def load_model_weights(model, ckpt_path, use_ema=True, device='cuda:0'):
|
5 |
+
"""
|
6 |
+
Load weights of a model from a checkpoint file.
|
7 |
+
|
8 |
+
Args:
|
9 |
+
model (torch.nn.Module): The model to load weights into.
|
10 |
+
ckpt_path (str): Path to the checkpoint file.
|
11 |
+
use_ema (bool): Whether to use Exponential Moving Average (EMA) weights if available.
|
12 |
+
"""
|
13 |
+
checkpoint = torch.load(ckpt_path,map_location={'cuda:0': str(device)})
|
14 |
+
total_iter = checkpoint.get('total_it', 0)
|
15 |
+
|
16 |
+
if "model_ema" in checkpoint and use_ema:
|
17 |
+
ema_key = next(iter(checkpoint["model_ema"]))
|
18 |
+
if ('module' in ema_key) or ('n_averaged' in ema_key):
|
19 |
+
model = ExponentialMovingAverage(model, decay=1.0)
|
20 |
+
|
21 |
+
model.load_state_dict(checkpoint["model_ema"], strict=True)
|
22 |
+
if ('module' in ema_key) or ('n_averaged' in ema_key):
|
23 |
+
model = model.module
|
24 |
+
print(f'\nLoading EMA module model from {ckpt_path} with {total_iter} iterations')
|
25 |
+
else:
|
26 |
+
print(f'\nLoading EMA model from {ckpt_path} with {total_iter} iterations')
|
27 |
+
else:
|
28 |
+
model.load_state_dict(checkpoint['encoder'], strict=True)
|
29 |
+
print(f'\nLoading model from {ckpt_path} with {total_iter} iterations')
|
30 |
+
|
31 |
+
return total_iter
|
utils/motion_process.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE.
|
2 |
+
# Copyright (c) 2022 Chuan Guo
|
3 |
+
from os.path import join as pjoin
|
4 |
+
from typing import Union
|
5 |
+
import numpy as np
|
6 |
+
import os
|
7 |
+
from utils.quaternion import *
|
8 |
+
from utils.skeleton import Skeleton
|
9 |
+
from utils.paramUtil import *
|
10 |
+
|
11 |
+
import torch
|
12 |
+
from tqdm import tqdm
|
13 |
+
|
14 |
+
# positions (batch, joint_num, 3)
|
15 |
+
def uniform_skeleton(positions, target_offset):
|
16 |
+
src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
|
17 |
+
src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0]))
|
18 |
+
src_offset = src_offset.numpy()
|
19 |
+
tgt_offset = target_offset.numpy()
|
20 |
+
|
21 |
+
'''Calculate Scale Ratio as the ratio of legs'''
|
22 |
+
src_leg_len = np.abs(src_offset[l_idx1]).max() + np.abs(src_offset[l_idx2]).max()
|
23 |
+
tgt_leg_len = np.abs(tgt_offset[l_idx1]).max() + np.abs(tgt_offset[l_idx2]).max()
|
24 |
+
|
25 |
+
scale_rt = tgt_leg_len / src_leg_len
|
26 |
+
src_root_pos = positions[:, 0]
|
27 |
+
tgt_root_pos = src_root_pos * scale_rt
|
28 |
+
|
29 |
+
'''Inverse Kinematics'''
|
30 |
+
quat_params = src_skel.inverse_kinematics_np(positions, face_joint_indx)
|
31 |
+
|
32 |
+
'''Forward Kinematics'''
|
33 |
+
src_skel.set_offset(target_offset)
|
34 |
+
new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos)
|
35 |
+
return new_joints
|
36 |
+
|
37 |
+
|
38 |
+
def extract_features(positions, feet_thre, n_raw_offsets, kinematic_chain, face_joint_indx, fid_r, fid_l):
|
39 |
+
global_positions = positions.copy()
|
40 |
+
""" Get Foot Contacts """
|
41 |
+
|
42 |
+
def foot_detect(positions, thres):
|
43 |
+
velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])
|
44 |
+
|
45 |
+
feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
|
46 |
+
feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
|
47 |
+
feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
|
48 |
+
|
49 |
+
feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float)
|
50 |
+
|
51 |
+
feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
|
52 |
+
feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
|
53 |
+
feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
|
54 |
+
|
55 |
+
feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float)
|
56 |
+
return feet_l, feet_r
|
57 |
+
|
58 |
+
feet_l, feet_r = foot_detect(positions, feet_thre)
|
59 |
+
|
60 |
+
'''Quaternion and Cartesian representation'''
|
61 |
+
r_rot = None
|
62 |
+
|
63 |
+
def get_rifke(positions):
|
64 |
+
'''Local pose'''
|
65 |
+
positions[..., 0] -= positions[:, 0:1, 0]
|
66 |
+
positions[..., 2] -= positions[:, 0:1, 2]
|
67 |
+
'''All pose face Z+'''
|
68 |
+
positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)
|
69 |
+
return positions
|
70 |
+
|
71 |
+
def get_quaternion(positions):
|
72 |
+
skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
|
73 |
+
# (seq_len, joints_num, 4)
|
74 |
+
quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)
|
75 |
+
|
76 |
+
'''Fix Quaternion Discontinuity'''
|
77 |
+
quat_params = qfix(quat_params)
|
78 |
+
# (seq_len, 4)
|
79 |
+
r_rot = quat_params[:, 0].copy()
|
80 |
+
# print(r_rot[0])
|
81 |
+
'''Root Linear Velocity'''
|
82 |
+
# (seq_len - 1, 3)
|
83 |
+
velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
|
84 |
+
# print(r_rot.shape, velocity.shape)
|
85 |
+
velocity = qrot_np(r_rot[1:], velocity)
|
86 |
+
'''Root Angular Velocity'''
|
87 |
+
# (seq_len - 1, 4)
|
88 |
+
r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
|
89 |
+
quat_params[1:, 0] = r_velocity
|
90 |
+
# (seq_len, joints_num, 4)
|
91 |
+
return quat_params, r_velocity, velocity, r_rot
|
92 |
+
|
93 |
+
def get_cont6d_params(positions):
|
94 |
+
skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
|
95 |
+
# (seq_len, joints_num, 4)
|
96 |
+
quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)
|
97 |
+
|
98 |
+
'''Quaternion to continuous 6D'''
|
99 |
+
cont_6d_params = quaternion_to_cont6d_np(quat_params)
|
100 |
+
|
101 |
+
# (seq_len, 4)
|
102 |
+
r_rot = quat_params[:, 0].copy()
|
103 |
+
|
104 |
+
'''Root Linear Velocity'''
|
105 |
+
# (seq_len - 1, 3)
|
106 |
+
velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
|
107 |
+
|
108 |
+
velocity = qrot_np(r_rot[1:], velocity)
|
109 |
+
|
110 |
+
'''Root Angular Velocity'''
|
111 |
+
# (seq_len - 1, 4)
|
112 |
+
r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
|
113 |
+
# (seq_len, joints_num, 4)
|
114 |
+
return cont_6d_params, r_velocity, velocity, r_rot
|
115 |
+
|
116 |
+
cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)
|
117 |
+
positions = get_rifke(positions)
|
118 |
+
|
119 |
+
'''Root height'''
|
120 |
+
root_y = positions[:, 0, 1:2]
|
121 |
+
|
122 |
+
'''Root rotation and linear velocity'''
|
123 |
+
# (seq_len-1, 1) rotation velocity along y-axis
|
124 |
+
# (seq_len-1, 2) linear velovity on xz plane
|
125 |
+
r_velocity = np.arcsin(r_velocity[:, 2:3])
|
126 |
+
l_velocity = velocity[:, [0, 2]]
|
127 |
+
# print(r_velocity.shape, l_velocity.shape, root_y.shape)
|
128 |
+
root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)
|
129 |
+
|
130 |
+
'''Get Joint Rotation Representation'''
|
131 |
+
# (seq_len, (joints_num-1) *6) quaternion for skeleton joints
|
132 |
+
rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)
|
133 |
+
|
134 |
+
'''Get Joint Rotation Invariant Position Represention'''
|
135 |
+
# (seq_len, (joints_num-1)*3) local joint position
|
136 |
+
ric_data = positions[:, 1:].reshape(len(positions), -1)
|
137 |
+
|
138 |
+
'''Get Joint Velocity Representation'''
|
139 |
+
# (seq_len-1, joints_num*3)
|
140 |
+
local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),
|
141 |
+
global_positions[1:] - global_positions[:-1])
|
142 |
+
local_vel = local_vel.reshape(len(local_vel), -1)
|
143 |
+
|
144 |
+
data = root_data
|
145 |
+
data = np.concatenate([data, ric_data[:-1]], axis=-1)
|
146 |
+
data = np.concatenate([data, rot_data[:-1]], axis=-1)
|
147 |
+
# print(data.shape, local_vel.shape)
|
148 |
+
data = np.concatenate([data, local_vel], axis=-1)
|
149 |
+
data = np.concatenate([data, feet_l, feet_r], axis=-1)
|
150 |
+
|
151 |
+
return data
|
152 |
+
|
153 |
+
|
154 |
+
def process_file(positions, feet_thre):
|
155 |
+
'''Uniform Skeleton'''
|
156 |
+
positions = uniform_skeleton(positions, tgt_offsets)
|
157 |
+
|
158 |
+
'''Put on Floor'''
|
159 |
+
floor_height = positions.min(axis=0).min(axis=0)[1]
|
160 |
+
positions[:, :, 1] -= floor_height
|
161 |
+
|
162 |
+
'''XZ at origin'''
|
163 |
+
root_pos_init = positions[0]
|
164 |
+
root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1])
|
165 |
+
positions = positions - root_pose_init_xz
|
166 |
+
|
167 |
+
# '''Move the first pose to origin '''
|
168 |
+
# root_pos_init = positions[0]
|
169 |
+
# positions = positions - root_pos_init[0]
|
170 |
+
|
171 |
+
'''All initially face Z+'''
|
172 |
+
r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
|
173 |
+
across1 = root_pos_init[r_hip] - root_pos_init[l_hip]
|
174 |
+
across2 = root_pos_init[sdr_r] - root_pos_init[sdr_l]
|
175 |
+
across = across1 + across2
|
176 |
+
across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis]
|
177 |
+
|
178 |
+
# forward (3,), rotate around y-axis
|
179 |
+
forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
|
180 |
+
# forward (3,)
|
181 |
+
forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis]
|
182 |
+
|
183 |
+
# print(forward_init)
|
184 |
+
|
185 |
+
target = np.array([[0, 0, 1]])
|
186 |
+
root_quat_init = qbetween_np(forward_init, target)
|
187 |
+
root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init
|
188 |
+
|
189 |
+
positions_b = positions.copy()
|
190 |
+
|
191 |
+
positions = qrot_np(root_quat_init, positions)
|
192 |
+
|
193 |
+
'''New ground truth positions'''
|
194 |
+
global_positions = positions.copy()
|
195 |
+
|
196 |
+
""" Get Foot Contacts """
|
197 |
+
|
198 |
+
def foot_detect(positions, thres):
|
199 |
+
velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])
|
200 |
+
|
201 |
+
feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
|
202 |
+
feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
|
203 |
+
feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
|
204 |
+
# feet_l_h = positions[:-1,fid_l,1]
|
205 |
+
# feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float)
|
206 |
+
feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float)
|
207 |
+
|
208 |
+
feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
|
209 |
+
feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
|
210 |
+
feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
|
211 |
+
# feet_r_h = positions[:-1,fid_r,1]
|
212 |
+
# feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float)
|
213 |
+
feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float)
|
214 |
+
return feet_l, feet_r
|
215 |
+
|
216 |
+
feet_l, feet_r = foot_detect(positions, feet_thre)
|
217 |
+
|
218 |
+
'''Quaternion and Cartesian representation'''
|
219 |
+
r_rot = None
|
220 |
+
|
221 |
+
def get_rifke(positions):
|
222 |
+
'''Local pose'''
|
223 |
+
positions[..., 0] -= positions[:, 0:1, 0]
|
224 |
+
positions[..., 2] -= positions[:, 0:1, 2]
|
225 |
+
'''All pose face Z+'''
|
226 |
+
positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)
|
227 |
+
return positions
|
228 |
+
|
229 |
+
def get_quaternion(positions):
|
230 |
+
skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
|
231 |
+
# (seq_len, joints_num, 4)
|
232 |
+
quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)
|
233 |
+
|
234 |
+
'''Fix Quaternion Discontinuity'''
|
235 |
+
quat_params = qfix(quat_params)
|
236 |
+
# (seq_len, 4)
|
237 |
+
r_rot = quat_params[:, 0].copy()
|
238 |
+
# print(r_rot[0])
|
239 |
+
'''Root Linear Velocity'''
|
240 |
+
# (seq_len - 1, 3)
|
241 |
+
velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
|
242 |
+
# print(r_rot.shape, velocity.shape)
|
243 |
+
velocity = qrot_np(r_rot[1:], velocity)
|
244 |
+
'''Root Angular Velocity'''
|
245 |
+
# (seq_len - 1, 4)
|
246 |
+
r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
|
247 |
+
quat_params[1:, 0] = r_velocity
|
248 |
+
# (seq_len, joints_num, 4)
|
249 |
+
return quat_params, r_velocity, velocity, r_rot
|
250 |
+
|
251 |
+
def get_cont6d_params(positions):
|
252 |
+
skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
|
253 |
+
# (seq_len, joints_num, 4)
|
254 |
+
quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)
|
255 |
+
|
256 |
+
'''Quaternion to continuous 6D'''
|
257 |
+
cont_6d_params = quaternion_to_cont6d_np(quat_params)
|
258 |
+
# (seq_len, 4)
|
259 |
+
r_rot = quat_params[:, 0].copy()
|
260 |
+
# print(r_rot[0])
|
261 |
+
'''Root Linear Velocity'''
|
262 |
+
# (seq_len - 1, 3)
|
263 |
+
velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
|
264 |
+
# print(r_rot.shape, velocity.shape)
|
265 |
+
velocity = qrot_np(r_rot[1:], velocity)
|
266 |
+
'''Root Angular Velocity'''
|
267 |
+
# (seq_len - 1, 4)
|
268 |
+
r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
|
269 |
+
# (seq_len, joints_num, 4)
|
270 |
+
return cont_6d_params, r_velocity, velocity, r_rot
|
271 |
+
|
272 |
+
cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)
|
273 |
+
positions = get_rifke(positions)
|
274 |
+
|
275 |
+
'''Root height'''
|
276 |
+
root_y = positions[:, 0, 1:2]
|
277 |
+
|
278 |
+
'''Root rotation and linear velocity'''
|
279 |
+
# (seq_len-1, 1) rotation velocity along y-axis
|
280 |
+
# (seq_len-1, 2) linear velovity on xz plane
|
281 |
+
r_velocity = np.arcsin(r_velocity[:, 2:3])
|
282 |
+
l_velocity = velocity[:, [0, 2]]
|
283 |
+
# print(r_velocity.shape, l_velocity.shape, root_y.shape)
|
284 |
+
root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)
|
285 |
+
|
286 |
+
'''Get Joint Rotation Representation'''
|
287 |
+
# (seq_len, (joints_num-1) *6) quaternion for skeleton joints
|
288 |
+
rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)
|
289 |
+
|
290 |
+
'''Get Joint Rotation Invariant Position Represention'''
|
291 |
+
# (seq_len, (joints_num-1)*3) local joint position
|
292 |
+
ric_data = positions[:, 1:].reshape(len(positions), -1)
|
293 |
+
|
294 |
+
'''Get Joint Velocity Representation'''
|
295 |
+
# (seq_len-1, joints_num*3)
|
296 |
+
local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),
|
297 |
+
global_positions[1:] - global_positions[:-1])
|
298 |
+
local_vel = local_vel.reshape(len(local_vel), -1)
|
299 |
+
|
300 |
+
data = root_data
|
301 |
+
data = np.concatenate([data, ric_data[:-1]], axis=-1)
|
302 |
+
data = np.concatenate([data, rot_data[:-1]], axis=-1)
|
303 |
+
|
304 |
+
data = np.concatenate([data, local_vel], axis=-1)
|
305 |
+
data = np.concatenate([data, feet_l, feet_r], axis=-1)
|
306 |
+
|
307 |
+
return data, global_positions, positions, l_velocity
|
308 |
+
|
309 |
+
|
310 |
+
# Recover global angle and positions for rotation data
|
311 |
+
# root_rot_velocity (B, seq_len, 1)
|
312 |
+
# root_linear_velocity (B, seq_len, 2)
|
313 |
+
# root_y (B, seq_len, 1)
|
314 |
+
# ric_data (B, seq_len, (joint_num - 1)*3)
|
315 |
+
# rot_data (B, seq_len, (joint_num - 1)*6)
|
316 |
+
# local_velocity (B, seq_len, joint_num*3)
|
317 |
+
# foot contact (B, seq_len, 4)
|
318 |
+
def recover_root_rot_pos(data):
|
319 |
+
rot_vel = data[..., 0]
|
320 |
+
r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
|
321 |
+
'''Get Y-axis rotation from rotation velocity'''
|
322 |
+
r_rot_ang[..., 1:] = rot_vel[..., :-1]
|
323 |
+
r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
|
324 |
+
|
325 |
+
r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
|
326 |
+
r_rot_quat[..., 0] = torch.cos(r_rot_ang)
|
327 |
+
r_rot_quat[..., 2] = torch.sin(r_rot_ang)
|
328 |
+
|
329 |
+
r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
|
330 |
+
r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
|
331 |
+
'''Add Y-axis rotation to root position'''
|
332 |
+
r_pos = qrot(qinv(r_rot_quat), r_pos)
|
333 |
+
|
334 |
+
r_pos = torch.cumsum(r_pos, dim=-2)
|
335 |
+
|
336 |
+
r_pos[..., 1] = data[..., 3]
|
337 |
+
return r_rot_quat, r_pos
|
338 |
+
|
339 |
+
|
340 |
+
def recover_from_rot(data, joints_num, skeleton):
|
341 |
+
r_rot_quat, r_pos = recover_root_rot_pos(data)
|
342 |
+
|
343 |
+
r_rot_cont6d = quaternion_to_cont6d(r_rot_quat)
|
344 |
+
|
345 |
+
start_indx = 1 + 2 + 1 + (joints_num - 1) * 3
|
346 |
+
end_indx = start_indx + (joints_num - 1) * 6
|
347 |
+
cont6d_params = data[..., start_indx:end_indx]
|
348 |
+
# print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape)
|
349 |
+
cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1)
|
350 |
+
cont6d_params = cont6d_params.view(-1, joints_num, 6)
|
351 |
+
|
352 |
+
positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos)
|
353 |
+
|
354 |
+
return positions
|
355 |
+
|
356 |
+
|
357 |
+
# NOTE: Expand input data types (torch.Tensor -> Union[torch.Tensor, np.array])
|
358 |
+
def recover_from_ric(
|
359 |
+
data: Union[torch.Tensor, np.array], joints_num: int
|
360 |
+
) -> Union[torch.Tensor, np.array]:
|
361 |
+
if isinstance(data, np.ndarray):
|
362 |
+
data = torch.from_numpy(data).float()
|
363 |
+
dtype = "numpy"
|
364 |
+
else:
|
365 |
+
data = data.float()
|
366 |
+
dtype = "tensor"
|
367 |
+
r_rot_quat, r_pos = recover_root_rot_pos(data)
|
368 |
+
positions = data[..., 4:(joints_num - 1) * 3 + 4]
|
369 |
+
positions = positions.view(positions.shape[:-1] + (-1, 3))
|
370 |
+
|
371 |
+
'''Add Y-axis rotation to local joints'''
|
372 |
+
positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
|
373 |
+
|
374 |
+
'''Add root XZ to joints'''
|
375 |
+
positions[..., 0] += r_pos[..., 0:1]
|
376 |
+
positions[..., 2] += r_pos[..., 2:3]
|
377 |
+
|
378 |
+
'''Concate root and joints'''
|
379 |
+
positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
|
380 |
+
|
381 |
+
if dtype == "numpy":
|
382 |
+
positions = positions.numpy()
|
383 |
+
|
384 |
+
return positions
|
385 |
+
'''
|
386 |
+
For Text2Motion Dataset
|
387 |
+
'''
|
388 |
+
'''
|
389 |
+
if __name__ == "__main__":
|
390 |
+
example_id = "000021"
|
391 |
+
# Lower legs
|
392 |
+
l_idx1, l_idx2 = 5, 8
|
393 |
+
# Right/Left foot
|
394 |
+
fid_r, fid_l = [8, 11], [7, 10]
|
395 |
+
# Face direction, r_hip, l_hip, sdr_r, sdr_l
|
396 |
+
face_joint_indx = [2, 1, 17, 16]
|
397 |
+
# l_hip, r_hip
|
398 |
+
r_hip, l_hip = 2, 1
|
399 |
+
joints_num = 22
|
400 |
+
# ds_num = 8
|
401 |
+
data_dir = '../dataset/pose_data_raw/joints/'
|
402 |
+
save_dir1 = '../dataset/pose_data_raw/new_joints/'
|
403 |
+
save_dir2 = '../dataset/pose_data_raw/new_joint_vecs/'
|
404 |
+
|
405 |
+
n_raw_offsets = torch.from_numpy(t2m_raw_offsets)
|
406 |
+
kinematic_chain = t2m_kinematic_chain
|
407 |
+
|
408 |
+
# Get offsets of target skeleton
|
409 |
+
example_data = np.load(os.path.join(data_dir, example_id + '.npy'))
|
410 |
+
example_data = example_data.reshape(len(example_data), -1, 3)
|
411 |
+
example_data = torch.from_numpy(example_data)
|
412 |
+
tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
|
413 |
+
# (joints_num, 3)
|
414 |
+
tgt_offsets = tgt_skel.get_offsets_joints(example_data[0])
|
415 |
+
# print(tgt_offsets)
|
416 |
+
|
417 |
+
source_list = os.listdir(data_dir)
|
418 |
+
frame_num = 0
|
419 |
+
for source_file in tqdm(source_list):
|
420 |
+
source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num]
|
421 |
+
try:
|
422 |
+
data, ground_positions, positions, l_velocity = process_file(source_data, 0.002)
|
423 |
+
rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num)
|
424 |
+
np.save(pjoin(save_dir1, source_file), rec_ric_data.squeeze().numpy())
|
425 |
+
np.save(pjoin(save_dir2, source_file), data)
|
426 |
+
frame_num += data.shape[0]
|
427 |
+
except Exception as e:
|
428 |
+
print(source_file)
|
429 |
+
print(e)
|
430 |
+
|
431 |
+
print('Total clips: %d, Frames: %d, Duration: %fm' %
|
432 |
+
(len(source_list), frame_num, frame_num / 20 / 60))
|
433 |
+
'''
|
434 |
+
|
435 |
+
if __name__ == "__main__":
|
436 |
+
example_id = "03950_gt"
|
437 |
+
# Lower legs
|
438 |
+
l_idx1, l_idx2 = 17, 18
|
439 |
+
# Right/Left foot
|
440 |
+
fid_r, fid_l = [14, 15], [19, 20]
|
441 |
+
# Face direction, r_hip, l_hip, sdr_r, sdr_l
|
442 |
+
face_joint_indx = [11, 16, 5, 8]
|
443 |
+
# l_hip, r_hip
|
444 |
+
r_hip, l_hip = 11, 16
|
445 |
+
joints_num = 21
|
446 |
+
# ds_num = 8
|
447 |
+
data_dir = '../dataset/kit_mocap_dataset/joints/'
|
448 |
+
save_dir1 = '../dataset/kit_mocap_dataset/new_joints/'
|
449 |
+
save_dir2 = '../dataset/kit_mocap_dataset/new_joint_vecs/'
|
450 |
+
|
451 |
+
n_raw_offsets = torch.from_numpy(kit_raw_offsets)
|
452 |
+
kinematic_chain = kit_kinematic_chain
|
453 |
+
|
454 |
+
'''Get offsets of target skeleton'''
|
455 |
+
example_data = np.load(os.path.join(data_dir, example_id + '.npy'))
|
456 |
+
example_data = example_data.reshape(len(example_data), -1, 3)
|
457 |
+
example_data = torch.from_numpy(example_data)
|
458 |
+
tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
|
459 |
+
# (joints_num, 3)
|
460 |
+
tgt_offsets = tgt_skel.get_offsets_joints(example_data[0])
|
461 |
+
# print(tgt_offsets)
|
462 |
+
|
463 |
+
source_list = os.listdir(data_dir)
|
464 |
+
frame_num = 0
|
465 |
+
'''Read source data'''
|
466 |
+
for source_file in tqdm(source_list):
|
467 |
+
source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num]
|
468 |
+
try:
|
469 |
+
name = ''.join(source_file[:-7].split('_')) + '.npy'
|
470 |
+
data, ground_positions, positions, l_velocity = process_file(source_data, 0.05)
|
471 |
+
rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num)
|
472 |
+
if np.isnan(rec_ric_data.numpy()).any():
|
473 |
+
print(source_file)
|
474 |
+
continue
|
475 |
+
np.save(pjoin(save_dir1, name), rec_ric_data.squeeze().numpy())
|
476 |
+
np.save(pjoin(save_dir2, name), data)
|
477 |
+
frame_num += data.shape[0]
|
478 |
+
except Exception as e:
|
479 |
+
print(source_file)
|
480 |
+
print(e)
|
481 |
+
|
482 |
+
print('Total clips: %d, Frames: %d, Duration: %fm' %
|
483 |
+
(len(source_list), frame_num, frame_num / 12.5 / 60))
|
utils/paramUtil.py
ADDED
@@ -0,0 +1,152 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
# Define a kinematic tree for the skeletal struture
|
4 |
+
kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
|
5 |
+
|
6 |
+
kit_raw_offsets = np.array(
|
7 |
+
[
|
8 |
+
[0, 0, 0],
|
9 |
+
[0, 1, 0],
|
10 |
+
[0, 1, 0],
|
11 |
+
[0, 1, 0],
|
12 |
+
[0, 1, 0],
|
13 |
+
[1, 0, 0],
|
14 |
+
[0, -1, 0],
|
15 |
+
[0, -1, 0],
|
16 |
+
[-1, 0, 0],
|
17 |
+
[0, -1, 0],
|
18 |
+
[0, -1, 0],
|
19 |
+
[1, 0, 0],
|
20 |
+
[0, -1, 0],
|
21 |
+
[0, -1, 0],
|
22 |
+
[0, 0, 1],
|
23 |
+
[0, 0, 1],
|
24 |
+
[-1, 0, 0],
|
25 |
+
[0, -1, 0],
|
26 |
+
[0, -1, 0],
|
27 |
+
[0, 0, 1],
|
28 |
+
[0, 0, 1]
|
29 |
+
]
|
30 |
+
)
|
31 |
+
|
32 |
+
t2m_raw_offsets = np.array([[0,0,0],
|
33 |
+
[1,0,0],
|
34 |
+
[-1,0,0],
|
35 |
+
[0,1,0],
|
36 |
+
[0,-1,0],
|
37 |
+
[0,-1,0],
|
38 |
+
[0,1,0],
|
39 |
+
[0,-1,0],
|
40 |
+
[0,-1,0],
|
41 |
+
[0,1,0],
|
42 |
+
[0,0,1],
|
43 |
+
[0,0,1],
|
44 |
+
[0,1,0],
|
45 |
+
[1,0,0],
|
46 |
+
[-1,0,0],
|
47 |
+
[0,0,1],
|
48 |
+
[0,-1,0],
|
49 |
+
[0,-1,0],
|
50 |
+
[0,-1,0],
|
51 |
+
[0,-1,0],
|
52 |
+
[0,-1,0],
|
53 |
+
[0,-1,0]])
|
54 |
+
|
55 |
+
t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
|
56 |
+
t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
|
57 |
+
t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
|
58 |
+
|
59 |
+
kit_kinematic_tree = [
|
60 |
+
[0, 1, 2, 3, 4], # body
|
61 |
+
[3, 5, 6, 7], # right arm
|
62 |
+
[3, 8, 9, 10], # left arm
|
63 |
+
[0, 11, 12, 13, 14, 15], # right leg
|
64 |
+
[0, 16, 17, 18, 19, 20], # left leg
|
65 |
+
]
|
66 |
+
|
67 |
+
humanml3d_kinematic_tree = [
|
68 |
+
[0, 3, 6, 9, 12, 15], # body
|
69 |
+
[9, 14, 17, 19, 21], # right arm
|
70 |
+
[9, 13, 16, 18, 20], # left arm
|
71 |
+
[0, 2, 5, 8, 11], # right leg
|
72 |
+
[0, 1, 4, 7, 10], # left leg
|
73 |
+
]
|
74 |
+
|
75 |
+
|
76 |
+
kit_tgt_skel_id = '03950'
|
77 |
+
|
78 |
+
t2m_tgt_skel_id = '000021'
|
79 |
+
|
80 |
+
KIT_JOINT_NAMES = [
|
81 |
+
"pelvis", # 0
|
82 |
+
"spine_1", # 1
|
83 |
+
"spine_3", # 2
|
84 |
+
"neck", # 3
|
85 |
+
"head", # 4
|
86 |
+
"left_shoulder", # 5
|
87 |
+
"left_elbow", # 6
|
88 |
+
"left_wrist", # 7
|
89 |
+
"right_shoulder", # 8
|
90 |
+
"right_elbow", # 9
|
91 |
+
"right_wrist", # 10
|
92 |
+
"left_hip", # 11
|
93 |
+
"left_knee", # 12
|
94 |
+
"left_ankle", # 13
|
95 |
+
"left_heel", # 14
|
96 |
+
"left_foot", # 15
|
97 |
+
"right_hip", # 16
|
98 |
+
"right_knee", # 17
|
99 |
+
"right_ankle", # 18
|
100 |
+
"right_heel", # 19
|
101 |
+
"right_foot", # 20
|
102 |
+
]
|
103 |
+
|
104 |
+
HumanML3D_JOINT_NAMES = [
|
105 |
+
"pelvis", # 0: root
|
106 |
+
"left_hip", # 1: lhip
|
107 |
+
"right_hip", # 2: rhip
|
108 |
+
"spine_1", # 3: belly
|
109 |
+
"left_knee", # 4: lknee
|
110 |
+
"right_knee", # 5: rknee
|
111 |
+
"spine_2", # 6: spine
|
112 |
+
"left_ankle", # 7: lankle
|
113 |
+
"right_ankle", # 8: rankle
|
114 |
+
"spine_3", # 9: chest
|
115 |
+
"left_foot", # 10: ltoes
|
116 |
+
"right_foot", # 11: rtoes
|
117 |
+
"neck", # 12: neck
|
118 |
+
"left_clavicle", # 13: linshoulder
|
119 |
+
"right_clavicle", # 14: rinshoulder
|
120 |
+
"head", # 15: head
|
121 |
+
"left_shoulder", # 16: lshoulder
|
122 |
+
"right_shoulder", # 17: rshoulder
|
123 |
+
"left_elbow", # 18: lelbow
|
124 |
+
"right_elbow", # 19: relbow
|
125 |
+
"left_wrist", # 20: lwrist
|
126 |
+
"right_wrist", # 21: rwrist
|
127 |
+
]
|
128 |
+
|
129 |
+
HumanML3D2KIT = {
|
130 |
+
0:0, # pelvis--pelvis
|
131 |
+
1:16, # left_hip--right_hip
|
132 |
+
2:11, # right_hip--left_hip
|
133 |
+
3:1, # spine_1--spine_1
|
134 |
+
4:17, # left_knee--right_knee
|
135 |
+
5:12, # right_knee--left_knee
|
136 |
+
6:1, # spine_2--spine_1
|
137 |
+
7:18, # left_ankle--right_ankle
|
138 |
+
8:13, # right_ankle--left_ankle
|
139 |
+
9:2, # spine_3--spine_3
|
140 |
+
10:20, # left_foot--right_foot
|
141 |
+
11:15, # right_foot--left_foot
|
142 |
+
12:3, # neck--neck
|
143 |
+
13:8, # left_clavicle--right_shoulder
|
144 |
+
14:5, # right_clavicle--left_shoulder
|
145 |
+
15:4, # head--head
|
146 |
+
16:8, # left_shoulder--right_shoulder
|
147 |
+
17:5, # right_shoulder--left_shoulder
|
148 |
+
18:9, # left_elbow--right_elbow
|
149 |
+
19:6, # right_elbow--left_elbow
|
150 |
+
20:10, # left_wrist--right_wrist
|
151 |
+
21:7, # right_wrist--left_wrist
|
152 |
+
}
|
utils/plot_script.py
ADDED
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib
|
4 |
+
import matplotlib.pyplot as plt
|
5 |
+
from mpl_toolkits.mplot3d import Axes3D
|
6 |
+
from matplotlib.animation import FuncAnimation, FFMpegFileWriter
|
7 |
+
from mpl_toolkits.mplot3d.art3d import Poly3DCollection
|
8 |
+
import mpl_toolkits.mplot3d.axes3d as p3
|
9 |
+
# import cv2
|
10 |
+
|
11 |
+
|
12 |
+
def list_cut_average(ll, intervals):
|
13 |
+
if intervals == 1:
|
14 |
+
return ll
|
15 |
+
|
16 |
+
bins = math.ceil(len(ll) * 1.0 / intervals)
|
17 |
+
ll_new = []
|
18 |
+
for i in range(bins):
|
19 |
+
l_low = intervals * i
|
20 |
+
l_high = l_low + intervals
|
21 |
+
l_high = l_high if l_high < len(ll) else len(ll)
|
22 |
+
ll_new.append(np.mean(ll[l_low:l_high]))
|
23 |
+
return ll_new
|
24 |
+
|
25 |
+
|
26 |
+
def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4):
|
27 |
+
matplotlib.use('Agg')
|
28 |
+
|
29 |
+
title_sp = title.split(' ')
|
30 |
+
if len(title_sp) > 20:
|
31 |
+
title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])])
|
32 |
+
elif len(title_sp) > 10:
|
33 |
+
title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])
|
34 |
+
|
35 |
+
def init():
|
36 |
+
ax.set_xlim3d([-radius / 4, radius / 4])
|
37 |
+
ax.set_ylim3d([0, radius / 2])
|
38 |
+
ax.set_zlim3d([0, radius / 2])
|
39 |
+
# print(title)
|
40 |
+
fig.suptitle(title, fontsize=20)
|
41 |
+
ax.grid(b=False)
|
42 |
+
|
43 |
+
def plot_xzPlane(minx, maxx, miny, minz, maxz):
|
44 |
+
## Plot a plane XZ
|
45 |
+
verts = [
|
46 |
+
[minx, miny, minz],
|
47 |
+
[minx, miny, maxz],
|
48 |
+
[maxx, miny, maxz],
|
49 |
+
[maxx, miny, minz]
|
50 |
+
]
|
51 |
+
xz_plane = Poly3DCollection([verts])
|
52 |
+
xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
|
53 |
+
ax.add_collection3d(xz_plane)
|
54 |
+
|
55 |
+
# return ax
|
56 |
+
|
57 |
+
# (seq_len, joints_num, 3)
|
58 |
+
data = joints.copy().reshape(len(joints), -1, 3)
|
59 |
+
fig = plt.figure(figsize=figsize)
|
60 |
+
ax = p3.Axes3D(fig)
|
61 |
+
init()
|
62 |
+
MINS = data.min(axis=0).min(axis=0)
|
63 |
+
MAXS = data.max(axis=0).max(axis=0)
|
64 |
+
colors = ['red', 'blue', 'black', 'red', 'blue',
|
65 |
+
'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
|
66 |
+
'darkred', 'darkred', 'darkred', 'darkred', 'darkred']
|
67 |
+
frame_number = data.shape[0]
|
68 |
+
# print(data.shape)
|
69 |
+
|
70 |
+
height_offset = MINS[1]
|
71 |
+
data[:, :, 1] -= height_offset
|
72 |
+
trajec = data[:, 0, [0, 2]]
|
73 |
+
|
74 |
+
data[..., 0] -= data[:, 0:1, 0]
|
75 |
+
data[..., 2] -= data[:, 0:1, 2]
|
76 |
+
|
77 |
+
# print(trajec.shape)
|
78 |
+
|
79 |
+
def update(index):
|
80 |
+
# print(index)
|
81 |
+
ax.lines = []
|
82 |
+
ax.collections = []
|
83 |
+
ax.view_init(elev=120, azim=-90)
|
84 |
+
ax.dist = 7.5
|
85 |
+
# ax =
|
86 |
+
plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
|
87 |
+
MAXS[2] - trajec[index, 1])
|
88 |
+
# ax.scatter(data[index, :22, 0], data[index, :22, 1], data[index, :22, 2], color='black', s=3)
|
89 |
+
|
90 |
+
# if index > 1:
|
91 |
+
# ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]),
|
92 |
+
# trajec[:index, 1] - trajec[index, 1], linewidth=1.0,
|
93 |
+
# color='blue')
|
94 |
+
# ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2])
|
95 |
+
|
96 |
+
for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
|
97 |
+
# print(color)
|
98 |
+
if i < 5:
|
99 |
+
linewidth = 4.0
|
100 |
+
else:
|
101 |
+
linewidth = 2.0
|
102 |
+
ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
|
103 |
+
color=color)
|
104 |
+
# print(trajec[:index, 0].shape)
|
105 |
+
|
106 |
+
plt.axis('off')
|
107 |
+
ax.set_xticklabels([])
|
108 |
+
ax.set_yticklabels([])
|
109 |
+
ax.set_zticklabels([])
|
110 |
+
|
111 |
+
ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
|
112 |
+
|
113 |
+
from matplotlib import rcParams
|
114 |
+
rcParams['animation.ffmpeg_path'] = '/home/chenlinghao/software/miniconda3/envs/py10/bin/ffmpeg'
|
115 |
+
writer = FFMpegFileWriter(fps=fps)
|
116 |
+
ani.save(save_path, writer=writer)
|
117 |
+
plt.close()
|
utils/quaternion.py
ADDED
@@ -0,0 +1,423 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) 2018-present, Facebook, Inc.
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# LICENSE file in the root directory of this source tree.
|
6 |
+
#
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
_EPS4 = np.finfo(float).eps * 4.0
|
12 |
+
|
13 |
+
_FLOAT_EPS = np.finfo(np.float32).eps
|
14 |
+
|
15 |
+
# PyTorch-backed implementations
|
16 |
+
def qinv(q):
|
17 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
18 |
+
mask = torch.ones_like(q)
|
19 |
+
mask[..., 1:] = -mask[..., 1:]
|
20 |
+
return q * mask
|
21 |
+
|
22 |
+
|
23 |
+
def qinv_np(q):
|
24 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
25 |
+
return qinv(torch.from_numpy(q).float()).numpy()
|
26 |
+
|
27 |
+
|
28 |
+
def qnormalize(q):
|
29 |
+
assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
|
30 |
+
return q / torch.norm(q, dim=-1, keepdim=True)
|
31 |
+
|
32 |
+
|
33 |
+
def qmul(q, r):
|
34 |
+
"""
|
35 |
+
Multiply quaternion(s) q with quaternion(s) r.
|
36 |
+
Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
|
37 |
+
Returns q*r as a tensor of shape (*, 4).
|
38 |
+
"""
|
39 |
+
assert q.shape[-1] == 4
|
40 |
+
assert r.shape[-1] == 4
|
41 |
+
|
42 |
+
original_shape = q.shape
|
43 |
+
|
44 |
+
# Compute outer product
|
45 |
+
terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
|
46 |
+
|
47 |
+
w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
|
48 |
+
x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
|
49 |
+
y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
|
50 |
+
z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
|
51 |
+
return torch.stack((w, x, y, z), dim=1).view(original_shape)
|
52 |
+
|
53 |
+
|
54 |
+
def qrot(q, v):
|
55 |
+
"""
|
56 |
+
Rotate vector(s) v about the rotation described by quaternion(s) q.
|
57 |
+
Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
|
58 |
+
where * denotes any number of dimensions.
|
59 |
+
Returns a tensor of shape (*, 3).
|
60 |
+
"""
|
61 |
+
assert q.shape[-1] == 4
|
62 |
+
assert v.shape[-1] == 3
|
63 |
+
assert q.shape[:-1] == v.shape[:-1]
|
64 |
+
|
65 |
+
original_shape = list(v.shape)
|
66 |
+
# print(q.shape)
|
67 |
+
q = q.contiguous().view(-1, 4)
|
68 |
+
v = v.contiguous().view(-1, 3)
|
69 |
+
|
70 |
+
qvec = q[:, 1:]
|
71 |
+
uv = torch.cross(qvec, v, dim=1)
|
72 |
+
uuv = torch.cross(qvec, uv, dim=1)
|
73 |
+
return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
|
74 |
+
|
75 |
+
|
76 |
+
def qeuler(q, order, epsilon=0, deg=True):
|
77 |
+
"""
|
78 |
+
Convert quaternion(s) q to Euler angles.
|
79 |
+
Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
|
80 |
+
Returns a tensor of shape (*, 3).
|
81 |
+
"""
|
82 |
+
assert q.shape[-1] == 4
|
83 |
+
|
84 |
+
original_shape = list(q.shape)
|
85 |
+
original_shape[-1] = 3
|
86 |
+
q = q.view(-1, 4)
|
87 |
+
|
88 |
+
q0 = q[:, 0]
|
89 |
+
q1 = q[:, 1]
|
90 |
+
q2 = q[:, 2]
|
91 |
+
q3 = q[:, 3]
|
92 |
+
|
93 |
+
if order == 'xyz':
|
94 |
+
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
95 |
+
y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
|
96 |
+
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
|
97 |
+
elif order == 'yzx':
|
98 |
+
x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
99 |
+
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
|
100 |
+
z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
|
101 |
+
elif order == 'zxy':
|
102 |
+
x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
|
103 |
+
y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
104 |
+
z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
|
105 |
+
elif order == 'xzy':
|
106 |
+
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
107 |
+
y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
|
108 |
+
z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
|
109 |
+
elif order == 'yxz':
|
110 |
+
x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
|
111 |
+
y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
|
112 |
+
z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
|
113 |
+
elif order == 'zyx':
|
114 |
+
x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
|
115 |
+
y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
|
116 |
+
z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
|
117 |
+
else:
|
118 |
+
raise
|
119 |
+
|
120 |
+
if deg:
|
121 |
+
return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
|
122 |
+
else:
|
123 |
+
return torch.stack((x, y, z), dim=1).view(original_shape)
|
124 |
+
|
125 |
+
|
126 |
+
# Numpy-backed implementations
|
127 |
+
|
128 |
+
def qmul_np(q, r):
|
129 |
+
q = torch.from_numpy(q).contiguous().float()
|
130 |
+
r = torch.from_numpy(r).contiguous().float()
|
131 |
+
return qmul(q, r).numpy()
|
132 |
+
|
133 |
+
|
134 |
+
def qrot_np(q, v):
|
135 |
+
q = torch.from_numpy(q).contiguous().float()
|
136 |
+
v = torch.from_numpy(v).contiguous().float()
|
137 |
+
return qrot(q, v).numpy()
|
138 |
+
|
139 |
+
|
140 |
+
def qeuler_np(q, order, epsilon=0, use_gpu=False):
|
141 |
+
if use_gpu:
|
142 |
+
q = torch.from_numpy(q).cuda().float()
|
143 |
+
return qeuler(q, order, epsilon).cpu().numpy()
|
144 |
+
else:
|
145 |
+
q = torch.from_numpy(q).contiguous().float()
|
146 |
+
return qeuler(q, order, epsilon).numpy()
|
147 |
+
|
148 |
+
|
149 |
+
def qfix(q):
|
150 |
+
"""
|
151 |
+
Enforce quaternion continuity across the time dimension by selecting
|
152 |
+
the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
|
153 |
+
between two consecutive frames.
|
154 |
+
|
155 |
+
Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
|
156 |
+
Returns a tensor of the same shape.
|
157 |
+
"""
|
158 |
+
assert len(q.shape) == 3
|
159 |
+
assert q.shape[-1] == 4
|
160 |
+
|
161 |
+
result = q.copy()
|
162 |
+
dot_products = np.sum(q[1:] * q[:-1], axis=2)
|
163 |
+
mask = dot_products < 0
|
164 |
+
mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
|
165 |
+
result[1:][mask] *= -1
|
166 |
+
return result
|
167 |
+
|
168 |
+
|
169 |
+
def euler2quat(e, order, deg=True):
|
170 |
+
"""
|
171 |
+
Convert Euler angles to quaternions.
|
172 |
+
"""
|
173 |
+
assert e.shape[-1] == 3
|
174 |
+
|
175 |
+
original_shape = list(e.shape)
|
176 |
+
original_shape[-1] = 4
|
177 |
+
|
178 |
+
e = e.view(-1, 3)
|
179 |
+
|
180 |
+
## if euler angles in degrees
|
181 |
+
if deg:
|
182 |
+
e = e * np.pi / 180.
|
183 |
+
|
184 |
+
x = e[:, 0]
|
185 |
+
y = e[:, 1]
|
186 |
+
z = e[:, 2]
|
187 |
+
|
188 |
+
rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
|
189 |
+
ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
|
190 |
+
rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
|
191 |
+
|
192 |
+
result = None
|
193 |
+
for coord in order:
|
194 |
+
if coord == 'x':
|
195 |
+
r = rx
|
196 |
+
elif coord == 'y':
|
197 |
+
r = ry
|
198 |
+
elif coord == 'z':
|
199 |
+
r = rz
|
200 |
+
else:
|
201 |
+
raise
|
202 |
+
if result is None:
|
203 |
+
result = r
|
204 |
+
else:
|
205 |
+
result = qmul(result, r)
|
206 |
+
|
207 |
+
# Reverse antipodal representation to have a non-negative "w"
|
208 |
+
if order in ['xyz', 'yzx', 'zxy']:
|
209 |
+
result *= -1
|
210 |
+
|
211 |
+
return result.view(original_shape)
|
212 |
+
|
213 |
+
|
214 |
+
def expmap_to_quaternion(e):
|
215 |
+
"""
|
216 |
+
Convert axis-angle rotations (aka exponential maps) to quaternions.
|
217 |
+
Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
|
218 |
+
Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
|
219 |
+
Returns a tensor of shape (*, 4).
|
220 |
+
"""
|
221 |
+
assert e.shape[-1] == 3
|
222 |
+
|
223 |
+
original_shape = list(e.shape)
|
224 |
+
original_shape[-1] = 4
|
225 |
+
e = e.reshape(-1, 3)
|
226 |
+
|
227 |
+
theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
|
228 |
+
w = np.cos(0.5 * theta).reshape(-1, 1)
|
229 |
+
xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
|
230 |
+
return np.concatenate((w, xyz), axis=1).reshape(original_shape)
|
231 |
+
|
232 |
+
|
233 |
+
def euler_to_quaternion(e, order):
|
234 |
+
"""
|
235 |
+
Convert Euler angles to quaternions.
|
236 |
+
"""
|
237 |
+
assert e.shape[-1] == 3
|
238 |
+
|
239 |
+
original_shape = list(e.shape)
|
240 |
+
original_shape[-1] = 4
|
241 |
+
|
242 |
+
e = e.reshape(-1, 3)
|
243 |
+
|
244 |
+
x = e[:, 0]
|
245 |
+
y = e[:, 1]
|
246 |
+
z = e[:, 2]
|
247 |
+
|
248 |
+
rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
|
249 |
+
ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
|
250 |
+
rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
|
251 |
+
|
252 |
+
result = None
|
253 |
+
for coord in order:
|
254 |
+
if coord == 'x':
|
255 |
+
r = rx
|
256 |
+
elif coord == 'y':
|
257 |
+
r = ry
|
258 |
+
elif coord == 'z':
|
259 |
+
r = rz
|
260 |
+
else:
|
261 |
+
raise
|
262 |
+
if result is None:
|
263 |
+
result = r
|
264 |
+
else:
|
265 |
+
result = qmul_np(result, r)
|
266 |
+
|
267 |
+
# Reverse antipodal representation to have a non-negative "w"
|
268 |
+
if order in ['xyz', 'yzx', 'zxy']:
|
269 |
+
result *= -1
|
270 |
+
|
271 |
+
return result.reshape(original_shape)
|
272 |
+
|
273 |
+
|
274 |
+
def quaternion_to_matrix(quaternions):
|
275 |
+
"""
|
276 |
+
Convert rotations given as quaternions to rotation matrices.
|
277 |
+
Args:
|
278 |
+
quaternions: quaternions with real part first,
|
279 |
+
as tensor of shape (..., 4).
|
280 |
+
Returns:
|
281 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
282 |
+
"""
|
283 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
284 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
285 |
+
|
286 |
+
o = torch.stack(
|
287 |
+
(
|
288 |
+
1 - two_s * (j * j + k * k),
|
289 |
+
two_s * (i * j - k * r),
|
290 |
+
two_s * (i * k + j * r),
|
291 |
+
two_s * (i * j + k * r),
|
292 |
+
1 - two_s * (i * i + k * k),
|
293 |
+
two_s * (j * k - i * r),
|
294 |
+
two_s * (i * k - j * r),
|
295 |
+
two_s * (j * k + i * r),
|
296 |
+
1 - two_s * (i * i + j * j),
|
297 |
+
),
|
298 |
+
-1,
|
299 |
+
)
|
300 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
301 |
+
|
302 |
+
|
303 |
+
def quaternion_to_matrix_np(quaternions):
|
304 |
+
q = torch.from_numpy(quaternions).contiguous().float()
|
305 |
+
return quaternion_to_matrix(q).numpy()
|
306 |
+
|
307 |
+
|
308 |
+
def quaternion_to_cont6d_np(quaternions):
|
309 |
+
rotation_mat = quaternion_to_matrix_np(quaternions)
|
310 |
+
cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
|
311 |
+
return cont_6d
|
312 |
+
|
313 |
+
|
314 |
+
def quaternion_to_cont6d(quaternions):
|
315 |
+
rotation_mat = quaternion_to_matrix(quaternions)
|
316 |
+
cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
|
317 |
+
return cont_6d
|
318 |
+
|
319 |
+
|
320 |
+
def cont6d_to_matrix(cont6d):
|
321 |
+
assert cont6d.shape[-1] == 6, "The last dimension must be 6"
|
322 |
+
x_raw = cont6d[..., 0:3]
|
323 |
+
y_raw = cont6d[..., 3:6]
|
324 |
+
|
325 |
+
x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
|
326 |
+
z = torch.cross(x, y_raw, dim=-1)
|
327 |
+
z = z / torch.norm(z, dim=-1, keepdim=True)
|
328 |
+
|
329 |
+
y = torch.cross(z, x, dim=-1)
|
330 |
+
|
331 |
+
x = x[..., None]
|
332 |
+
y = y[..., None]
|
333 |
+
z = z[..., None]
|
334 |
+
|
335 |
+
mat = torch.cat([x, y, z], dim=-1)
|
336 |
+
return mat
|
337 |
+
|
338 |
+
|
339 |
+
def cont6d_to_matrix_np(cont6d):
|
340 |
+
q = torch.from_numpy(cont6d).contiguous().float()
|
341 |
+
return cont6d_to_matrix(q).numpy()
|
342 |
+
|
343 |
+
|
344 |
+
def qpow(q0, t, dtype=torch.float):
|
345 |
+
''' q0 : tensor of quaternions
|
346 |
+
t: tensor of powers
|
347 |
+
'''
|
348 |
+
q0 = qnormalize(q0)
|
349 |
+
theta0 = torch.acos(q0[..., 0])
|
350 |
+
|
351 |
+
## if theta0 is close to zero, add epsilon to avoid NaNs
|
352 |
+
mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
|
353 |
+
theta0 = (1 - mask) * theta0 + mask * 10e-10
|
354 |
+
v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
|
355 |
+
|
356 |
+
if isinstance(t, torch.Tensor):
|
357 |
+
q = torch.zeros(t.shape + q0.shape)
|
358 |
+
theta = t.view(-1, 1) * theta0.view(1, -1)
|
359 |
+
else: ## if t is a number
|
360 |
+
q = torch.zeros(q0.shape)
|
361 |
+
theta = t * theta0
|
362 |
+
|
363 |
+
q[..., 0] = torch.cos(theta)
|
364 |
+
q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
|
365 |
+
|
366 |
+
return q.to(dtype)
|
367 |
+
|
368 |
+
|
369 |
+
def qslerp(q0, q1, t):
|
370 |
+
'''
|
371 |
+
q0: starting quaternion
|
372 |
+
q1: ending quaternion
|
373 |
+
t: array of points along the way
|
374 |
+
|
375 |
+
Returns:
|
376 |
+
Tensor of Slerps: t.shape + q0.shape
|
377 |
+
'''
|
378 |
+
|
379 |
+
q0 = qnormalize(q0)
|
380 |
+
q1 = qnormalize(q1)
|
381 |
+
q_ = qpow(qmul(q1, qinv(q0)), t)
|
382 |
+
|
383 |
+
return qmul(q_,
|
384 |
+
q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
|
385 |
+
|
386 |
+
|
387 |
+
def qbetween(v0, v1):
|
388 |
+
'''
|
389 |
+
find the quaternion used to rotate v0 to v1
|
390 |
+
'''
|
391 |
+
assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
|
392 |
+
assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
|
393 |
+
|
394 |
+
v = torch.cross(v0, v1)
|
395 |
+
w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
|
396 |
+
keepdim=True)
|
397 |
+
return qnormalize(torch.cat([w, v], dim=-1))
|
398 |
+
|
399 |
+
|
400 |
+
def qbetween_np(v0, v1):
|
401 |
+
'''
|
402 |
+
find the quaternion used to rotate v0 to v1
|
403 |
+
'''
|
404 |
+
assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
|
405 |
+
assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
|
406 |
+
|
407 |
+
v0 = torch.from_numpy(v0).float()
|
408 |
+
v1 = torch.from_numpy(v1).float()
|
409 |
+
return qbetween(v0, v1).numpy()
|
410 |
+
|
411 |
+
|
412 |
+
def lerp(p0, p1, t):
|
413 |
+
if not isinstance(t, torch.Tensor):
|
414 |
+
t = torch.Tensor([t])
|
415 |
+
|
416 |
+
new_shape = t.shape + p0.shape
|
417 |
+
new_view_t = t.shape + torch.Size([1] * len(p0.shape))
|
418 |
+
new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
|
419 |
+
p0 = p0.view(new_view_p).expand(new_shape)
|
420 |
+
p1 = p1.view(new_view_p).expand(new_shape)
|
421 |
+
t = t.view(new_view_t).expand(new_shape)
|
422 |
+
|
423 |
+
return p0 + t * (p1 - p0)
|
utils/skeleton.py
ADDED
@@ -0,0 +1,199 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils.quaternion import *
|
2 |
+
import scipy.ndimage.filters as filters
|
3 |
+
|
4 |
+
class Skeleton(object):
|
5 |
+
def __init__(self, offset, kinematic_tree, device):
|
6 |
+
self.device = device
|
7 |
+
self._raw_offset_np = offset.numpy()
|
8 |
+
self._raw_offset = offset.clone().detach().to(device).float()
|
9 |
+
self._kinematic_tree = kinematic_tree
|
10 |
+
self._offset = None
|
11 |
+
self._parents = [0] * len(self._raw_offset)
|
12 |
+
self._parents[0] = -1
|
13 |
+
for chain in self._kinematic_tree:
|
14 |
+
for j in range(1, len(chain)):
|
15 |
+
self._parents[chain[j]] = chain[j-1]
|
16 |
+
|
17 |
+
def njoints(self):
|
18 |
+
return len(self._raw_offset)
|
19 |
+
|
20 |
+
def offset(self):
|
21 |
+
return self._offset
|
22 |
+
|
23 |
+
def set_offset(self, offsets):
|
24 |
+
self._offset = offsets.clone().detach().to(self.device).float()
|
25 |
+
|
26 |
+
def kinematic_tree(self):
|
27 |
+
return self._kinematic_tree
|
28 |
+
|
29 |
+
def parents(self):
|
30 |
+
return self._parents
|
31 |
+
|
32 |
+
# joints (batch_size, joints_num, 3)
|
33 |
+
def get_offsets_joints_batch(self, joints):
|
34 |
+
assert len(joints.shape) == 3
|
35 |
+
_offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
|
36 |
+
for i in range(1, self._raw_offset.shape[0]):
|
37 |
+
_offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
|
38 |
+
|
39 |
+
self._offset = _offsets.detach()
|
40 |
+
return _offsets
|
41 |
+
|
42 |
+
# joints (joints_num, 3)
|
43 |
+
def get_offsets_joints(self, joints):
|
44 |
+
assert len(joints.shape) == 2
|
45 |
+
_offsets = self._raw_offset.clone()
|
46 |
+
for i in range(1, self._raw_offset.shape[0]):
|
47 |
+
# print(joints.shape)
|
48 |
+
_offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
|
49 |
+
|
50 |
+
self._offset = _offsets.detach()
|
51 |
+
return _offsets
|
52 |
+
|
53 |
+
# face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
|
54 |
+
# joints (batch_size, joints_num, 3)
|
55 |
+
def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
|
56 |
+
assert len(face_joint_idx) == 4
|
57 |
+
'''Get Forward Direction'''
|
58 |
+
l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
|
59 |
+
across1 = joints[:, r_hip] - joints[:, l_hip]
|
60 |
+
across2 = joints[:, sdr_r] - joints[:, sdr_l]
|
61 |
+
across = across1 + across2
|
62 |
+
across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
|
63 |
+
# print(across1.shape, across2.shape)
|
64 |
+
|
65 |
+
# forward (batch_size, 3)
|
66 |
+
forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
|
67 |
+
if smooth_forward:
|
68 |
+
forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
|
69 |
+
# forward (batch_size, 3)
|
70 |
+
forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
|
71 |
+
|
72 |
+
'''Get Root Rotation'''
|
73 |
+
target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
|
74 |
+
root_quat = qbetween_np(forward, target)
|
75 |
+
|
76 |
+
'''Inverse Kinematics'''
|
77 |
+
# quat_params (batch_size, joints_num, 4)
|
78 |
+
# print(joints.shape[:-1])
|
79 |
+
quat_params = np.zeros(joints.shape[:-1] + (4,))
|
80 |
+
# print(quat_params.shape)
|
81 |
+
root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
|
82 |
+
quat_params[:, 0] = root_quat
|
83 |
+
# quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
|
84 |
+
for chain in self._kinematic_tree:
|
85 |
+
R = root_quat
|
86 |
+
for j in range(len(chain) - 1):
|
87 |
+
# (batch, 3)
|
88 |
+
u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
|
89 |
+
# print(u.shape)
|
90 |
+
# (batch, 3)
|
91 |
+
v = joints[:, chain[j+1]] - joints[:, chain[j]]
|
92 |
+
v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
|
93 |
+
# print(u.shape, v.shape)
|
94 |
+
rot_u_v = qbetween_np(u, v)
|
95 |
+
|
96 |
+
R_loc = qmul_np(qinv_np(R), rot_u_v)
|
97 |
+
|
98 |
+
quat_params[:,chain[j + 1], :] = R_loc
|
99 |
+
R = qmul_np(R, R_loc)
|
100 |
+
|
101 |
+
return quat_params
|
102 |
+
|
103 |
+
# Be sure root joint is at the beginning of kinematic chains
|
104 |
+
def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
|
105 |
+
# quat_params (batch_size, joints_num, 4)
|
106 |
+
# joints (batch_size, joints_num, 3)
|
107 |
+
# root_pos (batch_size, 3)
|
108 |
+
if skel_joints is not None:
|
109 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
110 |
+
if len(self._offset.shape) == 2:
|
111 |
+
offsets = self._offset.expand(quat_params.shape[0], -1, -1)
|
112 |
+
joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
|
113 |
+
joints[:, 0] = root_pos
|
114 |
+
for chain in self._kinematic_tree:
|
115 |
+
if do_root_R:
|
116 |
+
R = quat_params[:, 0]
|
117 |
+
else:
|
118 |
+
R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
|
119 |
+
for i in range(1, len(chain)):
|
120 |
+
R = qmul(R, quat_params[:, chain[i]])
|
121 |
+
offset_vec = offsets[:, chain[i]]
|
122 |
+
joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
|
123 |
+
return joints
|
124 |
+
|
125 |
+
# Be sure root joint is at the beginning of kinematic chains
|
126 |
+
def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
|
127 |
+
# quat_params (batch_size, joints_num, 4)
|
128 |
+
# joints (batch_size, joints_num, 3)
|
129 |
+
# root_pos (batch_size, 3)
|
130 |
+
if skel_joints is not None:
|
131 |
+
skel_joints = torch.from_numpy(skel_joints)
|
132 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
133 |
+
if len(self._offset.shape) == 2:
|
134 |
+
offsets = self._offset.expand(quat_params.shape[0], -1, -1)
|
135 |
+
offsets = offsets.numpy()
|
136 |
+
joints = np.zeros(quat_params.shape[:-1] + (3,))
|
137 |
+
joints[:, 0] = root_pos
|
138 |
+
for chain in self._kinematic_tree:
|
139 |
+
if do_root_R:
|
140 |
+
R = quat_params[:, 0]
|
141 |
+
else:
|
142 |
+
R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
|
143 |
+
for i in range(1, len(chain)):
|
144 |
+
R = qmul_np(R, quat_params[:, chain[i]])
|
145 |
+
offset_vec = offsets[:, chain[i]]
|
146 |
+
joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
|
147 |
+
return joints
|
148 |
+
|
149 |
+
def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
|
150 |
+
# cont6d_params (batch_size, joints_num, 6)
|
151 |
+
# joints (batch_size, joints_num, 3)
|
152 |
+
# root_pos (batch_size, 3)
|
153 |
+
if skel_joints is not None:
|
154 |
+
skel_joints = torch.from_numpy(skel_joints)
|
155 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
156 |
+
if len(self._offset.shape) == 2:
|
157 |
+
offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
|
158 |
+
offsets = offsets.numpy()
|
159 |
+
joints = np.zeros(cont6d_params.shape[:-1] + (3,))
|
160 |
+
joints[:, 0] = root_pos
|
161 |
+
for chain in self._kinematic_tree:
|
162 |
+
if do_root_R:
|
163 |
+
matR = cont6d_to_matrix_np(cont6d_params[:, 0])
|
164 |
+
else:
|
165 |
+
matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
|
166 |
+
for i in range(1, len(chain)):
|
167 |
+
matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
|
168 |
+
offset_vec = offsets[:, chain[i]][..., np.newaxis]
|
169 |
+
# print(matR.shape, offset_vec.shape)
|
170 |
+
joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
|
171 |
+
return joints
|
172 |
+
|
173 |
+
def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
|
174 |
+
# cont6d_params (batch_size, joints_num, 6)
|
175 |
+
# joints (batch_size, joints_num, 3)
|
176 |
+
# root_pos (batch_size, 3)
|
177 |
+
if skel_joints is not None:
|
178 |
+
# skel_joints = torch.from_numpy(skel_joints)
|
179 |
+
offsets = self.get_offsets_joints_batch(skel_joints)
|
180 |
+
if len(self._offset.shape) == 2:
|
181 |
+
offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
|
182 |
+
joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
|
183 |
+
joints[..., 0, :] = root_pos
|
184 |
+
for chain in self._kinematic_tree:
|
185 |
+
if do_root_R:
|
186 |
+
matR = cont6d_to_matrix(cont6d_params[:, 0])
|
187 |
+
else:
|
188 |
+
matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
|
189 |
+
for i in range(1, len(chain)):
|
190 |
+
matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
|
191 |
+
offset_vec = offsets[:, chain[i]].unsqueeze(-1)
|
192 |
+
# print(matR.shape, offset_vec.shape)
|
193 |
+
joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
|
194 |
+
return joints
|
195 |
+
|
196 |
+
|
197 |
+
|
198 |
+
|
199 |
+
|
utils/smpl.py
ADDED
@@ -0,0 +1,343 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import torch
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
from utils.transforms import *
|
7 |
+
|
8 |
+
import pickle
|
9 |
+
from typing import Optional
|
10 |
+
# import smplx
|
11 |
+
# from smplx.lbs import vertices2joints
|
12 |
+
import os
|
13 |
+
# from smplx import SMPL as _SMPL
|
14 |
+
# from smplx.body_models import ModelOutput
|
15 |
+
|
16 |
+
smpl_joints = [
|
17 |
+
"root", # 0
|
18 |
+
"lhip", # 1
|
19 |
+
"rhip", # 2
|
20 |
+
"belly", # 3
|
21 |
+
"lknee", # 4
|
22 |
+
"rknee", # 5
|
23 |
+
"spine", # 6
|
24 |
+
"lankle",# 7
|
25 |
+
"rankle",# 8
|
26 |
+
"chest", # 9
|
27 |
+
"ltoes", # 10
|
28 |
+
"rtoes", # 11
|
29 |
+
"neck", # 12
|
30 |
+
"linshoulder", # 13
|
31 |
+
"rinshoulder", # 14
|
32 |
+
"head", # 15
|
33 |
+
"lshoulder", # 16
|
34 |
+
"rshoulder", # 17
|
35 |
+
"lelbow", # 18
|
36 |
+
"relbow", # 19
|
37 |
+
"lwrist", # 20
|
38 |
+
"rwrist", # 21
|
39 |
+
# "lhand", # 22
|
40 |
+
# "rhand", # 23
|
41 |
+
]
|
42 |
+
|
43 |
+
smpl_parents = [
|
44 |
+
-1,
|
45 |
+
0,
|
46 |
+
0,
|
47 |
+
0,
|
48 |
+
1,
|
49 |
+
2,
|
50 |
+
3,
|
51 |
+
4,
|
52 |
+
5,
|
53 |
+
6,
|
54 |
+
7,
|
55 |
+
8,
|
56 |
+
9,
|
57 |
+
9,
|
58 |
+
9,
|
59 |
+
12,
|
60 |
+
13,
|
61 |
+
14,
|
62 |
+
16,
|
63 |
+
17,
|
64 |
+
18,
|
65 |
+
19,
|
66 |
+
# 20,
|
67 |
+
# 21,
|
68 |
+
]
|
69 |
+
|
70 |
+
smpl_offsets = [
|
71 |
+
[0.0, 0.0, 0.0],
|
72 |
+
[0.05858135, -0.08228004, -0.01766408],
|
73 |
+
[-0.06030973, -0.09051332, -0.01354254],
|
74 |
+
[0.00443945, 0.12440352, -0.03838522],
|
75 |
+
[0.04345142, -0.38646945, 0.008037],
|
76 |
+
[-0.04325663, -0.38368791, -0.00484304],
|
77 |
+
[0.00448844, 0.1379564, 0.02682033],
|
78 |
+
[-0.01479032, -0.42687458, -0.037428],
|
79 |
+
[0.01905555, -0.4200455, -0.03456167],
|
80 |
+
[-0.00226458, 0.05603239, 0.00285505],
|
81 |
+
[0.04105436, -0.06028581, 0.12204243],
|
82 |
+
[-0.03483987, -0.06210566, 0.13032329],
|
83 |
+
[-0.0133902, 0.21163553, -0.03346758],
|
84 |
+
[0.07170245, 0.11399969, -0.01889817],
|
85 |
+
[-0.08295366, 0.11247234, -0.02370739],
|
86 |
+
[0.01011321, 0.08893734, 0.05040987],
|
87 |
+
[0.12292141, 0.04520509, -0.019046],
|
88 |
+
[-0.11322832, 0.04685326, -0.00847207],
|
89 |
+
[0.2553319, -0.01564902, -0.02294649],
|
90 |
+
[-0.26012748, -0.01436928, -0.03126873],
|
91 |
+
[0.26570925, 0.01269811, -0.00737473],
|
92 |
+
[-0.26910836, 0.00679372, -0.00602676],
|
93 |
+
# [0.08669055, -0.01063603, -0.01559429],
|
94 |
+
# [-0.0887537, -0.00865157, -0.01010708],
|
95 |
+
]
|
96 |
+
|
97 |
+
|
98 |
+
def set_line_data_3d(line, x):
|
99 |
+
line.set_data(x[:, :2].T)
|
100 |
+
line.set_3d_properties(x[:, 2])
|
101 |
+
|
102 |
+
|
103 |
+
def set_scatter_data_3d(scat, x, c):
|
104 |
+
scat.set_offsets(x[:, :2])
|
105 |
+
scat.set_3d_properties(x[:, 2], "z")
|
106 |
+
scat.set_facecolors([c])
|
107 |
+
|
108 |
+
|
109 |
+
def get_axrange(poses):
|
110 |
+
pose = poses[0]
|
111 |
+
x_min = pose[:, 0].min()
|
112 |
+
x_max = pose[:, 0].max()
|
113 |
+
|
114 |
+
y_min = pose[:, 1].min()
|
115 |
+
y_max = pose[:, 1].max()
|
116 |
+
|
117 |
+
z_min = pose[:, 2].min()
|
118 |
+
z_max = pose[:, 2].max()
|
119 |
+
|
120 |
+
xdiff = x_max - x_min
|
121 |
+
ydiff = y_max - y_min
|
122 |
+
zdiff = z_max - z_min
|
123 |
+
|
124 |
+
biggestdiff = max([xdiff, ydiff, zdiff])
|
125 |
+
return biggestdiff
|
126 |
+
|
127 |
+
|
128 |
+
def plot_single_pose(num, poses, lines, ax, axrange, scat, contact):
|
129 |
+
pose = poses[num]
|
130 |
+
static = contact[num]
|
131 |
+
indices = [7, 8, 10, 11]
|
132 |
+
|
133 |
+
for i, (point, idx) in enumerate(zip(scat, indices)):
|
134 |
+
position = pose[idx : idx + 1]
|
135 |
+
color = "r" if static[i] else "g"
|
136 |
+
set_scatter_data_3d(point, position, color)
|
137 |
+
|
138 |
+
for i, (p, line) in enumerate(zip(smpl_parents, lines)):
|
139 |
+
# don't plot root
|
140 |
+
if i == 0:
|
141 |
+
continue
|
142 |
+
# stack to create a line
|
143 |
+
data = np.stack((pose[i], pose[p]), axis=0)
|
144 |
+
set_line_data_3d(line, data)
|
145 |
+
|
146 |
+
if num == 0:
|
147 |
+
if isinstance(axrange, int):
|
148 |
+
axrange = (axrange, axrange, axrange)
|
149 |
+
xcenter, ycenter, zcenter = 0, 0, 2.5
|
150 |
+
stepx, stepy, stepz = axrange[0] / 2, axrange[1] / 2, axrange[2] / 2
|
151 |
+
|
152 |
+
x_min, x_max = xcenter - stepx, xcenter + stepx
|
153 |
+
y_min, y_max = ycenter - stepy, ycenter + stepy
|
154 |
+
z_min, z_max = zcenter - stepz, zcenter + stepz
|
155 |
+
|
156 |
+
ax.set_xlim(x_min, x_max)
|
157 |
+
ax.set_ylim(y_min, y_max)
|
158 |
+
ax.set_zlim(z_min, z_max)
|
159 |
+
|
160 |
+
|
161 |
+
class SMPLSkeleton:
|
162 |
+
def __init__(
|
163 |
+
self, device=None,
|
164 |
+
):
|
165 |
+
offsets = smpl_offsets
|
166 |
+
parents = smpl_parents
|
167 |
+
assert len(offsets) == len(parents)
|
168 |
+
|
169 |
+
self._offsets = torch.Tensor(offsets).to(device)
|
170 |
+
self._parents = np.array(parents)
|
171 |
+
self._compute_metadata()
|
172 |
+
|
173 |
+
def _compute_metadata(self):
|
174 |
+
self._has_children = np.zeros(len(self._parents)).astype(bool)
|
175 |
+
for i, parent in enumerate(self._parents):
|
176 |
+
if parent != -1:
|
177 |
+
self._has_children[parent] = True
|
178 |
+
|
179 |
+
self._children = []
|
180 |
+
for i, parent in enumerate(self._parents):
|
181 |
+
self._children.append([])
|
182 |
+
for i, parent in enumerate(self._parents):
|
183 |
+
if parent != -1:
|
184 |
+
self._children[parent].append(i)
|
185 |
+
|
186 |
+
def forward(self, rotations, root_positions):
|
187 |
+
"""
|
188 |
+
Perform forward kinematics using the given trajectory and local rotations.
|
189 |
+
Arguments (where N = batch size, L = sequence length, J = number of joints):
|
190 |
+
-- rotations: (N, L, J, 3) tensor of axis-angle rotations describing the local rotations of each joint.
|
191 |
+
-- root_positions: (N, L, 3) tensor describing the root joint positions.
|
192 |
+
"""
|
193 |
+
assert len(rotations.shape) == 4
|
194 |
+
assert len(root_positions.shape) == 3
|
195 |
+
# transform from axis angle to quaternion
|
196 |
+
rotations = axis_angle_to_quaternion(rotations)
|
197 |
+
|
198 |
+
positions_world = []
|
199 |
+
rotations_world = []
|
200 |
+
|
201 |
+
expanded_offsets = self._offsets.expand(
|
202 |
+
rotations.shape[0],
|
203 |
+
rotations.shape[1],
|
204 |
+
self._offsets.shape[0],
|
205 |
+
self._offsets.shape[1],
|
206 |
+
)
|
207 |
+
|
208 |
+
# Parallelize along the batch and time dimensions
|
209 |
+
for i in range(self._offsets.shape[0]):
|
210 |
+
if self._parents[i] == -1:
|
211 |
+
positions_world.append(root_positions)
|
212 |
+
rotations_world.append(rotations[:, :, 0])
|
213 |
+
else:
|
214 |
+
positions_world.append(
|
215 |
+
quaternion_apply(
|
216 |
+
rotations_world[self._parents[i]], expanded_offsets[:, :, i]
|
217 |
+
)
|
218 |
+
+ positions_world[self._parents[i]]
|
219 |
+
)
|
220 |
+
if self._has_children[i]:
|
221 |
+
rotations_world.append(
|
222 |
+
quaternion_multiply(
|
223 |
+
rotations_world[self._parents[i]], rotations[:, :, i]
|
224 |
+
)
|
225 |
+
)
|
226 |
+
else:
|
227 |
+
# This joint is a terminal node -> it would be useless to compute the transformation
|
228 |
+
rotations_world.append(None)
|
229 |
+
|
230 |
+
return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2)
|
231 |
+
|
232 |
+
|
233 |
+
# class SMPL_old(smplx.SMPLLayer):
|
234 |
+
# def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs):
|
235 |
+
# """
|
236 |
+
# Extension of the official SMPL implementation to support more joints.
|
237 |
+
# Args:
|
238 |
+
# Same as SMPLLayer.
|
239 |
+
# joint_regressor_extra (str): Path to extra joint regressor.
|
240 |
+
# """
|
241 |
+
# super(SMPL, self).__init__(*args, **kwargs)
|
242 |
+
# smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
|
243 |
+
# 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
|
244 |
+
|
245 |
+
# if joint_regressor_extra is not None:
|
246 |
+
# self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
|
247 |
+
# self.register_buffer('joint_map', torch.tensor(smpl_to_openpose, dtype=torch.long))
|
248 |
+
# self.update_hips = update_hips
|
249 |
+
|
250 |
+
# def forward(self, *args, **kwargs):
|
251 |
+
# """
|
252 |
+
# Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified.
|
253 |
+
# """
|
254 |
+
# smpl_output = super(SMPL, self).forward(*args, **kwargs)
|
255 |
+
# joints = smpl_output.joints[:, self.joint_map, :]
|
256 |
+
# if self.update_hips:
|
257 |
+
# joints[:,[9,12]] = joints[:,[9,12]] + \
|
258 |
+
# 0.25*(joints[:,[9,12]]-joints[:,[12,9]]) + \
|
259 |
+
# 0.5*(joints[:,[8]] - 0.5*(joints[:,[9,12]] + joints[:,[12,9]]))
|
260 |
+
# if hasattr(self, 'joint_regressor_extra'):
|
261 |
+
# extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices)
|
262 |
+
# joints = torch.cat([joints, extra_joints], dim=1)
|
263 |
+
# smpl_output.joints = joints
|
264 |
+
# return smpl_output
|
265 |
+
|
266 |
+
# Map joints to SMPL joints
|
267 |
+
JOINT_MAP = {
|
268 |
+
'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
|
269 |
+
'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
|
270 |
+
'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
|
271 |
+
'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
|
272 |
+
'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
|
273 |
+
'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
|
274 |
+
'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
|
275 |
+
'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
|
276 |
+
'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
|
277 |
+
'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
|
278 |
+
'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
|
279 |
+
'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
|
280 |
+
'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
|
281 |
+
'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
|
282 |
+
'Spine (H36M)': 51, 'Jaw (H36M)': 52,
|
283 |
+
'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
|
284 |
+
'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
|
285 |
+
}
|
286 |
+
JOINT_NAMES = [
|
287 |
+
'OP Nose', 'OP Neck', 'OP RShoulder',
|
288 |
+
'OP RElbow', 'OP RWrist', 'OP LShoulder',
|
289 |
+
'OP LElbow', 'OP LWrist', 'OP MidHip',
|
290 |
+
'OP RHip', 'OP RKnee', 'OP RAnkle',
|
291 |
+
'OP LHip', 'OP LKnee', 'OP LAnkle',
|
292 |
+
'OP REye', 'OP LEye', 'OP REar',
|
293 |
+
'OP LEar', 'OP LBigToe', 'OP LSmallToe',
|
294 |
+
'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel',
|
295 |
+
'Right Ankle', 'Right Knee', 'Right Hip',
|
296 |
+
'Left Hip', 'Left Knee', 'Left Ankle',
|
297 |
+
'Right Wrist', 'Right Elbow', 'Right Shoulder',
|
298 |
+
'Left Shoulder', 'Left Elbow', 'Left Wrist',
|
299 |
+
'Neck (LSP)', 'Top of Head (LSP)',
|
300 |
+
'Pelvis (MPII)', 'Thorax (MPII)',
|
301 |
+
'Spine (H36M)', 'Jaw (H36M)',
|
302 |
+
'Head (H36M)', 'Nose', 'Left Eye',
|
303 |
+
'Right Eye', 'Left Ear', 'Right Ear'
|
304 |
+
]
|
305 |
+
BASE_DATA_DIR = "/data2/TSMC_data/base_data"
|
306 |
+
JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}
|
307 |
+
JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(BASE_DATA_DIR, 'J_regressor_extra.npy')
|
308 |
+
SMPL_MEAN_PARAMS = os.path.join(BASE_DATA_DIR, 'smpl_mean_params.npz')
|
309 |
+
SMPL_MODEL_DIR = BASE_DATA_DIR
|
310 |
+
H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
|
311 |
+
H36M_TO_J14 = H36M_TO_J17[:14]
|
312 |
+
|
313 |
+
|
314 |
+
# class SMPL(_SMPL):
|
315 |
+
# """ Extension of the official SMPL implementation to support more joints """
|
316 |
+
|
317 |
+
# def __init__(self, *args, **kwargs):
|
318 |
+
# super(SMPL, self).__init__(*args, **kwargs)
|
319 |
+
# joints = [JOINT_MAP[i] for i in JOINT_NAMES]
|
320 |
+
# J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA)
|
321 |
+
# self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
|
322 |
+
# self.joint_map = torch.tensor(joints, dtype=torch.long)
|
323 |
+
|
324 |
+
|
325 |
+
# def forward(self, *args, **kwargs):
|
326 |
+
# kwargs['get_skin'] = True
|
327 |
+
# smpl_output = super(SMPL, self).forward(*args, **kwargs)
|
328 |
+
# extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
|
329 |
+
# joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
|
330 |
+
# joints = joints[:, self.joint_map, :]
|
331 |
+
# output = ModelOutput(vertices=smpl_output.vertices,
|
332 |
+
# global_orient=smpl_output.global_orient,
|
333 |
+
# body_pose=smpl_output.body_pose,
|
334 |
+
# joints=joints,
|
335 |
+
# betas=smpl_output.betas,
|
336 |
+
# full_pose=smpl_output.full_pose)
|
337 |
+
# return output
|
338 |
+
|
339 |
+
|
340 |
+
# def get_smpl_faces():
|
341 |
+
# print("Get SMPL faces")
|
342 |
+
# smpl = SMPL(SMPL_MODEL_DIR, batch_size=1, create_transl=False)
|
343 |
+
# return smpl.faces
|
utils/transforms.py
ADDED
@@ -0,0 +1,600 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# This code is based on https://github.com/Mathux/ACTOR.git
|
2 |
+
# Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
|
3 |
+
# Check PYTORCH3D_LICENCE before use
|
4 |
+
|
5 |
+
import functools
|
6 |
+
from typing import Optional
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn.functional as F
|
10 |
+
import numpy as np
|
11 |
+
|
12 |
+
"""
|
13 |
+
The transformation matrices returned from the functions in this file assume
|
14 |
+
the points on which the transformation will be applied are column vectors.
|
15 |
+
i.e. the R matrix is structured as
|
16 |
+
|
17 |
+
R = [
|
18 |
+
[Rxx, Rxy, Rxz],
|
19 |
+
[Ryx, Ryy, Ryz],
|
20 |
+
[Rzx, Rzy, Rzz],
|
21 |
+
] # (3, 3)
|
22 |
+
|
23 |
+
This matrix can be applied to column vectors by post multiplication
|
24 |
+
by the points e.g.
|
25 |
+
|
26 |
+
points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
|
27 |
+
transformed_points = R * points
|
28 |
+
|
29 |
+
To apply the same matrix to points which are row vectors, the R matrix
|
30 |
+
can be transposed and pre multiplied by the points:
|
31 |
+
|
32 |
+
e.g.
|
33 |
+
points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
|
34 |
+
transformed_points = points * R.transpose(1, 0)
|
35 |
+
"""
|
36 |
+
|
37 |
+
|
38 |
+
def quaternion_to_matrix(quaternions):
|
39 |
+
"""
|
40 |
+
Convert rotations given as quaternions to rotation matrices.
|
41 |
+
|
42 |
+
Args:
|
43 |
+
quaternions: quaternions with real part first,
|
44 |
+
as tensor of shape (..., 4).
|
45 |
+
|
46 |
+
Returns:
|
47 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
48 |
+
"""
|
49 |
+
r, i, j, k = torch.unbind(quaternions, -1)
|
50 |
+
two_s = 2.0 / (quaternions * quaternions).sum(-1)
|
51 |
+
|
52 |
+
o = torch.stack(
|
53 |
+
(
|
54 |
+
1 - two_s * (j * j + k * k),
|
55 |
+
two_s * (i * j - k * r),
|
56 |
+
two_s * (i * k + j * r),
|
57 |
+
two_s * (i * j + k * r),
|
58 |
+
1 - two_s * (i * i + k * k),
|
59 |
+
two_s * (j * k - i * r),
|
60 |
+
two_s * (i * k - j * r),
|
61 |
+
two_s * (j * k + i * r),
|
62 |
+
1 - two_s * (i * i + j * j),
|
63 |
+
),
|
64 |
+
-1,
|
65 |
+
)
|
66 |
+
return o.reshape(quaternions.shape[:-1] + (3, 3))
|
67 |
+
|
68 |
+
|
69 |
+
def _copysign(a, b):
|
70 |
+
"""
|
71 |
+
Return a tensor where each element has the absolute value taken from the,
|
72 |
+
corresponding element of a, with sign taken from the corresponding
|
73 |
+
element of b. This is like the standard copysign floating-point operation,
|
74 |
+
but is not careful about negative 0 and NaN.
|
75 |
+
|
76 |
+
Args:
|
77 |
+
a: source tensor.
|
78 |
+
b: tensor whose signs will be used, of the same shape as a.
|
79 |
+
|
80 |
+
Returns:
|
81 |
+
Tensor of the same shape as a with the signs of b.
|
82 |
+
"""
|
83 |
+
signs_differ = (a < 0) != (b < 0)
|
84 |
+
return torch.where(signs_differ, -a, a)
|
85 |
+
|
86 |
+
|
87 |
+
def _sqrt_positive_part(x):
|
88 |
+
"""
|
89 |
+
Returns torch.sqrt(torch.max(0, x))
|
90 |
+
but with a zero subgradient where x is 0.
|
91 |
+
"""
|
92 |
+
ret = torch.zeros_like(x)
|
93 |
+
positive_mask = x > 0
|
94 |
+
ret[positive_mask] = torch.sqrt(x[positive_mask])
|
95 |
+
return ret
|
96 |
+
|
97 |
+
|
98 |
+
def matrix_to_quaternion(matrix):
|
99 |
+
"""
|
100 |
+
Convert rotations given as rotation matrices to quaternions.
|
101 |
+
|
102 |
+
Args:
|
103 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
104 |
+
|
105 |
+
Returns:
|
106 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
107 |
+
"""
|
108 |
+
# print("matrix.size:", matrix.size())
|
109 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
110 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
111 |
+
m00 = matrix[..., 0, 0]
|
112 |
+
m11 = matrix[..., 1, 1]
|
113 |
+
m22 = matrix[..., 2, 2]
|
114 |
+
o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
|
115 |
+
x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
|
116 |
+
y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
|
117 |
+
z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
|
118 |
+
o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
|
119 |
+
o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
|
120 |
+
o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
|
121 |
+
return torch.stack((o0, o1, o2, o3), -1)
|
122 |
+
|
123 |
+
|
124 |
+
def _axis_angle_rotation(axis: str, angle):
|
125 |
+
"""
|
126 |
+
Return the rotation matrices for one of the rotations about an axis
|
127 |
+
of which Euler angles describe, for each value of the angle given.
|
128 |
+
|
129 |
+
Args:
|
130 |
+
axis: Axis label "X" or "Y or "Z".
|
131 |
+
angle: any shape tensor of Euler angles in radians
|
132 |
+
|
133 |
+
Returns:
|
134 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
135 |
+
"""
|
136 |
+
|
137 |
+
cos = torch.cos(angle)
|
138 |
+
sin = torch.sin(angle)
|
139 |
+
one = torch.ones_like(angle)
|
140 |
+
zero = torch.zeros_like(angle)
|
141 |
+
|
142 |
+
if axis == "X":
|
143 |
+
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
|
144 |
+
if axis == "Y":
|
145 |
+
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
|
146 |
+
if axis == "Z":
|
147 |
+
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
|
148 |
+
|
149 |
+
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
|
150 |
+
|
151 |
+
|
152 |
+
def euler_angles_to_matrix(euler_angles, convention: str):
|
153 |
+
"""
|
154 |
+
Convert rotations given as Euler angles in radians to rotation matrices.
|
155 |
+
|
156 |
+
Args:
|
157 |
+
euler_angles: Euler angles in radians as tensor of shape (..., 3).
|
158 |
+
convention: Convention string of three uppercase letters from
|
159 |
+
{"X", "Y", and "Z"}.
|
160 |
+
|
161 |
+
Returns:
|
162 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
163 |
+
"""
|
164 |
+
if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
|
165 |
+
raise ValueError("Invalid input euler angles.")
|
166 |
+
if len(convention) != 3:
|
167 |
+
raise ValueError("Convention must have 3 letters.")
|
168 |
+
if convention[1] in (convention[0], convention[2]):
|
169 |
+
raise ValueError(f"Invalid convention {convention}.")
|
170 |
+
for letter in convention:
|
171 |
+
if letter not in ("X", "Y", "Z"):
|
172 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
173 |
+
matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
|
174 |
+
return functools.reduce(torch.matmul, matrices)
|
175 |
+
|
176 |
+
|
177 |
+
def _angle_from_tan(
|
178 |
+
axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
|
179 |
+
):
|
180 |
+
"""
|
181 |
+
Extract the first or third Euler angle from the two members of
|
182 |
+
the matrix which are positive constant times its sine and cosine.
|
183 |
+
|
184 |
+
Args:
|
185 |
+
axis: Axis label "X" or "Y or "Z" for the angle we are finding.
|
186 |
+
other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
|
187 |
+
convention.
|
188 |
+
data: Rotation matrices as tensor of shape (..., 3, 3).
|
189 |
+
horizontal: Whether we are looking for the angle for the third axis,
|
190 |
+
which means the relevant entries are in the same row of the
|
191 |
+
rotation matrix. If not, they are in the same column.
|
192 |
+
tait_bryan: Whether the first and third axes in the convention differ.
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
Euler Angles in radians for each matrix in dataset as a tensor
|
196 |
+
of shape (...).
|
197 |
+
"""
|
198 |
+
|
199 |
+
i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
|
200 |
+
if horizontal:
|
201 |
+
i2, i1 = i1, i2
|
202 |
+
even = (axis + other_axis) in ["XY", "YZ", "ZX"]
|
203 |
+
if horizontal == even:
|
204 |
+
return torch.atan2(data[..., i1], data[..., i2])
|
205 |
+
if tait_bryan:
|
206 |
+
return torch.atan2(-data[..., i2], data[..., i1])
|
207 |
+
return torch.atan2(data[..., i2], -data[..., i1])
|
208 |
+
|
209 |
+
|
210 |
+
def _index_from_letter(letter: str):
|
211 |
+
if letter == "X":
|
212 |
+
return 0
|
213 |
+
if letter == "Y":
|
214 |
+
return 1
|
215 |
+
if letter == "Z":
|
216 |
+
return 2
|
217 |
+
|
218 |
+
|
219 |
+
def matrix_to_euler_angles(matrix, convention: str):
|
220 |
+
"""
|
221 |
+
Convert rotations given as rotation matrices to Euler angles in radians.
|
222 |
+
|
223 |
+
Args:
|
224 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
225 |
+
convention: Convention string of three uppercase letters.
|
226 |
+
|
227 |
+
Returns:
|
228 |
+
Euler angles in radians as tensor of shape (..., 3).
|
229 |
+
"""
|
230 |
+
if len(convention) != 3:
|
231 |
+
raise ValueError("Convention must have 3 letters.")
|
232 |
+
if convention[1] in (convention[0], convention[2]):
|
233 |
+
raise ValueError(f"Invalid convention {convention}.")
|
234 |
+
for letter in convention:
|
235 |
+
if letter not in ("X", "Y", "Z"):
|
236 |
+
raise ValueError(f"Invalid letter {letter} in convention string.")
|
237 |
+
if matrix.size(-1) != 3 or matrix.size(-2) != 3:
|
238 |
+
raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
|
239 |
+
i0 = _index_from_letter(convention[0])
|
240 |
+
i2 = _index_from_letter(convention[2])
|
241 |
+
tait_bryan = i0 != i2
|
242 |
+
if tait_bryan:
|
243 |
+
central_angle = torch.asin(
|
244 |
+
matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
|
245 |
+
)
|
246 |
+
else:
|
247 |
+
central_angle = torch.acos(matrix[..., i0, i0])
|
248 |
+
|
249 |
+
o = (
|
250 |
+
_angle_from_tan(
|
251 |
+
convention[0], convention[1], matrix[..., i2], False, tait_bryan
|
252 |
+
),
|
253 |
+
central_angle,
|
254 |
+
_angle_from_tan(
|
255 |
+
convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
|
256 |
+
),
|
257 |
+
)
|
258 |
+
return torch.stack(o, -1)
|
259 |
+
|
260 |
+
|
261 |
+
def random_quaternions(
|
262 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
263 |
+
):
|
264 |
+
"""
|
265 |
+
Generate random quaternions representing rotations,
|
266 |
+
i.e. versors with nonnegative real part.
|
267 |
+
|
268 |
+
Args:
|
269 |
+
n: Number of quaternions in a batch to return.
|
270 |
+
dtype: Type to return.
|
271 |
+
device: Desired device of returned tensor. Default:
|
272 |
+
uses the current device for the default tensor type.
|
273 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
274 |
+
flag set.
|
275 |
+
|
276 |
+
Returns:
|
277 |
+
Quaternions as tensor of shape (N, 4).
|
278 |
+
"""
|
279 |
+
o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
|
280 |
+
s = (o * o).sum(1)
|
281 |
+
o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
|
282 |
+
return o
|
283 |
+
|
284 |
+
|
285 |
+
def random_rotations(
|
286 |
+
n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
287 |
+
):
|
288 |
+
"""
|
289 |
+
Generate random rotations as 3x3 rotation matrices.
|
290 |
+
|
291 |
+
Args:
|
292 |
+
n: Number of rotation matrices in a batch to return.
|
293 |
+
dtype: Type to return.
|
294 |
+
device: Device of returned tensor. Default: if None,
|
295 |
+
uses the current device for the default tensor type.
|
296 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
297 |
+
flag set.
|
298 |
+
|
299 |
+
Returns:
|
300 |
+
Rotation matrices as tensor of shape (n, 3, 3).
|
301 |
+
"""
|
302 |
+
quaternions = random_quaternions(
|
303 |
+
n, dtype=dtype, device=device, requires_grad=requires_grad
|
304 |
+
)
|
305 |
+
return quaternion_to_matrix(quaternions)
|
306 |
+
|
307 |
+
|
308 |
+
def random_rotation(
|
309 |
+
dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
|
310 |
+
):
|
311 |
+
"""
|
312 |
+
Generate a single random 3x3 rotation matrix.
|
313 |
+
|
314 |
+
Args:
|
315 |
+
dtype: Type to return
|
316 |
+
device: Device of returned tensor. Default: if None,
|
317 |
+
uses the current device for the default tensor type
|
318 |
+
requires_grad: Whether the resulting tensor should have the gradient
|
319 |
+
flag set
|
320 |
+
|
321 |
+
Returns:
|
322 |
+
Rotation matrix as tensor of shape (3, 3).
|
323 |
+
"""
|
324 |
+
return random_rotations(1, dtype, device, requires_grad)[0]
|
325 |
+
|
326 |
+
|
327 |
+
def standardize_quaternion(quaternions):
|
328 |
+
"""
|
329 |
+
Convert a unit quaternion to a standard form: one in which the real
|
330 |
+
part is non negative.
|
331 |
+
|
332 |
+
Args:
|
333 |
+
quaternions: Quaternions with real part first,
|
334 |
+
as tensor of shape (..., 4).
|
335 |
+
|
336 |
+
Returns:
|
337 |
+
Standardized quaternions as tensor of shape (..., 4).
|
338 |
+
"""
|
339 |
+
return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
|
340 |
+
|
341 |
+
|
342 |
+
def quaternion_raw_multiply(a, b):
|
343 |
+
"""
|
344 |
+
Multiply two quaternions.
|
345 |
+
Usual torch rules for broadcasting apply.
|
346 |
+
|
347 |
+
Args:
|
348 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
349 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
350 |
+
|
351 |
+
Returns:
|
352 |
+
The product of a and b, a tensor of quaternions shape (..., 4).
|
353 |
+
"""
|
354 |
+
aw, ax, ay, az = torch.unbind(a, -1)
|
355 |
+
bw, bx, by, bz = torch.unbind(b, -1)
|
356 |
+
ow = aw * bw - ax * bx - ay * by - az * bz
|
357 |
+
ox = aw * bx + ax * bw + ay * bz - az * by
|
358 |
+
oy = aw * by - ax * bz + ay * bw + az * bx
|
359 |
+
oz = aw * bz + ax * by - ay * bx + az * bw
|
360 |
+
return torch.stack((ow, ox, oy, oz), -1)
|
361 |
+
|
362 |
+
|
363 |
+
def quaternion_multiply(a, b):
|
364 |
+
"""
|
365 |
+
Multiply two quaternions representing rotations, returning the quaternion
|
366 |
+
representing their composition, i.e. the versor with nonnegative real part.
|
367 |
+
Usual torch rules for broadcasting apply.
|
368 |
+
|
369 |
+
Args:
|
370 |
+
a: Quaternions as tensor of shape (..., 4), real part first.
|
371 |
+
b: Quaternions as tensor of shape (..., 4), real part first.
|
372 |
+
|
373 |
+
Returns:
|
374 |
+
The product of a and b, a tensor of quaternions of shape (..., 4).
|
375 |
+
"""
|
376 |
+
ab = quaternion_raw_multiply(a, b)
|
377 |
+
return standardize_quaternion(ab)
|
378 |
+
|
379 |
+
|
380 |
+
def quaternion_invert(quaternion):
|
381 |
+
"""
|
382 |
+
Given a quaternion representing rotation, get the quaternion representing
|
383 |
+
its inverse.
|
384 |
+
|
385 |
+
Args:
|
386 |
+
quaternion: Quaternions as tensor of shape (..., 4), with real part
|
387 |
+
first, which must be versors (unit quaternions).
|
388 |
+
|
389 |
+
Returns:
|
390 |
+
The inverse, a tensor of quaternions of shape (..., 4).
|
391 |
+
"""
|
392 |
+
|
393 |
+
return quaternion * quaternion.new_tensor([1, -1, -1, -1])
|
394 |
+
|
395 |
+
|
396 |
+
def quaternion_apply(quaternion, point):
|
397 |
+
"""
|
398 |
+
Apply the rotation given by a quaternion to a 3D point.
|
399 |
+
Usual torch rules for broadcasting apply.
|
400 |
+
|
401 |
+
Args:
|
402 |
+
quaternion: Tensor of quaternions, real part first, of shape (..., 4).
|
403 |
+
point: Tensor of 3D points of shape (..., 3).
|
404 |
+
|
405 |
+
Returns:
|
406 |
+
Tensor of rotated points of shape (..., 3).
|
407 |
+
"""
|
408 |
+
if point.size(-1) != 3:
|
409 |
+
raise ValueError(f"Points are not in 3D, f{point.shape}.")
|
410 |
+
real_parts = point.new_zeros(point.shape[:-1] + (1,))
|
411 |
+
point_as_quaternion = torch.cat((real_parts, point), -1)
|
412 |
+
out = quaternion_raw_multiply(
|
413 |
+
quaternion_raw_multiply(quaternion, point_as_quaternion),
|
414 |
+
quaternion_invert(quaternion),
|
415 |
+
)
|
416 |
+
return out[..., 1:]
|
417 |
+
|
418 |
+
|
419 |
+
def axis_angle_to_matrix(axis_angle):
|
420 |
+
"""
|
421 |
+
Convert rotations given as axis/angle to rotation matrices.
|
422 |
+
|
423 |
+
Args:
|
424 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
425 |
+
as a tensor of shape (..., 3), where the magnitude is
|
426 |
+
the angle turned anticlockwise in radians around the
|
427 |
+
vector's direction.
|
428 |
+
|
429 |
+
Returns:
|
430 |
+
Rotation matrices as tensor of shape (..., 3, 3).
|
431 |
+
"""
|
432 |
+
return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
|
433 |
+
|
434 |
+
|
435 |
+
def matrix_to_axis_angle(matrix):
|
436 |
+
"""
|
437 |
+
Convert rotations given as rotation matrices to axis/angle.
|
438 |
+
|
439 |
+
Args:
|
440 |
+
matrix: Rotation matrices as tensor of shape (..., 3, 3).
|
441 |
+
|
442 |
+
Returns:
|
443 |
+
Rotations given as a vector in axis angle form, as a tensor
|
444 |
+
of shape (..., 3), where the magnitude is the angle
|
445 |
+
turned anticlockwise in radians around the vector's
|
446 |
+
direction.
|
447 |
+
"""
|
448 |
+
return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
|
449 |
+
|
450 |
+
|
451 |
+
def axis_angle_to_quaternion(axis_angle):
|
452 |
+
"""
|
453 |
+
Convert rotations given as axis/angle to quaternions.
|
454 |
+
|
455 |
+
Args:
|
456 |
+
axis_angle: Rotations given as a vector in axis angle form,
|
457 |
+
as a tensor of shape (..., 3), where the magnitude is
|
458 |
+
the angle turned anticlockwise in radians around the
|
459 |
+
vector's direction.
|
460 |
+
|
461 |
+
Returns:
|
462 |
+
quaternions with real part first, as tensor of shape (..., 4).
|
463 |
+
"""
|
464 |
+
angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
|
465 |
+
half_angles = 0.5 * angles
|
466 |
+
eps = 1e-6
|
467 |
+
small_angles = angles.abs() < eps
|
468 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
469 |
+
sin_half_angles_over_angles[~small_angles] = (
|
470 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
471 |
+
)
|
472 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
473 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
474 |
+
sin_half_angles_over_angles[small_angles] = (
|
475 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
476 |
+
)
|
477 |
+
quaternions = torch.cat(
|
478 |
+
[torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
|
479 |
+
)
|
480 |
+
return quaternions
|
481 |
+
|
482 |
+
|
483 |
+
def quaternion_to_axis_angle(quaternions):
|
484 |
+
"""
|
485 |
+
Convert rotations given as quaternions to axis/angle.
|
486 |
+
|
487 |
+
Args:
|
488 |
+
quaternions: quaternions with real part first,
|
489 |
+
as tensor of shape (..., 4).
|
490 |
+
|
491 |
+
Returns:
|
492 |
+
Rotations given as a vector in axis angle form, as a tensor
|
493 |
+
of shape (..., 3), where the magnitude is the angle
|
494 |
+
turned anticlockwise in radians around the vector's
|
495 |
+
direction.
|
496 |
+
"""
|
497 |
+
norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
|
498 |
+
half_angles = torch.atan2(norms, quaternions[..., :1])
|
499 |
+
angles = 2 * half_angles
|
500 |
+
eps = 1e-6
|
501 |
+
small_angles = angles.abs() < eps
|
502 |
+
sin_half_angles_over_angles = torch.empty_like(angles)
|
503 |
+
sin_half_angles_over_angles[~small_angles] = (
|
504 |
+
torch.sin(half_angles[~small_angles]) / angles[~small_angles]
|
505 |
+
)
|
506 |
+
# for x small, sin(x/2) is about x/2 - (x/2)^3/6
|
507 |
+
# so sin(x/2)/x is about 1/2 - (x*x)/48
|
508 |
+
sin_half_angles_over_angles[small_angles] = (
|
509 |
+
0.5 - (angles[small_angles] * angles[small_angles]) / 48
|
510 |
+
)
|
511 |
+
return quaternions[..., 1:] / sin_half_angles_over_angles
|
512 |
+
|
513 |
+
|
514 |
+
def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
|
515 |
+
"""
|
516 |
+
Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
|
517 |
+
using Gram--Schmidt orthogonalisation per Section B of [1].
|
518 |
+
Args:
|
519 |
+
d6: 6D rotation representation, of size (*, 6)
|
520 |
+
|
521 |
+
Returns:
|
522 |
+
batch of rotation matrices of size (*, 3, 3)
|
523 |
+
|
524 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
525 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
526 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
527 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
528 |
+
"""
|
529 |
+
|
530 |
+
a1, a2 = d6[..., :3], d6[..., 3:]
|
531 |
+
b1 = F.normalize(a1, dim=-1)
|
532 |
+
b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
|
533 |
+
b2 = F.normalize(b2, dim=-1)
|
534 |
+
b3 = torch.cross(b1, b2, dim=-1)
|
535 |
+
return torch.stack((b1, b2, b3), dim=-2)
|
536 |
+
|
537 |
+
|
538 |
+
def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
|
539 |
+
"""
|
540 |
+
Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
|
541 |
+
by dropping the last row. Note that 6D representation is not unique.
|
542 |
+
Args:
|
543 |
+
matrix: batch of rotation matrices of size (*, 3, 3)
|
544 |
+
|
545 |
+
Returns:
|
546 |
+
6D rotation representation, of size (*, 6)
|
547 |
+
|
548 |
+
[1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
|
549 |
+
On the Continuity of Rotation Representations in Neural Networks.
|
550 |
+
IEEE Conference on Computer Vision and Pattern Recognition, 2019.
|
551 |
+
Retrieved from http://arxiv.org/abs/1812.07035
|
552 |
+
"""
|
553 |
+
return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
|
554 |
+
|
555 |
+
def axis_angle_to_rotation_6d(axis_angle):
|
556 |
+
matrix = axis_angle_to_matrix(axis_angle)
|
557 |
+
return matrix_to_rotation_6d(matrix)
|
558 |
+
|
559 |
+
def rotateXYZ(mesh_v, Rxyz):
|
560 |
+
for rot in Rxyz:
|
561 |
+
angle = np.radians(rot[0])
|
562 |
+
rx = np.array([
|
563 |
+
[1., 0., 0. ],
|
564 |
+
[0., np.cos(angle), -np.sin(angle)],
|
565 |
+
[0., np.sin(angle), np.cos(angle) ]
|
566 |
+
])
|
567 |
+
|
568 |
+
angle = np.radians(rot[1])
|
569 |
+
ry = np.array([
|
570 |
+
[np.cos(angle), 0., np.sin(angle)],
|
571 |
+
[0., 1., 0. ],
|
572 |
+
[-np.sin(angle), 0., np.cos(angle)]
|
573 |
+
])
|
574 |
+
|
575 |
+
angle = np.radians(rot[2])
|
576 |
+
rz = np.array([
|
577 |
+
[np.cos(angle), -np.sin(angle), 0. ],
|
578 |
+
[np.sin(angle), np.cos(angle), 0. ],
|
579 |
+
[0., 0., 1. ]
|
580 |
+
])
|
581 |
+
# return rotateZ(rotateY(rotateX(mesh_v, Rxyz[0]), Rxyz[1]), Rxyz[2])
|
582 |
+
mesh_v = rz.dot(ry.dot(rx.dot(mesh_v.T))).T
|
583 |
+
# return rx.dot(mesh_v.T).T
|
584 |
+
return mesh_v
|
585 |
+
|
586 |
+
def rotate_trans(trans_3d, rot_body=None):
|
587 |
+
if rot_body is not None:
|
588 |
+
trans_3d = rotateXYZ(trans_3d, rot_body)
|
589 |
+
trans_3d = torch.from_numpy(trans_3d)
|
590 |
+
return trans_3d
|
591 |
+
|
592 |
+
def rotate_root(pose_6d, rot_body=None):
|
593 |
+
root_6d = pose_6d[:, :6]
|
594 |
+
root_6d = rotation_6d_to_matrix(root_6d)
|
595 |
+
if rot_body is not None:
|
596 |
+
root_6d = rotateXYZ(root_6d, rot_body)
|
597 |
+
root_6d = torch.from_numpy(root_6d)
|
598 |
+
root_6d = matrix_to_rotation_6d(root_6d)
|
599 |
+
pose_6d[:, :6] = root_6d
|
600 |
+
return pose_6d
|