EvanTHU commited on
Commit
b887ad8
1 Parent(s): 53ca020
This view is limited to 50 files because it contains too many changes.   See raw diff
1gpu.yaml ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ compute_environment: LOCAL_MACHINE
2
+ debug: false
3
+ distributed_type: 'NO'
4
+ downcast_bf16: 'no'
5
+ machine_rank: 0
6
+ main_training_function: main
7
+ mixed_precision: no
8
+ num_machines: 1
9
+ num_processes: 1
10
+ rdzv_backend: static
11
+ same_network: true
12
+ tpu_env: []
13
+ tpu_use_cluster: false
14
+ tpu_use_sudo: false
15
+ use_cpu: false
16
+ main_process_port: 21000
LICENSE ADDED
@@ -0,0 +1,57 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #--------------------------------------------
2
+ # MotionCLR
3
+ # Copyright (c) 2024 IDEA. All Rights Reserved.
4
+ # Licensed under the IDEA License, Version 1.0 [see LICENSE for details]
5
+ #--------------------------------------------
6
+
7
+ IDEA License 1.0
8
+
9
+ This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and the International Digital Economy Academy (“IDEA” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by IDEA under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by IDEA related to the Software (“Documentation”).
10
+
11
+ By downloading the Software or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to IDEA that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity.
12
+
13
+ 1. LICENSE GRANT
14
+
15
+ a. You are granted a non-exclusive, worldwide, transferable, sublicensable, irrevocable, royalty free and limited license under IDEA’s copyright interests to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Software solely for your non-commercial research purposes.
16
+
17
+ b. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. IDEA and its licensors reserve all rights not expressly granted by this License.
18
+
19
+ c. If you intend to use the Software Products for any commercial purposes, you must request a license from IDEA, which IDEA may grant to you in its sole discretion.
20
+
21
+ 2. REDISTRIBUTION AND USE
22
+
23
+ a. If you distribute or make the Software Products, or any derivative works thereof, available to a third party, you shall provide a copy of this Agreement to such third party.
24
+
25
+ b. You must retain in all copies of the Software Products that you distribute the following attribution notice: "MotionCLR is licensed under the IDEA License 1.0, Copyright (c) IDEA. All Rights Reserved."
26
+
27
+ d. Your use of the Software Products must comply with applicable laws and regulations (including trade compliance laws and regulations).
28
+
29
+ e. You will not, and will not permit, assist or cause any third party to use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for in any manner that infringes, misappropriates, or otherwise violates any third-party rights.
30
+
31
+ 3. DISCLAIMER OF WARRANTY
32
+
33
+ UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS.
34
+
35
+ 4. LIMITATION OF LIABILITY
36
+
37
+ IN NO EVENT WILL IDEA OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF IDEA OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING.
38
+
39
+ 5. INDEMNIFICATION
40
+
41
+ You will indemnify, defend and hold harmless IDEA and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “IDEA Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any IDEA Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the IDEA Parties of any such Claims, and cooperate with IDEA Parties in defending such Claims. You will also grant the IDEA Parties sole control of the defense or settlement, at IDEA’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and IDEA or the other IDEA Parties.
42
+
43
+ 6. TERMINATION; SURVIVAL
44
+
45
+ a. This License will automatically terminate upon any breach by you of the terms of this License.
46
+
47
+ b. If you institute litigation or other proceedings against IDEA or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted.
48
+
49
+ c. The following sections survive termination of this License: 2 (Redistribution and use), 3 (Disclaimers of Warranty), 4 (Limitation of Liability), 5 (Indemnification), 6 (Termination; Survival), 7 (Trademarks) and 8 (Applicable Law; Dispute Resolution).
50
+
51
+ 7. TRADEMARKS
52
+
53
+ Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with IDEA without the prior written permission of IDEA, except to the extent necessary to make the reference required by the attribution notice of this Agreement.
54
+
55
+ 8. APPLICABLE LAW; DISPUTE RESOLUTION
56
+
57
+ This License will be governed and construed under the laws of the People’s Republic of China without regard to conflicts of law provisions. The parties expressly agree that the United Nations Convention on Contracts for the International Sale of Goods will not apply. Any suit or proceeding arising out of or relating to this License will be brought in the courts, as applicable, in Shenzhen, Guangdong, and each party irrevocably submits to the jurisdiction and venue of such courts.
app.py ADDED
@@ -0,0 +1,540 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import spaces
2
+ import gradio as gr
3
+ import sys
4
+ import os
5
+ import torch
6
+ import numpy as np
7
+ from os.path import join as pjoin
8
+ import utils.paramUtil as paramUtil
9
+ from utils.plot_script import *
10
+ from utils.utils import *
11
+ from utils.motion_process import recover_from_ric
12
+ from accelerate.utils import set_seed
13
+ from models.gaussian_diffusion import DiffusePipeline
14
+ from options.generate_options import GenerateOptions
15
+ from utils.model_load import load_model_weights
16
+ from motion_loader import get_dataset_loader
17
+ from models import build_models
18
+ import yaml
19
+ import time
20
+ from box import Box
21
+ import hashlib
22
+ from huggingface_hub import hf_hub_download
23
+
24
+ ckptdir = './checkpoints/t2m/release'
25
+ os.makedirs(ckptdir, exist_ok=True)
26
+
27
+
28
+ mean_path = hf_hub_download(
29
+ repo_id="EvanTHU/MotionCLR",
30
+ filename="meta/mean.npy",
31
+ local_dir=ckptdir,
32
+ local_dir_use_symlinks=False
33
+ )
34
+
35
+ std_path = hf_hub_download(
36
+ repo_id="EvanTHU/MotionCLR",
37
+ filename="meta/std.npy",
38
+ local_dir=ckptdir,
39
+ local_dir_use_symlinks=False
40
+ )
41
+
42
+ model_path = hf_hub_download(
43
+ repo_id="EvanTHU/MotionCLR",
44
+ filename="model/latest.tar",
45
+ local_dir=ckptdir,
46
+ local_dir_use_symlinks=False
47
+ )
48
+
49
+ opt_path = hf_hub_download(
50
+ repo_id="EvanTHU/MotionCLR",
51
+ filename="opt.txt",
52
+ local_dir=ckptdir,
53
+ local_dir_use_symlinks=False
54
+ )
55
+
56
+
57
+
58
+ os.makedirs("tmp", exist_ok=True)
59
+ os.environ['GRADIO_TEMP_DIR'] = './tmp'
60
+
61
+ def generate_md5(input_string):
62
+ # Encode the string and compute the MD5 hash
63
+ md5_hash = hashlib.md5(input_string.encode())
64
+ # Return the hexadecimal representation of the hash
65
+ return md5_hash.hexdigest()
66
+
67
+ def set_all_use_to_false(data):
68
+ for key, value in data.items():
69
+ if isinstance(value, Box):
70
+ set_all_use_to_false(value)
71
+ elif key == 'use':
72
+ data[key] = False
73
+ return data
74
+
75
+ def yaml_to_box(yaml_file):
76
+ with open(yaml_file, 'r') as file:
77
+ yaml_data = yaml.safe_load(file)
78
+
79
+ return Box(yaml_data)
80
+
81
+ HEAD = """<div class="embed_hidden">
82
+ <h1 style='text-align: center'> MotionCLR User Interaction Demo </h1>
83
+ """
84
+
85
+ edit_config = yaml_to_box('options/edit.yaml')
86
+ os.environ['GRADIO_TEMP_DIR'] = './tmp'
87
+ CSS = """
88
+ .retrieved_video {
89
+ position: relative;
90
+ margin: 0;
91
+ box-shadow: var(--block-shadow);
92
+ border-width: var(--block-border-width);
93
+ border-color: #000000;
94
+ border-radius: var(--block-radius);
95
+ background: var(--block-background-fill);
96
+ width: 100%;
97
+ line-height: var(--line-sm);
98
+ }
99
+ .contour_video {
100
+ display: flex;
101
+ flex-direction: column;
102
+ justify-content: center;
103
+ align-items: center;
104
+ z-index: var(--layer-5);
105
+ border-radius: var(--block-radius);
106
+ background: var(--background-fill-primary);
107
+ padding: 0 var(--size-6);
108
+ max-height: var(--size-screen-h);
109
+ overflow: hidden;
110
+ }
111
+ """
112
+
113
+ def generate_video_from_text(text, opt, pipeline):
114
+ width = 500
115
+ height = 500
116
+ texts = [text]
117
+ motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)]
118
+
119
+ save_dir = './tmp/gen/'
120
+ filename = generate_md5(str(time.time())) + ".mp4"
121
+ save_path = pjoin(save_dir, str(filename))
122
+ os.makedirs(save_dir, exist_ok=True)
123
+
124
+ start_time = time.perf_counter()
125
+ gr.Info("Generating motion...", duration = 3)
126
+ pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens]))
127
+ end_time = time.perf_counter()
128
+ exc = end_time - start_time
129
+ gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3)
130
+ start_time = time.perf_counter()
131
+ mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
132
+ std = np.load(pjoin(opt.meta_dir, 'std.npy'))
133
+
134
+
135
+ samples = []
136
+
137
+ root_list = []
138
+ for i, motion in enumerate(pred_motions):
139
+ motion = motion.cpu().numpy() * std + mean
140
+ # 1. recover 3d joints representation by ik
141
+ motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
142
+ # 2. put on Floor (Y axis)
143
+ floor_height = motion.min(dim=0)[0].min(dim=0)[0][1]
144
+ motion[:, :, 1] -= floor_height
145
+ motion = motion.numpy()
146
+ # 3. remove jitter
147
+ motion = motion_temporal_filter(motion, sigma=1)
148
+
149
+ samples.append(motion)
150
+
151
+ i = 0
152
+ title = texts[i]
153
+ motion = samples[i]
154
+ kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain
155
+ plot_3d_motion(save_path, kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius)
156
+
157
+ gr.Info("Rendered motion...", duration = 3)
158
+ end_time = time.perf_counter()
159
+ exc = end_time - start_time
160
+ gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3)
161
+
162
+ video_dis = f'<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_path}"></video>'
163
+ style_dis = video_dis + """<br> <p align="center"> Content Reference </p>"""
164
+ global edit_config
165
+ edit_config = set_all_use_to_false(edit_config)
166
+ return video_dis, video_dis, video_dis, video_dis, style_dis, video_dis, gr.update(visible=True)
167
+
168
+ def reweighting(text, idx, weight, opt, pipeline):
169
+ global edit_config
170
+ edit_config.reweighting_attn.use = True
171
+ edit_config.reweighting_attn.idx = idx
172
+ edit_config.reweighting_attn.reweighting_attn_weight = weight
173
+
174
+
175
+ gr.Info("Loading Configurations...", duration = 3)
176
+ model = build_models(opt, edit_config=edit_config)
177
+ ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
178
+ niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)
179
+
180
+ pipeline = DiffusePipeline(
181
+ opt = opt,
182
+ model = model,
183
+ diffuser_name = opt.diffuser_name,
184
+ device=opt.device,
185
+ num_inference_steps=opt.num_inference_steps,
186
+ torch_dtype=torch.float16,
187
+ )
188
+
189
+ print(edit_config)
190
+
191
+ width = 500
192
+ height = 500
193
+ texts = [text, text]
194
+ motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)]
195
+
196
+ save_dir = './tmp/gen/'
197
+ filenames = [generate_md5(str(time.time())) + ".mp4", generate_md5(str(time.time())) + ".mp4"]
198
+ save_paths = [pjoin(save_dir, str(filenames[0])), pjoin(save_dir, str(filenames[1]))]
199
+ os.makedirs(save_dir, exist_ok=True)
200
+
201
+ start_time = time.perf_counter()
202
+ gr.Info("Generating motion...", duration = 3)
203
+ pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens]))
204
+ end_time = time.perf_counter()
205
+ exc = end_time - start_time
206
+ gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3)
207
+ start_time = time.perf_counter()
208
+ mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
209
+ std = np.load(pjoin(opt.meta_dir, 'std.npy'))
210
+
211
+
212
+ samples = []
213
+
214
+ root_list = []
215
+ for i, motion in enumerate(pred_motions):
216
+ motion = motion.cpu().numpy() * std + mean
217
+ # 1. recover 3d joints representation by ik
218
+ motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
219
+ # 2. put on Floor (Y axis)
220
+ floor_height = motion.min(dim=0)[0].min(dim=0)[0][1]
221
+ motion[:, :, 1] -= floor_height
222
+ motion = motion.numpy()
223
+ # 3. remove jitter
224
+ motion = motion_temporal_filter(motion, sigma=1)
225
+
226
+ samples.append(motion)
227
+
228
+ i = 1
229
+ title = texts[i]
230
+ motion = samples[i]
231
+ kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain
232
+ plot_3d_motion(save_paths[1], kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius)
233
+
234
+
235
+ gr.Info("Rendered motion...", duration = 3)
236
+ end_time = time.perf_counter()
237
+ exc = end_time - start_time
238
+ gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3)
239
+
240
+ video_dis = f'<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_paths[1]}"></video>'
241
+
242
+
243
+ edit_config = set_all_use_to_false(edit_config)
244
+ return video_dis
245
+
246
+ def generate_example_based_motion(text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion, opt, pipeline):
247
+ global edit_config
248
+ edit_config.example_based.use = True
249
+ edit_config.example_based.chunk_size = chunk_size
250
+ edit_config.example_based.example_based_steps_end = example_based_steps_end
251
+ edit_config.example_based.temp_seed = temp_seed
252
+ edit_config.example_based.temp_seed_bar = temp_seed_bar
253
+
254
+
255
+ gr.Info("Loading Configurations...", duration = 3)
256
+ model = build_models(opt, edit_config=edit_config)
257
+ ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
258
+ niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)
259
+
260
+ pipeline = DiffusePipeline(
261
+ opt = opt,
262
+ model = model,
263
+ diffuser_name = opt.diffuser_name,
264
+ device=opt.device,
265
+ num_inference_steps=opt.num_inference_steps,
266
+ torch_dtype=torch.float16,
267
+ )
268
+
269
+ width = 500
270
+ height = 500
271
+ texts = [text for _ in range(num_motion)]
272
+ motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)]
273
+
274
+ save_dir = './tmp/gen/'
275
+ filenames = [generate_md5(str(time.time())) + ".mp4" for _ in range(num_motion)]
276
+ save_paths = [pjoin(save_dir, str(filenames[i])) for i in range(num_motion)]
277
+ os.makedirs(save_dir, exist_ok=True)
278
+
279
+ start_time = time.perf_counter()
280
+ gr.Info("Generating motion...", duration = 3)
281
+ pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens]))
282
+ end_time = time.perf_counter()
283
+ exc = end_time - start_time
284
+ gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3)
285
+ start_time = time.perf_counter()
286
+ mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
287
+ std = np.load(pjoin(opt.meta_dir, 'std.npy'))
288
+
289
+
290
+ samples = []
291
+
292
+ root_list = []
293
+ progress=gr.Progress()
294
+ progress(0, desc="Starting...")
295
+ for i, motion in enumerate(pred_motions):
296
+ motion = motion.cpu().numpy() * std + mean
297
+ # 1. recover 3d joints representation by ik
298
+ motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
299
+ # 2. put on Floor (Y axis)
300
+ floor_height = motion.min(dim=0)[0].min(dim=0)[0][1]
301
+ motion[:, :, 1] -= floor_height
302
+ motion = motion.numpy()
303
+ # 3. remove jitter
304
+ motion = motion_temporal_filter(motion, sigma=1)
305
+
306
+ samples.append(motion)
307
+
308
+ video_dis = []
309
+ i = 0
310
+ for title in progress.tqdm(texts):
311
+ print(save_paths[i])
312
+ title = texts[i]
313
+ motion = samples[i]
314
+ kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain
315
+ plot_3d_motion(save_paths[i], kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius)
316
+ video_html = f'''
317
+ <video class="retrieved_video" width="{width}" height="{height}" preload="auto" muted playsinline onpause="this.load()" autoplay loop disablepictureinpicture src="./file={save_paths[i]}"> </video>
318
+ '''
319
+ video_dis.append(video_html)
320
+ i += 1
321
+
322
+ for _ in range(24 - num_motion):
323
+ video_dis.append(None)
324
+ gr.Info("Rendered motion...", duration = 3)
325
+ end_time = time.perf_counter()
326
+ exc = end_time - start_time
327
+ gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3)
328
+
329
+ edit_config = set_all_use_to_false(edit_config)
330
+ return video_dis
331
+
332
+ def transfer_style(text, style_text, style_transfer_steps_end, opt, pipeline):
333
+ global edit_config
334
+ edit_config.style_tranfer.use = True
335
+ edit_config.style_tranfer.style_transfer_steps_end = style_transfer_steps_end
336
+
337
+ gr.Info("Loading Configurations...", duration = 3)
338
+ model = build_models(opt, edit_config=edit_config)
339
+ ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
340
+ niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)
341
+
342
+ pipeline = DiffusePipeline(
343
+ opt = opt,
344
+ model = model,
345
+ diffuser_name = opt.diffuser_name,
346
+ device=opt.device,
347
+ num_inference_steps=opt.num_inference_steps,
348
+ torch_dtype=torch.float16,
349
+ )
350
+
351
+ print(edit_config)
352
+
353
+ width = 500
354
+ height = 500
355
+ texts = [style_text, text, text]
356
+ motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)]
357
+
358
+ save_dir = './tmp/gen/'
359
+ filenames = [generate_md5(str(time.time())) + ".mp4", generate_md5(str(time.time())) + ".mp4", generate_md5(str(time.time())) + ".mp4"]
360
+ save_paths = [pjoin(save_dir, str(filenames[0])), pjoin(save_dir, str(filenames[1])), pjoin(save_dir, str(filenames[2]))]
361
+ os.makedirs(save_dir, exist_ok=True)
362
+
363
+ start_time = time.perf_counter()
364
+ gr.Info("Generating motion...", duration = 3)
365
+ pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens]))
366
+ end_time = time.perf_counter()
367
+ exc = end_time - start_time
368
+ gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3)
369
+ start_time = time.perf_counter()
370
+ mean = np.load(pjoin(opt.meta_dir, 'mean.npy'))
371
+ std = np.load(pjoin(opt.meta_dir, 'std.npy'))
372
+
373
+ samples = []
374
+
375
+ root_list = []
376
+ for i, motion in enumerate(pred_motions):
377
+ motion = motion.cpu().numpy() * std + mean
378
+ # 1. recover 3d joints representation by ik
379
+ motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
380
+ # 2. put on Floor (Y axis)
381
+ floor_height = motion.min(dim=0)[0].min(dim=0)[0][1]
382
+ motion[:, :, 1] -= floor_height
383
+ motion = motion.numpy()
384
+ # 3. remove jitter
385
+ motion = motion_temporal_filter(motion, sigma=1)
386
+
387
+ samples.append(motion)
388
+
389
+ for i,title in enumerate(texts):
390
+ title = texts[i]
391
+ motion = samples[i]
392
+ kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain
393
+ plot_3d_motion(save_paths[i], kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius)
394
+
395
+ gr.Info("Rendered motion...", duration = 3)
396
+ end_time = time.perf_counter()
397
+ exc = end_time - start_time
398
+ gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3)
399
+
400
+ video_dis0 = f"""<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_paths[0]}"></video> <br> <p align="center"> Style Reference </p>"""
401
+ video_dis1 = f"""<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_paths[2]}"></video> <br> <p align="center"> Content Reference </p>"""
402
+ video_dis2 = f"""<video controls playsinline width="{width}" style="display: block; margin: 0 auto;" src="./file={save_paths[1]}"></video> <br> <p align="center"> Transfered Result </p>"""
403
+
404
+ edit_config = set_all_use_to_false(edit_config)
405
+ return video_dis0, video_dis2
406
+
407
+
408
+ @spaces.GPU
409
+ def main():
410
+ parser = GenerateOptions()
411
+ opt = parser.parse_app()
412
+ set_seed(opt.seed)
413
+ device_id = opt.gpu_id
414
+ device = torch.device('cuda:%d' % device_id if torch.cuda.is_available() else 'cpu')
415
+ opt.device = device
416
+
417
+
418
+ # load model
419
+ model = build_models(opt, edit_config=edit_config)
420
+ ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar')
421
+ niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)
422
+
423
+ pipeline = DiffusePipeline(
424
+ opt = opt,
425
+ model = model,
426
+ diffuser_name = opt.diffuser_name,
427
+ device=device,
428
+ num_inference_steps=opt.num_inference_steps,
429
+ torch_dtype=torch.float16,
430
+ )
431
+
432
+ with gr.Blocks() as demo:
433
+ gr.Markdown(HEAD)
434
+ with gr.Row():
435
+ with gr.Column(scale=7):
436
+ text_input = gr.Textbox(label="Input the text prompt to generate motion...")
437
+ with gr.Column(scale=3):
438
+ sequence_length = gr.Slider(minimum=1, maximum=9.6, step=0.1, label="Motion length", value=8)
439
+ with gr.Row():
440
+ generate_button = gr.Button("Generate motion")
441
+
442
+ with gr.Row():
443
+ video_display = gr.HTML(label="生成的视频", visible=True)
444
+
445
+
446
+ tabs = gr.Tabs(visible=True)
447
+ with tabs:
448
+ with gr.Tab("Motion (de-)emphasizing"):
449
+ with gr.Row():
450
+ int_input = gr.Number(label="Editing word index", minimum=0, maximum=70)
451
+ weight_input = gr.Slider(minimum=-1, maximum=1, step=0.01, label="Input weight for (de-)emphasizing [-1, 1]", value=0)
452
+
453
+ trim_button = gr.Button("Edit reweighting")
454
+
455
+ with gr.Row():
456
+ original_video1 = gr.HTML(label="before editing", visible=False)
457
+ edited_video = gr.HTML(label="after editing")
458
+
459
+ trim_button.click(
460
+ fn=lambda x, int_input, weight_input : reweighting(x, int_input, weight_input, opt, pipeline),
461
+ inputs=[text_input, int_input, weight_input],
462
+ outputs=edited_video,
463
+ )
464
+
465
+ with gr.Tab("Example-based motion genration"):
466
+ with gr.Row():
467
+ with gr.Column(scale=4):
468
+ chunk_size = gr.Number(minimum=10, maximum=20, step=10,label="Chunk size (#frames)", value=20)
469
+ example_based_steps_end = gr.Number(minimum=0, maximum=9,label="Ending step of manipulation", value=6)
470
+ with gr.Column(scale=3):
471
+ temp_seed = gr.Number(label="Seed for random", value=200, minimum=0)
472
+ temp_seed_bar = gr.Slider(minimum=0, maximum=100, step=1, label="Seed for random bar", value=15)
473
+ with gr.Column(scale=3):
474
+ num_motion = gr.Radio(choices=[4, 8, 12, 16, 24], value=8, label="Select number of motions")
475
+
476
+ gen_button = gr.Button("Generate example-based motion")
477
+
478
+
479
+ example_video_display = []
480
+ for _ in range(6):
481
+ with gr.Row():
482
+ for _ in range(4):
483
+ video = gr.HTML(label="Example-based motion", visible=True)
484
+ example_video_display.append(video)
485
+
486
+ gen_button.click(
487
+ fn=lambda text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion: generate_example_based_motion(text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion, opt, pipeline),
488
+ inputs=[text_input, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion],
489
+ outputs=example_video_display
490
+ )
491
+
492
+ with gr.Tab("Style transfer"):
493
+ with gr.Row():
494
+ style_text = gr.Textbox(label="Reference prompt (e.g. 'a man walks.')", value="a man walks.")
495
+ style_transfer_steps_end = gr.Number(label="The end step of diffusion (0~9)", minimum=0, maximum=9, value=5)
496
+
497
+ style_transfer_button = gr.Button("Transfer style")
498
+
499
+ with gr.Row():
500
+ style_reference = gr.HTML(label="style reference")
501
+ original_video4 = gr.HTML(label="before style transfer", visible=False)
502
+ styled_video = gr.HTML(label="after style transfer")
503
+
504
+ style_transfer_button.click(
505
+ fn=lambda text, style_text, style_transfer_steps_end: transfer_style(text, style_text, style_transfer_steps_end, opt, pipeline),
506
+ inputs=[text_input, style_text, style_transfer_steps_end],
507
+ outputs=[style_reference, styled_video],
508
+ )
509
+
510
+ def update_motion_length(sequence_length):
511
+ opt.motion_length = sequence_length
512
+
513
+ def on_generate(text, length, pipeline):
514
+ update_motion_length(length)
515
+ return generate_video_from_text(text, opt, pipeline)
516
+
517
+
518
+ generate_button.click(
519
+ fn=lambda text, length: on_generate(text, length, pipeline),
520
+ inputs=[text_input, sequence_length],
521
+ outputs=[
522
+ video_display,
523
+ original_video1,
524
+ original_video4,
525
+ tabs,
526
+ ],
527
+ show_progress=True
528
+ )
529
+
530
+ generate_button.click(
531
+ fn=lambda: [gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)],
532
+ inputs=None,
533
+ outputs=[video_display, original_video1, original_video4]
534
+ )
535
+
536
+ demo.launch()
537
+
538
+
539
+ if __name__ == '__main__':
540
+ main()
assets/motion_lens.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ 160
assets/prompts.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ a man jumps.
config/diffuser_params.yaml ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dpmsolver:
2
+ scheduler_class: DPMSolverMultistepScheduler
3
+ additional_params:
4
+ algorithm_type: sde-dpmsolver++
5
+ use_karras_sigmas: true
6
+
7
+ ddpm:
8
+ scheduler_class: DDPMScheduler
9
+ additional_params:
10
+ variance_type: fixed_small
11
+ clip_sample: false
12
+
13
+ ddim:
14
+ scheduler_class: DDIMScheduler
15
+ additional_params:
16
+ clip_sample: false
17
+
18
+ deis:
19
+ scheduler_class: DEISMultistepScheduler
20
+ additional_params:
21
+ num_train_timesteps: 1000
22
+
23
+ pndm:
24
+ scheduler_class: PNDMScheduler
25
+ additional_params:
26
+ num_train_timesteps: 1000
config/evaluator.yaml ADDED
@@ -0,0 +1,14 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ unit_length: 4
2
+ max_text_len: 20
3
+ text_enc_mod: bigru
4
+ estimator_mod: bigru
5
+ dim_text_hidden: 512
6
+ dim_att_vec: 512
7
+ dim_z: 128
8
+ dim_movement_enc_hidden: 512
9
+ dim_movement_dec_hidden: 512
10
+ dim_movement_latent: 512
11
+ dim_word: 300
12
+ dim_pos_ohot: 15
13
+ dim_motion_hidden: 1024
14
+ dim_coemb_hidden: 512
datasets/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ from .t2m_dataset import HumanML3D,KIT
3
+
4
+ from os.path import join as pjoin
5
+ __all__ = [
6
+ 'HumanML3D', 'KIT', 'get_dataset',]
7
+
8
+ def get_dataset(opt, split='train', mode='train', accelerator=None):
9
+ if opt.dataset_name == 't2m' :
10
+ dataset = HumanML3D(opt, split, mode, accelerator)
11
+ elif opt.dataset_name == 'kit' :
12
+ dataset = KIT(opt,split, mode, accelerator)
13
+ else:
14
+ raise KeyError('Dataset Does Not Exist')
15
+
16
+ if accelerator:
17
+ accelerator.print('Completing loading %s dataset' % opt.dataset_name)
18
+ else:
19
+ print('Completing loading %s dataset' % opt.dataset_name)
20
+
21
+ return dataset
22
+
datasets/t2m_dataset.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.utils import data
3
+ import numpy as np
4
+ from os.path import join as pjoin
5
+ import random
6
+ import codecs as cs
7
+ from tqdm.auto import tqdm
8
+ from utils.word_vectorizer import WordVectorizer, POS_enumerator
9
+ from utils.motion_process import recover_from_ric
10
+
11
+
12
+ class Text2MotionDataset(data.Dataset):
13
+ """
14
+ Dataset for Text2Motion generation task.
15
+ """
16
+
17
+ data_root = ""
18
+ min_motion_len = 40
19
+ joints_num = None
20
+ dim_pose = None
21
+ max_motion_length = 196
22
+
23
+ def __init__(self, opt, split, mode="train", accelerator=None):
24
+ self.max_text_len = getattr(opt, "max_text_len", 20)
25
+ self.unit_length = getattr(opt, "unit_length", 4)
26
+ self.mode = mode
27
+ motion_dir = pjoin(self.data_root, "new_joint_vecs")
28
+ text_dir = pjoin(self.data_root, "texts")
29
+
30
+ if mode not in ["train", "eval", "gt_eval", "xyz_gt", "hml_gt"]:
31
+ raise ValueError(
32
+ f"Mode '{mode}' is not supported. Please use one of: 'train', 'eval', 'gt_eval', 'xyz_gt','hml_gt'."
33
+ )
34
+
35
+ mean, std = None, None
36
+ if mode == "gt_eval":
37
+ print(pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_std.npy"))
38
+ # used by T2M models (including evaluators)
39
+ mean = np.load(pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_mean.npy"))
40
+ std = np.load(pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_std.npy"))
41
+ elif mode in ["eval"]:
42
+ print(pjoin(opt.meta_dir, "std.npy"))
43
+ # used by our models during inference
44
+ mean = np.load(pjoin(opt.meta_dir, "mean.npy"))
45
+ std = np.load(pjoin(opt.meta_dir, "std.npy"))
46
+ else:
47
+ # used by our models during train
48
+ mean = np.load(pjoin(self.data_root, "Mean.npy"))
49
+ std = np.load(pjoin(self.data_root, "Std.npy"))
50
+
51
+ if mode == "eval":
52
+ # used by T2M models (including evaluators)
53
+ # this is to translate ours norms to theirs
54
+ self.mean_for_eval = np.load(
55
+ pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_mean.npy")
56
+ )
57
+ self.std_for_eval = np.load(
58
+ pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_std.npy")
59
+ )
60
+ if mode in ["gt_eval", "eval"]:
61
+ self.w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab")
62
+
63
+ data_dict = {}
64
+ id_list = []
65
+ split_file = pjoin(self.data_root, f"{split}.txt")
66
+ with cs.open(split_file, "r") as f:
67
+ for line in f.readlines():
68
+ id_list.append(line.strip())
69
+
70
+ if opt.debug == True:
71
+ id_list = id_list[:1000]
72
+
73
+ new_name_list = []
74
+ length_list = []
75
+ for name in tqdm(
76
+ id_list,
77
+ disable=(
78
+ not accelerator.is_local_main_process
79
+ if accelerator is not None
80
+ else False
81
+ ),
82
+ ):
83
+ motion = np.load(pjoin(motion_dir, name + ".npy"))
84
+ if (len(motion)) < self.min_motion_len or (len(motion) >= 200):
85
+ continue
86
+ text_data = []
87
+ flag = False
88
+ with cs.open(pjoin(text_dir, name + ".txt")) as f:
89
+ for line in f.readlines():
90
+ text_dict = {}
91
+ line_split = line.strip().split("#")
92
+ caption = line_split[0]
93
+ try:
94
+ tokens = line_split[1].split(" ")
95
+ f_tag = float(line_split[2])
96
+ to_tag = float(line_split[3])
97
+ f_tag = 0.0 if np.isnan(f_tag) else f_tag
98
+ to_tag = 0.0 if np.isnan(to_tag) else to_tag
99
+ except:
100
+ tokens = ["a/NUM", "a/NUM"]
101
+ f_tag = 0.0
102
+ to_tag = 8.0
103
+ text_dict["caption"] = caption
104
+ text_dict["tokens"] = tokens
105
+ if f_tag == 0.0 and to_tag == 0.0:
106
+ flag = True
107
+ text_data.append(text_dict)
108
+ else:
109
+ n_motion = motion[int(f_tag * 20) : int(to_tag * 20)]
110
+ if (len(n_motion)) < self.min_motion_len or (
111
+ len(n_motion) >= 200
112
+ ):
113
+ continue
114
+ new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
115
+ while new_name in data_dict:
116
+ new_name = (
117
+ random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name
118
+ )
119
+ data_dict[new_name] = {
120
+ "motion": n_motion,
121
+ "length": len(n_motion),
122
+ "text": [text_dict],
123
+ }
124
+ new_name_list.append(new_name)
125
+ length_list.append(len(n_motion))
126
+ if flag:
127
+ data_dict[name] = {
128
+ "motion": motion,
129
+ "length": len(motion),
130
+ "text": text_data,
131
+ }
132
+ new_name_list.append(name)
133
+ length_list.append(len(motion))
134
+
135
+ name_list, length_list = zip(
136
+ *sorted(zip(new_name_list, length_list), key=lambda x: x[1])
137
+ )
138
+
139
+ if mode == "train":
140
+ if opt.dataset_name != "amass":
141
+ joints_num = self.joints_num
142
+ # root_rot_velocity (B, seq_len, 1)
143
+ std[0:1] = std[0:1] / opt.feat_bias
144
+ # root_linear_velocity (B, seq_len, 2)
145
+ std[1:3] = std[1:3] / opt.feat_bias
146
+ # root_y (B, seq_len, 1)
147
+ std[3:4] = std[3:4] / opt.feat_bias
148
+ # ric_data (B, seq_len, (joint_num - 1)*3)
149
+ std[4 : 4 + (joints_num - 1) * 3] = (
150
+ std[4 : 4 + (joints_num - 1) * 3] / 1.0
151
+ )
152
+ # rot_data (B, seq_len, (joint_num - 1)*6)
153
+ std[4 + (joints_num - 1) * 3 : 4 + (joints_num - 1) * 9] = (
154
+ std[4 + (joints_num - 1) * 3 : 4 + (joints_num - 1) * 9] / 1.0
155
+ )
156
+ # local_velocity (B, seq_len, joint_num*3)
157
+ std[
158
+ 4 + (joints_num - 1) * 9 : 4 + (joints_num - 1) * 9 + joints_num * 3
159
+ ] = (
160
+ std[
161
+ 4
162
+ + (joints_num - 1) * 9 : 4
163
+ + (joints_num - 1) * 9
164
+ + joints_num * 3
165
+ ]
166
+ / 1.0
167
+ )
168
+ # foot contact (B, seq_len, 4)
169
+ std[4 + (joints_num - 1) * 9 + joints_num * 3 :] = (
170
+ std[4 + (joints_num - 1) * 9 + joints_num * 3 :] / opt.feat_bias
171
+ )
172
+
173
+ assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1]
174
+
175
+ if accelerator is not None and accelerator.is_main_process:
176
+ np.save(pjoin(opt.meta_dir, "mean.npy"), mean)
177
+ np.save(pjoin(opt.meta_dir, "std.npy"), std)
178
+
179
+ self.mean = mean
180
+ self.std = std
181
+ self.data_dict = data_dict
182
+ self.name_list = name_list
183
+
184
+ def inv_transform(self, data):
185
+ return data * self.std + self.mean
186
+
187
+ def __len__(self):
188
+ return len(self.data_dict)
189
+
190
+ def __getitem__(self, idx):
191
+ data = self.data_dict[self.name_list[idx]]
192
+ motion, m_length, text_list = data["motion"], data["length"], data["text"]
193
+
194
+ # Randomly select a caption
195
+ text_data = random.choice(text_list)
196
+ caption = text_data["caption"]
197
+
198
+ "Z Normalization"
199
+ if self.mode not in ["xyz_gt", "hml_gt"]:
200
+ motion = (motion - self.mean) / self.std
201
+
202
+ "crop motion"
203
+ if self.mode in ["eval", "gt_eval"]:
204
+ # Crop the motions in to times of 4, and introduce small variations
205
+ if self.unit_length < 10:
206
+ coin2 = np.random.choice(["single", "single", "double"])
207
+ else:
208
+ coin2 = "single"
209
+ if coin2 == "double":
210
+ m_length = (m_length // self.unit_length - 1) * self.unit_length
211
+ elif coin2 == "single":
212
+ m_length = (m_length // self.unit_length) * self.unit_length
213
+ idx = random.randint(0, len(motion) - m_length)
214
+ motion = motion[idx : idx + m_length]
215
+ elif m_length >= self.max_motion_length:
216
+ idx = random.randint(0, len(motion) - self.max_motion_length)
217
+ motion = motion[idx : idx + self.max_motion_length]
218
+ m_length = self.max_motion_length
219
+
220
+ "pad motion"
221
+ if m_length < self.max_motion_length:
222
+ motion = np.concatenate(
223
+ [
224
+ motion,
225
+ np.zeros((self.max_motion_length - m_length, motion.shape[1])),
226
+ ],
227
+ axis=0,
228
+ )
229
+ assert len(motion) == self.max_motion_length
230
+
231
+ if self.mode in ["gt_eval", "eval"]:
232
+ "word embedding for text-to-motion evaluation"
233
+ tokens = text_data["tokens"]
234
+ if len(tokens) < self.max_text_len:
235
+ # pad with "unk"
236
+ tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
237
+ sent_len = len(tokens)
238
+ tokens = tokens + ["unk/OTHER"] * (self.max_text_len + 2 - sent_len)
239
+ else:
240
+ # crop
241
+ tokens = tokens[: self.max_text_len]
242
+ tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"]
243
+ sent_len = len(tokens)
244
+ pos_one_hots = []
245
+ word_embeddings = []
246
+ for token in tokens:
247
+ word_emb, pos_oh = self.w_vectorizer[token]
248
+ pos_one_hots.append(pos_oh[None, :])
249
+ word_embeddings.append(word_emb[None, :])
250
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
251
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
252
+ return (
253
+ word_embeddings,
254
+ pos_one_hots,
255
+ caption,
256
+ sent_len,
257
+ motion,
258
+ m_length,
259
+ "_".join(tokens),
260
+ )
261
+ elif self.mode in ["xyz_gt"]:
262
+ "Convert motion hml representation to skeleton points xyz"
263
+ # 1. Use kn to get the keypoints position (the padding position after kn is all zero)
264
+ motion = torch.from_numpy(motion).float()
265
+ pred_joints = recover_from_ric(
266
+ motion, self.joints_num
267
+ ) # (nframe, njoints, 3)
268
+
269
+ # 2. Put on Floor (Y axis)
270
+ floor_height = pred_joints.min(dim=0)[0].min(dim=0)[0][1]
271
+ pred_joints[:, :, 1] -= floor_height
272
+ return pred_joints
273
+
274
+ return caption, motion, m_length
275
+
276
+
277
+ class HumanML3D(Text2MotionDataset):
278
+ def __init__(self, opt, split="train", mode="train", accelerator=None):
279
+ self.data_root = "./data/HumanML3D"
280
+ self.min_motion_len = 40
281
+ self.joints_num = 22
282
+ self.dim_pose = 263
283
+ self.max_motion_length = 196
284
+ if accelerator:
285
+ accelerator.print(
286
+ "\n Loading %s mode HumanML3D %s dataset ..." % (mode, split)
287
+ )
288
+ else:
289
+ print("\n Loading %s mode HumanML3D dataset ..." % mode)
290
+ super(HumanML3D, self).__init__(opt, split, mode, accelerator)
291
+
292
+
293
+ class KIT(Text2MotionDataset):
294
+ def __init__(self, opt, split="train", mode="train", accelerator=None):
295
+ self.data_root = "./data/KIT-ML"
296
+ self.min_motion_len = 24
297
+ self.joints_num = 21
298
+ self.dim_pose = 251
299
+ self.max_motion_length = 196
300
+ if accelerator:
301
+ accelerator.print("\n Loading %s mode KIT %s dataset ..." % (mode, split))
302
+ else:
303
+ print("\n Loading %s mode KIT dataset ..." % mode)
304
+ super(KIT, self).__init__(opt, split, mode, accelerator)
eval/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .evaluator_wrapper import EvaluatorModelWrapper
2
+ from .eval_t2m import evaluation
eval/eval_t2m.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE.
2
+ # Copyright (c) 2022 Chuan Guo
3
+ from datetime import datetime
4
+ import numpy as np
5
+ import torch
6
+ from utils.metrics import *
7
+ from collections import OrderedDict
8
+
9
+
10
+ def evaluate_matching_score(eval_wrapper,motion_loaders, file):
11
+ match_score_dict = OrderedDict({})
12
+ R_precision_dict = OrderedDict({})
13
+ activation_dict = OrderedDict({})
14
+ # print(motion_loaders.keys())
15
+ print('========== Evaluating Matching Score ==========')
16
+ for motion_loader_name, motion_loader in motion_loaders.items():
17
+ all_motion_embeddings = []
18
+ score_list = []
19
+ all_size = 0
20
+ matching_score_sum = 0
21
+ top_k_count = 0
22
+ # print(motion_loader_name)
23
+ with torch.no_grad():
24
+ for idx, batch in enumerate(motion_loader):
25
+ word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch
26
+ text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(
27
+ word_embs=word_embeddings,
28
+ pos_ohot=pos_one_hots,
29
+ cap_lens=sent_lens,
30
+ motions=motions,
31
+ m_lens=m_lens
32
+ )
33
+ dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
34
+ motion_embeddings.cpu().numpy())
35
+ matching_score_sum += dist_mat.trace()
36
+ # import pdb;pdb.set_trace()
37
+
38
+ argsmax = np.argsort(dist_mat, axis=1)
39
+ top_k_mat = calculate_top_k(argsmax, top_k=3)
40
+ top_k_count += top_k_mat.sum(axis=0)
41
+
42
+ all_size += text_embeddings.shape[0]
43
+
44
+ all_motion_embeddings.append(motion_embeddings.cpu().numpy())
45
+
46
+ all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
47
+ # import pdb;pdb.set_trace()
48
+ matching_score = matching_score_sum / all_size
49
+ R_precision = top_k_count / all_size
50
+ match_score_dict[motion_loader_name] = matching_score
51
+ R_precision_dict[motion_loader_name] = R_precision
52
+ activation_dict[motion_loader_name] = all_motion_embeddings
53
+
54
+ print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}')
55
+ print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True)
56
+
57
+ line = f'---> [{motion_loader_name}] R_precision: '
58
+ for i in range(len(R_precision)):
59
+ line += '(top %d): %.4f ' % (i+1, R_precision[i])
60
+ print(line)
61
+ print(line, file=file, flush=True)
62
+
63
+ return match_score_dict, R_precision_dict, activation_dict
64
+
65
+
66
+ def evaluate_fid(eval_wrapper,groundtruth_loader, activation_dict, file):
67
+ eval_dict = OrderedDict({})
68
+ gt_motion_embeddings = []
69
+ print('========== Evaluating FID ==========')
70
+ with torch.no_grad():
71
+ for idx, batch in enumerate(groundtruth_loader):
72
+ _, _, _, sent_lens, motions, m_lens, _ = batch
73
+ motion_embeddings = eval_wrapper.get_motion_embeddings(
74
+ motions=motions,
75
+ m_lens=m_lens
76
+ )
77
+ gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
78
+ gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
79
+ gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
80
+
81
+ for model_name, motion_embeddings in activation_dict.items():
82
+ mu, cov = calculate_activation_statistics(motion_embeddings)
83
+ # print(mu)
84
+ fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
85
+ print(f'---> [{model_name}] FID: {fid:.4f}')
86
+ print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
87
+ eval_dict[model_name] = fid
88
+ return eval_dict
89
+
90
+
91
+ def evaluate_diversity(activation_dict, file, diversity_times):
92
+ eval_dict = OrderedDict({})
93
+ print('========== Evaluating Diversity ==========')
94
+ for model_name, motion_embeddings in activation_dict.items():
95
+ diversity = calculate_diversity(motion_embeddings, diversity_times)
96
+ eval_dict[model_name] = diversity
97
+ print(f'---> [{model_name}] Diversity: {diversity:.4f}')
98
+ print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
99
+ return eval_dict
100
+
101
+
102
+ def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times):
103
+ eval_dict = OrderedDict({})
104
+ print('========== Evaluating MultiModality ==========')
105
+ for model_name, mm_motion_loader in mm_motion_loaders.items():
106
+ mm_motion_embeddings = []
107
+ with torch.no_grad():
108
+ for idx, batch in enumerate(mm_motion_loader):
109
+ # (1, mm_replications, dim_pos)
110
+ motions, m_lens = batch
111
+ motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0])
112
+ mm_motion_embeddings.append(motion_embedings.unsqueeze(0))
113
+ if len(mm_motion_embeddings) == 0:
114
+ multimodality = 0
115
+ else:
116
+ mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
117
+ multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
118
+ print(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
119
+ print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
120
+ eval_dict[model_name] = multimodality
121
+ return eval_dict
122
+
123
+
124
+ def get_metric_statistics(values, replication_times):
125
+ mean = np.mean(values, axis=0)
126
+ std = np.std(values, axis=0)
127
+ conf_interval = 1.96 * std / np.sqrt(replication_times)
128
+ return mean, conf_interval
129
+
130
+
131
+ def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False):
132
+ with open(log_file, 'a') as f:
133
+ all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
134
+ 'R_precision': OrderedDict({}),
135
+ 'FID': OrderedDict({}),
136
+ 'Diversity': OrderedDict({}),
137
+ 'MultiModality': OrderedDict({})})
138
+
139
+ for replication in range(replication_times):
140
+ print(f'Time: {datetime.now()}')
141
+ print(f'Time: {datetime.now()}', file=f, flush=True)
142
+ motion_loaders = {}
143
+ motion_loaders['ground truth'] = gt_loader
144
+ mm_motion_loaders = {}
145
+ # motion_loaders['ground truth'] = gt_loader
146
+ for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
147
+ motion_loader, mm_motion_loader,eval_generate_time = motion_loader_getter()
148
+ print(f'---> [{motion_loader_name}] batch_generate_time: {eval_generate_time}s', file=f, flush=True)
149
+ motion_loaders[motion_loader_name] = motion_loader
150
+ mm_motion_loaders[motion_loader_name] = mm_motion_loader
151
+
152
+ if replication_times>1:
153
+ print(f'==================== Replication {replication} ====================')
154
+ print(f'==================== Replication {replication} ====================', file=f, flush=True)
155
+
156
+
157
+ mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)
158
+
159
+ fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)
160
+
161
+ div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)
162
+
163
+ if run_mm:
164
+ mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times)
165
+
166
+ print(f'!!! DONE !!!')
167
+ print(f'!!! DONE !!!', file=f, flush=True)
168
+
169
+ for key, item in mat_score_dict.items():
170
+ if key not in all_metrics['Matching Score']:
171
+ all_metrics['Matching Score'][key] = [item]
172
+ else:
173
+ all_metrics['Matching Score'][key] += [item]
174
+
175
+ for key, item in R_precision_dict.items():
176
+ if key not in all_metrics['R_precision']:
177
+ all_metrics['R_precision'][key] = [item]
178
+ else:
179
+ all_metrics['R_precision'][key] += [item]
180
+
181
+ for key, item in fid_score_dict.items():
182
+ if key not in all_metrics['FID']:
183
+ all_metrics['FID'][key] = [item]
184
+ else:
185
+ all_metrics['FID'][key] += [item]
186
+
187
+ for key, item in div_score_dict.items():
188
+ if key not in all_metrics['Diversity']:
189
+ all_metrics['Diversity'][key] = [item]
190
+ else:
191
+ all_metrics['Diversity'][key] += [item]
192
+
193
+ for key, item in mm_score_dict.items():
194
+ if key not in all_metrics['MultiModality']:
195
+ all_metrics['MultiModality'][key] = [item]
196
+ else:
197
+ all_metrics['MultiModality'][key] += [item]
198
+
199
+
200
+ mean_dict = {}
201
+ if replication_times>1:
202
+ for metric_name, metric_dict in all_metrics.items():
203
+ print('========== %s Summary ==========' % metric_name)
204
+ print('========== %s Summary ==========' % metric_name, file=f, flush=True)
205
+
206
+ for model_name, values in metric_dict.items():
207
+ # print(metric_name, model_name)
208
+ mean, conf_interval = get_metric_statistics(np.array(values),replication_times)
209
+ mean_dict[metric_name + '_' + model_name] = mean
210
+ # print(mean, mean.dtype)
211
+ if isinstance(mean, np.float64) or isinstance(mean, np.float32):
212
+ print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
213
+ print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
214
+ elif isinstance(mean, np.ndarray):
215
+ line = f'---> [{model_name}]'
216
+ for i in range(len(mean)):
217
+ line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
218
+ print(line)
219
+ print(line, file=f, flush=True)
220
+ return mean_dict
221
+ else:
222
+ return all_metrics
eval/evaluator_modules.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE.
2
+ # Copyright (c) 2022 Chuan Guo
3
+ import torch
4
+ import torch.nn as nn
5
+ import numpy as np
6
+ import time
7
+ import math
8
+ from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
9
+ import torch.nn.functional as F
10
+
11
+
12
+ class ContrastiveLoss(torch.nn.Module):
13
+ """
14
+ Contrastive loss function.
15
+ Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf
16
+ """
17
+ def __init__(self, margin=3.0):
18
+ super(ContrastiveLoss, self).__init__()
19
+ self.margin = margin
20
+
21
+ def forward(self, output1, output2, label):
22
+ euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True)
23
+ loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) +
24
+ (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2))
25
+ return loss_contrastive
26
+
27
+
28
+ def init_weight(m):
29
+ if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d):
30
+ nn.init.xavier_normal_(m.weight)
31
+ # m.bias.data.fill_(0.01)
32
+ if m.bias is not None:
33
+ nn.init.constant_(m.bias, 0)
34
+
35
+
36
+ def reparameterize(mu, logvar):
37
+ s_var = logvar.mul(0.5).exp_()
38
+ eps = s_var.data.new(s_var.size()).normal_()
39
+ return eps.mul(s_var).add_(mu)
40
+
41
+
42
+ # batch_size, dimension and position
43
+ # output: (batch_size, dim)
44
+ def positional_encoding(batch_size, dim, pos):
45
+ assert batch_size == pos.shape[0]
46
+ positions_enc = np.array([
47
+ [pos[j] / np.power(10000, (i-i%2)/dim) for i in range(dim)]
48
+ for j in range(batch_size)
49
+ ], dtype=np.float32)
50
+ positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2])
51
+ positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2])
52
+ return torch.from_numpy(positions_enc).float()
53
+
54
+
55
+ def get_padding_mask(batch_size, seq_len, cap_lens):
56
+ cap_lens = cap_lens.data.tolist()
57
+ mask_2d = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32)
58
+ for i, cap_len in enumerate(cap_lens):
59
+ mask_2d[i, :, :cap_len] = 0
60
+ return mask_2d.bool(), 1 - mask_2d[:, :, 0].clone()
61
+
62
+
63
+ class PositionalEncoding(nn.Module):
64
+
65
+ def __init__(self, d_model, max_len=300):
66
+ super(PositionalEncoding, self).__init__()
67
+
68
+ pe = torch.zeros(max_len, d_model)
69
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
70
+ div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
71
+ pe[:, 0::2] = torch.sin(position * div_term)
72
+ pe[:, 1::2] = torch.cos(position * div_term)
73
+ # pe = pe.unsqueeze(0).transpose(0, 1)
74
+ self.register_buffer('pe', pe)
75
+
76
+ def forward(self, pos):
77
+ return self.pe[pos]
78
+
79
+
80
+ class MovementConvEncoder(nn.Module):
81
+ def __init__(self, input_size, hidden_size, output_size):
82
+ super(MovementConvEncoder, self).__init__()
83
+ self.main = nn.Sequential(
84
+ nn.Conv1d(input_size, hidden_size, 4, 2, 1),
85
+ nn.Dropout(0.2, inplace=True),
86
+ nn.LeakyReLU(0.2, inplace=True),
87
+ nn.Conv1d(hidden_size, output_size, 4, 2, 1),
88
+ nn.Dropout(0.2, inplace=True),
89
+ nn.LeakyReLU(0.2, inplace=True),
90
+ )
91
+ self.out_net = nn.Linear(output_size, output_size)
92
+ self.main.apply(init_weight)
93
+ self.out_net.apply(init_weight)
94
+
95
+ def forward(self, inputs):
96
+ inputs = inputs.permute(0, 2, 1)
97
+ outputs = self.main(inputs).permute(0, 2, 1)
98
+ # print(outputs.shape)
99
+ return self.out_net(outputs)
100
+
101
+
102
+ class MovementConvDecoder(nn.Module):
103
+ def __init__(self, input_size, hidden_size, output_size):
104
+ super(MovementConvDecoder, self).__init__()
105
+ self.main = nn.Sequential(
106
+ nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1),
107
+ # nn.Dropout(0.2, inplace=True),
108
+ nn.LeakyReLU(0.2, inplace=True),
109
+ nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1),
110
+ # nn.Dropout(0.2, inplace=True),
111
+ nn.LeakyReLU(0.2, inplace=True),
112
+ )
113
+ self.out_net = nn.Linear(output_size, output_size)
114
+
115
+ self.main.apply(init_weight)
116
+ self.out_net.apply(init_weight)
117
+
118
+ def forward(self, inputs):
119
+ inputs = inputs.permute(0, 2, 1)
120
+ outputs = self.main(inputs).permute(0, 2, 1)
121
+ return self.out_net(outputs)
122
+
123
+
124
+ class TextVAEDecoder(nn.Module):
125
+ def __init__(self, text_size, input_size, output_size, hidden_size, n_layers):
126
+ super(TextVAEDecoder, self).__init__()
127
+ self.input_size = input_size
128
+ self.output_size = output_size
129
+ self.hidden_size = hidden_size
130
+ self.n_layers = n_layers
131
+ self.emb = nn.Sequential(
132
+ nn.Linear(input_size, hidden_size),
133
+ nn.LayerNorm(hidden_size),
134
+ nn.LeakyReLU(0.2, inplace=True))
135
+
136
+ self.z2init = nn.Linear(text_size, hidden_size * n_layers)
137
+ self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)])
138
+ self.positional_encoder = PositionalEncoding(hidden_size)
139
+
140
+
141
+ self.output = nn.Sequential(
142
+ nn.Linear(hidden_size, hidden_size),
143
+ nn.LayerNorm(hidden_size),
144
+ nn.LeakyReLU(0.2, inplace=True),
145
+ nn.Linear(hidden_size, output_size)
146
+ )
147
+
148
+ #
149
+ # self.output = nn.Sequential(
150
+ # nn.Linear(hidden_size, hidden_size),
151
+ # nn.LayerNorm(hidden_size),
152
+ # nn.LeakyReLU(0.2, inplace=True),
153
+ # nn.Linear(hidden_size, output_size-4)
154
+ # )
155
+
156
+ # self.contact_net = nn.Sequential(
157
+ # nn.Linear(output_size-4, 64),
158
+ # nn.LayerNorm(64),
159
+ # nn.LeakyReLU(0.2, inplace=True),
160
+ # nn.Linear(64, 4)
161
+ # )
162
+
163
+ self.output.apply(init_weight)
164
+ self.emb.apply(init_weight)
165
+ self.z2init.apply(init_weight)
166
+ # self.contact_net.apply(init_weight)
167
+
168
+ def get_init_hidden(self, latent):
169
+ hidden = self.z2init(latent)
170
+ hidden = torch.split(hidden, self.hidden_size, dim=-1)
171
+ return list(hidden)
172
+
173
+ def forward(self, inputs, last_pred, hidden, p):
174
+ h_in = self.emb(inputs)
175
+ pos_enc = self.positional_encoder(p).to(inputs.device).detach()
176
+ h_in = h_in + pos_enc
177
+ for i in range(self.n_layers):
178
+ # print(h_in.shape)
179
+ hidden[i] = self.gru[i](h_in, hidden[i])
180
+ h_in = hidden[i]
181
+ pose_pred = self.output(h_in)
182
+ # pose_pred = self.output(h_in) + last_pred.detach()
183
+ # contact = self.contact_net(pose_pred)
184
+ # return torch.cat([pose_pred, contact], dim=-1), hidden
185
+ return pose_pred, hidden
186
+
187
+
188
+ class TextDecoder(nn.Module):
189
+ def __init__(self, text_size, input_size, output_size, hidden_size, n_layers):
190
+ super(TextDecoder, self).__init__()
191
+ self.input_size = input_size
192
+ self.output_size = output_size
193
+ self.hidden_size = hidden_size
194
+ self.n_layers = n_layers
195
+ self.emb = nn.Sequential(
196
+ nn.Linear(input_size, hidden_size),
197
+ nn.LayerNorm(hidden_size),
198
+ nn.LeakyReLU(0.2, inplace=True))
199
+
200
+ self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)])
201
+ self.z2init = nn.Linear(text_size, hidden_size * n_layers)
202
+ self.positional_encoder = PositionalEncoding(hidden_size)
203
+
204
+ self.mu_net = nn.Linear(hidden_size, output_size)
205
+ self.logvar_net = nn.Linear(hidden_size, output_size)
206
+
207
+ self.emb.apply(init_weight)
208
+ self.z2init.apply(init_weight)
209
+ self.mu_net.apply(init_weight)
210
+ self.logvar_net.apply(init_weight)
211
+
212
+ def get_init_hidden(self, latent):
213
+
214
+ hidden = self.z2init(latent)
215
+ hidden = torch.split(hidden, self.hidden_size, dim=-1)
216
+
217
+ return list(hidden)
218
+
219
+ def forward(self, inputs, hidden, p):
220
+ # print(inputs.shape)
221
+ x_in = self.emb(inputs)
222
+ pos_enc = self.positional_encoder(p).to(inputs.device).detach()
223
+ x_in = x_in + pos_enc
224
+
225
+ for i in range(self.n_layers):
226
+ hidden[i] = self.gru[i](x_in, hidden[i])
227
+ h_in = hidden[i]
228
+ mu = self.mu_net(h_in)
229
+ logvar = self.logvar_net(h_in)
230
+ z = reparameterize(mu, logvar)
231
+ return z, mu, logvar, hidden
232
+
233
+ class AttLayer(nn.Module):
234
+ def __init__(self, query_dim, key_dim, value_dim):
235
+ super(AttLayer, self).__init__()
236
+ self.W_q = nn.Linear(query_dim, value_dim)
237
+ self.W_k = nn.Linear(key_dim, value_dim, bias=False)
238
+ self.W_v = nn.Linear(key_dim, value_dim)
239
+
240
+ self.softmax = nn.Softmax(dim=1)
241
+ self.dim = value_dim
242
+
243
+ self.W_q.apply(init_weight)
244
+ self.W_k.apply(init_weight)
245
+ self.W_v.apply(init_weight)
246
+
247
+ def forward(self, query, key_mat):
248
+ '''
249
+ query (batch, query_dim)
250
+ key (batch, seq_len, key_dim)
251
+ '''
252
+ # print(query.shape)
253
+ query_vec = self.W_q(query).unsqueeze(-1) # (batch, value_dim, 1)
254
+ val_set = self.W_v(key_mat) # (batch, seq_len, value_dim)
255
+ key_set = self.W_k(key_mat) # (batch, seq_len, value_dim)
256
+
257
+ weights = torch.matmul(key_set, query_vec) / np.sqrt(self.dim)
258
+
259
+ co_weights = self.softmax(weights) # (batch, seq_len, 1)
260
+ values = val_set * co_weights # (batch, seq_len, value_dim)
261
+ pred = values.sum(dim=1) # (batch, value_dim)
262
+ return pred, co_weights
263
+
264
+ def short_cut(self, querys, keys):
265
+ return self.W_q(querys), self.W_k(keys)
266
+
267
+
268
+ class TextEncoderBiGRU(nn.Module):
269
+ def __init__(self, word_size, pos_size, hidden_size, device):
270
+ super(TextEncoderBiGRU, self).__init__()
271
+ self.device = device
272
+
273
+ self.pos_emb = nn.Linear(pos_size, word_size)
274
+ self.input_emb = nn.Linear(word_size, hidden_size)
275
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
276
+ # self.linear2 = nn.Linear(hidden_size, output_size)
277
+
278
+ self.input_emb.apply(init_weight)
279
+ self.pos_emb.apply(init_weight)
280
+ # self.linear2.apply(init_weight)
281
+ # self.batch_size = batch_size
282
+ self.hidden_size = hidden_size
283
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
284
+
285
+ # input(batch_size, seq_len, dim)
286
+ def forward(self, word_embs, pos_onehot, cap_lens):
287
+ num_samples = word_embs.shape[0]
288
+
289
+ pos_embs = self.pos_emb(pos_onehot)
290
+ inputs = word_embs + pos_embs
291
+ input_embs = self.input_emb(inputs)
292
+ hidden = self.hidden.repeat(1, num_samples, 1)
293
+
294
+ cap_lens = cap_lens.data.tolist()
295
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
296
+
297
+ gru_seq, gru_last = self.gru(emb, hidden)
298
+
299
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
300
+ gru_seq = pad_packed_sequence(gru_seq, batch_first=True)[0]
301
+ forward_seq = gru_seq[..., :self.hidden_size]
302
+ backward_seq = gru_seq[..., self.hidden_size:].clone()
303
+
304
+ # Concate the forward and backward word embeddings
305
+ for i, length in enumerate(cap_lens):
306
+ backward_seq[i:i+1, :length] = torch.flip(backward_seq[i:i+1, :length].clone(), dims=[1])
307
+ gru_seq = torch.cat([forward_seq, backward_seq], dim=-1)
308
+
309
+ return gru_seq, gru_last
310
+
311
+
312
+ class TextEncoderBiGRUCo(nn.Module):
313
+ def __init__(self, word_size, pos_size, hidden_size, output_size, device):
314
+ super(TextEncoderBiGRUCo, self).__init__()
315
+ self.device = device
316
+
317
+ self.pos_emb = nn.Linear(pos_size, word_size)
318
+ self.input_emb = nn.Linear(word_size, hidden_size)
319
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
320
+ self.output_net = nn.Sequential(
321
+ nn.Linear(hidden_size * 2, hidden_size),
322
+ nn.LayerNorm(hidden_size),
323
+ nn.LeakyReLU(0.2, inplace=True),
324
+ nn.Linear(hidden_size, output_size)
325
+ )
326
+
327
+ self.input_emb.apply(init_weight)
328
+ self.pos_emb.apply(init_weight)
329
+ self.output_net.apply(init_weight)
330
+ # self.linear2.apply(init_weight)
331
+ # self.batch_size = batch_size
332
+ self.hidden_size = hidden_size
333
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
334
+
335
+ # input(batch_size, seq_len, dim)
336
+ def forward(self, word_embs, pos_onehot, cap_lens):
337
+ num_samples = word_embs.shape[0]
338
+
339
+ pos_embs = self.pos_emb(pos_onehot)
340
+ inputs = word_embs + pos_embs
341
+ input_embs = self.input_emb(inputs)
342
+ hidden = self.hidden.repeat(1, num_samples, 1)
343
+
344
+ cap_lens = cap_lens.data.tolist()
345
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
346
+
347
+ gru_seq, gru_last = self.gru(emb, hidden)
348
+
349
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
350
+
351
+ return self.output_net(gru_last)
352
+
353
+
354
+ class MotionEncoderBiGRUCo(nn.Module):
355
+ def __init__(self, input_size, hidden_size, output_size, device):
356
+ super(MotionEncoderBiGRUCo, self).__init__()
357
+ self.device = device
358
+
359
+ self.input_emb = nn.Linear(input_size, hidden_size)
360
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
361
+ self.output_net = nn.Sequential(
362
+ nn.Linear(hidden_size*2, hidden_size),
363
+ nn.LayerNorm(hidden_size),
364
+ nn.LeakyReLU(0.2, inplace=True),
365
+ nn.Linear(hidden_size, output_size)
366
+ )
367
+
368
+ self.input_emb.apply(init_weight)
369
+ self.output_net.apply(init_weight)
370
+ self.hidden_size = hidden_size
371
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
372
+
373
+ # input(batch_size, seq_len, dim)
374
+ def forward(self, inputs, m_lens):
375
+ num_samples = inputs.shape[0]
376
+
377
+ input_embs = self.input_emb(inputs)
378
+ hidden = self.hidden.repeat(1, num_samples, 1)
379
+
380
+ cap_lens = m_lens.data.tolist()
381
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
382
+
383
+ gru_seq, gru_last = self.gru(emb, hidden)
384
+
385
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
386
+
387
+ return self.output_net(gru_last)
388
+
389
+
390
+ class MotionLenEstimatorBiGRU(nn.Module):
391
+ def __init__(self, word_size, pos_size, hidden_size, output_size):
392
+ super(MotionLenEstimatorBiGRU, self).__init__()
393
+
394
+ self.pos_emb = nn.Linear(pos_size, word_size)
395
+ self.input_emb = nn.Linear(word_size, hidden_size)
396
+ self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True)
397
+ nd = 512
398
+ self.output = nn.Sequential(
399
+ nn.Linear(hidden_size*2, nd),
400
+ nn.LayerNorm(nd),
401
+ nn.LeakyReLU(0.2, inplace=True),
402
+
403
+ nn.Linear(nd, nd // 2),
404
+ nn.LayerNorm(nd // 2),
405
+ nn.LeakyReLU(0.2, inplace=True),
406
+
407
+ nn.Linear(nd // 2, nd // 4),
408
+ nn.LayerNorm(nd // 4),
409
+ nn.LeakyReLU(0.2, inplace=True),
410
+
411
+ nn.Linear(nd // 4, output_size)
412
+ )
413
+
414
+ self.input_emb.apply(init_weight)
415
+ self.pos_emb.apply(init_weight)
416
+ self.output.apply(init_weight)
417
+ self.hidden_size = hidden_size
418
+ self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True))
419
+
420
+ # input(batch_size, seq_len, dim)
421
+ def forward(self, word_embs, pos_onehot, cap_lens):
422
+ num_samples = word_embs.shape[0]
423
+
424
+ pos_embs = self.pos_emb(pos_onehot)
425
+ inputs = word_embs + pos_embs
426
+ input_embs = self.input_emb(inputs)
427
+ hidden = self.hidden.repeat(1, num_samples, 1)
428
+
429
+ cap_lens = cap_lens.data.tolist()
430
+ emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True)
431
+
432
+ gru_seq, gru_last = self.gru(emb, hidden)
433
+
434
+ gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1)
435
+
436
+ return self.output(gru_last)
eval/evaluator_wrapper.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE.
2
+ # Copyright (c) 2022 Chuan Guo
3
+ import torch
4
+ from os.path import join as pjoin
5
+ import numpy as np
6
+ from .evaluator_modules import *
7
+
8
+
9
+ def build_models(opt):
10
+ movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent)
11
+ text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word,
12
+ pos_size=opt.dim_pos_ohot,
13
+ hidden_size=opt.dim_text_hidden,
14
+ output_size=opt.dim_coemb_hidden,
15
+ device=opt.device)
16
+
17
+ motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent,
18
+ hidden_size=opt.dim_motion_hidden,
19
+ output_size=opt.dim_coemb_hidden,
20
+ device=opt.device)
21
+ checkpoint = torch.load(pjoin(opt.evaluator_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'),
22
+ map_location=opt.device)
23
+ movement_enc.load_state_dict(checkpoint['movement_encoder'])
24
+ text_enc.load_state_dict(checkpoint['text_encoder'])
25
+ motion_enc.load_state_dict(checkpoint['motion_encoder'])
26
+ print('\nLoading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch']))
27
+ return text_enc, motion_enc, movement_enc
28
+
29
+ class EvaluatorModelWrapper(object):
30
+
31
+ def __init__(self, opt):
32
+ self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt)
33
+ self.opt = opt
34
+ self.device = opt.device
35
+
36
+ self.text_encoder.to(opt.device)
37
+ self.motion_encoder.to(opt.device)
38
+ self.movement_encoder.to(opt.device)
39
+
40
+ self.text_encoder.eval()
41
+ self.motion_encoder.eval()
42
+ self.movement_encoder.eval()
43
+
44
+ # Please note that the results does not following the order of inputs
45
+ def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens):
46
+ with torch.no_grad():
47
+ word_embs = word_embs.detach().to(self.device).float()
48
+ pos_ohot = pos_ohot.detach().to(self.device).float()
49
+ motions = motions.detach().to(self.device).float()
50
+
51
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
52
+ motions = motions[align_idx]
53
+ m_lens = m_lens[align_idx]
54
+
55
+ '''Movement Encoding'''
56
+ movements = self.movement_encoder(motions[..., :-4]).detach()
57
+ m_lens = torch.div(m_lens, self.opt.unit_length, rounding_mode='trunc')
58
+ motion_embedding = self.motion_encoder(movements, m_lens)
59
+
60
+ '''Text Encoding'''
61
+ text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens)
62
+ text_embedding = text_embedding[align_idx]
63
+ return text_embedding, motion_embedding
64
+
65
+ # Please note that the results does not following the order of inputs
66
+ def get_motion_embeddings(self, motions, m_lens):
67
+ with torch.no_grad():
68
+ motions = motions.detach().to(self.device).float()
69
+
70
+ align_idx = np.argsort(m_lens.data.tolist())[::-1].copy()
71
+ motions = motions[align_idx]
72
+ m_lens = m_lens[align_idx]
73
+
74
+ '''Movement Encoding'''
75
+ movements = self.movement_encoder(motions[..., :-4]).detach()
76
+ m_lens = torch.div(m_lens, self.opt.unit_length, rounding_mode='trunc')
77
+ motion_embedding = self.motion_encoder(movements, m_lens)
78
+ return motion_embedding
models/__init__.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .unet import MotionCLR
2
+
3
+
4
+ __all__ = ["MotionCLR"]
5
+
6
+
7
+ def build_models(opt, edit_config=None, out_path=None):
8
+ print("\nInitializing model ...")
9
+ model = MotionCLR(
10
+ input_feats=opt.dim_pose,
11
+ text_latent_dim=opt.text_latent_dim,
12
+ base_dim=opt.base_dim,
13
+ dim_mults=opt.dim_mults,
14
+ time_dim=opt.time_dim,
15
+ adagn=not opt.no_adagn,
16
+ zero=True,
17
+ dropout=opt.dropout,
18
+ no_eff=opt.no_eff,
19
+ cond_mask_prob=getattr(opt, "cond_mask_prob", 0.0),
20
+ self_attention=opt.self_attention,
21
+ vis_attn=opt.vis_attn,
22
+ edit_config=edit_config,
23
+ out_path=out_path,
24
+ )
25
+
26
+ return model
models/gaussian_diffusion.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from diffusers import (
2
+ DPMSolverMultistepScheduler,
3
+ DDPMScheduler,
4
+ DDIMScheduler,
5
+ PNDMScheduler,
6
+ DEISMultistepScheduler,
7
+ )
8
+ import torch
9
+ import yaml
10
+ import math
11
+ import tqdm
12
+ import time
13
+
14
+
15
+ class DiffusePipeline(object):
16
+
17
+ def __init__(
18
+ self,
19
+ opt,
20
+ model,
21
+ diffuser_name,
22
+ num_inference_steps,
23
+ device,
24
+ torch_dtype=torch.float16,
25
+ ):
26
+ self.device = device
27
+ self.torch_dtype = torch_dtype
28
+ self.diffuser_name = diffuser_name
29
+ self.num_inference_steps = num_inference_steps
30
+ if self.torch_dtype == torch.float16:
31
+ model = model.half()
32
+ self.model = model.to(device)
33
+ self.opt = opt
34
+
35
+ # Load parameters from YAML file
36
+ with open("config/diffuser_params.yaml", "r") as yaml_file:
37
+ diffuser_params = yaml.safe_load(yaml_file)
38
+
39
+ # Select diffusion'parameters based on diffuser_name
40
+ if diffuser_name in diffuser_params:
41
+ params = diffuser_params[diffuser_name]
42
+ scheduler_class_name = params["scheduler_class"]
43
+ additional_params = params["additional_params"]
44
+
45
+ # align training parameters
46
+ additional_params["num_train_timesteps"] = opt.diffusion_steps
47
+ additional_params["beta_schedule"] = opt.beta_schedule
48
+ additional_params["prediction_type"] = opt.prediction_type
49
+
50
+ try:
51
+ scheduler_class = globals()[scheduler_class_name]
52
+ except KeyError:
53
+ raise ValueError(f"Class '{scheduler_class_name}' not found.")
54
+
55
+ self.scheduler = scheduler_class(**additional_params)
56
+ else:
57
+ raise ValueError(f"Unsupported diffuser_name: {diffuser_name}")
58
+
59
+ def generate_batch(self, caption, m_lens):
60
+ B = len(caption)
61
+ T = m_lens.max()
62
+ shape = (B, T, self.model.input_feats)
63
+
64
+ # random sampling noise x_T
65
+ sample = torch.randn(shape, device=self.device, dtype=self.torch_dtype)
66
+
67
+ # set timesteps
68
+ self.scheduler.set_timesteps(self.num_inference_steps, self.device)
69
+ timesteps = [
70
+ torch.tensor([t] * B, device=self.device).long()
71
+ for t in self.scheduler.timesteps
72
+ ]
73
+
74
+ # cache text_embedded
75
+ enc_text = self.model.encode_text(caption, self.device)
76
+
77
+ for i, t in enumerate(timesteps):
78
+ # 1. model predict
79
+ with torch.no_grad():
80
+ if getattr(self.model, "cond_mask_prob", 0) > 0:
81
+ predict = self.model.forward_with_cfg(sample, t, enc_text=enc_text)
82
+ else:
83
+
84
+ predict = self.model(sample, t, enc_text=enc_text)
85
+
86
+ # 2. compute less noisy motion and set x_t -> x_t-1
87
+ sample = self.scheduler.step(predict, t[0], sample).prev_sample
88
+
89
+ return sample
90
+
91
+ def generate(self, caption, m_lens, batch_size=32):
92
+ N = len(caption)
93
+ infer_mode = ""
94
+ if getattr(self.model, "cond_mask_prob", 0) > 0:
95
+ infer_mode = "classifier-free-guidance"
96
+ print(
97
+ f"\nUsing {self.diffuser_name} diffusion scheduler to {infer_mode} generate {N} motions, sampling {self.num_inference_steps} steps."
98
+ )
99
+ self.model.eval()
100
+
101
+ all_output = []
102
+ t_sum = 0
103
+ cur_idx = 0
104
+ for bacth_idx in tqdm.tqdm(range(math.ceil(N / batch_size))):
105
+ if cur_idx + batch_size >= N:
106
+ batch_caption = caption[cur_idx:]
107
+ batch_m_lens = m_lens[cur_idx:]
108
+ else:
109
+ batch_caption = caption[cur_idx : cur_idx + batch_size]
110
+ batch_m_lens = m_lens[cur_idx : cur_idx + batch_size]
111
+ torch.cuda.synchronize()
112
+ start_time = time.time()
113
+ output = self.generate_batch(batch_caption, batch_m_lens)
114
+ torch.cuda.synchronize()
115
+ now_time = time.time()
116
+
117
+ # The average inference time is calculated after GPU warm-up in the first 50 steps.
118
+ if (bacth_idx + 1) * self.num_inference_steps >= 50:
119
+ t_sum += now_time - start_time
120
+
121
+ # Crop motion with gt/predicted motion length
122
+ B = output.shape[0]
123
+ for i in range(B):
124
+ all_output.append(output[i, : batch_m_lens[i]])
125
+
126
+ cur_idx += batch_size
127
+
128
+ # calcalate average inference time
129
+ t_eval = t_sum / (bacth_idx - 1)
130
+ print(
131
+ "The average generation time of a batch motion (bs=%d) is %f seconds"
132
+ % (batch_size, t_eval)
133
+ )
134
+ return all_output, t_eval
models/unet.py ADDED
@@ -0,0 +1,1073 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import clip
2
+ import math
3
+
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from torch import nn
7
+ import numpy as np
8
+ from einops.layers.torch import Rearrange
9
+ from einops import rearrange
10
+ import matplotlib.pyplot as plt
11
+ import os
12
+
13
+
14
+ MONITOR_ATTN = []
15
+ SELF_ATTN = []
16
+
17
+
18
+ def vis_attn(att, out_path, step, layer, shape, type_="self", lines=True):
19
+ if lines:
20
+ plt.figure(figsize=(10, 3))
21
+ for token_index in range(att.shape[1]):
22
+ plt.plot(att[:, token_index], label=f"Token {token_index}")
23
+
24
+ plt.title("Attention Values for Each Token")
25
+ plt.xlabel("time")
26
+ plt.ylabel("Attention Value")
27
+ plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1))
28
+
29
+ # save image
30
+ savepath = os.path.join(out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_lines_{shape}.png")
31
+ os.makedirs(os.path.dirname(savepath), exist_ok=True)
32
+ plt.savefig(savepath, bbox_inches="tight")
33
+ np.save(savepath.replace(".png", ".npy"), att)
34
+ else:
35
+ plt.figure(figsize=(10, 10))
36
+ plt.imshow(att.transpose(), cmap="viridis", aspect="auto")
37
+ plt.colorbar()
38
+ plt.title("Attention Matrix Heatmap")
39
+ plt.ylabel("time")
40
+ plt.xlabel("time")
41
+
42
+ # save image
43
+ savepath = os.path.join(out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_heatmap_{shape}.png")
44
+ os.makedirs(os.path.dirname(savepath), exist_ok=True)
45
+ plt.savefig(savepath, bbox_inches="tight")
46
+ np.save(savepath.replace(".png", ".npy"), att)
47
+
48
+
49
+ def zero_module(module):
50
+ """
51
+ Zero out the parameters of a module and return it.
52
+ """
53
+ for p in module.parameters():
54
+ p.detach().zero_()
55
+ return module
56
+
57
+
58
+ class FFN(nn.Module):
59
+
60
+ def __init__(self, latent_dim, ffn_dim, dropout):
61
+ super().__init__()
62
+ self.linear1 = nn.Linear(latent_dim, ffn_dim)
63
+ self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim))
64
+ self.activation = nn.GELU()
65
+ self.dropout = nn.Dropout(dropout)
66
+
67
+ def forward(self, x):
68
+ y = self.linear2(self.dropout(self.activation(self.linear1(x))))
69
+ y = x + y
70
+ return y
71
+
72
+
73
+ class Conv1dAdaGNBlock(nn.Module):
74
+ """
75
+ Conv1d --> GroupNorm --> scale,shift --> Mish
76
+ """
77
+
78
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4):
79
+ super().__init__()
80
+ self.out_channels = out_channels
81
+ self.block = nn.Conv1d(
82
+ inp_channels, out_channels, kernel_size, padding=kernel_size // 2
83
+ )
84
+ self.group_norm = nn.GroupNorm(n_groups, out_channels)
85
+ self.avtication = nn.Mish()
86
+
87
+ def forward(self, x, scale, shift):
88
+ """
89
+ Args:
90
+ x: [bs, nfeat, nframes]
91
+ scale: [bs, out_feat, 1]
92
+ shift: [bs, out_feat, 1]
93
+ """
94
+ x = self.block(x)
95
+
96
+ batch_size, channels, horizon = x.size()
97
+ x = rearrange(
98
+ x, "batch channels horizon -> (batch horizon) channels"
99
+ ) # [bs*seq, nfeats]
100
+ x = self.group_norm(x)
101
+ x = rearrange(
102
+ x.reshape(batch_size, horizon, channels),
103
+ "batch horizon channels -> batch channels horizon",
104
+ )
105
+ x = ada_shift_scale(x, shift, scale)
106
+
107
+ return self.avtication(x)
108
+
109
+
110
+ class SelfAttention(nn.Module):
111
+
112
+ def __init__(
113
+ self,
114
+ latent_dim,
115
+ text_latent_dim,
116
+ num_heads: int = 8,
117
+ dropout: float = 0.0,
118
+ log_attn=False,
119
+ edit_config=None,
120
+ ):
121
+ super().__init__()
122
+ self.num_head = num_heads
123
+ self.norm = nn.LayerNorm(latent_dim)
124
+ self.query = nn.Linear(latent_dim, latent_dim)
125
+ self.key = nn.Linear(latent_dim, latent_dim)
126
+ self.value = nn.Linear(latent_dim, latent_dim)
127
+ self.dropout = nn.Dropout(dropout)
128
+
129
+ self.edit_config = edit_config
130
+ self.log_attn = log_attn
131
+
132
+ def forward(self, x):
133
+ """
134
+ x: B, T, D
135
+ xf: B, N, L
136
+ """
137
+ B, T, D = x.shape
138
+ N = x.shape[1]
139
+ assert N == T
140
+ H = self.num_head
141
+
142
+ # B, T, 1, D
143
+ query = self.query(self.norm(x)).unsqueeze(2)
144
+ # B, 1, N, D
145
+ key = self.key(self.norm(x)).unsqueeze(1)
146
+ query = query.view(B, T, H, -1)
147
+ key = key.view(B, N, H, -1)
148
+
149
+ # style transfer motion editing
150
+ style_tranfer = self.edit_config.style_tranfer.use
151
+ if style_tranfer:
152
+ if (
153
+ len(SELF_ATTN)
154
+ <= self.edit_config.style_tranfer.style_transfer_steps_end
155
+ ):
156
+ query[1] = query[0]
157
+
158
+ # example based motion generation
159
+ example_based = self.edit_config.example_based.use
160
+ if example_based:
161
+ if len(SELF_ATTN) == self.edit_config.example_based.example_based_steps_end:
162
+
163
+ temp_seed = self.edit_config.example_based.temp_seed
164
+ for id_ in range(query.shape[0] - 1):
165
+ with torch.random.fork_rng():
166
+ torch.manual_seed(temp_seed)
167
+ tensor = query[0]
168
+ chunks = torch.split(
169
+ tensor, self.edit_config.example_based.chunk_size, dim=0
170
+ )
171
+ shuffled_indices = torch.randperm(len(chunks))
172
+ shuffled_chunks = [chunks[i] for i in shuffled_indices]
173
+ shuffled_tensor = torch.cat(shuffled_chunks, dim=0)
174
+ query[1 + id_] = shuffled_tensor
175
+ temp_seed += self.edit_config.example_based.temp_seed_bar
176
+
177
+ # time shift motion editing (q, k)
178
+ time_shift = self.edit_config.time_shift.use
179
+ if time_shift:
180
+ if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end:
181
+ part1 = int(
182
+ key.shape[1] * self.edit_config.time_shift.time_shift_ratio // 1
183
+ )
184
+ part2 = int(
185
+ key.shape[1]
186
+ * (1 - self.edit_config.time_shift.time_shift_ratio)
187
+ // 1
188
+ )
189
+ q_front_part = query[0, :part1, :, :]
190
+ q_back_part = query[0, -part2:, :, :]
191
+
192
+ new_q = torch.cat((q_back_part, q_front_part), dim=0)
193
+ query[1] = new_q
194
+
195
+ k_front_part = key[0, :part1, :, :]
196
+ k_back_part = key[0, -part2:, :, :]
197
+ new_k = torch.cat((k_back_part, k_front_part), dim=0)
198
+ key[1] = new_k
199
+
200
+ # B, T, N, H
201
+ attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H)
202
+ weight = self.dropout(F.softmax(attention, dim=2))
203
+
204
+ # for counting the step and logging attention maps
205
+ try:
206
+ attention_matrix = (
207
+ weight[0, :, :].mean(dim=-1).detach().cpu().numpy().astype(float)
208
+ )
209
+ SELF_ATTN[-1].append(attention_matrix)
210
+ except:
211
+ pass
212
+
213
+ # attention manipulation for replacement
214
+ attention_manipulation = self.edit_config.manipulation.use
215
+ if attention_manipulation:
216
+ if len(SELF_ATTN) <= self.edit_config.manipulation.manipulation_steps_end:
217
+ weight[1, :, :, :] = weight[0, :, :, :]
218
+
219
+ value = self.value(self.norm(x)).view(B, N, H, -1)
220
+
221
+ # time shift motion editing (v)
222
+ if time_shift:
223
+ if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end:
224
+ v_front_part = value[0, :part1, :, :]
225
+ v_back_part = value[0, -part2:, :, :]
226
+ new_v = torch.cat((v_back_part, v_front_part), dim=0)
227
+ value[1] = new_v
228
+ y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D)
229
+ return y
230
+
231
+
232
+ class TimestepEmbedder(nn.Module):
233
+ def __init__(self, d_model, max_len=5000):
234
+ super(TimestepEmbedder, self).__init__()
235
+
236
+ pe = torch.zeros(max_len, d_model)
237
+ position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
238
+ div_term = torch.exp(
239
+ torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model)
240
+ )
241
+ pe[:, 0::2] = torch.sin(position * div_term)
242
+ pe[:, 1::2] = torch.cos(position * div_term)
243
+
244
+ self.register_buffer("pe", pe)
245
+
246
+ def forward(self, x):
247
+ return self.pe[x]
248
+
249
+
250
+ class Downsample1d(nn.Module):
251
+ def __init__(self, dim):
252
+ super().__init__()
253
+ self.conv = nn.Conv1d(dim, dim, 3, 2, 1)
254
+
255
+ def forward(self, x):
256
+ return self.conv(x)
257
+
258
+
259
+ class Upsample1d(nn.Module):
260
+ def __init__(self, dim_in, dim_out=None):
261
+ super().__init__()
262
+ dim_out = dim_out or dim_in
263
+ self.conv = nn.ConvTranspose1d(dim_in, dim_out, 4, 2, 1)
264
+
265
+ def forward(self, x):
266
+ return self.conv(x)
267
+
268
+
269
+ class Conv1dBlock(nn.Module):
270
+ """
271
+ Conv1d --> GroupNorm --> Mish
272
+ """
273
+
274
+ def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4, zero=False):
275
+ super().__init__()
276
+ self.out_channels = out_channels
277
+ self.block = nn.Conv1d(
278
+ inp_channels, out_channels, kernel_size, padding=kernel_size // 2
279
+ )
280
+ self.norm = nn.GroupNorm(n_groups, out_channels)
281
+ self.activation = nn.Mish()
282
+
283
+ if zero:
284
+ # zero init the convolution
285
+ nn.init.zeros_(self.block.weight)
286
+ nn.init.zeros_(self.block.bias)
287
+
288
+ def forward(self, x):
289
+ """
290
+ Args:
291
+ x: [bs, nfeat, nframes]
292
+ """
293
+ x = self.block(x)
294
+
295
+ batch_size, channels, horizon = x.size()
296
+ x = rearrange(
297
+ x, "batch channels horizon -> (batch horizon) channels"
298
+ ) # [bs*seq, nfeats]
299
+ x = self.norm(x)
300
+ x = rearrange(
301
+ x.reshape(batch_size, horizon, channels),
302
+ "batch horizon channels -> batch channels horizon",
303
+ )
304
+
305
+ return self.activation(x)
306
+
307
+
308
+ def ada_shift_scale(x, shift, scale):
309
+ return x * (1 + scale) + shift
310
+
311
+
312
+ class ResidualTemporalBlock(nn.Module):
313
+ def __init__(
314
+ self,
315
+ inp_channels,
316
+ out_channels,
317
+ embed_dim,
318
+ kernel_size=5,
319
+ zero=True,
320
+ n_groups=8,
321
+ dropout: float = 0.1,
322
+ adagn=True,
323
+ ):
324
+ super().__init__()
325
+ self.adagn = adagn
326
+
327
+ self.blocks = nn.ModuleList(
328
+ [
329
+ # adagn only the first conv (following guided-diffusion)
330
+ (
331
+ Conv1dAdaGNBlock(inp_channels, out_channels, kernel_size, n_groups)
332
+ if adagn
333
+ else Conv1dBlock(inp_channels, out_channels, kernel_size)
334
+ ),
335
+ Conv1dBlock(
336
+ out_channels, out_channels, kernel_size, n_groups, zero=zero
337
+ ),
338
+ ]
339
+ )
340
+
341
+ self.time_mlp = nn.Sequential(
342
+ nn.Mish(),
343
+ # adagn = scale and shift
344
+ nn.Linear(embed_dim, out_channels * 2 if adagn else out_channels),
345
+ Rearrange("batch t -> batch t 1"),
346
+ )
347
+ self.dropout = nn.Dropout(dropout)
348
+ if zero:
349
+ nn.init.zeros_(self.time_mlp[1].weight)
350
+ nn.init.zeros_(self.time_mlp[1].bias)
351
+
352
+ self.residual_conv = (
353
+ nn.Conv1d(inp_channels, out_channels, 1)
354
+ if inp_channels != out_channels
355
+ else nn.Identity()
356
+ )
357
+
358
+ def forward(self, x, time_embeds=None):
359
+ """
360
+ x : [ batch_size x inp_channels x nframes ]
361
+ t : [ batch_size x embed_dim ]
362
+ returns: [ batch_size x out_channels x nframes ]
363
+ """
364
+ if self.adagn:
365
+ scale, shift = self.time_mlp(time_embeds).chunk(2, dim=1)
366
+ out = self.blocks[0](x, scale, shift)
367
+ else:
368
+ out = self.blocks[0](x) + self.time_mlp(time_embeds)
369
+ out = self.blocks[1](out)
370
+ out = self.dropout(out)
371
+ return out + self.residual_conv(x)
372
+
373
+
374
+ class CrossAttention(nn.Module):
375
+
376
+ def __init__(
377
+ self,
378
+ latent_dim,
379
+ text_latent_dim,
380
+ num_heads: int = 8,
381
+ dropout: float = 0.0,
382
+ log_attn=False,
383
+ edit_config=None,
384
+ ):
385
+ super().__init__()
386
+ self.num_head = num_heads
387
+ self.norm = nn.LayerNorm(latent_dim)
388
+ self.text_norm = nn.LayerNorm(text_latent_dim)
389
+ self.query = nn.Linear(latent_dim, latent_dim)
390
+ self.key = nn.Linear(text_latent_dim, latent_dim)
391
+ self.value = nn.Linear(text_latent_dim, latent_dim)
392
+ self.dropout = nn.Dropout(dropout)
393
+
394
+ self.edit_config = edit_config
395
+ self.log_attn = log_attn
396
+
397
+ def forward(self, x, xf):
398
+ """
399
+ x: B, T, D
400
+ xf: B, N, L
401
+ """
402
+ B, T, D = x.shape
403
+ N = xf.shape[1]
404
+ H = self.num_head
405
+ # B, T, 1, D
406
+ query = self.query(self.norm(x)).unsqueeze(2)
407
+ # B, 1, N, D
408
+ key = self.key(self.text_norm(xf)).unsqueeze(1)
409
+ query = query.view(B, T, H, -1)
410
+ key = key.view(B, N, H, -1)
411
+ # B, T, N, H
412
+ attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H)
413
+ weight = self.dropout(F.softmax(attention, dim=2))
414
+
415
+ # attention reweighting for (de)-emphasizing motion
416
+ if self.edit_config.reweighting_attn.use:
417
+ reweighting_attn = self.edit_config.reweighting_attn.reweighting_attn_weight
418
+ if self.edit_config.reweighting_attn.idx == -1:
419
+ # read idxs from txt file
420
+ with open("./assets/reweighting_idx.txt", "r") as f:
421
+ idxs = f.readlines()
422
+ else:
423
+ # gradio demo mode
424
+ idxs = [0, self.edit_config.reweighting_attn.idx]
425
+ idxs = [int(idx) for idx in idxs]
426
+ for i in range(len(idxs)):
427
+ weight[i, :, 1 + idxs[i]] = weight[i, :, 1 + idxs[i]] + reweighting_attn
428
+ weight[i, :, 1 + idxs[i] + 1] = (
429
+ weight[i, :, 1 + idxs[i] + 1] + reweighting_attn
430
+ )
431
+
432
+ # for counting the step and logging attention maps
433
+ try:
434
+ attention_matrix = (
435
+ weight[0, :, 1 : 1 + 3]
436
+ .mean(dim=-1)
437
+ .detach()
438
+ .cpu()
439
+ .numpy()
440
+ .astype(float)
441
+ )
442
+ MONITOR_ATTN[-1].append(attention_matrix)
443
+ except:
444
+ pass
445
+
446
+ # erasing motion (autually is the deemphasizing motion)
447
+ erasing_motion = self.edit_config.erasing_motion.use
448
+ if erasing_motion:
449
+ reweighting_attn = self.edit_config.erasing_motion.erasing_motion_weight
450
+ begin = self.edit_config.erasing_motion.time_start
451
+ end = self.edit_config.erasing_motion.time_end
452
+ idx = self.edit_config.erasing_motion.idx
453
+ if reweighting_attn > 0.01 or reweighting_attn < -0.01:
454
+ weight[1, int(T * begin) : int(T * end), idx] = (
455
+ weight[1, int(T * begin) : int(T * end) :, idx] * reweighting_attn
456
+ )
457
+ weight[1, int(T * begin) : int(T * end), idx + 1] = (
458
+ weight[1, int(T * begin) : int(T * end), idx + 1] * reweighting_attn
459
+ )
460
+
461
+ # attention manipulation for motion replacement
462
+ manipulation = self.edit_config.manipulation.use
463
+ if manipulation:
464
+ if (
465
+ len(MONITOR_ATTN)
466
+ <= self.edit_config.manipulation.manipulation_steps_end_crossattn
467
+ ):
468
+ word_idx = self.edit_config.manipulation.word_idx
469
+ weight[1, :, : 1 + word_idx, :] = weight[0, :, : 1 + word_idx, :]
470
+ weight[1, :, 1 + word_idx + 1 :, :] = weight[
471
+ 0, :, 1 + word_idx + 1 :, :
472
+ ]
473
+
474
+ value = self.value(self.text_norm(xf)).view(B, N, H, -1)
475
+ y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D)
476
+ return y
477
+
478
+
479
+ class ResidualCLRAttentionLayer(nn.Module):
480
+ def __init__(
481
+ self,
482
+ dim1,
483
+ dim2,
484
+ num_heads: int = 8,
485
+ dropout: float = 0.1,
486
+ no_eff: bool = False,
487
+ self_attention: bool = False,
488
+ log_attn=False,
489
+ edit_config=None,
490
+ ):
491
+ super(ResidualCLRAttentionLayer, self).__init__()
492
+ self.dim1 = dim1
493
+ self.dim2 = dim2
494
+ self.num_heads = num_heads
495
+
496
+ # Multi-Head Attention Layer
497
+ if no_eff:
498
+ self.cross_attention = CrossAttention(
499
+ latent_dim=dim1,
500
+ text_latent_dim=dim2,
501
+ num_heads=num_heads,
502
+ dropout=dropout,
503
+ log_attn=log_attn,
504
+ edit_config=edit_config,
505
+ )
506
+ else:
507
+ self.cross_attention = LinearCrossAttention(
508
+ latent_dim=dim1,
509
+ text_latent_dim=dim2,
510
+ num_heads=num_heads,
511
+ dropout=dropout,
512
+ log_attn=log_attn,
513
+ )
514
+ if self_attention:
515
+ self.self_attn_use = True
516
+ self.self_attention = SelfAttention(
517
+ latent_dim=dim1,
518
+ text_latent_dim=dim2,
519
+ num_heads=num_heads,
520
+ dropout=dropout,
521
+ log_attn=log_attn,
522
+ edit_config=edit_config,
523
+ )
524
+ else:
525
+ self.self_attn_use = False
526
+
527
+ def forward(self, input_tensor, condition_tensor, cond_indices):
528
+ """
529
+ input_tensor :B, D, L
530
+ condition_tensor: B, L, D
531
+ """
532
+ if cond_indices.numel() == 0:
533
+ return input_tensor
534
+
535
+ # self attention
536
+ if self.self_attn_use:
537
+ x = input_tensor
538
+ x = x.permute(0, 2, 1) # (batch_size, seq_length, feat_dim)
539
+ x = self.self_attention(x)
540
+ x = x.permute(0, 2, 1) # (batch_size, feat_dim, seq_length)
541
+ input_tensor = input_tensor + x
542
+ x = input_tensor
543
+
544
+ # cross attention
545
+ x = x[cond_indices].permute(0, 2, 1) # (batch_size, seq_length, feat_dim)
546
+ x = self.cross_attention(x, condition_tensor[cond_indices])
547
+ x = x.permute(0, 2, 1) # (batch_size, feat_dim, seq_length)
548
+
549
+ input_tensor[cond_indices] = input_tensor[cond_indices] + x
550
+
551
+ return input_tensor
552
+
553
+
554
+ class CLRBlock(nn.Module):
555
+ def __init__(
556
+ self,
557
+ dim_in,
558
+ dim_out,
559
+ cond_dim,
560
+ time_dim,
561
+ adagn=True,
562
+ zero=True,
563
+ no_eff=False,
564
+ self_attention=False,
565
+ dropout: float = 0.1,
566
+ log_attn=False,
567
+ edit_config=None,
568
+ ) -> None:
569
+ super().__init__()
570
+ self.conv1d = ResidualTemporalBlock(
571
+ dim_in, dim_out, embed_dim=time_dim, adagn=adagn, zero=zero, dropout=dropout
572
+ )
573
+ self.clr_attn = ResidualCLRAttentionLayer(
574
+ dim1=dim_out,
575
+ dim2=cond_dim,
576
+ no_eff=no_eff,
577
+ dropout=dropout,
578
+ self_attention=self_attention,
579
+ log_attn=log_attn,
580
+ edit_config=edit_config,
581
+ )
582
+ # import pdb; pdb.set_trace()
583
+ self.ffn = FFN(dim_out, dim_out * 4, dropout=dropout)
584
+
585
+ def forward(self, x, t, cond, cond_indices=None):
586
+ x = self.conv1d(x, t)
587
+ x = self.clr_attn(x, cond, cond_indices)
588
+ x = self.ffn(x.permute(0, 2, 1)).permute(0, 2, 1)
589
+ return x
590
+
591
+
592
+ class CondUnet1D(nn.Module):
593
+ """
594
+ Diffusion's style UNET with 1D convolution and adaptive group normalization for motion suquence denoising,
595
+ cross-attention to introduce conditional prompts (like text).
596
+ """
597
+
598
+ def __init__(
599
+ self,
600
+ input_dim,
601
+ cond_dim,
602
+ dim=128,
603
+ dim_mults=(1, 2, 4, 8),
604
+ dims=None,
605
+ time_dim=512,
606
+ adagn=True,
607
+ zero=True,
608
+ dropout=0.1,
609
+ no_eff=False,
610
+ self_attention=False,
611
+ log_attn=False,
612
+ edit_config=None,
613
+ ):
614
+ super().__init__()
615
+ if not dims:
616
+ dims = [input_dim, *map(lambda m: int(dim * m), dim_mults)] ##[d, d,2d,4d]
617
+ print("dims: ", dims, "mults: ", dim_mults)
618
+ in_out = list(zip(dims[:-1], dims[1:]))
619
+
620
+ self.time_mlp = nn.Sequential(
621
+ TimestepEmbedder(time_dim),
622
+ nn.Linear(time_dim, time_dim * 4),
623
+ nn.Mish(),
624
+ nn.Linear(time_dim * 4, time_dim),
625
+ )
626
+
627
+ self.downs = nn.ModuleList([])
628
+ self.ups = nn.ModuleList([])
629
+
630
+ for ind, (dim_in, dim_out) in enumerate(in_out):
631
+ self.downs.append(
632
+ nn.ModuleList(
633
+ [
634
+ CLRBlock(
635
+ dim_in,
636
+ dim_out,
637
+ cond_dim,
638
+ time_dim,
639
+ adagn=adagn,
640
+ zero=zero,
641
+ no_eff=no_eff,
642
+ dropout=dropout,
643
+ self_attention=self_attention,
644
+ log_attn=log_attn,
645
+ edit_config=edit_config,
646
+ ),
647
+ CLRBlock(
648
+ dim_out,
649
+ dim_out,
650
+ cond_dim,
651
+ time_dim,
652
+ adagn=adagn,
653
+ zero=zero,
654
+ no_eff=no_eff,
655
+ dropout=dropout,
656
+ self_attention=self_attention,
657
+ log_attn=log_attn,
658
+ edit_config=edit_config,
659
+ ),
660
+ Downsample1d(dim_out),
661
+ ]
662
+ )
663
+ )
664
+
665
+ mid_dim = dims[-1]
666
+ self.mid_block1 = CLRBlock(
667
+ dim_in=mid_dim,
668
+ dim_out=mid_dim,
669
+ cond_dim=cond_dim,
670
+ time_dim=time_dim,
671
+ adagn=adagn,
672
+ zero=zero,
673
+ no_eff=no_eff,
674
+ dropout=dropout,
675
+ self_attention=self_attention,
676
+ log_attn=log_attn,
677
+ edit_config=edit_config,
678
+ )
679
+ self.mid_block2 = CLRBlock(
680
+ dim_in=mid_dim,
681
+ dim_out=mid_dim,
682
+ cond_dim=cond_dim,
683
+ time_dim=time_dim,
684
+ adagn=adagn,
685
+ zero=zero,
686
+ no_eff=no_eff,
687
+ dropout=dropout,
688
+ self_attention=self_attention,
689
+ log_attn=log_attn,
690
+ edit_config=edit_config,
691
+ )
692
+
693
+ last_dim = mid_dim
694
+ for ind, dim_out in enumerate(reversed(dims[1:])):
695
+ self.ups.append(
696
+ nn.ModuleList(
697
+ [
698
+ Upsample1d(last_dim, dim_out),
699
+ CLRBlock(
700
+ dim_out * 2,
701
+ dim_out,
702
+ cond_dim,
703
+ time_dim,
704
+ adagn=adagn,
705
+ zero=zero,
706
+ no_eff=no_eff,
707
+ dropout=dropout,
708
+ self_attention=self_attention,
709
+ log_attn=log_attn,
710
+ edit_config=edit_config,
711
+ ),
712
+ CLRBlock(
713
+ dim_out,
714
+ dim_out,
715
+ cond_dim,
716
+ time_dim,
717
+ adagn=adagn,
718
+ zero=zero,
719
+ no_eff=no_eff,
720
+ dropout=dropout,
721
+ self_attention=self_attention,
722
+ log_attn=log_attn,
723
+ edit_config=edit_config,
724
+ ),
725
+ ]
726
+ )
727
+ )
728
+ last_dim = dim_out
729
+ self.final_conv = nn.Conv1d(dim_out, input_dim, 1)
730
+
731
+ if zero:
732
+ nn.init.zeros_(self.final_conv.weight)
733
+ nn.init.zeros_(self.final_conv.bias)
734
+
735
+ def forward(
736
+ self,
737
+ x,
738
+ t,
739
+ cond,
740
+ cond_indices,
741
+ ):
742
+ temb = self.time_mlp(t)
743
+
744
+ h = []
745
+ for block1, block2, downsample in self.downs:
746
+ x = block1(x, temb, cond, cond_indices)
747
+ x = block2(x, temb, cond, cond_indices)
748
+ h.append(x)
749
+ x = downsample(x)
750
+
751
+ x = self.mid_block1(x, temb, cond, cond_indices)
752
+ x = self.mid_block2(x, temb, cond, cond_indices)
753
+
754
+ for upsample, block1, block2 in self.ups:
755
+ x = upsample(x)
756
+ x = torch.cat((x, h.pop()), dim=1)
757
+ x = block1(x, temb, cond, cond_indices)
758
+ x = block2(x, temb, cond, cond_indices)
759
+
760
+ x = self.final_conv(x)
761
+ return x
762
+
763
+
764
+ class MotionCLR(nn.Module):
765
+ """
766
+ Diffuser's style UNET for text-to-motion task.
767
+ """
768
+
769
+ def __init__(
770
+ self,
771
+ input_feats,
772
+ base_dim=128,
773
+ dim_mults=(1, 2, 2, 2),
774
+ dims=None,
775
+ adagn=True,
776
+ zero=True,
777
+ dropout=0.1,
778
+ no_eff=False,
779
+ time_dim=512,
780
+ latent_dim=256,
781
+ cond_mask_prob=0.1,
782
+ clip_dim=512,
783
+ clip_version="ViT-B/32",
784
+ text_latent_dim=256,
785
+ text_ff_size=2048,
786
+ text_num_heads=4,
787
+ activation="gelu",
788
+ num_text_layers=4,
789
+ self_attention=False,
790
+ vis_attn=False,
791
+ edit_config=None,
792
+ out_path=None,
793
+ ):
794
+ super().__init__()
795
+ self.input_feats = input_feats
796
+ self.dim_mults = dim_mults
797
+ self.base_dim = base_dim
798
+ self.latent_dim = latent_dim
799
+ self.cond_mask_prob = cond_mask_prob
800
+ self.vis_attn = vis_attn
801
+ self.counting_map = []
802
+ self.out_path = out_path
803
+
804
+ print(
805
+ f"The T2M Unet mask the text prompt by {self.cond_mask_prob} prob. in training"
806
+ )
807
+
808
+ # text encoder
809
+ self.embed_text = nn.Linear(clip_dim, text_latent_dim)
810
+ self.clip_version = clip_version
811
+ self.clip_model = self.load_and_freeze_clip(clip_version)
812
+ textTransEncoderLayer = nn.TransformerEncoderLayer(
813
+ d_model=text_latent_dim,
814
+ nhead=text_num_heads,
815
+ dim_feedforward=text_ff_size,
816
+ dropout=dropout,
817
+ activation=activation,
818
+ )
819
+ self.textTransEncoder = nn.TransformerEncoder(
820
+ textTransEncoderLayer, num_layers=num_text_layers
821
+ )
822
+ self.text_ln = nn.LayerNorm(text_latent_dim)
823
+
824
+ self.unet = CondUnet1D(
825
+ input_dim=self.input_feats,
826
+ cond_dim=text_latent_dim,
827
+ dim=self.base_dim,
828
+ dim_mults=self.dim_mults,
829
+ adagn=adagn,
830
+ zero=zero,
831
+ dropout=dropout,
832
+ no_eff=no_eff,
833
+ dims=dims,
834
+ time_dim=time_dim,
835
+ self_attention=self_attention,
836
+ log_attn=self.vis_attn,
837
+ edit_config=edit_config,
838
+ )
839
+
840
+ def encode_text(self, raw_text, device):
841
+ with torch.no_grad():
842
+ texts = clip.tokenize(raw_text, truncate=True).to(
843
+ device
844
+ ) # [bs, context_length] # if n_tokens > 77 -> will truncate
845
+ x = self.clip_model.token_embedding(texts).type(
846
+ self.clip_model.dtype
847
+ ) # [batch_size, n_ctx, d_model]
848
+ x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype)
849
+ x = x.permute(1, 0, 2) # NLD -> LND
850
+ x = self.clip_model.transformer(x)
851
+ x = self.clip_model.ln_final(x).type(
852
+ self.clip_model.dtype
853
+ ) # [len, batch_size, 512]
854
+
855
+ x = self.embed_text(x) # [len, batch_size, 256]
856
+ x = self.textTransEncoder(x)
857
+ x = self.text_ln(x)
858
+
859
+ # T, B, D -> B, T, D
860
+ xf_out = x.permute(1, 0, 2)
861
+
862
+ ablation_text = False
863
+ if ablation_text:
864
+ xf_out[:, 1:, :] = xf_out[:, 0, :].unsqueeze(1)
865
+ return xf_out
866
+
867
+ def load_and_freeze_clip(self, clip_version):
868
+ clip_model, _ = clip.load( # clip_model.dtype=float32
869
+ clip_version, device="cpu", jit=False
870
+ ) # Must set jit=False for training
871
+
872
+ # Freeze CLIP weights
873
+ clip_model.eval()
874
+ for p in clip_model.parameters():
875
+ p.requires_grad = False
876
+
877
+ return clip_model
878
+
879
+ def mask_cond(self, bs, force_mask=False):
880
+ """
881
+ mask motion condition , return contitional motion index in the batch
882
+ """
883
+ if force_mask:
884
+ cond_indices = torch.empty(0)
885
+ elif self.training and self.cond_mask_prob > 0.0:
886
+ mask = torch.bernoulli(
887
+ torch.ones(
888
+ bs,
889
+ )
890
+ * self.cond_mask_prob
891
+ ) # 1-> use null_cond, 0-> use real cond
892
+ mask = 1.0 - mask
893
+ cond_indices = torch.nonzero(mask).squeeze(-1)
894
+ else:
895
+ cond_indices = torch.arange(bs)
896
+
897
+ return cond_indices
898
+
899
+ def forward(
900
+ self,
901
+ x,
902
+ timesteps,
903
+ text=None,
904
+ uncond=False,
905
+ enc_text=None,
906
+ ):
907
+ """
908
+ Args:
909
+ x: [batch_size, nframes, nfeats],
910
+ timesteps: [batch_size] (int)
911
+ text: list (batch_size length) of strings with input text prompts
912
+ uncond: whethere using text condition
913
+
914
+ Returns: [batch_size, seq_length, nfeats]
915
+ """
916
+ B, T, _ = x.shape
917
+ x = x.transpose(1, 2) # [bs, nfeats, nframes]
918
+
919
+ if enc_text is None:
920
+ enc_text = self.encode_text(text, x.device) # [bs, seqlen, text_dim]
921
+
922
+ cond_indices = self.mask_cond(x.shape[0], force_mask=uncond)
923
+
924
+ # NOTE: need to pad to be the multiplier of 8 for the unet
925
+ PADDING_NEEEDED = (16 - (T % 16)) % 16
926
+
927
+ padding = (0, PADDING_NEEEDED)
928
+ x = F.pad(x, padding, value=0)
929
+
930
+ x = self.unet(
931
+ x,
932
+ t=timesteps,
933
+ cond=enc_text,
934
+ cond_indices=cond_indices,
935
+ ) # [bs, nfeats,, nframes]
936
+
937
+ x = x[:, :, :T].transpose(1, 2) # [bs, nframes, nfeats,]
938
+
939
+ return x
940
+
941
+ def forward_with_cfg(self, x, timesteps, text=None, enc_text=None, cfg_scale=2.5):
942
+ """
943
+ Args:
944
+ x: [batch_size, nframes, nfeats],
945
+ timesteps: [batch_size] (int)
946
+ text: list (batch_size length) of strings with input text prompts
947
+
948
+ Returns: [batch_size, max_frames, nfeats]
949
+ """
950
+ global SELF_ATTN
951
+ global MONITOR_ATTN
952
+ MONITOR_ATTN.append([])
953
+ SELF_ATTN.append([])
954
+
955
+ B, T, _ = x.shape
956
+ x = x.transpose(1, 2) # [bs, nfeats, nframes]
957
+ if enc_text is None:
958
+ enc_text = self.encode_text(text, x.device) # [bs, seqlen, text_dim]
959
+
960
+ cond_indices = self.mask_cond(B)
961
+
962
+ # NOTE: need to pad to be the multiplier of 8 for the unet
963
+ PADDING_NEEEDED = (16 - (T % 16)) % 16
964
+
965
+ padding = (0, PADDING_NEEEDED)
966
+ x = F.pad(x, padding, value=0)
967
+
968
+ combined_x = torch.cat([x, x], dim=0)
969
+ combined_t = torch.cat([timesteps, timesteps], dim=0)
970
+ out = self.unet(
971
+ x=combined_x,
972
+ t=combined_t,
973
+ cond=enc_text,
974
+ cond_indices=cond_indices,
975
+ ) # [bs, nfeats, nframes]
976
+
977
+ out = out[:, :, :T].transpose(1, 2) # [bs, nframes, nfeats,]
978
+
979
+ out_cond, out_uncond = torch.split(out, len(out) // 2, dim=0)
980
+
981
+ if self.vis_attn == True:
982
+ i = len(MONITOR_ATTN)
983
+ attnlist = MONITOR_ATTN[-1]
984
+ print(i, "cross", len(attnlist))
985
+ for j, att in enumerate(attnlist):
986
+ vis_attn(
987
+ att,
988
+ out_path=self.out_path,
989
+ step=i,
990
+ layer=j,
991
+ shape="_".join(map(str, att.shape)),
992
+ type_="cross",
993
+ )
994
+
995
+ attnlist = SELF_ATTN[-1]
996
+ print(i, "self", len(attnlist))
997
+ for j, att in enumerate(attnlist):
998
+ vis_attn(
999
+ att,
1000
+ out_path=self.out_path,
1001
+ step=i,
1002
+ layer=j,
1003
+ shape="_".join(map(str, att.shape)),
1004
+ type_="self",
1005
+ lines=False,
1006
+ )
1007
+
1008
+ if len(SELF_ATTN) % 10 == 0:
1009
+ SELF_ATTN = []
1010
+ MONITOR_ATTN = []
1011
+
1012
+ return out_uncond + (cfg_scale * (out_cond - out_uncond))
1013
+
1014
+
1015
+ if __name__ == "__main__":
1016
+
1017
+ device = "cuda:0"
1018
+ n_feats = 263
1019
+ num_frames = 196
1020
+ text_latent_dim = 256
1021
+ dim_mults = [2, 2, 2, 2]
1022
+ base_dim = 512
1023
+ model = MotionCLR(
1024
+ input_feats=n_feats,
1025
+ text_latent_dim=text_latent_dim,
1026
+ base_dim=base_dim,
1027
+ dim_mults=dim_mults,
1028
+ adagn=True,
1029
+ zero=True,
1030
+ dropout=0.1,
1031
+ no_eff=True,
1032
+ cond_mask_prob=0.1,
1033
+ self_attention=True,
1034
+ )
1035
+
1036
+ model = model.to(device)
1037
+ from utils.model_load import load_model_weights
1038
+
1039
+ checkpoint_path = "/comp_robot/chenlinghao/StableMoFusion/checkpoints/t2m/self_attn—fulllayer-ffn-drop0_1-lr1e4/model/latest.tar"
1040
+ new_state_dict = {}
1041
+ checkpoint = torch.load(checkpoint_path)
1042
+ ckpt2 = checkpoint.copy()
1043
+ ckpt2["model_ema"] = {}
1044
+ ckpt2["encoder"] = {}
1045
+
1046
+ for key, value in list(checkpoint["model_ema"].items()):
1047
+ new_key = key.replace(
1048
+ "cross_attn", "clr_attn"
1049
+ ) # Replace 'cross_attn' with 'clr_attn'
1050
+ ckpt2["model_ema"][new_key] = value
1051
+ for key, value in list(checkpoint["encoder"].items()):
1052
+ new_key = key.replace(
1053
+ "cross_attn", "clr_attn"
1054
+ ) # Replace 'cross_attn' with 'clr_attn'
1055
+ ckpt2["encoder"][new_key] = value
1056
+
1057
+ torch.save(
1058
+ ckpt2,
1059
+ "/comp_robot/chenlinghao/CLRpreview/checkpoints/t2m/release/model/latest.tar",
1060
+ )
1061
+
1062
+ dtype = torch.float32
1063
+ bs = 1
1064
+ x = torch.rand((bs, 196, 263), dtype=dtype).to(device)
1065
+ timesteps = torch.randint(low=0, high=1000, size=(bs,)).to(device)
1066
+ y = ["A man jumps to his left." for i in range(bs)]
1067
+ length = torch.randint(low=20, high=196, size=(bs,)).to(device)
1068
+
1069
+ out = model(x, timesteps, text=y)
1070
+ print(out.shape)
1071
+ model.eval()
1072
+ out = model.forward_with_cfg(x, timesteps, text=y)
1073
+ print(out.shape)
motion_loader/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .model_motion_loaders import get_motion_loader
2
+ from .dataset_motion_loaders import get_dataset_loader
motion_loader/dataset_motion_loaders.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from datasets import get_dataset
2
+ from torch.utils.data import DataLoader
3
+ from torch.utils.data._utils.collate import default_collate
4
+
5
+
6
+ def collate_fn(batch):
7
+ batch.sort(key=lambda x: x[3], reverse=True)
8
+ return default_collate(batch)
9
+
10
+
11
+ def get_dataset_loader(opt, batch_size, mode="eval", split="test", accelerator=None):
12
+ dataset = get_dataset(opt, split, mode, accelerator)
13
+ if mode in ["eval", "gt_eval"]:
14
+ dataloader = DataLoader(
15
+ dataset,
16
+ batch_size=batch_size,
17
+ shuffle=True,
18
+ num_workers=4,
19
+ drop_last=True,
20
+ collate_fn=collate_fn,
21
+ )
22
+ else:
23
+ dataloader = DataLoader(
24
+ dataset,
25
+ batch_size=batch_size,
26
+ shuffle=True,
27
+ num_workers=4,
28
+ drop_last=True,
29
+ persistent_workers=True,
30
+ )
31
+ return dataloader
motion_loader/model_motion_loaders.py ADDED
@@ -0,0 +1,243 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from utils.word_vectorizer import WordVectorizer
3
+ from torch.utils.data import Dataset, DataLoader
4
+ from os.path import join as pjoin
5
+ from tqdm import tqdm
6
+ import numpy as np
7
+ from eval.evaluator_modules import *
8
+
9
+ from torch.utils.data._utils.collate import default_collate
10
+
11
+
12
+ class GeneratedDataset(Dataset):
13
+ """
14
+ opt.dataset_name
15
+ opt.max_motion_length
16
+ opt.unit_length
17
+ """
18
+
19
+ def __init__(
20
+ self, opt, pipeline, dataset, w_vectorizer, mm_num_samples, mm_num_repeats
21
+ ):
22
+ assert mm_num_samples < len(dataset)
23
+ self.dataset = dataset
24
+ dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True)
25
+ generated_motion = []
26
+ min_mov_length = 10 if opt.dataset_name == "t2m" else 6
27
+
28
+ # Pre-process all target captions
29
+ mm_generated_motions = []
30
+ if mm_num_samples > 0:
31
+ mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False)
32
+ mm_idxs = np.sort(mm_idxs)
33
+
34
+ all_caption = []
35
+ all_m_lens = []
36
+ all_data = []
37
+ with torch.no_grad():
38
+ for i, data in tqdm(enumerate(dataloader)):
39
+ word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
40
+ all_data.append(data)
41
+ tokens = tokens[0].split("_")
42
+ mm_num_now = len(mm_generated_motions)
43
+ is_mm = (
44
+ True
45
+ if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now]))
46
+ else False
47
+ )
48
+ repeat_times = mm_num_repeats if is_mm else 1
49
+ m_lens = max(
50
+ torch.div(m_lens, opt.unit_length, rounding_mode="trunc")
51
+ * opt.unit_length,
52
+ min_mov_length * opt.unit_length,
53
+ )
54
+ m_lens = min(m_lens, opt.max_motion_length)
55
+ if isinstance(m_lens, int):
56
+ m_lens = torch.LongTensor([m_lens]).to(opt.device)
57
+ else:
58
+ m_lens = m_lens.to(opt.device)
59
+ for t in range(repeat_times):
60
+ all_m_lens.append(m_lens)
61
+ all_caption.extend(caption)
62
+ if is_mm:
63
+ mm_generated_motions.append(0)
64
+ all_m_lens = torch.stack(all_m_lens)
65
+
66
+ # Generate all sequences
67
+ with torch.no_grad():
68
+ all_pred_motions, t_eval = pipeline.generate(all_caption, all_m_lens)
69
+ self.eval_generate_time = t_eval
70
+
71
+ cur_idx = 0
72
+ mm_generated_motions = []
73
+ with torch.no_grad():
74
+ for i, data_dummy in tqdm(enumerate(dataloader)):
75
+ data = all_data[i]
76
+ word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data
77
+ tokens = tokens[0].split("_")
78
+ mm_num_now = len(mm_generated_motions)
79
+ is_mm = (
80
+ True
81
+ if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now]))
82
+ else False
83
+ )
84
+ repeat_times = mm_num_repeats if is_mm else 1
85
+ mm_motions = []
86
+ for t in range(repeat_times):
87
+ pred_motions = all_pred_motions[cur_idx]
88
+ cur_idx += 1
89
+ if t == 0:
90
+ sub_dict = {
91
+ "motion": pred_motions.cpu().numpy(),
92
+ "length": pred_motions.shape[0], # m_lens[0].item(), #
93
+ "caption": caption[0],
94
+ "cap_len": cap_lens[0].item(),
95
+ "tokens": tokens,
96
+ }
97
+ generated_motion.append(sub_dict)
98
+
99
+ if is_mm:
100
+ mm_motions.append(
101
+ {
102
+ "motion": pred_motions.cpu().numpy(),
103
+ "length": pred_motions.shape[
104
+ 0
105
+ ], # m_lens[0].item(), #m_lens[0].item()
106
+ }
107
+ )
108
+ if is_mm:
109
+ mm_generated_motions.append(
110
+ {
111
+ "caption": caption[0],
112
+ "tokens": tokens,
113
+ "cap_len": cap_lens[0].item(),
114
+ "mm_motions": mm_motions,
115
+ }
116
+ )
117
+ self.generated_motion = generated_motion
118
+ self.mm_generated_motion = mm_generated_motions
119
+ self.opt = opt
120
+ self.w_vectorizer = w_vectorizer
121
+
122
+ def __len__(self):
123
+ return len(self.generated_motion)
124
+
125
+ def __getitem__(self, item):
126
+ data = self.generated_motion[item]
127
+ motion, m_length, caption, tokens = (
128
+ data["motion"],
129
+ data["length"],
130
+ data["caption"],
131
+ data["tokens"],
132
+ )
133
+ sent_len = data["cap_len"]
134
+
135
+ # This step is needed because T2M evaluators expect their norm convention
136
+ normed_motion = motion
137
+ denormed_motion = self.dataset.inv_transform(normed_motion)
138
+ renormed_motion = (
139
+ denormed_motion - self.dataset.mean_for_eval
140
+ ) / self.dataset.std_for_eval # according to T2M norms
141
+ motion = renormed_motion
142
+
143
+ pos_one_hots = []
144
+ word_embeddings = []
145
+ for token in tokens:
146
+ word_emb, pos_oh = self.w_vectorizer[token]
147
+ pos_one_hots.append(pos_oh[None, :])
148
+ word_embeddings.append(word_emb[None, :])
149
+ pos_one_hots = np.concatenate(pos_one_hots, axis=0)
150
+ word_embeddings = np.concatenate(word_embeddings, axis=0)
151
+ length = len(motion)
152
+ if length < self.opt.max_motion_length:
153
+ motion = np.concatenate(
154
+ [
155
+ motion,
156
+ np.zeros((self.opt.max_motion_length - length, motion.shape[1])),
157
+ ],
158
+ axis=0,
159
+ )
160
+ return (
161
+ word_embeddings,
162
+ pos_one_hots,
163
+ caption,
164
+ sent_len,
165
+ motion,
166
+ m_length,
167
+ "_".join(tokens),
168
+ )
169
+
170
+
171
+ def collate_fn(batch):
172
+ batch.sort(key=lambda x: x[3], reverse=True)
173
+ return default_collate(batch)
174
+
175
+
176
+ class MMGeneratedDataset(Dataset):
177
+ def __init__(self, opt, motion_dataset, w_vectorizer):
178
+ self.opt = opt
179
+ self.dataset = motion_dataset.mm_generated_motion
180
+ self.w_vectorizer = w_vectorizer
181
+
182
+ def __len__(self):
183
+ return len(self.dataset)
184
+
185
+ def __getitem__(self, item):
186
+ data = self.dataset[item]
187
+ mm_motions = data["mm_motions"]
188
+ m_lens = []
189
+ motions = []
190
+ for mm_motion in mm_motions:
191
+ m_lens.append(mm_motion["length"])
192
+ motion = mm_motion["motion"]
193
+ if len(motion) < self.opt.max_motion_length:
194
+ motion = np.concatenate(
195
+ [
196
+ motion,
197
+ np.zeros(
198
+ (self.opt.max_motion_length - len(motion), motion.shape[1])
199
+ ),
200
+ ],
201
+ axis=0,
202
+ )
203
+ motion = motion[None, :]
204
+ motions.append(motion)
205
+ m_lens = np.array(m_lens, dtype=np.int32)
206
+ motions = np.concatenate(motions, axis=0)
207
+ sort_indx = np.argsort(m_lens)[::-1].copy()
208
+
209
+ m_lens = m_lens[sort_indx]
210
+ motions = motions[sort_indx]
211
+ return motions, m_lens
212
+
213
+
214
+ def get_motion_loader(
215
+ opt, batch_size, pipeline, ground_truth_dataset, mm_num_samples, mm_num_repeats
216
+ ):
217
+
218
+ # Currently the configurations of two datasets are almost the same
219
+ if opt.dataset_name == "t2m" or opt.dataset_name == "kit":
220
+ w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab")
221
+ else:
222
+ raise KeyError("Dataset not recognized!!")
223
+
224
+ dataset = GeneratedDataset(
225
+ opt,
226
+ pipeline,
227
+ ground_truth_dataset,
228
+ w_vectorizer,
229
+ mm_num_samples,
230
+ mm_num_repeats,
231
+ )
232
+ mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer)
233
+
234
+ motion_loader = DataLoader(
235
+ dataset,
236
+ batch_size=batch_size,
237
+ collate_fn=collate_fn,
238
+ drop_last=True,
239
+ num_workers=4,
240
+ )
241
+ mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1)
242
+
243
+ return motion_loader, mm_motion_loader, dataset.eval_generate_time
options/__init__.py ADDED
File without changes
options/edit.yaml ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # edit.yaml
2
+ reweighting_attn:
3
+ use: False
4
+ reweighting_attn_weight: 0.0 # the weight of reweighting attention for motion emphasizing and de-emphasizing
5
+ idx: -1 # the index of the word to be emphasized or de-emphasized (0 ~ 10)
6
+
7
+ erasing_motion:
8
+ use: False
9
+ erasing_motion_weight: 0.1 # the weight of motion erasing
10
+ time_start: 0.5 # the start time of motion erasing (0.0 ~ 1.0), ratio of the total time
11
+ time_end: 1.0 # the end time of motion erasing (0.0 ~ 1.0), ratio of the total time
12
+ idx: -1
13
+
14
+ manipulation: # motion manipulation means in-place motion replacement
15
+ use: False
16
+ manipulation_steps_start: 0 # the start step of motion manipulation, 0 ~ 10
17
+ manipulation_steps_end: 3 # the end step of motion manipulation, 0 ~ 10
18
+ manipulation_steps_end_crossattn: 3 # the end step of cross-attention for motion manipulation, 0 ~ 10
19
+ word_idx: 3 # the index of the word to be manipulated
20
+
21
+ time_shift:
22
+ use: False
23
+ time_shift_steps_start: 0 # the start step of time shifting, 0 ~ 10
24
+ time_shift_steps_end: 4 # the end step of time shifting, 0 ~ 10
25
+ time_shift_ratio: 0.5 # the ratio of time shifting, 0.0 ~ 1.0
26
+
27
+ example_based:
28
+ use: False
29
+ chunk_size: 20 # the size of the chunk for example-based generation
30
+ example_based_steps_end: 6 # the end step of example-based generation, 0 ~ 10
31
+ temp_seed: 200 # the inintial seed for example-based generation
32
+ temp_seed_bar: 15 # the the seed bar for example-based generation
33
+
34
+ style_tranfer:
35
+ use: False
36
+ style_transfer_steps_start: 0 # the start step of style transfer, 0 ~ 10
37
+ style_transfer_steps_end: 5 # the end step of style transfer, 0 ~ 10
38
+
39
+ grounded_generation:
40
+ use: False
options/evaluate_options.py ADDED
@@ -0,0 +1,53 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from .get_opt import get_opt
3
+ import yaml
4
+
5
+ class TestOptions():
6
+ def __init__(self):
7
+ self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
8
+ self.initialize()
9
+
10
+ def initialize(self):
11
+ self.parser.add_argument("--opt_path", type=str, default='./checkpoints/t2m/t2m_condunet1d_batch64/opt.txt',help='option file path for loading model')
12
+ self.parser.add_argument("--gpu_id", type=int, default=0, help='GPU id')
13
+
14
+ # evaluator
15
+ self.parser.add_argument("--evaluator_dir", type=str, default='./data/checkpoints', help='Directory path where save T2M evaluator\'s checkpoints')
16
+ self.parser.add_argument("--eval_meta_dir", type=str, default='./data', help='Directory path where save T2M evaluator\'s normalization data.')
17
+ self.parser.add_argument("--glove_dir",type=str,default='./data/glove', help='Directory path where save glove')
18
+
19
+ # inference
20
+ self.parser.add_argument("--num_inference_steps", type=int, default=10, help='Number of iterative denoising steps during inference.')
21
+ self.parser.add_argument("--which_ckpt", type=str, default='latest', help='name of checkpoint to load')
22
+ self.parser.add_argument("--diffuser_name", type=str, default='dpmsolver', help='sampler\'s scheduler class name in the diffuser library')
23
+ self.parser.add_argument("--no_ema", action="store_true", help='Where use EMA model in inference')
24
+ self.parser.add_argument("--no_fp16", action="store_true", help='Whether use FP16 in inference')
25
+ self.parser.add_argument('--debug', action="store_true", help='debug mode')
26
+ self.parser.add_argument('--self_attention', action="store_true", help='self_attention use or not')
27
+ self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient linear attention')
28
+ self.parser.add_argument('--vis_attn', action='store_true', help='vis attention value or not')
29
+ self.parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
30
+
31
+ # evaluation
32
+ self.parser.add_argument("--replication_times", type=int, default=1, help='Number of generation rounds for each text description')
33
+ self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size for eval')
34
+ self.parser.add_argument('--diversity_times', type=int, default=300, help='')
35
+ self.parser.add_argument('--mm_num_samples', type=int, default=100, help='Number of samples for evaluating multimodality')
36
+ self.parser.add_argument('--mm_num_repeats', type=int, default=30, help='Number of generation rounds for each text description when evaluating multimodality')
37
+ self.parser.add_argument('--mm_num_times', type=int, default=10, help='')
38
+ self.parser.add_argument('--edit_mode', action='store_true', help='editing mode')
39
+
40
+ def parse(self):
41
+ # load evaluation options
42
+ self.opt = self.parser.parse_args()
43
+ opt_dict = vars(self.opt)
44
+
45
+ # load the model options of T2m evaluator
46
+ with open('./config/evaluator.yaml', 'r') as yaml_file:
47
+ yaml_config = yaml.safe_load(yaml_file)
48
+ opt_dict.update(yaml_config)
49
+
50
+ # load the training options of the selected checkpoint
51
+ get_opt(self.opt, self.opt.opt_path)
52
+
53
+ return self.opt
options/generate_options.py ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from .get_opt import get_opt
3
+
4
+ class GenerateOptions():
5
+ def __init__(self, app=False):
6
+ self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
7
+ self.initialize()
8
+
9
+ def initialize(self):
10
+ self.parser.add_argument("--opt_path", type=str, default='./checkpoints/t2m/t2m_condunet1d_batch64/opt.txt', help='option file path for loading model')
11
+ self.parser.add_argument("--gpu_id", type=int, default=0, help='GPU id')
12
+ self.parser.add_argument("--output_dir", type=str, default='', help='Directory path to save generation result')
13
+ self.parser.add_argument("--footskate_cleanup", action="store_true", help='Where use footskate cleanup in inference')
14
+
15
+ # inference
16
+ self.parser.add_argument("--num_inference_steps", type=int, default=10, help='Number of iterative denoising steps during inference.')
17
+ self.parser.add_argument("--which_ckpt", type=str, default='latest', help='name of checkpoint to load')
18
+ self.parser.add_argument("--diffuser_name", type=str, default='dpmsolver', help='sampler\'s scheduler class name in the diffuser library')
19
+ self.parser.add_argument("--no_ema", action="store_true", help='Where use EMA model in inference')
20
+ self.parser.add_argument("--no_fp16", action="store_true", help='Whether use FP16 in inference')
21
+ self.parser.add_argument('--batch_size', type=int, default=1, help='Batch size for generate')
22
+ self.parser.add_argument("--seed", default=0, type=int, help="For fixing random seed.")
23
+ self.parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
24
+
25
+ # generate prompts
26
+ self.parser.add_argument('--text_prompt', type=str, default="", help='One text description pompt for motion generation')
27
+ self.parser.add_argument("--motion_length", default=4.0, type=float, help="The length of the generated motion [in seconds] when using prompts. Maximum is 9.8 for HumanML3D (text-to-motion)")
28
+ self.parser.add_argument('--input_text', type=str, default='', help='File path of texts when using multiple texts.')
29
+ self.parser.add_argument('--input_lens', type=str, default='', help='File path of expected motion frame lengths when using multitext.')
30
+ self.parser.add_argument("--num_samples", type=int, default=10, help='Number of samples for generate when using dataset.')
31
+ self.parser.add_argument('--debug', action="store_true", help='debug mode')
32
+ self.parser.add_argument('--self_attention', action="store_true", help='self_attention use or not')
33
+ self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient linear attention')
34
+ self.parser.add_argument('--vis_attn', action='store_true', help='vis attention value or not')
35
+ self.parser.add_argument('--edit_mode', action='store_true', help='editing mode')
36
+
37
+
38
+ def parse(self):
39
+ self.opt = self.parser.parse_args()
40
+ opt_path = self.opt.opt_path
41
+ get_opt(self.opt, opt_path)
42
+ return self.opt
43
+
44
+ def parse_app(self):
45
+ self.opt = self.parser.parse_args(
46
+ args=['--motion_length', '8', '--self_attention', '--no_eff', '--opt_path', './checkpoints/t2m/release/opt.txt', '--edit_mode']
47
+ )
48
+ opt_path = self.opt.opt_path
49
+ get_opt(self.opt, opt_path)
50
+ return self.opt
options/get_opt.py ADDED
@@ -0,0 +1,74 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ from argparse import Namespace
3
+ import re
4
+ from os.path import join as pjoin
5
+
6
+
7
+ def is_float(numStr):
8
+ flag = False
9
+ numStr = str(numStr).strip().lstrip("-").lstrip("+")
10
+ try:
11
+ reg = re.compile(r"^[-+]?[0-9]+\.[0-9]+$")
12
+ res = reg.match(str(numStr))
13
+ if res:
14
+ flag = True
15
+ except Exception as ex:
16
+ print("is_float() - error: " + str(ex))
17
+ return flag
18
+
19
+
20
+ def is_number(numStr):
21
+ flag = False
22
+ numStr = str(numStr).strip().lstrip("-").lstrip("+")
23
+ if str(numStr).isdigit():
24
+ flag = True
25
+ return flag
26
+
27
+
28
+ def get_opt(opt, opt_path):
29
+ opt_dict = vars(opt)
30
+
31
+ skip = (
32
+ "-------------- End ----------------",
33
+ "------------ Options -------------",
34
+ "\n",
35
+ )
36
+ print("Reading", opt_path)
37
+ with open(opt_path) as f:
38
+ for line in f:
39
+ if line.strip() not in skip:
40
+ print(line.strip())
41
+ key, value = line.strip().split(": ")
42
+ if getattr(opt, key, None) is not None:
43
+ continue
44
+ if value in ("True", "False"):
45
+ opt_dict[key] = True if value == "True" else False
46
+ elif is_float(value):
47
+ opt_dict[key] = float(value)
48
+ elif is_number(value):
49
+ opt_dict[key] = int(value)
50
+ elif "," in value:
51
+ value = value[1:-1].split(",")
52
+ opt_dict[key] = [int(i) for i in value]
53
+ else:
54
+ opt_dict[key] = str(value)
55
+
56
+ # opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
57
+ opt.save_root = os.path.dirname(opt_path)
58
+ opt.model_dir = pjoin(opt.save_root, "model")
59
+ opt.meta_dir = pjoin(opt.save_root, "meta")
60
+
61
+ if opt.dataset_name == "t2m" or opt.dataset_name == "humanml":
62
+ opt.joints_num = 22
63
+ opt.dim_pose = 263
64
+ opt.max_motion_length = 196
65
+ opt.radius = 4
66
+ opt.fps = 20
67
+ elif opt.dataset_name == "kit":
68
+ opt.joints_num = 21
69
+ opt.dim_pose = 251
70
+ opt.max_motion_length = 196
71
+ opt.radius = 240 * 8
72
+ opt.fps = 12.5
73
+ else:
74
+ raise KeyError("Dataset not recognized")
options/noedit.yaml ADDED
@@ -0,0 +1,21 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # noedit.yaml
2
+ reweighting_attn:
3
+ use: False
4
+
5
+ erasing_motion:
6
+ use: False
7
+
8
+ manipulation:
9
+ use: False
10
+
11
+ time_shift:
12
+ use: False
13
+
14
+ example_based:
15
+ use: False
16
+
17
+ style_tranfer:
18
+ use: False
19
+
20
+ grounded_generation:
21
+ use: False
options/train_options.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ from .get_opt import get_opt
3
+ from os.path import join as pjoin
4
+ import os
5
+
6
+ class TrainOptions():
7
+ def __init__(self):
8
+ self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
9
+ self.initialized = False
10
+
11
+ def initialize(self):
12
+ # base set
13
+ self.parser.add_argument('--name', type=str, default="test", help='Name of this trial')
14
+ self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name')
15
+ self.parser.add_argument('--feat_bias', type=float, default=5, help='Scales for global motion features and foot contact')
16
+ self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here')
17
+ self.parser.add_argument('--log_every', type=int, default=5, help='Frequency of printing training progress (by iteration)')
18
+ self.parser.add_argument('--save_interval', type=int, default=10_000, help='Frequency of evaluateing and saving models (by iteration)')
19
+
20
+
21
+ # network hyperparams
22
+ self.parser.add_argument('--num_layers', type=int, default=8, help='num_layers of transformer')
23
+ self.parser.add_argument('--latent_dim', type=int, default=512, help='latent_dim of transformer')
24
+ self.parser.add_argument('--text_latent_dim', type=int, default=256, help='latent_dim of text embeding')
25
+ self.parser.add_argument('--time_dim', type=int, default=512, help='latent_dim of timesteps')
26
+ self.parser.add_argument('--base_dim', type=int, default=512, help='Dimension of Unet base channel')
27
+ self.parser.add_argument('--dim_mults', type=int, default=[2,2,2,2], nargs='+', help='Unet channel multipliers.')
28
+ self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient linear attention')
29
+ self.parser.add_argument('--no_adagn', action='store_true', help='whether use adagn block')
30
+ self.parser.add_argument('--diffusion_steps', type=int, default=1000, help='diffusion_steps of transformer')
31
+ self.parser.add_argument('--prediction_type', type=str, default='sample', help='diffusion_steps of transformer')
32
+
33
+ # train hyperparams
34
+ self.parser.add_argument('--seed', type=int, default=0, help='seed for train')
35
+ self.parser.add_argument('--num_train_steps', type=int, default=50_000, help='Number of training iterations')
36
+ self.parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate')
37
+ self.parser.add_argument("--decay_rate", default=0.9, type=float, help="the decay rate of lr (0-1 default 0.9)")
38
+ self.parser.add_argument("--update_lr_steps", default=5_000, type=int, help="")
39
+ self.parser.add_argument("--cond_mask_prob", default=0.1, type=float,
40
+ help="The probability of masking the condition during training."
41
+ " For classifier-free guidance learning.")
42
+ self.parser.add_argument('--clip_grad_norm', type=float, default=1, help='Gradient clip')
43
+ self.parser.add_argument('--weight_decay', type=float, default=1e-2, help='Learning rate weight_decay')
44
+ self.parser.add_argument('--batch_size', type=int, default=64, help='Batch size per GPU')
45
+ self.parser.add_argument("--beta_schedule", default='linear', type=str, help="Types of beta in diffusion (e.g. linear, cosine)")
46
+ self.parser.add_argument('--dropout', type=float, default=0.1, help='dropout')
47
+
48
+ # continue training
49
+ self.parser.add_argument('--is_continue', action="store_true", help='Is this trail continued from previous trail?')
50
+ self.parser.add_argument('--continue_ckpt', type=str, default="latest.tar", help='previous trail to continue')
51
+ self.parser.add_argument("--opt_path", type=str, default='',help='option file path for loading model')
52
+ self.parser.add_argument('--debug', action="store_true", help='debug mode')
53
+ self.parser.add_argument('--self_attention', action="store_true", help='self_attention use or not')
54
+ self.parser.add_argument('--vis_attn', action='store_true', help='vis attention value or not')
55
+
56
+ self.parser.add_argument('--edit_mode', action='store_true', help='editing mode')
57
+
58
+ # EMA params
59
+ self.parser.add_argument(
60
+ "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters"
61
+ )
62
+ self.parser.add_argument(
63
+ "--model-ema-steps",
64
+ type=int,
65
+ default=32,
66
+ help="the number of iterations that controls how often to update the EMA model (default: 32)",
67
+ )
68
+ self.parser.add_argument(
69
+ "--model-ema-decay",
70
+ type=float,
71
+ default=0.9999,
72
+ help="decay factor for Exponential Moving Average of model parameters (default: 0.99988)",
73
+ )
74
+
75
+ self.initialized = True
76
+
77
+ def parse(self,accelerator):
78
+ if not self.initialized:
79
+ self.initialize()
80
+
81
+ self.opt = self.parser.parse_args()
82
+
83
+ if self.opt.is_continue:
84
+ assert self.opt.opt_path.endswith('.txt')
85
+ get_opt(self.opt, self.opt.opt_path)
86
+ self.opt.is_train = True
87
+ self.opt.is_continue=True
88
+ elif accelerator.is_main_process:
89
+ args = vars(self.opt)
90
+ accelerator.print('------------ Options -------------')
91
+ for k, v in sorted(args.items()):
92
+ accelerator.print('%s: %s' % (str(k), str(v)))
93
+ accelerator.print('-------------- End ----------------')
94
+ # save to the disk
95
+ expr_dir = pjoin(self.opt.checkpoints_dir, self.opt.dataset_name, self.opt.name)
96
+ os.makedirs(expr_dir,exist_ok=True)
97
+ file_name = pjoin(expr_dir, 'opt.txt')
98
+ with open(file_name, 'wt') as opt_file:
99
+ opt_file.write('------------ Options -------------\n')
100
+ for k, v in sorted(args.items()):
101
+ if k =='opt_path':
102
+ continue
103
+ opt_file.write('%s: %s\n' % (str(k), str(v)))
104
+ opt_file.write('-------------- End ----------------\n')
105
+
106
+
107
+ if self.opt.dataset_name == 't2m' or self.opt.dataset_name == 'humanml':
108
+ self.opt.joints_num = 22
109
+ self.opt.dim_pose = 263
110
+ self.opt.max_motion_length = 196
111
+ self.opt.radius = 4
112
+ self.opt.fps = 20
113
+ elif self.opt.dataset_name == 'kit':
114
+ self.opt.joints_num = 21
115
+ self.opt.dim_pose = 251
116
+ self.opt.max_motion_length = 196
117
+ self.opt.radius = 240 * 8
118
+ self.opt.fps = 12.5
119
+ else:
120
+ raise KeyError('Dataset not recognized')
121
+
122
+ self.opt.device = accelerator.device
123
+ self.opt.is_train = True
124
+ return self.opt
125
+
126
+
prepare/download_glove.sh ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cd ./data/
2
+
3
+ echo -e "Downloading glove (in use by the evaluators)"
4
+ gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing
5
+ rm -rf glove
6
+
7
+ unzip glove.zip
8
+ echo -e "Cleaning\n"
9
+ rm glove.zip
10
+ cd ..
11
+
12
+ echo -e "Downloading done!"
prepare/download_t2m_evaluators.sh ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ mkdir -p data/
2
+ cd data/
3
+ mkdir -p checkpoints/
4
+ cd checkpoints/
5
+
6
+ echo "The t2m evaluators will be stored in the './deps' folder"
7
+
8
+ echo "Downloading"
9
+ gdown --fuzzy https://drive.google.com/file/d/16hyR4XlEyksVyNVjhIWK684Lrm_7_pvX/view?usp=sharing
10
+ echo "Extracting"
11
+ unzip t2m.zip
12
+ echo "Cleaning"
13
+ rm t2m.zip
14
+
15
+ cd ../..
16
+
17
+ echo "Downloading done!"
prepare/prepare_clip.sh ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ mkdir -p deps/
2
+ cd deps/
3
+ git lfs install
4
+ git clone https://huggingface.co/openai/clip-vit-large-patch14
5
+ cd ..
requirements.txt ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ tqdm
2
+ opencv-python
3
+ scipy
4
+ matplotlib==3.3.1
5
+ spacy
6
+ accelerate
7
+ transformers
8
+ einops
9
+ diffusers
10
+ panda3d
11
+ numpy==1.23.0
12
+ git+https://github.com/openai/CLIP.git
13
+ diffusers==0.30.3
14
+ transformers==4.45.2
15
+
16
+ # for train
17
+ tensorboard
18
+ accelerate==1.0.1
19
+ smplx
20
+ python-box
scripts/__init__.py ADDED
File without changes
scripts/evaluation.py ADDED
@@ -0,0 +1,107 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import torch
3
+ from motion_loader import get_dataset_loader, get_motion_loader
4
+ from datasets import get_dataset
5
+ from models import build_models
6
+ from eval import EvaluatorModelWrapper, evaluation
7
+ from utils.utils import *
8
+ from utils.model_load import load_model_weights
9
+ import os
10
+ from os.path import join as pjoin
11
+
12
+ from models.gaussian_diffusion import DiffusePipeline
13
+ from accelerate.utils import set_seed
14
+
15
+ from options.evaluate_options import TestOptions
16
+
17
+ import yaml
18
+ from box import Box
19
+
20
+
21
+ def yaml_to_box(yaml_file):
22
+ with open(yaml_file, "r") as file:
23
+ yaml_data = yaml.safe_load(file)
24
+
25
+ return Box(yaml_data)
26
+
27
+
28
+ if __name__ == "__main__":
29
+ parser = TestOptions()
30
+ opt = parser.parse()
31
+ set_seed(0)
32
+
33
+ if opt.edit_mode:
34
+ edit_config = yaml_to_box("options/edit.yaml")
35
+ else:
36
+ edit_config = yaml_to_box("options/noedit.yaml")
37
+
38
+ device_id = opt.gpu_id
39
+ device = torch.device("cuda:%d" % device_id if torch.cuda.is_available() else "cpu")
40
+ torch.cuda.set_device(device)
41
+ opt.device = device
42
+
43
+ # load evaluator
44
+ eval_wrapper = EvaluatorModelWrapper(opt)
45
+
46
+ # load dataset
47
+ gt_loader = get_dataset_loader(opt, opt.batch_size, mode="gt_eval", split="test")
48
+ gen_dataset = get_dataset(opt, mode="eval", split="test")
49
+
50
+ # load model
51
+ model = build_models(opt, edit_config=edit_config)
52
+ ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + ".tar")
53
+ load_model_weights(model, ckpt_path, use_ema=not opt.no_ema, device=device)
54
+
55
+ # Create a pipeline for generation in diffusion model framework
56
+ pipeline = DiffusePipeline(
57
+ opt=opt,
58
+ model=model,
59
+ diffuser_name=opt.diffuser_name,
60
+ device=device,
61
+ num_inference_steps=opt.num_inference_steps,
62
+ torch_dtype=torch.float32 if opt.no_fp16 else torch.float16,
63
+ )
64
+
65
+ eval_motion_loaders = {
66
+ "text2motion": lambda: get_motion_loader(
67
+ opt,
68
+ opt.batch_size,
69
+ pipeline,
70
+ gen_dataset,
71
+ opt.mm_num_samples,
72
+ opt.mm_num_repeats,
73
+ )
74
+ }
75
+
76
+ save_dir = pjoin(opt.save_root, "eval")
77
+ os.makedirs(save_dir, exist_ok=True)
78
+ if opt.no_ema:
79
+ log_file = (
80
+ pjoin(save_dir, opt.diffuser_name)
81
+ + f"_{str(opt.num_inference_steps)}setps.log"
82
+ )
83
+ else:
84
+ log_file = (
85
+ pjoin(save_dir, opt.diffuser_name)
86
+ + f"_{str(opt.num_inference_steps)}steps_ema.log"
87
+ )
88
+
89
+ if not os.path.exists(log_file):
90
+ config_dict = dict(pipeline.scheduler.config)
91
+ config_dict["no_ema"] = opt.no_ema
92
+ with open(log_file, "wt") as f:
93
+ f.write("------------ Options -------------\n")
94
+ for k, v in sorted(config_dict.items()):
95
+ f.write("%s: %s\n" % (str(k), str(v)))
96
+ f.write("-------------- End ----------------\n")
97
+
98
+ all_metrics = evaluation(
99
+ eval_wrapper,
100
+ gt_loader,
101
+ eval_motion_loaders,
102
+ log_file,
103
+ opt.replication_times,
104
+ opt.diversity_times,
105
+ opt.mm_num_times,
106
+ run_mm=True,
107
+ )
scripts/generate.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ import torch
4
+ import numpy as np
5
+ from os.path import join as pjoin
6
+ import utils.paramUtil as paramUtil
7
+ from utils.plot_script import *
8
+
9
+ from utils.utils import *
10
+ from utils.motion_process import recover_from_ric
11
+ from accelerate.utils import set_seed
12
+ from models.gaussian_diffusion import DiffusePipeline
13
+ from options.generate_options import GenerateOptions
14
+ from utils.model_load import load_model_weights
15
+ from motion_loader import get_dataset_loader
16
+ from models import build_models
17
+ import yaml
18
+ from box import Box
19
+
20
+
21
+ def yaml_to_box(yaml_file):
22
+ with open(yaml_file, "r") as file:
23
+ yaml_data = yaml.safe_load(file)
24
+ return Box(yaml_data)
25
+
26
+
27
+ if __name__ == "__main__":
28
+ parser = GenerateOptions()
29
+ opt = parser.parse()
30
+ set_seed(opt.seed)
31
+ device_id = opt.gpu_id
32
+ device = torch.device("cuda:%d" % device_id if torch.cuda.is_available() else "cpu")
33
+ opt.device = device
34
+
35
+ assert opt.dataset_name == "t2m" or "kit"
36
+
37
+ # Using a text prompt for generation
38
+ if opt.text_prompt != "":
39
+ texts = [opt.text_prompt]
40
+ opt.num_samples = 1
41
+ motion_lens = [opt.motion_length * opt.fps]
42
+
43
+ # Or using texts (in .txt file) for generation
44
+ elif opt.input_text != "":
45
+ with open(opt.input_text, "r") as fr:
46
+ texts = [line.strip() for line in fr.readlines()]
47
+ opt.num_samples = len(texts)
48
+ if opt.input_lens != "":
49
+ with open(opt.input_lens, "r") as fr:
50
+ motion_lens = [int(line.strip()) for line in fr.readlines()]
51
+ assert len(texts) == len(
52
+ motion_lens
53
+ ), f"Please ensure that the motion length in {opt.input_lens} corresponds to the text in {opt.input_text}."
54
+ else:
55
+ motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)]
56
+
57
+ # Or usining texts in dataset
58
+ else:
59
+ gen_datasetloader = get_dataset_loader(
60
+ opt, opt.num_samples, mode="hml_gt", split="test"
61
+ )
62
+ texts, _, motion_lens = next(iter(gen_datasetloader))
63
+
64
+ # edit mode
65
+ if opt.edit_mode:
66
+ edit_config = yaml_to_box("options/edit.yaml")
67
+ else:
68
+ edit_config = yaml_to_box("options/noedit.yaml")
69
+ print(edit_config)
70
+
71
+ ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + ".tar")
72
+ checkpoint = torch.load(ckpt_path,map_location={'cuda:0': str(device)})
73
+ niter = checkpoint.get('total_it', 0)
74
+ # make save dir
75
+ out_path = opt.output_dir
76
+ if out_path == "":
77
+ out_path = pjoin(opt.save_root, "samples_iter{}_seed{}".format(niter, opt.seed))
78
+ if opt.text_prompt != "":
79
+ out_path += "_" + opt.text_prompt.replace(" ", "_").replace(".", "")
80
+ elif opt.input_text != "":
81
+ out_path += "_" + os.path.basename(opt.input_text).replace(
82
+ ".txt", ""
83
+ ).replace(" ", "_").replace(".", "")
84
+ os.makedirs(out_path, exist_ok=True)
85
+
86
+ # load model
87
+ model = build_models(opt, edit_config=edit_config, out_path=out_path)
88
+ niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema)
89
+
90
+ # Create a pipeline for generation in diffusion model framework
91
+ pipeline = DiffusePipeline(
92
+ opt=opt,
93
+ model=model,
94
+ diffuser_name=opt.diffuser_name,
95
+ device=device,
96
+ num_inference_steps=opt.num_inference_steps,
97
+ torch_dtype=torch.float16,
98
+ )
99
+
100
+ # generate
101
+ pred_motions, _ = pipeline.generate(
102
+ texts, torch.LongTensor([int(x) for x in motion_lens])
103
+ )
104
+
105
+ # Convert the generated motion representaion into 3D joint coordinates and save as npy file
106
+ npy_dir = pjoin(out_path, "joints_npy")
107
+ root_dir = pjoin(out_path, "root_npy")
108
+ os.makedirs(npy_dir, exist_ok=True)
109
+ os.makedirs(root_dir, exist_ok=True)
110
+ print(f"saving results npy file (3d joints) to [{npy_dir}]")
111
+ mean = np.load(pjoin(opt.meta_dir, "mean.npy"))
112
+ std = np.load(pjoin(opt.meta_dir, "std.npy"))
113
+ samples = []
114
+
115
+ root_list = []
116
+ for i, motion in enumerate(pred_motions):
117
+ motion = motion.cpu().numpy() * std + mean
118
+ np.save(pjoin(npy_dir, f"raw_{i:02}.npy"), motion)
119
+ npy_name = f"{i:02}.npy"
120
+ # 1. recover 3d joints representation by ik
121
+ motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num)
122
+ # 2. put on Floor (Y axis)
123
+ floor_height = motion.min(dim=0)[0].min(dim=0)[0][1]
124
+ motion[:, :, 1] -= floor_height
125
+ motion = motion.numpy()
126
+ # 3. remove jitter
127
+ motion = motion_temporal_filter(motion, sigma=1)
128
+
129
+ # save root trajectory (Y axis)
130
+ root_trajectory = motion[:, 0, :]
131
+ root_list.append(root_trajectory)
132
+ np.save(pjoin(root_dir, f"root_{i:02}.npy"), root_trajectory)
133
+ y = root_trajectory[:, 1]
134
+
135
+ plt.figure()
136
+ plt.plot(y)
137
+
138
+ plt.legend()
139
+
140
+ plt.title("Root Joint Trajectory")
141
+ plt.xlabel("Frame")
142
+ plt.ylabel("Position")
143
+
144
+ plt.savefig("./root_trajectory_xyz.png")
145
+ np.save(pjoin(npy_dir, npy_name), motion)
146
+ samples.append(motion)
147
+
148
+ root_list_res = np.concatenate(root_list, axis=0)
149
+ np.save("root_list.npy", root_list_res)
150
+
151
+ # save the text and length conditions used for this generation
152
+ with open(pjoin(out_path, "results.txt"), "w") as fw:
153
+ fw.write("\n".join(texts))
154
+ with open(pjoin(out_path, "results_lens.txt"), "w") as fw:
155
+ fw.write("\n".join([str(l) for l in motion_lens]))
156
+
157
+ # skeletal animation visualization
158
+ print(f"saving motion videos to [{out_path}]...")
159
+ for i, title in enumerate(texts):
160
+ motion = samples[i]
161
+ fname = f"{i:02}.mp4"
162
+ kinematic_tree = (
163
+ paramUtil.t2m_kinematic_chain
164
+ if (opt.dataset_name == "t2m")
165
+ else paramUtil.kit_kinematic_chain
166
+ )
167
+ plot_3d_motion(
168
+ pjoin(out_path, fname),
169
+ kinematic_tree,
170
+ motion,
171
+ title=title,
172
+ fps=opt.fps,
173
+ radius=opt.radius,
174
+ )
scripts/train.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ import os
3
+ from os.path import join as pjoin
4
+ from options.train_options import TrainOptions
5
+ from utils.plot_script import *
6
+
7
+ from models import build_models
8
+ from utils.ema import ExponentialMovingAverage
9
+ from trainers import DDPMTrainer
10
+ from motion_loader import get_dataset_loader
11
+
12
+ from accelerate.utils import set_seed
13
+ from accelerate import Accelerator
14
+ import torch
15
+
16
+ import yaml
17
+ from box import Box
18
+
19
+ def yaml_to_box(yaml_file):
20
+ with open(yaml_file, 'r') as file:
21
+ yaml_data = yaml.safe_load(file)
22
+
23
+ return Box(yaml_data)
24
+
25
+ if __name__ == '__main__':
26
+ accelerator = Accelerator()
27
+
28
+ parser = TrainOptions()
29
+ opt = parser.parse(accelerator)
30
+ set_seed(opt.seed)
31
+ torch.autograd.set_detect_anomaly(True)
32
+
33
+ opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name)
34
+ opt.model_dir = pjoin(opt.save_root, 'model')
35
+ opt.meta_dir = pjoin(opt.save_root, 'meta')
36
+
37
+ if opt.edit_mode:
38
+ edit_config = yaml_to_box('options/edit.yaml')
39
+ else:
40
+ edit_config = yaml_to_box('options/noedit.yaml')
41
+
42
+ if accelerator.is_main_process:
43
+ os.makedirs(opt.model_dir, exist_ok=True)
44
+ os.makedirs(opt.meta_dir, exist_ok=True)
45
+
46
+ train_datasetloader = get_dataset_loader(opt, batch_size = opt.batch_size, split='train', accelerator=accelerator, mode='train') # 7169
47
+
48
+
49
+ accelerator.print('\nInitializing model ...' )
50
+ encoder = build_models(opt, edit_config=edit_config)
51
+ model_ema = None
52
+ if opt.model_ema:
53
+ # Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at:
54
+ # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123
55
+ adjust = 106_667 * opt.model_ema_steps / opt.num_train_steps
56
+ alpha = 1.0 - opt.model_ema_decay
57
+ alpha = min(1.0, alpha * adjust)
58
+ print('EMA alpha:',alpha)
59
+ model_ema = ExponentialMovingAverage(encoder, decay=1.0 - alpha)
60
+ accelerator.print('Finish building Model.\n')
61
+
62
+ trainer = DDPMTrainer(opt, encoder,accelerator, model_ema)
63
+
64
+ trainer.train(train_datasetloader)
65
+
66
+
trainers/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from .ddpm_trainer import DDPMTrainer
2
+
3
+
4
+ __all__ = ['DDPMTrainer']
trainers/ddpm_trainer.py ADDED
@@ -0,0 +1,245 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import time
3
+ import torch.optim as optim
4
+ from collections import OrderedDict
5
+ from utils.utils import print_current_loss
6
+ from os.path import join as pjoin
7
+
8
+ from diffusers import DDPMScheduler
9
+ from torch.utils.tensorboard import SummaryWriter
10
+ import time
11
+ import pdb
12
+ import sys
13
+ import os
14
+ from torch.optim.lr_scheduler import ExponentialLR
15
+
16
+
17
+ class DDPMTrainer(object):
18
+
19
+ def __init__(self, args, model, accelerator, model_ema=None):
20
+ self.opt = args
21
+ self.accelerator = accelerator
22
+ self.device = self.accelerator.device
23
+ self.model = model
24
+ self.diffusion_steps = args.diffusion_steps
25
+ self.noise_scheduler = DDPMScheduler(
26
+ num_train_timesteps=self.diffusion_steps,
27
+ beta_schedule=args.beta_schedule,
28
+ variance_type="fixed_small",
29
+ prediction_type=args.prediction_type,
30
+ clip_sample=False,
31
+ )
32
+ self.model_ema = model_ema
33
+ if args.is_train:
34
+ self.mse_criterion = torch.nn.MSELoss(reduction="none")
35
+
36
+ accelerator.print("Diffusion_config:\n", self.noise_scheduler.config)
37
+
38
+ if self.accelerator.is_main_process:
39
+ starttime = time.strftime("%Y-%m-%d_%H:%M:%S")
40
+ print("Start experiment:", starttime)
41
+ self.writer = SummaryWriter(
42
+ log_dir=pjoin(args.save_root, "logs_") + starttime[:16],
43
+ comment=starttime[:16],
44
+ flush_secs=60,
45
+ )
46
+ self.accelerator.wait_for_everyone()
47
+
48
+ self.optimizer = optim.AdamW(
49
+ self.model.parameters(), lr=self.opt.lr, weight_decay=self.opt.weight_decay
50
+ )
51
+ self.scheduler = (
52
+ ExponentialLR(self.optimizer, gamma=args.decay_rate)
53
+ if args.decay_rate > 0
54
+ else None
55
+ )
56
+
57
+ @staticmethod
58
+ def zero_grad(opt_list):
59
+ for opt in opt_list:
60
+ opt.zero_grad()
61
+
62
+ def clip_norm(self, network_list):
63
+ for network in network_list:
64
+ self.accelerator.clip_grad_norm_(
65
+ network.parameters(), self.opt.clip_grad_norm
66
+ ) # 0.5 -> 1
67
+
68
+ @staticmethod
69
+ def step(opt_list):
70
+ for opt in opt_list:
71
+ opt.step()
72
+
73
+ def forward(self, batch_data):
74
+ caption, motions, m_lens = batch_data
75
+ motions = motions.detach().float()
76
+
77
+ x_start = motions
78
+ B, T = x_start.shape[:2]
79
+ cur_len = torch.LongTensor([min(T, m_len) for m_len in m_lens]).to(self.device)
80
+ self.src_mask = self.generate_src_mask(T, cur_len).to(x_start.device)
81
+
82
+ # 1. Sample noise that we'll add to the motion
83
+ real_noise = torch.randn_like(x_start)
84
+
85
+ # 2. Sample a random timestep for each motion
86
+ t = torch.randint(0, self.diffusion_steps, (B,), device=self.device)
87
+ self.timesteps = t
88
+
89
+ # 3. Add noise to the motion according to the noise magnitude at each timestep
90
+ # (this is the forward diffusion process)
91
+ x_t = self.noise_scheduler.add_noise(x_start, real_noise, t)
92
+
93
+ # 4. network prediction
94
+ self.prediction = self.model(x_t, t, text=caption)
95
+
96
+ if self.opt.prediction_type == "sample":
97
+ self.target = x_start
98
+ elif self.opt.prediction_type == "epsilon":
99
+ self.target = real_noise
100
+ elif self.opt.prediction_type == "v_prediction":
101
+ self.target = self.noise_scheduler.get_velocity(x_start, real_noise, t)
102
+
103
+ def masked_l2(self, a, b, mask, weights):
104
+
105
+ loss = self.mse_criterion(a, b).mean(dim=-1) # (bath_size, motion_length)
106
+
107
+ loss = (loss * mask).sum(-1) / mask.sum(-1) # (batch_size, )
108
+
109
+ loss = (loss * weights).mean()
110
+
111
+ return loss
112
+
113
+ def backward_G(self):
114
+ loss_logs = OrderedDict({})
115
+ mse_loss_weights = torch.ones_like(self.timesteps)
116
+ loss_logs["loss_mot_rec"] = self.masked_l2(
117
+ self.prediction, self.target, self.src_mask, mse_loss_weights
118
+ )
119
+
120
+ self.loss = loss_logs["loss_mot_rec"]
121
+
122
+ return loss_logs
123
+
124
+ def update(self):
125
+ self.zero_grad([self.optimizer])
126
+ loss_logs = self.backward_G()
127
+ self.accelerator.backward(self.loss)
128
+ self.clip_norm([self.model])
129
+ self.step([self.optimizer])
130
+
131
+ return loss_logs
132
+
133
+ def generate_src_mask(self, T, length):
134
+ B = len(length)
135
+ src_mask = torch.ones(B, T)
136
+ for i in range(B):
137
+ for j in range(length[i], T):
138
+ src_mask[i, j] = 0
139
+ return src_mask
140
+
141
+ def train_mode(self):
142
+ self.model.train()
143
+ if self.model_ema:
144
+ self.model_ema.train()
145
+
146
+ def eval_mode(self):
147
+ self.model.eval()
148
+ if self.model_ema:
149
+ self.model_ema.eval()
150
+
151
+ def save(self, file_name, total_it):
152
+ state = {
153
+ "opt_encoder": self.optimizer.state_dict(),
154
+ "total_it": total_it,
155
+ "encoder": self.accelerator.unwrap_model(self.model).state_dict(),
156
+ }
157
+ if self.model_ema:
158
+ state["model_ema"] = self.accelerator.unwrap_model(
159
+ self.model_ema
160
+ ).module.state_dict()
161
+ torch.save(state, file_name)
162
+ return
163
+
164
+ def load(self, model_dir):
165
+ checkpoint = torch.load(model_dir, map_location=self.device)
166
+ self.optimizer.load_state_dict(checkpoint["opt_encoder"])
167
+ if self.model_ema:
168
+ self.model_ema.load_state_dict(checkpoint["model_ema"], strict=True)
169
+ self.model.load_state_dict(checkpoint["encoder"], strict=True)
170
+
171
+ return checkpoint.get("total_it", 0)
172
+
173
+ def train(self, train_loader):
174
+
175
+ it = 0
176
+ if self.opt.is_continue:
177
+ model_path = pjoin(self.opt.model_dir, self.opt.continue_ckpt)
178
+ it = self.load(model_path)
179
+ self.accelerator.print(f"continue train from {it} iters in {model_path}")
180
+ start_time = time.time()
181
+
182
+ logs = OrderedDict()
183
+ self.dataset = train_loader.dataset
184
+ self.model, self.mse_criterion, self.optimizer, train_loader, self.model_ema = (
185
+ self.accelerator.prepare(
186
+ self.model,
187
+ self.mse_criterion,
188
+ self.optimizer,
189
+ train_loader,
190
+ self.model_ema,
191
+ )
192
+ )
193
+
194
+ num_epochs = (self.opt.num_train_steps - it) // len(train_loader) + 1
195
+ self.accelerator.print(f"need to train for {num_epochs} epochs....")
196
+
197
+ for epoch in range(0, num_epochs):
198
+ self.train_mode()
199
+ for i, batch_data in enumerate(train_loader):
200
+ self.forward(batch_data)
201
+ log_dict = self.update()
202
+ it += 1
203
+
204
+ if self.model_ema and it % self.opt.model_ema_steps == 0:
205
+ self.accelerator.unwrap_model(self.model_ema).update_parameters(
206
+ self.model
207
+ )
208
+
209
+ # update logger
210
+ for k, v in log_dict.items():
211
+ if k not in logs:
212
+ logs[k] = v
213
+ else:
214
+ logs[k] += v
215
+
216
+ if it % self.opt.log_every == 0:
217
+ mean_loss = OrderedDict({})
218
+ for tag, value in logs.items():
219
+ mean_loss[tag] = value / self.opt.log_every
220
+ logs = OrderedDict()
221
+ print_current_loss(
222
+ self.accelerator, start_time, it, mean_loss, epoch, inner_iter=i
223
+ )
224
+ if self.accelerator.is_main_process:
225
+ self.writer.add_scalar("loss", mean_loss["loss_mot_rec"], it)
226
+ self.accelerator.wait_for_everyone()
227
+
228
+ if (
229
+ it % self.opt.save_interval == 0
230
+ and self.accelerator.is_main_process
231
+ ): # Save model
232
+ self.save(pjoin(self.opt.model_dir, "latest.tar").format(it), it)
233
+ self.accelerator.wait_for_everyone()
234
+
235
+ if (self.scheduler is not None) and (
236
+ it % self.opt.update_lr_steps == 0
237
+ ):
238
+ self.scheduler.step()
239
+
240
+ # Save the last checkpoint if it wasn't already saved.
241
+ if it % self.opt.save_interval != 0 and self.accelerator.is_main_process:
242
+ self.save(pjoin(self.opt.model_dir, "latest.tar"), it)
243
+
244
+ self.accelerator.wait_for_everyone()
245
+ self.accelerator.print("FINISH")
utils/__init__.py ADDED
File without changes
utils/constants.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ SMPL_FOOT_R = [8, 11]
2
+ SMPL_FOOT_L = [7, 10]
3
+ SMPL_FACE_FORWARD_JOINTS = [2, 1, 17, 16]
4
+
5
+ # Define a kinematic tree for the skeletal struture
6
+ SMPL_BODY_CHAIN = [
7
+ [0, 2, 5, 8, 11],
8
+ [0, 1, 4, 7, 10],
9
+ [0, 3, 6, 9, 12, 15],
10
+ [9, 14, 17, 19, 21],
11
+ [9, 13, 16, 18, 20],
12
+ ]
13
+ SMPL_LEFT_HAND_CHAIN = [
14
+ [20, 22, 23, 24],
15
+ [20, 34, 35, 36],
16
+ [20, 25, 26, 27],
17
+ [20, 31, 32, 33],
18
+ [20, 28, 29, 30],
19
+ ]
20
+ SMPL_RIGHT_HAND_CHAIN = [
21
+ [21, 43, 44, 45],
22
+ [21, 46, 47, 48],
23
+ [21, 40, 41, 42],
24
+ [21, 37, 38, 39],
25
+ [21, 49, 50, 51],
26
+ ]
27
+
28
+ SMPL_BODY_BONES = [
29
+ -0.0018,
30
+ -0.2233,
31
+ 0.0282,
32
+ 0.0695,
33
+ -0.0914,
34
+ -0.0068,
35
+ -0.0677,
36
+ -0.0905,
37
+ -0.0043,
38
+ -0.0025,
39
+ 0.1090,
40
+ -0.0267,
41
+ 0.0343,
42
+ -0.3752,
43
+ -0.0045,
44
+ -0.0383,
45
+ -0.3826,
46
+ -0.0089,
47
+ 0.0055,
48
+ 0.1352,
49
+ 0.0011,
50
+ -0.0136,
51
+ -0.3980,
52
+ -0.0437,
53
+ 0.0158,
54
+ -0.3984,
55
+ -0.0423,
56
+ 0.0015,
57
+ 0.0529,
58
+ 0.0254,
59
+ 0.0264,
60
+ -0.0558,
61
+ 0.1193,
62
+ -0.0254,
63
+ -0.0481,
64
+ 0.1233,
65
+ -0.0028,
66
+ 0.2139,
67
+ -0.0429,
68
+ 0.0788,
69
+ 0.1217,
70
+ -0.0341,
71
+ -0.0818,
72
+ 0.1188,
73
+ -0.0386,
74
+ 0.0052,
75
+ 0.0650,
76
+ 0.0513,
77
+ 0.0910,
78
+ 0.0305,
79
+ -0.0089,
80
+ -0.0960,
81
+ 0.0326,
82
+ -0.0091,
83
+ 0.2596,
84
+ -0.0128,
85
+ -0.0275,
86
+ -0.2537,
87
+ -0.0133,
88
+ -0.0214,
89
+ 0.2492,
90
+ 0.0090,
91
+ -0.0012,
92
+ -0.2553,
93
+ 0.0078,
94
+ -0.0056,
95
+ 0.0840,
96
+ -0.0082,
97
+ -0.0149,
98
+ -0.0846,
99
+ -0.0061,
100
+ -0.0103,
101
+ ]
102
+
103
+ SMPL_HYBRIK = [
104
+ 0,
105
+ 0,
106
+ 0,
107
+ 0,
108
+ 0,
109
+ 0,
110
+ 0,
111
+ 0,
112
+ 0,
113
+ 0,
114
+ 0,
115
+ 0,
116
+ 1,
117
+ 1,
118
+ 1,
119
+ 1,
120
+ 0,
121
+ 0,
122
+ 0,
123
+ 0,
124
+ 0,
125
+ 0,
126
+ ]
127
+
128
+ SMPL_BODY_PARENTS = [
129
+ 0,
130
+ 0,
131
+ 0,
132
+ 0,
133
+ 1,
134
+ 2,
135
+ 3,
136
+ 4,
137
+ 5,
138
+ 6,
139
+ 7,
140
+ 8,
141
+ 9,
142
+ 9,
143
+ 9,
144
+ 12,
145
+ 13,
146
+ 14,
147
+ 16,
148
+ 17,
149
+ 18,
150
+ 19,
151
+ ]
152
+
153
+ SMPL_BODY_CHILDS = [
154
+ -1,
155
+ 4,
156
+ 5,
157
+ 6,
158
+ 7,
159
+ 8,
160
+ 9,
161
+ 10,
162
+ 11,
163
+ -1,
164
+ -2,
165
+ -2,
166
+ 15,
167
+ 16,
168
+ 17,
169
+ -2,
170
+ 18,
171
+ 19,
172
+ 20,
173
+ 21,
174
+ -2,
175
+ -2,
176
+ ]
utils/ema.py ADDED
@@ -0,0 +1,13 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+ class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel):
4
+ """Maintains moving averages of model parameters using an exponential decay.
5
+ ``ema_avg = decay * avg_model_param + (1 - decay) * model_param``
6
+ `torch.optim.swa_utils.AveragedModel <https://pytorch.org/docs/stable/optim.html#custom-averaging-strategies>`_
7
+ is used to compute the EMA.
8
+ """
9
+
10
+ def __init__(self, model, decay, device="cpu"):
11
+ def ema_avg(avg_model_param, model_param, num_averaged):
12
+ return decay * avg_model_param + (1 - decay) * model_param
13
+ super().__init__(model, device, ema_avg)
utils/eval_humanml.py ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys
2
+ from datetime import datetime
3
+ import numpy as np
4
+ import torch
5
+ from utils.metrics import *
6
+ from collections import OrderedDict
7
+ import os
8
+
9
+
10
+
11
+ def evaluate_matching_score(eval_wrapper,motion_loaders, file):
12
+ match_score_dict = OrderedDict({})
13
+ R_precision_dict = OrderedDict({})
14
+ activation_dict = OrderedDict({})
15
+ # print(motion_loaders.keys())
16
+ print('========== Evaluating Matching Score ==========')
17
+ for motion_loader_name, motion_loader in motion_loaders.items():
18
+ all_motion_embeddings = []
19
+ score_list = []
20
+ all_size = 0
21
+ matching_score_sum = 0
22
+ top_k_count = 0
23
+ # print(motion_loader_name)
24
+ with torch.no_grad():
25
+ for idx, batch in enumerate(motion_loader):
26
+ word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch
27
+ text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings(
28
+ word_embs=word_embeddings,
29
+ pos_ohot=pos_one_hots,
30
+ cap_lens=sent_lens,
31
+ motions=motions,
32
+ m_lens=m_lens
33
+ )
34
+ dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(),
35
+ motion_embeddings.cpu().numpy())
36
+ matching_score_sum += dist_mat.trace()
37
+ # import pdb;pdb.set_trace()
38
+
39
+ argsmax = np.argsort(dist_mat, axis=1)
40
+ top_k_mat = calculate_top_k(argsmax, top_k=3)
41
+ top_k_count += top_k_mat.sum(axis=0)
42
+
43
+ all_size += text_embeddings.shape[0]
44
+
45
+ all_motion_embeddings.append(motion_embeddings.cpu().numpy())
46
+
47
+ all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0)
48
+ # import pdb;pdb.set_trace()
49
+ matching_score = matching_score_sum / all_size
50
+ R_precision = top_k_count / all_size
51
+ match_score_dict[motion_loader_name] = matching_score
52
+ R_precision_dict[motion_loader_name] = R_precision
53
+ activation_dict[motion_loader_name] = all_motion_embeddings
54
+
55
+ print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}')
56
+ print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True)
57
+
58
+ line = f'---> [{motion_loader_name}] R_precision: '
59
+ for i in range(len(R_precision)):
60
+ line += '(top %d): %.4f ' % (i+1, R_precision[i])
61
+ print(line)
62
+ print(line, file=file, flush=True)
63
+
64
+ return match_score_dict, R_precision_dict, activation_dict
65
+
66
+
67
+ def evaluate_fid(eval_wrapper,groundtruth_loader, activation_dict, file):
68
+ eval_dict = OrderedDict({})
69
+ gt_motion_embeddings = []
70
+ print('========== Evaluating FID ==========')
71
+ with torch.no_grad():
72
+ for idx, batch in enumerate(groundtruth_loader):
73
+ _, _, _, sent_lens, motions, m_lens, _ = batch
74
+ motion_embeddings = eval_wrapper.get_motion_embeddings(
75
+ motions=motions,
76
+ m_lens=m_lens
77
+ )
78
+ gt_motion_embeddings.append(motion_embeddings.cpu().numpy())
79
+ gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0)
80
+ gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings)
81
+
82
+ for model_name, motion_embeddings in activation_dict.items():
83
+ mu, cov = calculate_activation_statistics(motion_embeddings)
84
+ # print(mu)
85
+ fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov)
86
+ print(f'---> [{model_name}] FID: {fid:.4f}')
87
+ print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True)
88
+ eval_dict[model_name] = fid
89
+ return eval_dict
90
+
91
+
92
+ def evaluate_diversity(activation_dict, file, diversity_times):
93
+ eval_dict = OrderedDict({})
94
+ print('========== Evaluating Diversity ==========')
95
+ for model_name, motion_embeddings in activation_dict.items():
96
+ diversity = calculate_diversity(motion_embeddings, diversity_times)
97
+ eval_dict[model_name] = diversity
98
+ print(f'---> [{model_name}] Diversity: {diversity:.4f}')
99
+ print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True)
100
+ return eval_dict
101
+
102
+
103
+ def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times):
104
+ eval_dict = OrderedDict({})
105
+ print('========== Evaluating MultiModality ==========')
106
+ for model_name, mm_motion_loader in mm_motion_loaders.items():
107
+ mm_motion_embeddings = []
108
+ with torch.no_grad():
109
+ for idx, batch in enumerate(mm_motion_loader):
110
+ # (1, mm_replications, dim_pos)
111
+ motions, m_lens = batch
112
+ motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0])
113
+ mm_motion_embeddings.append(motion_embedings.unsqueeze(0))
114
+ if len(mm_motion_embeddings) == 0:
115
+ multimodality = 0
116
+ else:
117
+ mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy()
118
+ multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times)
119
+ print(f'---> [{model_name}] Multimodality: {multimodality:.4f}')
120
+ print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True)
121
+ eval_dict[model_name] = multimodality
122
+ return eval_dict
123
+
124
+
125
+ def get_metric_statistics(values, replication_times):
126
+ mean = np.mean(values, axis=0)
127
+ std = np.std(values, axis=0)
128
+ conf_interval = 1.96 * std / np.sqrt(replication_times)
129
+ return mean, conf_interval
130
+
131
+
132
+ def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False):
133
+ with open(log_file, 'a') as f:
134
+ all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
135
+ 'R_precision': OrderedDict({}),
136
+ 'FID': OrderedDict({}),
137
+ 'Diversity': OrderedDict({}),
138
+ 'MultiModality': OrderedDict({})})
139
+
140
+ for replication in range(replication_times):
141
+ print(f'Time: {datetime.now()}')
142
+ print(f'Time: {datetime.now()}', file=f, flush=True)
143
+ motion_loaders = {}
144
+ motion_loaders['ground truth'] = gt_loader
145
+ mm_motion_loaders = {}
146
+ # motion_loaders['ground truth'] = gt_loader
147
+ for motion_loader_name, motion_loader_getter in eval_motion_loaders.items():
148
+ motion_loader, mm_motion_loader,eval_generate_time = motion_loader_getter()
149
+ print(f'---> [{motion_loader_name}] batch_generate_time: {eval_generate_time}s', file=f, flush=True)
150
+ motion_loaders[motion_loader_name] = motion_loader
151
+ mm_motion_loaders[motion_loader_name] = mm_motion_loader
152
+
153
+ if replication_times>1:
154
+ print(f'==================== Replication {replication} ====================')
155
+ print(f'==================== Replication {replication} ====================', file=f, flush=True)
156
+
157
+
158
+ mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)
159
+
160
+ fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)
161
+
162
+ div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)
163
+
164
+ if run_mm:
165
+ mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times)
166
+
167
+ print(f'!!! DONE !!!')
168
+ print(f'!!! DONE !!!', file=f, flush=True)
169
+
170
+ for key, item in mat_score_dict.items():
171
+ if key not in all_metrics['Matching Score']:
172
+ all_metrics['Matching Score'][key] = [item]
173
+ else:
174
+ all_metrics['Matching Score'][key] += [item]
175
+
176
+ for key, item in R_precision_dict.items():
177
+ if key not in all_metrics['R_precision']:
178
+ all_metrics['R_precision'][key] = [item]
179
+ else:
180
+ all_metrics['R_precision'][key] += [item]
181
+
182
+ for key, item in fid_score_dict.items():
183
+ if key not in all_metrics['FID']:
184
+ all_metrics['FID'][key] = [item]
185
+ else:
186
+ all_metrics['FID'][key] += [item]
187
+
188
+ for key, item in div_score_dict.items():
189
+ if key not in all_metrics['Diversity']:
190
+ all_metrics['Diversity'][key] = [item]
191
+ else:
192
+ all_metrics['Diversity'][key] += [item]
193
+
194
+ for key, item in mm_score_dict.items():
195
+ if key not in all_metrics['MultiModality']:
196
+ all_metrics['MultiModality'][key] = [item]
197
+ else:
198
+ all_metrics['MultiModality'][key] += [item]
199
+
200
+
201
+ mean_dict = {}
202
+ if replication_times>1:
203
+ for metric_name, metric_dict in all_metrics.items():
204
+ print('========== %s Summary ==========' % metric_name)
205
+ print('========== %s Summary ==========' % metric_name, file=f, flush=True)
206
+
207
+ for model_name, values in metric_dict.items():
208
+ # print(metric_name, model_name)
209
+ mean, conf_interval = get_metric_statistics(np.array(values),replication_times)
210
+ mean_dict[metric_name + '_' + model_name] = mean
211
+ # print(mean, mean.dtype)
212
+ if isinstance(mean, np.float64) or isinstance(mean, np.float32):
213
+ print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
214
+ print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
215
+ elif isinstance(mean, np.ndarray):
216
+ line = f'---> [{model_name}]'
217
+ for i in range(len(mean)):
218
+ line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
219
+ print(line)
220
+ print(line, file=f, flush=True)
221
+ return mean_dict
222
+ else:
223
+ return all_metrics
224
+
225
+
226
+ def distributed_evaluation(eval_wrapper, gt_loader, eval_motion_loader, log_file, replication_times, diversity_times):
227
+ with open(log_file, 'a') as f:
228
+ all_metrics = OrderedDict({'Matching Score': OrderedDict({}),
229
+ 'R_precision': OrderedDict({}),
230
+ 'FID': OrderedDict({}),
231
+ 'Diversity': OrderedDict({}),
232
+ 'MultiModality': OrderedDict({})})
233
+
234
+ for replication in range(replication_times):
235
+ print(f'Time: {datetime.now()}')
236
+ print(f'Time: {datetime.now()}', file=f, flush=True)
237
+ motion_loaders = {'test':eval_motion_loader}
238
+
239
+ if replication_times>1:
240
+ print(f'==================== Replication {replication} ====================')
241
+ print(f'==================== Replication {replication} ====================', file=f, flush=True)
242
+
243
+
244
+ mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f)
245
+
246
+ fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f)
247
+
248
+ div_score_dict = evaluate_diversity(acti_dict, f, diversity_times)
249
+
250
+
251
+ print(f'!!! DONE !!!')
252
+ print(f'!!! DONE !!!', file=f, flush=True)
253
+
254
+ for key, item in mat_score_dict.items():
255
+ if key not in all_metrics['Matching Score']:
256
+ all_metrics['Matching Score'][key] = [item]
257
+ else:
258
+ all_metrics['Matching Score'][key] += [item]
259
+
260
+ for key, item in R_precision_dict.items():
261
+ if key not in all_metrics['R_precision']:
262
+ all_metrics['R_precision'][key] = [item]
263
+ else:
264
+ all_metrics['R_precision'][key] += [item]
265
+
266
+ for key, item in fid_score_dict.items():
267
+ if key not in all_metrics['FID']:
268
+ all_metrics['FID'][key] = [item]
269
+ else:
270
+ all_metrics['FID'][key] += [item]
271
+
272
+ for key, item in div_score_dict.items():
273
+ if key not in all_metrics['Diversity']:
274
+ all_metrics['Diversity'][key] = [item]
275
+ else:
276
+ all_metrics['Diversity'][key] += [item]
277
+
278
+ mean_dict = {}
279
+ for metric_name, metric_dict in all_metrics.items():
280
+ print('========== %s Summary ==========' % metric_name)
281
+ print('========== %s Summary ==========' % metric_name, file=f, flush=True)
282
+
283
+ for model_name, values in metric_dict.items():
284
+ # print(metric_name, model_name)
285
+ mean, conf_interval = get_metric_statistics(np.array(values),replication_times)
286
+ mean_dict[metric_name + '_' + model_name] = mean
287
+ # print(mean, mean.dtype)
288
+ if isinstance(mean, np.float64) or isinstance(mean, np.float32):
289
+ print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}')
290
+ print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True)
291
+ elif isinstance(mean, np.ndarray):
292
+ line = f'---> [{model_name}]'
293
+ for i in range(len(mean)):
294
+ line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i])
295
+ print(line)
296
+ print(line, file=f, flush=True)
297
+ return mean_dict
298
+
utils/kinematics.py ADDED
@@ -0,0 +1,350 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Tuple
2
+
3
+ import numpy as np
4
+
5
+ import utils.constants as constants
6
+ import torch
7
+
8
+ class HybrIKJointsToRotmat:
9
+ def __init__(self):
10
+ self.naive_hybrik = constants.SMPL_HYBRIK
11
+ self.num_nodes = 22
12
+ self.parents = constants.SMPL_BODY_PARENTS
13
+ self.child = constants.SMPL_BODY_CHILDS
14
+ self.bones = np.array(constants.SMPL_BODY_BONES).reshape(24, 3)[
15
+ : self.num_nodes
16
+ ]
17
+
18
+ def multi_child_rot(
19
+ self, t: np.ndarray, p: np.ndarray, pose_global_parent: np.ndarray
20
+ ) -> Tuple[np.ndarray]:
21
+ """
22
+ t: B x 3 x child_num
23
+ p: B x 3 x child_num
24
+ pose_global_parent: B x 3 x 3
25
+ """
26
+ m = np.matmul(
27
+ t, np.transpose(np.matmul(np.linalg.inv(pose_global_parent), p), [0, 2, 1])
28
+ )
29
+ u, s, vt = np.linalg.svd(m)
30
+ r = np.matmul(np.transpose(vt, [0, 2, 1]), np.transpose(u, [0, 2, 1]))
31
+ err_det_mask = (np.linalg.det(r) < 0.0).reshape(-1, 1, 1)
32
+ id_fix = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]).reshape(
33
+ 1, 3, 3
34
+ )
35
+ r_fix = np.matmul(
36
+ np.transpose(vt, [0, 2, 1]), np.matmul(id_fix, np.transpose(u, [0, 2, 1]))
37
+ )
38
+ r = r * (1.0 - err_det_mask) + r_fix * err_det_mask
39
+ return r, np.matmul(pose_global_parent, r)
40
+
41
+ def single_child_rot(
42
+ self,
43
+ t: np.ndarray,
44
+ p: np.ndarray,
45
+ pose_global_parent: np.ndarray,
46
+ twist: np.ndarray = None,
47
+ ) -> Tuple[np.ndarray]:
48
+ """
49
+ t: B x 3 x 1
50
+ p: B x 3 x 1
51
+ pose_global_parent: B x 3 x 3
52
+ twist: B x 2 if given, default to None
53
+ """
54
+ p_rot = np.matmul(np.linalg.inv(pose_global_parent), p)
55
+ cross = np.cross(t, p_rot, axisa=1, axisb=1, axisc=1)
56
+ sina = np.linalg.norm(cross, axis=1, keepdims=True) / (
57
+ np.linalg.norm(t, axis=1, keepdims=True)
58
+ * np.linalg.norm(p_rot, axis=1, keepdims=True)
59
+ )
60
+ cross = cross / np.linalg.norm(cross, axis=1, keepdims=True)
61
+ cosa = np.sum(t * p_rot, axis=1, keepdims=True) / (
62
+ np.linalg.norm(t, axis=1, keepdims=True)
63
+ * np.linalg.norm(p_rot, axis=1, keepdims=True)
64
+ )
65
+ sina = sina.reshape(-1, 1, 1)
66
+ cosa = cosa.reshape(-1, 1, 1)
67
+ skew_sym_t = np.stack(
68
+ [
69
+ 0.0 * cross[:, 0],
70
+ -cross[:, 2],
71
+ cross[:, 1],
72
+ cross[:, 2],
73
+ 0.0 * cross[:, 0],
74
+ -cross[:, 0],
75
+ -cross[:, 1],
76
+ cross[:, 0],
77
+ 0.0 * cross[:, 0],
78
+ ],
79
+ 1,
80
+ )
81
+ skew_sym_t = skew_sym_t.reshape(-1, 3, 3)
82
+ dsw_rotmat = (
83
+ np.eye(3).reshape(1, 3, 3)
84
+ + sina * skew_sym_t
85
+ + (1.0 - cosa) * np.matmul(skew_sym_t, skew_sym_t)
86
+ )
87
+ if twist is not None:
88
+ skew_sym_t = np.stack(
89
+ [
90
+ 0.0 * t[:, 0],
91
+ -t[:, 2],
92
+ t[:, 1],
93
+ t[:, 2],
94
+ 0.0 * t[:, 0],
95
+ -t[:, 0],
96
+ -t[:, 1],
97
+ t[:, 0],
98
+ 0.0 * t[:, 0],
99
+ ],
100
+ 1,
101
+ )
102
+ skew_sym_t = skew_sym_t.reshape(-1, 3, 3)
103
+ sina = twist[:, 1].reshape(-1, 1, 1)
104
+ cosa = twist[:, 0].reshape(-1, 1, 1)
105
+ dtw_rotmat = (
106
+ np.eye(3).reshape([1, 3, 3])
107
+ + sina * skew_sym_t
108
+ + (1.0 - cosa) * np.matmul(skew_sym_t, skew_sym_t)
109
+ )
110
+ dsw_rotmat = np.matmul(dsw_rotmat, dtw_rotmat)
111
+ return dsw_rotmat, np.matmul(pose_global_parent, dsw_rotmat)
112
+
113
+ def __call__(self, joints: np.ndarray, twist: np.ndarray = None) -> np.ndarray:
114
+ """
115
+ joints: B x N x 3
116
+ twist: B x N x 2 if given, default to None
117
+ """
118
+ expand_dim = False
119
+ if len(joints.shape) == 2:
120
+ expand_dim = True
121
+ joints = np.expand_dims(joints, 0)
122
+ if twist is not None:
123
+ twist = np.expand_dims(twist, 0)
124
+ assert len(joints.shape) == 3
125
+ batch_size = np.shape(joints)[0]
126
+ joints_rel = joints - joints[:, self.parents]
127
+ joints_hybrik = 0.0 * joints_rel
128
+ pose_global = np.zeros([batch_size, self.num_nodes, 3, 3])
129
+ pose = np.zeros([batch_size, self.num_nodes, 3, 3])
130
+ for i in range(self.num_nodes):
131
+ if i == 0:
132
+ joints_hybrik[:, 0] = joints[:, 0]
133
+ else:
134
+ joints_hybrik[:, i] = (
135
+ np.matmul(
136
+ pose_global[:, self.parents[i]],
137
+ self.bones[i].reshape(1, 3, 1),
138
+ ).reshape(-1, 3)
139
+ + joints_hybrik[:, self.parents[i]]
140
+ )
141
+ if self.child[i] == -2:
142
+ pose[:, i] = pose[:, i] + np.eye(3).reshape(1, 3, 3)
143
+ pose_global[:, i] = pose_global[:, self.parents[i]]
144
+ continue
145
+ if i == 0:
146
+ r, rg = self.multi_child_rot(
147
+ np.transpose(self.bones[[1, 2, 3]].reshape(1, 3, 3), [0, 2, 1]),
148
+ np.transpose(joints_rel[:, [1, 2, 3]], [0, 2, 1]),
149
+ np.eye(3).reshape(1, 3, 3),
150
+ )
151
+
152
+ elif i == 9:
153
+ r, rg = self.multi_child_rot(
154
+ np.transpose(self.bones[[12, 13, 14]].reshape(1, 3, 3), [0, 2, 1]),
155
+ np.transpose(joints_rel[:, [12, 13, 14]], [0, 2, 1]),
156
+ pose_global[:, self.parents[9]],
157
+ )
158
+ else:
159
+ p = joints_rel[:, self.child[i]]
160
+ if self.naive_hybrik[i] == 0:
161
+ p = joints[:, self.child[i]] - joints_hybrik[:, i]
162
+ twi = None
163
+ if twist is not None:
164
+ twi = twist[:, i]
165
+ r, rg = self.single_child_rot(
166
+ self.bones[self.child[i]].reshape(1, 3, 1),
167
+ p.reshape(-1, 3, 1),
168
+ pose_global[:, self.parents[i]],
169
+ twi,
170
+ )
171
+ pose[:, i] = r
172
+ pose_global[:, i] = rg
173
+ if expand_dim:
174
+ pose = pose[0]
175
+ return pose
176
+
177
+ class HybrIKJointsToRotmat_Tensor:
178
+ def __init__(self):
179
+ self.naive_hybrik = constants.SMPL_HYBRIK
180
+ self.num_nodes = 22
181
+ self.parents = constants.SMPL_BODY_PARENTS
182
+ self.child = constants.SMPL_BODY_CHILDS
183
+ self.bones = torch.tensor(constants.SMPL_BODY_BONES).reshape(24, 3)[:self.num_nodes]
184
+
185
+ def multi_child_rot(self, t, p, pose_global_parent):
186
+ """
187
+ t: B x 3 x child_num
188
+ p: B x 3 x child_num
189
+ pose_global_parent: B x 3 x 3
190
+ """
191
+ m = torch.matmul(
192
+ t, torch.transpose(torch.matmul(torch.inverse(pose_global_parent), p), 1, 2)
193
+ )
194
+ u, s, vt = torch.linalg.svd(m)
195
+ r = torch.matmul(torch.transpose(vt, 1, 2), torch.transpose(u, 1, 2))
196
+ err_det_mask = (torch.det(r) < 0.0).reshape(-1, 1, 1)
197
+ id_fix = torch.tensor([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, -1.0]]).reshape(1, 3, 3)
198
+ r_fix = torch.matmul(
199
+ torch.transpose(vt, 1, 2), torch.matmul(id_fix, torch.transpose(u, 1, 2))
200
+ )
201
+ r = r * (~err_det_mask) + r_fix * err_det_mask
202
+ return r, torch.matmul(pose_global_parent, r)
203
+
204
+ def single_child_rot(
205
+ self,
206
+ t,
207
+ p,
208
+ pose_global_parent,
209
+ twist = None,
210
+ ) -> Tuple[torch.Tensor]:
211
+ """
212
+ t: B x 3 x 1
213
+ p: B x 3 x 1
214
+ pose_global_parent: B x 3 x 3
215
+ twist: B x 2 if given, default to None
216
+ """
217
+ t_tensor = t.clone().detach()#torch.tensor(t)
218
+ p_tensor = p.clone().detach()#torch.tensor(p)
219
+ pose_global_parent_tensor = pose_global_parent.clone().detach()#torch.tensor(pose_global_parent)
220
+
221
+ p_rot = torch.matmul(torch.linalg.inv(pose_global_parent_tensor), p_tensor)
222
+ cross = torch.cross(t_tensor, p_rot, dim=1)
223
+ sina = torch.linalg.norm(cross, dim=1, keepdim=True) / (
224
+ torch.linalg.norm(t_tensor, dim=1, keepdim=True)
225
+ * torch.linalg.norm(p_rot, dim=1, keepdim=True)
226
+ )
227
+ cross = cross / torch.linalg.norm(cross, dim=1, keepdim=True)
228
+ cosa = torch.sum(t_tensor * p_rot, dim=1, keepdim=True) / (
229
+ torch.linalg.norm(t_tensor, dim=1, keepdim=True)
230
+ * torch.linalg.norm(p_rot, dim=1, keepdim=True)
231
+ )
232
+ sina = sina.reshape(-1, 1, 1)
233
+ cosa = cosa.reshape(-1, 1, 1)
234
+ skew_sym_t = torch.stack(
235
+ [
236
+ 0.0 * cross[:, 0],
237
+ -cross[:, 2],
238
+ cross[:, 1],
239
+ cross[:, 2],
240
+ 0.0 * cross[:, 0],
241
+ -cross[:, 0],
242
+ -cross[:, 1],
243
+ cross[:, 0],
244
+ 0.0 * cross[:, 0],
245
+ ],
246
+ 1,
247
+ )
248
+ skew_sym_t = skew_sym_t.reshape(-1, 3, 3)
249
+ dsw_rotmat = (
250
+ torch.eye(3).reshape(1, 3, 3)
251
+ + sina * skew_sym_t
252
+ + (1.0 - cosa) * torch.matmul(skew_sym_t, skew_sym_t)
253
+ )
254
+ if twist is not None:
255
+ twist_tensor = torch.tensor(twist)
256
+ skew_sym_t = torch.stack(
257
+ [
258
+ 0.0 * t_tensor[:, 0],
259
+ -t_tensor[:, 2],
260
+ t_tensor[:, 1],
261
+ t_tensor[:, 2],
262
+ 0.0 * t_tensor[:, 0],
263
+ -t_tensor[:, 0],
264
+ -t_tensor[:, 1],
265
+ t_tensor[:, 0],
266
+ 0.0 * t_tensor[:, 0],
267
+ ],
268
+ 1,
269
+ )
270
+ skew_sym_t = skew_sym_t.reshape(-1, 3, 3)
271
+ sina = twist_tensor[:, 1].reshape(-1, 1, 1)
272
+ cosa = twist_tensor[:, 0].reshape(-1, 1, 1)
273
+ dtw_rotmat = (
274
+ torch.eye(3).reshape([1, 3, 3])
275
+ + sina * skew_sym_t
276
+ + (1.0 - cosa) * torch.matmul(skew_sym_t, skew_sym_t)
277
+ )
278
+ dsw_rotmat = torch.matmul(dsw_rotmat, dtw_rotmat)
279
+
280
+ return dsw_rotmat, torch.matmul(pose_global_parent_tensor, dsw_rotmat)
281
+
282
+ def __call__(self, joints, twist = None) -> torch.Tensor:
283
+ """
284
+ joints: B x N x 3
285
+ twist: B x N x 2 if given, default to None
286
+ """
287
+ expand_dim = False
288
+ if len(joints.shape) == 2:
289
+ expand_dim = True
290
+ joints = joints.unsqueeze(0)
291
+ if twist is not None:
292
+ twist = twist.unsqueeze(0)
293
+ assert len(joints.shape) == 3
294
+ batch_size = joints.shape[0]
295
+ joints_rel = joints - joints[:, self.parents]
296
+ joints_hybrik = torch.zeros_like(joints_rel)
297
+ pose_global = torch.zeros([batch_size, self.num_nodes, 3, 3])
298
+ pose = torch.zeros([batch_size, self.num_nodes, 3, 3])
299
+ for i in range(self.num_nodes):
300
+ if i == 0:
301
+ joints_hybrik[:, 0] = joints[:, 0]
302
+ else:
303
+ joints_hybrik[:, i] = (
304
+ torch.matmul(
305
+ pose_global[:, self.parents[i]],
306
+ self.bones[i].reshape(1, 3, 1),
307
+ ).reshape(-1, 3)
308
+ + joints_hybrik[:, self.parents[i]]
309
+ )
310
+ if self.child[i] == -2:
311
+ pose[:, i] = pose[:, i] + torch.eye(3).reshape(1, 3, 3)
312
+ pose_global[:, i] = pose_global[:, self.parents[i]]
313
+ continue
314
+ if i == 0:
315
+ t = self.bones[[1, 2, 3]].reshape(1, 3, 3).permute(0, 2, 1)
316
+ p = joints_rel[:, [1, 2, 3]].permute(0, 2, 1)
317
+ pose_global_parent = torch.eye(3).reshape(1, 3, 3)
318
+ r, rg = self.multi_child_rot(t, p, pose_global_parent)
319
+ elif i == 9:
320
+ t = self.bones[[12, 13, 14]].reshape(1, 3, 3).permute(0, 2, 1)
321
+ p = joints_rel[:, [12, 13, 14]].permute(0, 2, 1)
322
+ r, rg = self.multi_child_rot(t, p, pose_global[:, self.parents[9]],)
323
+ else:
324
+ p = joints_rel[:, self.child[i]]
325
+ if self.naive_hybrik[i] == 0:
326
+ p = joints[:, self.child[i]] - joints_hybrik[:, i]
327
+ twi = None
328
+ if twist is not None:
329
+ twi = twist[:, i]
330
+ t = self.bones[self.child[i]].reshape(-1, 3, 1)
331
+ p = p.reshape(-1, 3, 1)
332
+ nframes, _, _ = p.shape
333
+ t = t.repeat(nframes, 1, 1)
334
+ r, rg = self.single_child_rot(t, p, pose_global[:, self.parents[i]], twi)
335
+ pose[:, i] = r
336
+ pose_global[:, i] = rg
337
+ if expand_dim:
338
+ pose = pose[0]
339
+ return pose
340
+
341
+
342
+ if __name__ == "__main__":
343
+ jts2rot_hybrik = HybrIKJointsToRotmat_Tensor()
344
+ joints = torch.tensor(constants.SMPL_BODY_BONES).reshape(1, 24, 3)[:, :22]
345
+ parents = [0, 0, 0, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9, 9, 12, 13, 14, 16, 17, 18, 19]
346
+ for i in range(1, 22):
347
+ joints[:, i] = joints[:, i] + joints[:, parents[i]]
348
+ print(joints.shape)
349
+ pose = jts2rot_hybrik(joints)
350
+ print(pose.shape)
utils/metrics.py ADDED
@@ -0,0 +1,142 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from scipy import linalg
3
+
4
+ def euclidean_distance_matrix(matrix1, matrix2):
5
+ """
6
+ Params:
7
+ -- matrix1: N1 x D
8
+ -- matrix2: N2 x D
9
+ Returns:
10
+ -- dist: N1 x N2
11
+ dist[i, j] == distance(matrix1[i], matrix2[j])
12
+ """
13
+ assert matrix1.shape[1] == matrix2.shape[1]
14
+ d1 = -2 * np.dot(matrix1, matrix2.T) # shape (num_test, num_train)
15
+ d2 = np.sum(np.square(matrix1), axis=1, keepdims=True) # shape (num_test, 1)
16
+ d3 = np.sum(np.square(matrix2), axis=1) # shape (num_train, )
17
+ dists = np.sqrt(d1 + d2 + d3) # broadcasting
18
+ return dists
19
+
20
+ def calculate_top_k(mat, top_k):
21
+ size = mat.shape[0]
22
+ gt_mat = np.expand_dims(np.arange(size), 1).repeat(size, 1)
23
+ bool_mat = (mat == gt_mat)
24
+ correct_vec = False
25
+ top_k_list = []
26
+ for i in range(top_k):
27
+ correct_vec = (correct_vec | bool_mat[:, i])
28
+ top_k_list.append(correct_vec[:, None])
29
+ top_k_mat = np.concatenate(top_k_list, axis=1)
30
+ return top_k_mat
31
+
32
+
33
+ def calculate_R_precision(embedding1, embedding2, top_k, sum_all=False):
34
+ dist_mat = euclidean_distance_matrix(embedding1, embedding2)
35
+ argmax = np.argsort(dist_mat, axis=1)
36
+ top_k_mat = calculate_top_k(argmax, top_k)
37
+ if sum_all:
38
+ return top_k_mat.sum(axis=0)
39
+ else:
40
+ return top_k_mat
41
+
42
+
43
+ def calculate_matching_score(embedding1, embedding2, sum_all=False):
44
+ assert len(embedding1.shape) == 2
45
+ assert embedding1.shape[0] == embedding2.shape[0]
46
+ assert embedding1.shape[1] == embedding2.shape[1]
47
+
48
+ dist = linalg.norm(embedding1 - embedding2, axis=1)
49
+ if sum_all:
50
+ return dist.sum(axis=0)
51
+ else:
52
+ return dist
53
+
54
+
55
+
56
+ def calculate_activation_statistics(activations):
57
+ """
58
+ Params:
59
+ -- activation: num_samples x dim_feat
60
+ Returns:
61
+ -- mu: dim_feat
62
+ -- sigma: dim_feat x dim_feat
63
+ """
64
+ mu = np.mean(activations, axis=0)
65
+ cov = np.cov(activations, rowvar=False)
66
+ return mu, cov
67
+
68
+
69
+ def calculate_diversity(activation, diversity_times):
70
+ assert len(activation.shape) == 2
71
+ assert activation.shape[0] > diversity_times
72
+ num_samples = activation.shape[0]
73
+
74
+ first_indices = np.random.choice(num_samples, diversity_times, replace=False)
75
+ second_indices = np.random.choice(num_samples, diversity_times, replace=False)
76
+ dist = linalg.norm(activation[first_indices] - activation[second_indices], axis=1)
77
+ return dist.mean()
78
+
79
+
80
+ def calculate_multimodality(activation, multimodality_times):
81
+ assert len(activation.shape) == 3
82
+ assert activation.shape[1] > multimodality_times
83
+ num_per_sent = activation.shape[1]
84
+
85
+ first_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
86
+ second_dices = np.random.choice(num_per_sent, multimodality_times, replace=False)
87
+ dist = linalg.norm(activation[:, first_dices] - activation[:, second_dices], axis=2)
88
+ return dist.mean()
89
+
90
+
91
+ def calculate_frechet_distance(mu1, sigma1, mu2, sigma2, eps=1e-6):
92
+ """Numpy implementation of the Frechet Distance.
93
+ The Frechet distance between two multivariate Gaussians X_1 ~ N(mu_1, C_1)
94
+ and X_2 ~ N(mu_2, C_2) is
95
+ d^2 = ||mu_1 - mu_2||^2 + Tr(C_1 + C_2 - 2*sqrt(C_1*C_2)).
96
+ Stable version by Dougal J. Sutherland.
97
+ Params:
98
+ -- mu1 : Numpy array containing the activations of a layer of the
99
+ inception net (like returned by the function 'get_predictions')
100
+ for generated samples.
101
+ -- mu2 : The sample mean over activations, precalculated on an
102
+ representative data set.
103
+ -- sigma1: The covariance matrix over activations for generated samples.
104
+ -- sigma2: The covariance matrix over activations, precalculated on an
105
+ representative data set.
106
+ Returns:
107
+ -- : The Frechet Distance.
108
+ """
109
+
110
+ mu1 = np.atleast_1d(mu1)
111
+ mu2 = np.atleast_1d(mu2)
112
+
113
+ sigma1 = np.atleast_2d(sigma1)
114
+ sigma2 = np.atleast_2d(sigma2)
115
+
116
+ assert mu1.shape == mu2.shape, \
117
+ 'Training and test mean vectors have different lengths'
118
+ assert sigma1.shape == sigma2.shape, \
119
+ 'Training and test covariances have different dimensions'
120
+
121
+ diff = mu1 - mu2
122
+
123
+ # Product might be almost singular
124
+ covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
125
+ if not np.isfinite(covmean).all():
126
+ msg = ('fid calculation produces singular product; '
127
+ 'adding %s to diagonal of cov estimates') % eps
128
+ print(msg)
129
+ offset = np.eye(sigma1.shape[0]) * eps
130
+ covmean = linalg.sqrtm((sigma1 + offset).dot(sigma2 + offset))
131
+
132
+ # Numerical error might give slight imaginary component
133
+ if np.iscomplexobj(covmean):
134
+ if not np.allclose(np.diagonal(covmean).imag, 0, atol=1e-3):
135
+ m = np.max(np.abs(covmean.imag))
136
+ raise ValueError('Imaginary component {}'.format(m))
137
+ covmean = covmean.real
138
+
139
+ tr_covmean = np.trace(covmean)
140
+
141
+ return (diff.dot(diff) + np.trace(sigma1) +
142
+ np.trace(sigma2) - 2 * tr_covmean)
utils/model_load.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from .ema import ExponentialMovingAverage
3
+
4
+ def load_model_weights(model, ckpt_path, use_ema=True, device='cuda:0'):
5
+ """
6
+ Load weights of a model from a checkpoint file.
7
+
8
+ Args:
9
+ model (torch.nn.Module): The model to load weights into.
10
+ ckpt_path (str): Path to the checkpoint file.
11
+ use_ema (bool): Whether to use Exponential Moving Average (EMA) weights if available.
12
+ """
13
+ checkpoint = torch.load(ckpt_path,map_location={'cuda:0': str(device)})
14
+ total_iter = checkpoint.get('total_it', 0)
15
+
16
+ if "model_ema" in checkpoint and use_ema:
17
+ ema_key = next(iter(checkpoint["model_ema"]))
18
+ if ('module' in ema_key) or ('n_averaged' in ema_key):
19
+ model = ExponentialMovingAverage(model, decay=1.0)
20
+
21
+ model.load_state_dict(checkpoint["model_ema"], strict=True)
22
+ if ('module' in ema_key) or ('n_averaged' in ema_key):
23
+ model = model.module
24
+ print(f'\nLoading EMA module model from {ckpt_path} with {total_iter} iterations')
25
+ else:
26
+ print(f'\nLoading EMA model from {ckpt_path} with {total_iter} iterations')
27
+ else:
28
+ model.load_state_dict(checkpoint['encoder'], strict=True)
29
+ print(f'\nLoading model from {ckpt_path} with {total_iter} iterations')
30
+
31
+ return total_iter
utils/motion_process.py ADDED
@@ -0,0 +1,483 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE.
2
+ # Copyright (c) 2022 Chuan Guo
3
+ from os.path import join as pjoin
4
+ from typing import Union
5
+ import numpy as np
6
+ import os
7
+ from utils.quaternion import *
8
+ from utils.skeleton import Skeleton
9
+ from utils.paramUtil import *
10
+
11
+ import torch
12
+ from tqdm import tqdm
13
+
14
+ # positions (batch, joint_num, 3)
15
+ def uniform_skeleton(positions, target_offset):
16
+ src_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
17
+ src_offset = src_skel.get_offsets_joints(torch.from_numpy(positions[0]))
18
+ src_offset = src_offset.numpy()
19
+ tgt_offset = target_offset.numpy()
20
+
21
+ '''Calculate Scale Ratio as the ratio of legs'''
22
+ src_leg_len = np.abs(src_offset[l_idx1]).max() + np.abs(src_offset[l_idx2]).max()
23
+ tgt_leg_len = np.abs(tgt_offset[l_idx1]).max() + np.abs(tgt_offset[l_idx2]).max()
24
+
25
+ scale_rt = tgt_leg_len / src_leg_len
26
+ src_root_pos = positions[:, 0]
27
+ tgt_root_pos = src_root_pos * scale_rt
28
+
29
+ '''Inverse Kinematics'''
30
+ quat_params = src_skel.inverse_kinematics_np(positions, face_joint_indx)
31
+
32
+ '''Forward Kinematics'''
33
+ src_skel.set_offset(target_offset)
34
+ new_joints = src_skel.forward_kinematics_np(quat_params, tgt_root_pos)
35
+ return new_joints
36
+
37
+
38
+ def extract_features(positions, feet_thre, n_raw_offsets, kinematic_chain, face_joint_indx, fid_r, fid_l):
39
+ global_positions = positions.copy()
40
+ """ Get Foot Contacts """
41
+
42
+ def foot_detect(positions, thres):
43
+ velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])
44
+
45
+ feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
46
+ feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
47
+ feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
48
+
49
+ feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float)
50
+
51
+ feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
52
+ feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
53
+ feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
54
+
55
+ feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float)
56
+ return feet_l, feet_r
57
+
58
+ feet_l, feet_r = foot_detect(positions, feet_thre)
59
+
60
+ '''Quaternion and Cartesian representation'''
61
+ r_rot = None
62
+
63
+ def get_rifke(positions):
64
+ '''Local pose'''
65
+ positions[..., 0] -= positions[:, 0:1, 0]
66
+ positions[..., 2] -= positions[:, 0:1, 2]
67
+ '''All pose face Z+'''
68
+ positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)
69
+ return positions
70
+
71
+ def get_quaternion(positions):
72
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
73
+ # (seq_len, joints_num, 4)
74
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)
75
+
76
+ '''Fix Quaternion Discontinuity'''
77
+ quat_params = qfix(quat_params)
78
+ # (seq_len, 4)
79
+ r_rot = quat_params[:, 0].copy()
80
+ # print(r_rot[0])
81
+ '''Root Linear Velocity'''
82
+ # (seq_len - 1, 3)
83
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
84
+ # print(r_rot.shape, velocity.shape)
85
+ velocity = qrot_np(r_rot[1:], velocity)
86
+ '''Root Angular Velocity'''
87
+ # (seq_len - 1, 4)
88
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
89
+ quat_params[1:, 0] = r_velocity
90
+ # (seq_len, joints_num, 4)
91
+ return quat_params, r_velocity, velocity, r_rot
92
+
93
+ def get_cont6d_params(positions):
94
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
95
+ # (seq_len, joints_num, 4)
96
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)
97
+
98
+ '''Quaternion to continuous 6D'''
99
+ cont_6d_params = quaternion_to_cont6d_np(quat_params)
100
+
101
+ # (seq_len, 4)
102
+ r_rot = quat_params[:, 0].copy()
103
+
104
+ '''Root Linear Velocity'''
105
+ # (seq_len - 1, 3)
106
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
107
+
108
+ velocity = qrot_np(r_rot[1:], velocity)
109
+
110
+ '''Root Angular Velocity'''
111
+ # (seq_len - 1, 4)
112
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
113
+ # (seq_len, joints_num, 4)
114
+ return cont_6d_params, r_velocity, velocity, r_rot
115
+
116
+ cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)
117
+ positions = get_rifke(positions)
118
+
119
+ '''Root height'''
120
+ root_y = positions[:, 0, 1:2]
121
+
122
+ '''Root rotation and linear velocity'''
123
+ # (seq_len-1, 1) rotation velocity along y-axis
124
+ # (seq_len-1, 2) linear velovity on xz plane
125
+ r_velocity = np.arcsin(r_velocity[:, 2:3])
126
+ l_velocity = velocity[:, [0, 2]]
127
+ # print(r_velocity.shape, l_velocity.shape, root_y.shape)
128
+ root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)
129
+
130
+ '''Get Joint Rotation Representation'''
131
+ # (seq_len, (joints_num-1) *6) quaternion for skeleton joints
132
+ rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)
133
+
134
+ '''Get Joint Rotation Invariant Position Represention'''
135
+ # (seq_len, (joints_num-1)*3) local joint position
136
+ ric_data = positions[:, 1:].reshape(len(positions), -1)
137
+
138
+ '''Get Joint Velocity Representation'''
139
+ # (seq_len-1, joints_num*3)
140
+ local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),
141
+ global_positions[1:] - global_positions[:-1])
142
+ local_vel = local_vel.reshape(len(local_vel), -1)
143
+
144
+ data = root_data
145
+ data = np.concatenate([data, ric_data[:-1]], axis=-1)
146
+ data = np.concatenate([data, rot_data[:-1]], axis=-1)
147
+ # print(data.shape, local_vel.shape)
148
+ data = np.concatenate([data, local_vel], axis=-1)
149
+ data = np.concatenate([data, feet_l, feet_r], axis=-1)
150
+
151
+ return data
152
+
153
+
154
+ def process_file(positions, feet_thre):
155
+ '''Uniform Skeleton'''
156
+ positions = uniform_skeleton(positions, tgt_offsets)
157
+
158
+ '''Put on Floor'''
159
+ floor_height = positions.min(axis=0).min(axis=0)[1]
160
+ positions[:, :, 1] -= floor_height
161
+
162
+ '''XZ at origin'''
163
+ root_pos_init = positions[0]
164
+ root_pose_init_xz = root_pos_init[0] * np.array([1, 0, 1])
165
+ positions = positions - root_pose_init_xz
166
+
167
+ # '''Move the first pose to origin '''
168
+ # root_pos_init = positions[0]
169
+ # positions = positions - root_pos_init[0]
170
+
171
+ '''All initially face Z+'''
172
+ r_hip, l_hip, sdr_r, sdr_l = face_joint_indx
173
+ across1 = root_pos_init[r_hip] - root_pos_init[l_hip]
174
+ across2 = root_pos_init[sdr_r] - root_pos_init[sdr_l]
175
+ across = across1 + across2
176
+ across = across / np.sqrt((across ** 2).sum(axis=-1))[..., np.newaxis]
177
+
178
+ # forward (3,), rotate around y-axis
179
+ forward_init = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
180
+ # forward (3,)
181
+ forward_init = forward_init / np.sqrt((forward_init ** 2).sum(axis=-1))[..., np.newaxis]
182
+
183
+ # print(forward_init)
184
+
185
+ target = np.array([[0, 0, 1]])
186
+ root_quat_init = qbetween_np(forward_init, target)
187
+ root_quat_init = np.ones(positions.shape[:-1] + (4,)) * root_quat_init
188
+
189
+ positions_b = positions.copy()
190
+
191
+ positions = qrot_np(root_quat_init, positions)
192
+
193
+ '''New ground truth positions'''
194
+ global_positions = positions.copy()
195
+
196
+ """ Get Foot Contacts """
197
+
198
+ def foot_detect(positions, thres):
199
+ velfactor, heightfactor = np.array([thres, thres]), np.array([3.0, 2.0])
200
+
201
+ feet_l_x = (positions[1:, fid_l, 0] - positions[:-1, fid_l, 0]) ** 2
202
+ feet_l_y = (positions[1:, fid_l, 1] - positions[:-1, fid_l, 1]) ** 2
203
+ feet_l_z = (positions[1:, fid_l, 2] - positions[:-1, fid_l, 2]) ** 2
204
+ # feet_l_h = positions[:-1,fid_l,1]
205
+ # feet_l = (((feet_l_x + feet_l_y + feet_l_z) < velfactor) & (feet_l_h < heightfactor)).astype(np.float)
206
+ feet_l = ((feet_l_x + feet_l_y + feet_l_z) < velfactor).astype(np.float)
207
+
208
+ feet_r_x = (positions[1:, fid_r, 0] - positions[:-1, fid_r, 0]) ** 2
209
+ feet_r_y = (positions[1:, fid_r, 1] - positions[:-1, fid_r, 1]) ** 2
210
+ feet_r_z = (positions[1:, fid_r, 2] - positions[:-1, fid_r, 2]) ** 2
211
+ # feet_r_h = positions[:-1,fid_r,1]
212
+ # feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor) & (feet_r_h < heightfactor)).astype(np.float)
213
+ feet_r = (((feet_r_x + feet_r_y + feet_r_z) < velfactor)).astype(np.float)
214
+ return feet_l, feet_r
215
+
216
+ feet_l, feet_r = foot_detect(positions, feet_thre)
217
+
218
+ '''Quaternion and Cartesian representation'''
219
+ r_rot = None
220
+
221
+ def get_rifke(positions):
222
+ '''Local pose'''
223
+ positions[..., 0] -= positions[:, 0:1, 0]
224
+ positions[..., 2] -= positions[:, 0:1, 2]
225
+ '''All pose face Z+'''
226
+ positions = qrot_np(np.repeat(r_rot[:, None], positions.shape[1], axis=1), positions)
227
+ return positions
228
+
229
+ def get_quaternion(positions):
230
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
231
+ # (seq_len, joints_num, 4)
232
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=False)
233
+
234
+ '''Fix Quaternion Discontinuity'''
235
+ quat_params = qfix(quat_params)
236
+ # (seq_len, 4)
237
+ r_rot = quat_params[:, 0].copy()
238
+ # print(r_rot[0])
239
+ '''Root Linear Velocity'''
240
+ # (seq_len - 1, 3)
241
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
242
+ # print(r_rot.shape, velocity.shape)
243
+ velocity = qrot_np(r_rot[1:], velocity)
244
+ '''Root Angular Velocity'''
245
+ # (seq_len - 1, 4)
246
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
247
+ quat_params[1:, 0] = r_velocity
248
+ # (seq_len, joints_num, 4)
249
+ return quat_params, r_velocity, velocity, r_rot
250
+
251
+ def get_cont6d_params(positions):
252
+ skel = Skeleton(n_raw_offsets, kinematic_chain, "cpu")
253
+ # (seq_len, joints_num, 4)
254
+ quat_params = skel.inverse_kinematics_np(positions, face_joint_indx, smooth_forward=True)
255
+
256
+ '''Quaternion to continuous 6D'''
257
+ cont_6d_params = quaternion_to_cont6d_np(quat_params)
258
+ # (seq_len, 4)
259
+ r_rot = quat_params[:, 0].copy()
260
+ # print(r_rot[0])
261
+ '''Root Linear Velocity'''
262
+ # (seq_len - 1, 3)
263
+ velocity = (positions[1:, 0] - positions[:-1, 0]).copy()
264
+ # print(r_rot.shape, velocity.shape)
265
+ velocity = qrot_np(r_rot[1:], velocity)
266
+ '''Root Angular Velocity'''
267
+ # (seq_len - 1, 4)
268
+ r_velocity = qmul_np(r_rot[1:], qinv_np(r_rot[:-1]))
269
+ # (seq_len, joints_num, 4)
270
+ return cont_6d_params, r_velocity, velocity, r_rot
271
+
272
+ cont_6d_params, r_velocity, velocity, r_rot = get_cont6d_params(positions)
273
+ positions = get_rifke(positions)
274
+
275
+ '''Root height'''
276
+ root_y = positions[:, 0, 1:2]
277
+
278
+ '''Root rotation and linear velocity'''
279
+ # (seq_len-1, 1) rotation velocity along y-axis
280
+ # (seq_len-1, 2) linear velovity on xz plane
281
+ r_velocity = np.arcsin(r_velocity[:, 2:3])
282
+ l_velocity = velocity[:, [0, 2]]
283
+ # print(r_velocity.shape, l_velocity.shape, root_y.shape)
284
+ root_data = np.concatenate([r_velocity, l_velocity, root_y[:-1]], axis=-1)
285
+
286
+ '''Get Joint Rotation Representation'''
287
+ # (seq_len, (joints_num-1) *6) quaternion for skeleton joints
288
+ rot_data = cont_6d_params[:, 1:].reshape(len(cont_6d_params), -1)
289
+
290
+ '''Get Joint Rotation Invariant Position Represention'''
291
+ # (seq_len, (joints_num-1)*3) local joint position
292
+ ric_data = positions[:, 1:].reshape(len(positions), -1)
293
+
294
+ '''Get Joint Velocity Representation'''
295
+ # (seq_len-1, joints_num*3)
296
+ local_vel = qrot_np(np.repeat(r_rot[:-1, None], global_positions.shape[1], axis=1),
297
+ global_positions[1:] - global_positions[:-1])
298
+ local_vel = local_vel.reshape(len(local_vel), -1)
299
+
300
+ data = root_data
301
+ data = np.concatenate([data, ric_data[:-1]], axis=-1)
302
+ data = np.concatenate([data, rot_data[:-1]], axis=-1)
303
+
304
+ data = np.concatenate([data, local_vel], axis=-1)
305
+ data = np.concatenate([data, feet_l, feet_r], axis=-1)
306
+
307
+ return data, global_positions, positions, l_velocity
308
+
309
+
310
+ # Recover global angle and positions for rotation data
311
+ # root_rot_velocity (B, seq_len, 1)
312
+ # root_linear_velocity (B, seq_len, 2)
313
+ # root_y (B, seq_len, 1)
314
+ # ric_data (B, seq_len, (joint_num - 1)*3)
315
+ # rot_data (B, seq_len, (joint_num - 1)*6)
316
+ # local_velocity (B, seq_len, joint_num*3)
317
+ # foot contact (B, seq_len, 4)
318
+ def recover_root_rot_pos(data):
319
+ rot_vel = data[..., 0]
320
+ r_rot_ang = torch.zeros_like(rot_vel).to(data.device)
321
+ '''Get Y-axis rotation from rotation velocity'''
322
+ r_rot_ang[..., 1:] = rot_vel[..., :-1]
323
+ r_rot_ang = torch.cumsum(r_rot_ang, dim=-1)
324
+
325
+ r_rot_quat = torch.zeros(data.shape[:-1] + (4,)).to(data.device)
326
+ r_rot_quat[..., 0] = torch.cos(r_rot_ang)
327
+ r_rot_quat[..., 2] = torch.sin(r_rot_ang)
328
+
329
+ r_pos = torch.zeros(data.shape[:-1] + (3,)).to(data.device)
330
+ r_pos[..., 1:, [0, 2]] = data[..., :-1, 1:3]
331
+ '''Add Y-axis rotation to root position'''
332
+ r_pos = qrot(qinv(r_rot_quat), r_pos)
333
+
334
+ r_pos = torch.cumsum(r_pos, dim=-2)
335
+
336
+ r_pos[..., 1] = data[..., 3]
337
+ return r_rot_quat, r_pos
338
+
339
+
340
+ def recover_from_rot(data, joints_num, skeleton):
341
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
342
+
343
+ r_rot_cont6d = quaternion_to_cont6d(r_rot_quat)
344
+
345
+ start_indx = 1 + 2 + 1 + (joints_num - 1) * 3
346
+ end_indx = start_indx + (joints_num - 1) * 6
347
+ cont6d_params = data[..., start_indx:end_indx]
348
+ # print(r_rot_cont6d.shape, cont6d_params.shape, r_pos.shape)
349
+ cont6d_params = torch.cat([r_rot_cont6d, cont6d_params], dim=-1)
350
+ cont6d_params = cont6d_params.view(-1, joints_num, 6)
351
+
352
+ positions = skeleton.forward_kinematics_cont6d(cont6d_params, r_pos)
353
+
354
+ return positions
355
+
356
+
357
+ # NOTE: Expand input data types (torch.Tensor -> Union[torch.Tensor, np.array])
358
+ def recover_from_ric(
359
+ data: Union[torch.Tensor, np.array], joints_num: int
360
+ ) -> Union[torch.Tensor, np.array]:
361
+ if isinstance(data, np.ndarray):
362
+ data = torch.from_numpy(data).float()
363
+ dtype = "numpy"
364
+ else:
365
+ data = data.float()
366
+ dtype = "tensor"
367
+ r_rot_quat, r_pos = recover_root_rot_pos(data)
368
+ positions = data[..., 4:(joints_num - 1) * 3 + 4]
369
+ positions = positions.view(positions.shape[:-1] + (-1, 3))
370
+
371
+ '''Add Y-axis rotation to local joints'''
372
+ positions = qrot(qinv(r_rot_quat[..., None, :]).expand(positions.shape[:-1] + (4,)), positions)
373
+
374
+ '''Add root XZ to joints'''
375
+ positions[..., 0] += r_pos[..., 0:1]
376
+ positions[..., 2] += r_pos[..., 2:3]
377
+
378
+ '''Concate root and joints'''
379
+ positions = torch.cat([r_pos.unsqueeze(-2), positions], dim=-2)
380
+
381
+ if dtype == "numpy":
382
+ positions = positions.numpy()
383
+
384
+ return positions
385
+ '''
386
+ For Text2Motion Dataset
387
+ '''
388
+ '''
389
+ if __name__ == "__main__":
390
+ example_id = "000021"
391
+ # Lower legs
392
+ l_idx1, l_idx2 = 5, 8
393
+ # Right/Left foot
394
+ fid_r, fid_l = [8, 11], [7, 10]
395
+ # Face direction, r_hip, l_hip, sdr_r, sdr_l
396
+ face_joint_indx = [2, 1, 17, 16]
397
+ # l_hip, r_hip
398
+ r_hip, l_hip = 2, 1
399
+ joints_num = 22
400
+ # ds_num = 8
401
+ data_dir = '../dataset/pose_data_raw/joints/'
402
+ save_dir1 = '../dataset/pose_data_raw/new_joints/'
403
+ save_dir2 = '../dataset/pose_data_raw/new_joint_vecs/'
404
+
405
+ n_raw_offsets = torch.from_numpy(t2m_raw_offsets)
406
+ kinematic_chain = t2m_kinematic_chain
407
+
408
+ # Get offsets of target skeleton
409
+ example_data = np.load(os.path.join(data_dir, example_id + '.npy'))
410
+ example_data = example_data.reshape(len(example_data), -1, 3)
411
+ example_data = torch.from_numpy(example_data)
412
+ tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
413
+ # (joints_num, 3)
414
+ tgt_offsets = tgt_skel.get_offsets_joints(example_data[0])
415
+ # print(tgt_offsets)
416
+
417
+ source_list = os.listdir(data_dir)
418
+ frame_num = 0
419
+ for source_file in tqdm(source_list):
420
+ source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num]
421
+ try:
422
+ data, ground_positions, positions, l_velocity = process_file(source_data, 0.002)
423
+ rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num)
424
+ np.save(pjoin(save_dir1, source_file), rec_ric_data.squeeze().numpy())
425
+ np.save(pjoin(save_dir2, source_file), data)
426
+ frame_num += data.shape[0]
427
+ except Exception as e:
428
+ print(source_file)
429
+ print(e)
430
+
431
+ print('Total clips: %d, Frames: %d, Duration: %fm' %
432
+ (len(source_list), frame_num, frame_num / 20 / 60))
433
+ '''
434
+
435
+ if __name__ == "__main__":
436
+ example_id = "03950_gt"
437
+ # Lower legs
438
+ l_idx1, l_idx2 = 17, 18
439
+ # Right/Left foot
440
+ fid_r, fid_l = [14, 15], [19, 20]
441
+ # Face direction, r_hip, l_hip, sdr_r, sdr_l
442
+ face_joint_indx = [11, 16, 5, 8]
443
+ # l_hip, r_hip
444
+ r_hip, l_hip = 11, 16
445
+ joints_num = 21
446
+ # ds_num = 8
447
+ data_dir = '../dataset/kit_mocap_dataset/joints/'
448
+ save_dir1 = '../dataset/kit_mocap_dataset/new_joints/'
449
+ save_dir2 = '../dataset/kit_mocap_dataset/new_joint_vecs/'
450
+
451
+ n_raw_offsets = torch.from_numpy(kit_raw_offsets)
452
+ kinematic_chain = kit_kinematic_chain
453
+
454
+ '''Get offsets of target skeleton'''
455
+ example_data = np.load(os.path.join(data_dir, example_id + '.npy'))
456
+ example_data = example_data.reshape(len(example_data), -1, 3)
457
+ example_data = torch.from_numpy(example_data)
458
+ tgt_skel = Skeleton(n_raw_offsets, kinematic_chain, 'cpu')
459
+ # (joints_num, 3)
460
+ tgt_offsets = tgt_skel.get_offsets_joints(example_data[0])
461
+ # print(tgt_offsets)
462
+
463
+ source_list = os.listdir(data_dir)
464
+ frame_num = 0
465
+ '''Read source data'''
466
+ for source_file in tqdm(source_list):
467
+ source_data = np.load(os.path.join(data_dir, source_file))[:, :joints_num]
468
+ try:
469
+ name = ''.join(source_file[:-7].split('_')) + '.npy'
470
+ data, ground_positions, positions, l_velocity = process_file(source_data, 0.05)
471
+ rec_ric_data = recover_from_ric(torch.from_numpy(data).unsqueeze(0).float(), joints_num)
472
+ if np.isnan(rec_ric_data.numpy()).any():
473
+ print(source_file)
474
+ continue
475
+ np.save(pjoin(save_dir1, name), rec_ric_data.squeeze().numpy())
476
+ np.save(pjoin(save_dir2, name), data)
477
+ frame_num += data.shape[0]
478
+ except Exception as e:
479
+ print(source_file)
480
+ print(e)
481
+
482
+ print('Total clips: %d, Frames: %d, Duration: %fm' %
483
+ (len(source_list), frame_num, frame_num / 12.5 / 60))
utils/paramUtil.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+
3
+ # Define a kinematic tree for the skeletal struture
4
+ kit_kinematic_chain = [[0, 11, 12, 13, 14, 15], [0, 16, 17, 18, 19, 20], [0, 1, 2, 3, 4], [3, 5, 6, 7], [3, 8, 9, 10]]
5
+
6
+ kit_raw_offsets = np.array(
7
+ [
8
+ [0, 0, 0],
9
+ [0, 1, 0],
10
+ [0, 1, 0],
11
+ [0, 1, 0],
12
+ [0, 1, 0],
13
+ [1, 0, 0],
14
+ [0, -1, 0],
15
+ [0, -1, 0],
16
+ [-1, 0, 0],
17
+ [0, -1, 0],
18
+ [0, -1, 0],
19
+ [1, 0, 0],
20
+ [0, -1, 0],
21
+ [0, -1, 0],
22
+ [0, 0, 1],
23
+ [0, 0, 1],
24
+ [-1, 0, 0],
25
+ [0, -1, 0],
26
+ [0, -1, 0],
27
+ [0, 0, 1],
28
+ [0, 0, 1]
29
+ ]
30
+ )
31
+
32
+ t2m_raw_offsets = np.array([[0,0,0],
33
+ [1,0,0],
34
+ [-1,0,0],
35
+ [0,1,0],
36
+ [0,-1,0],
37
+ [0,-1,0],
38
+ [0,1,0],
39
+ [0,-1,0],
40
+ [0,-1,0],
41
+ [0,1,0],
42
+ [0,0,1],
43
+ [0,0,1],
44
+ [0,1,0],
45
+ [1,0,0],
46
+ [-1,0,0],
47
+ [0,0,1],
48
+ [0,-1,0],
49
+ [0,-1,0],
50
+ [0,-1,0],
51
+ [0,-1,0],
52
+ [0,-1,0],
53
+ [0,-1,0]])
54
+
55
+ t2m_kinematic_chain = [[0, 2, 5, 8, 11], [0, 1, 4, 7, 10], [0, 3, 6, 9, 12, 15], [9, 14, 17, 19, 21], [9, 13, 16, 18, 20]]
56
+ t2m_left_hand_chain = [[20, 22, 23, 24], [20, 34, 35, 36], [20, 25, 26, 27], [20, 31, 32, 33], [20, 28, 29, 30]]
57
+ t2m_right_hand_chain = [[21, 43, 44, 45], [21, 46, 47, 48], [21, 40, 41, 42], [21, 37, 38, 39], [21, 49, 50, 51]]
58
+
59
+ kit_kinematic_tree = [
60
+ [0, 1, 2, 3, 4], # body
61
+ [3, 5, 6, 7], # right arm
62
+ [3, 8, 9, 10], # left arm
63
+ [0, 11, 12, 13, 14, 15], # right leg
64
+ [0, 16, 17, 18, 19, 20], # left leg
65
+ ]
66
+
67
+ humanml3d_kinematic_tree = [
68
+ [0, 3, 6, 9, 12, 15], # body
69
+ [9, 14, 17, 19, 21], # right arm
70
+ [9, 13, 16, 18, 20], # left arm
71
+ [0, 2, 5, 8, 11], # right leg
72
+ [0, 1, 4, 7, 10], # left leg
73
+ ]
74
+
75
+
76
+ kit_tgt_skel_id = '03950'
77
+
78
+ t2m_tgt_skel_id = '000021'
79
+
80
+ KIT_JOINT_NAMES = [
81
+ "pelvis", # 0
82
+ "spine_1", # 1
83
+ "spine_3", # 2
84
+ "neck", # 3
85
+ "head", # 4
86
+ "left_shoulder", # 5
87
+ "left_elbow", # 6
88
+ "left_wrist", # 7
89
+ "right_shoulder", # 8
90
+ "right_elbow", # 9
91
+ "right_wrist", # 10
92
+ "left_hip", # 11
93
+ "left_knee", # 12
94
+ "left_ankle", # 13
95
+ "left_heel", # 14
96
+ "left_foot", # 15
97
+ "right_hip", # 16
98
+ "right_knee", # 17
99
+ "right_ankle", # 18
100
+ "right_heel", # 19
101
+ "right_foot", # 20
102
+ ]
103
+
104
+ HumanML3D_JOINT_NAMES = [
105
+ "pelvis", # 0: root
106
+ "left_hip", # 1: lhip
107
+ "right_hip", # 2: rhip
108
+ "spine_1", # 3: belly
109
+ "left_knee", # 4: lknee
110
+ "right_knee", # 5: rknee
111
+ "spine_2", # 6: spine
112
+ "left_ankle", # 7: lankle
113
+ "right_ankle", # 8: rankle
114
+ "spine_3", # 9: chest
115
+ "left_foot", # 10: ltoes
116
+ "right_foot", # 11: rtoes
117
+ "neck", # 12: neck
118
+ "left_clavicle", # 13: linshoulder
119
+ "right_clavicle", # 14: rinshoulder
120
+ "head", # 15: head
121
+ "left_shoulder", # 16: lshoulder
122
+ "right_shoulder", # 17: rshoulder
123
+ "left_elbow", # 18: lelbow
124
+ "right_elbow", # 19: relbow
125
+ "left_wrist", # 20: lwrist
126
+ "right_wrist", # 21: rwrist
127
+ ]
128
+
129
+ HumanML3D2KIT = {
130
+ 0:0, # pelvis--pelvis
131
+ 1:16, # left_hip--right_hip
132
+ 2:11, # right_hip--left_hip
133
+ 3:1, # spine_1--spine_1
134
+ 4:17, # left_knee--right_knee
135
+ 5:12, # right_knee--left_knee
136
+ 6:1, # spine_2--spine_1
137
+ 7:18, # left_ankle--right_ankle
138
+ 8:13, # right_ankle--left_ankle
139
+ 9:2, # spine_3--spine_3
140
+ 10:20, # left_foot--right_foot
141
+ 11:15, # right_foot--left_foot
142
+ 12:3, # neck--neck
143
+ 13:8, # left_clavicle--right_shoulder
144
+ 14:5, # right_clavicle--left_shoulder
145
+ 15:4, # head--head
146
+ 16:8, # left_shoulder--right_shoulder
147
+ 17:5, # right_shoulder--left_shoulder
148
+ 18:9, # left_elbow--right_elbow
149
+ 19:6, # right_elbow--left_elbow
150
+ 20:10, # left_wrist--right_wrist
151
+ 21:7, # right_wrist--left_wrist
152
+ }
utils/plot_script.py ADDED
@@ -0,0 +1,117 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import matplotlib
4
+ import matplotlib.pyplot as plt
5
+ from mpl_toolkits.mplot3d import Axes3D
6
+ from matplotlib.animation import FuncAnimation, FFMpegFileWriter
7
+ from mpl_toolkits.mplot3d.art3d import Poly3DCollection
8
+ import mpl_toolkits.mplot3d.axes3d as p3
9
+ # import cv2
10
+
11
+
12
+ def list_cut_average(ll, intervals):
13
+ if intervals == 1:
14
+ return ll
15
+
16
+ bins = math.ceil(len(ll) * 1.0 / intervals)
17
+ ll_new = []
18
+ for i in range(bins):
19
+ l_low = intervals * i
20
+ l_high = l_low + intervals
21
+ l_high = l_high if l_high < len(ll) else len(ll)
22
+ ll_new.append(np.mean(ll[l_low:l_high]))
23
+ return ll_new
24
+
25
+
26
+ def plot_3d_motion(save_path, kinematic_tree, joints, title, figsize=(10, 10), fps=120, radius=4):
27
+ matplotlib.use('Agg')
28
+
29
+ title_sp = title.split(' ')
30
+ if len(title_sp) > 20:
31
+ title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:20]), ' '.join(title_sp[20:])])
32
+ elif len(title_sp) > 10:
33
+ title = '\n'.join([' '.join(title_sp[:10]), ' '.join(title_sp[10:])])
34
+
35
+ def init():
36
+ ax.set_xlim3d([-radius / 4, radius / 4])
37
+ ax.set_ylim3d([0, radius / 2])
38
+ ax.set_zlim3d([0, radius / 2])
39
+ # print(title)
40
+ fig.suptitle(title, fontsize=20)
41
+ ax.grid(b=False)
42
+
43
+ def plot_xzPlane(minx, maxx, miny, minz, maxz):
44
+ ## Plot a plane XZ
45
+ verts = [
46
+ [minx, miny, minz],
47
+ [minx, miny, maxz],
48
+ [maxx, miny, maxz],
49
+ [maxx, miny, minz]
50
+ ]
51
+ xz_plane = Poly3DCollection([verts])
52
+ xz_plane.set_facecolor((0.5, 0.5, 0.5, 0.5))
53
+ ax.add_collection3d(xz_plane)
54
+
55
+ # return ax
56
+
57
+ # (seq_len, joints_num, 3)
58
+ data = joints.copy().reshape(len(joints), -1, 3)
59
+ fig = plt.figure(figsize=figsize)
60
+ ax = p3.Axes3D(fig)
61
+ init()
62
+ MINS = data.min(axis=0).min(axis=0)
63
+ MAXS = data.max(axis=0).max(axis=0)
64
+ colors = ['red', 'blue', 'black', 'red', 'blue',
65
+ 'darkblue', 'darkblue', 'darkblue', 'darkblue', 'darkblue',
66
+ 'darkred', 'darkred', 'darkred', 'darkred', 'darkred']
67
+ frame_number = data.shape[0]
68
+ # print(data.shape)
69
+
70
+ height_offset = MINS[1]
71
+ data[:, :, 1] -= height_offset
72
+ trajec = data[:, 0, [0, 2]]
73
+
74
+ data[..., 0] -= data[:, 0:1, 0]
75
+ data[..., 2] -= data[:, 0:1, 2]
76
+
77
+ # print(trajec.shape)
78
+
79
+ def update(index):
80
+ # print(index)
81
+ ax.lines = []
82
+ ax.collections = []
83
+ ax.view_init(elev=120, azim=-90)
84
+ ax.dist = 7.5
85
+ # ax =
86
+ plot_xzPlane(MINS[0] - trajec[index, 0], MAXS[0] - trajec[index, 0], 0, MINS[2] - trajec[index, 1],
87
+ MAXS[2] - trajec[index, 1])
88
+ # ax.scatter(data[index, :22, 0], data[index, :22, 1], data[index, :22, 2], color='black', s=3)
89
+
90
+ # if index > 1:
91
+ # ax.plot3D(trajec[:index, 0] - trajec[index, 0], np.zeros_like(trajec[:index, 0]),
92
+ # trajec[:index, 1] - trajec[index, 1], linewidth=1.0,
93
+ # color='blue')
94
+ # ax = plot_xzPlane(ax, MINS[0], MAXS[0], 0, MINS[2], MAXS[2])
95
+
96
+ for i, (chain, color) in enumerate(zip(kinematic_tree, colors)):
97
+ # print(color)
98
+ if i < 5:
99
+ linewidth = 4.0
100
+ else:
101
+ linewidth = 2.0
102
+ ax.plot3D(data[index, chain, 0], data[index, chain, 1], data[index, chain, 2], linewidth=linewidth,
103
+ color=color)
104
+ # print(trajec[:index, 0].shape)
105
+
106
+ plt.axis('off')
107
+ ax.set_xticklabels([])
108
+ ax.set_yticklabels([])
109
+ ax.set_zticklabels([])
110
+
111
+ ani = FuncAnimation(fig, update, frames=frame_number, interval=1000 / fps, repeat=False)
112
+
113
+ from matplotlib import rcParams
114
+ rcParams['animation.ffmpeg_path'] = '/home/chenlinghao/software/miniconda3/envs/py10/bin/ffmpeg'
115
+ writer = FFMpegFileWriter(fps=fps)
116
+ ani.save(save_path, writer=writer)
117
+ plt.close()
utils/quaternion.py ADDED
@@ -0,0 +1,423 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2018-present, Facebook, Inc.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+ #
7
+
8
+ import torch
9
+ import numpy as np
10
+
11
+ _EPS4 = np.finfo(float).eps * 4.0
12
+
13
+ _FLOAT_EPS = np.finfo(np.float32).eps
14
+
15
+ # PyTorch-backed implementations
16
+ def qinv(q):
17
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
18
+ mask = torch.ones_like(q)
19
+ mask[..., 1:] = -mask[..., 1:]
20
+ return q * mask
21
+
22
+
23
+ def qinv_np(q):
24
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
25
+ return qinv(torch.from_numpy(q).float()).numpy()
26
+
27
+
28
+ def qnormalize(q):
29
+ assert q.shape[-1] == 4, 'q must be a tensor of shape (*, 4)'
30
+ return q / torch.norm(q, dim=-1, keepdim=True)
31
+
32
+
33
+ def qmul(q, r):
34
+ """
35
+ Multiply quaternion(s) q with quaternion(s) r.
36
+ Expects two equally-sized tensors of shape (*, 4), where * denotes any number of dimensions.
37
+ Returns q*r as a tensor of shape (*, 4).
38
+ """
39
+ assert q.shape[-1] == 4
40
+ assert r.shape[-1] == 4
41
+
42
+ original_shape = q.shape
43
+
44
+ # Compute outer product
45
+ terms = torch.bmm(r.view(-1, 4, 1), q.view(-1, 1, 4))
46
+
47
+ w = terms[:, 0, 0] - terms[:, 1, 1] - terms[:, 2, 2] - terms[:, 3, 3]
48
+ x = terms[:, 0, 1] + terms[:, 1, 0] - terms[:, 2, 3] + terms[:, 3, 2]
49
+ y = terms[:, 0, 2] + terms[:, 1, 3] + terms[:, 2, 0] - terms[:, 3, 1]
50
+ z = terms[:, 0, 3] - terms[:, 1, 2] + terms[:, 2, 1] + terms[:, 3, 0]
51
+ return torch.stack((w, x, y, z), dim=1).view(original_shape)
52
+
53
+
54
+ def qrot(q, v):
55
+ """
56
+ Rotate vector(s) v about the rotation described by quaternion(s) q.
57
+ Expects a tensor of shape (*, 4) for q and a tensor of shape (*, 3) for v,
58
+ where * denotes any number of dimensions.
59
+ Returns a tensor of shape (*, 3).
60
+ """
61
+ assert q.shape[-1] == 4
62
+ assert v.shape[-1] == 3
63
+ assert q.shape[:-1] == v.shape[:-1]
64
+
65
+ original_shape = list(v.shape)
66
+ # print(q.shape)
67
+ q = q.contiguous().view(-1, 4)
68
+ v = v.contiguous().view(-1, 3)
69
+
70
+ qvec = q[:, 1:]
71
+ uv = torch.cross(qvec, v, dim=1)
72
+ uuv = torch.cross(qvec, uv, dim=1)
73
+ return (v + 2 * (q[:, :1] * uv + uuv)).view(original_shape)
74
+
75
+
76
+ def qeuler(q, order, epsilon=0, deg=True):
77
+ """
78
+ Convert quaternion(s) q to Euler angles.
79
+ Expects a tensor of shape (*, 4), where * denotes any number of dimensions.
80
+ Returns a tensor of shape (*, 3).
81
+ """
82
+ assert q.shape[-1] == 4
83
+
84
+ original_shape = list(q.shape)
85
+ original_shape[-1] = 3
86
+ q = q.view(-1, 4)
87
+
88
+ q0 = q[:, 0]
89
+ q1 = q[:, 1]
90
+ q2 = q[:, 2]
91
+ q3 = q[:, 3]
92
+
93
+ if order == 'xyz':
94
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
95
+ y = torch.asin(torch.clamp(2 * (q1 * q3 + q0 * q2), -1 + epsilon, 1 - epsilon))
96
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
97
+ elif order == 'yzx':
98
+ x = torch.atan2(2 * (q0 * q1 - q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
99
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
100
+ z = torch.asin(torch.clamp(2 * (q1 * q2 + q0 * q3), -1 + epsilon, 1 - epsilon))
101
+ elif order == 'zxy':
102
+ x = torch.asin(torch.clamp(2 * (q0 * q1 + q2 * q3), -1 + epsilon, 1 - epsilon))
103
+ y = torch.atan2(2 * (q0 * q2 - q1 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
104
+ z = torch.atan2(2 * (q0 * q3 - q1 * q2), 1 - 2 * (q1 * q1 + q3 * q3))
105
+ elif order == 'xzy':
106
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
107
+ y = torch.atan2(2 * (q0 * q2 + q1 * q3), 1 - 2 * (q2 * q2 + q3 * q3))
108
+ z = torch.asin(torch.clamp(2 * (q0 * q3 - q1 * q2), -1 + epsilon, 1 - epsilon))
109
+ elif order == 'yxz':
110
+ x = torch.asin(torch.clamp(2 * (q0 * q1 - q2 * q3), -1 + epsilon, 1 - epsilon))
111
+ y = torch.atan2(2 * (q1 * q3 + q0 * q2), 1 - 2 * (q1 * q1 + q2 * q2))
112
+ z = torch.atan2(2 * (q1 * q2 + q0 * q3), 1 - 2 * (q1 * q1 + q3 * q3))
113
+ elif order == 'zyx':
114
+ x = torch.atan2(2 * (q0 * q1 + q2 * q3), 1 - 2 * (q1 * q1 + q2 * q2))
115
+ y = torch.asin(torch.clamp(2 * (q0 * q2 - q1 * q3), -1 + epsilon, 1 - epsilon))
116
+ z = torch.atan2(2 * (q0 * q3 + q1 * q2), 1 - 2 * (q2 * q2 + q3 * q3))
117
+ else:
118
+ raise
119
+
120
+ if deg:
121
+ return torch.stack((x, y, z), dim=1).view(original_shape) * 180 / np.pi
122
+ else:
123
+ return torch.stack((x, y, z), dim=1).view(original_shape)
124
+
125
+
126
+ # Numpy-backed implementations
127
+
128
+ def qmul_np(q, r):
129
+ q = torch.from_numpy(q).contiguous().float()
130
+ r = torch.from_numpy(r).contiguous().float()
131
+ return qmul(q, r).numpy()
132
+
133
+
134
+ def qrot_np(q, v):
135
+ q = torch.from_numpy(q).contiguous().float()
136
+ v = torch.from_numpy(v).contiguous().float()
137
+ return qrot(q, v).numpy()
138
+
139
+
140
+ def qeuler_np(q, order, epsilon=0, use_gpu=False):
141
+ if use_gpu:
142
+ q = torch.from_numpy(q).cuda().float()
143
+ return qeuler(q, order, epsilon).cpu().numpy()
144
+ else:
145
+ q = torch.from_numpy(q).contiguous().float()
146
+ return qeuler(q, order, epsilon).numpy()
147
+
148
+
149
+ def qfix(q):
150
+ """
151
+ Enforce quaternion continuity across the time dimension by selecting
152
+ the representation (q or -q) with minimal distance (or, equivalently, maximal dot product)
153
+ between two consecutive frames.
154
+
155
+ Expects a tensor of shape (L, J, 4), where L is the sequence length and J is the number of joints.
156
+ Returns a tensor of the same shape.
157
+ """
158
+ assert len(q.shape) == 3
159
+ assert q.shape[-1] == 4
160
+
161
+ result = q.copy()
162
+ dot_products = np.sum(q[1:] * q[:-1], axis=2)
163
+ mask = dot_products < 0
164
+ mask = (np.cumsum(mask, axis=0) % 2).astype(bool)
165
+ result[1:][mask] *= -1
166
+ return result
167
+
168
+
169
+ def euler2quat(e, order, deg=True):
170
+ """
171
+ Convert Euler angles to quaternions.
172
+ """
173
+ assert e.shape[-1] == 3
174
+
175
+ original_shape = list(e.shape)
176
+ original_shape[-1] = 4
177
+
178
+ e = e.view(-1, 3)
179
+
180
+ ## if euler angles in degrees
181
+ if deg:
182
+ e = e * np.pi / 180.
183
+
184
+ x = e[:, 0]
185
+ y = e[:, 1]
186
+ z = e[:, 2]
187
+
188
+ rx = torch.stack((torch.cos(x / 2), torch.sin(x / 2), torch.zeros_like(x), torch.zeros_like(x)), dim=1)
189
+ ry = torch.stack((torch.cos(y / 2), torch.zeros_like(y), torch.sin(y / 2), torch.zeros_like(y)), dim=1)
190
+ rz = torch.stack((torch.cos(z / 2), torch.zeros_like(z), torch.zeros_like(z), torch.sin(z / 2)), dim=1)
191
+
192
+ result = None
193
+ for coord in order:
194
+ if coord == 'x':
195
+ r = rx
196
+ elif coord == 'y':
197
+ r = ry
198
+ elif coord == 'z':
199
+ r = rz
200
+ else:
201
+ raise
202
+ if result is None:
203
+ result = r
204
+ else:
205
+ result = qmul(result, r)
206
+
207
+ # Reverse antipodal representation to have a non-negative "w"
208
+ if order in ['xyz', 'yzx', 'zxy']:
209
+ result *= -1
210
+
211
+ return result.view(original_shape)
212
+
213
+
214
+ def expmap_to_quaternion(e):
215
+ """
216
+ Convert axis-angle rotations (aka exponential maps) to quaternions.
217
+ Stable formula from "Practical Parameterization of Rotations Using the Exponential Map".
218
+ Expects a tensor of shape (*, 3), where * denotes any number of dimensions.
219
+ Returns a tensor of shape (*, 4).
220
+ """
221
+ assert e.shape[-1] == 3
222
+
223
+ original_shape = list(e.shape)
224
+ original_shape[-1] = 4
225
+ e = e.reshape(-1, 3)
226
+
227
+ theta = np.linalg.norm(e, axis=1).reshape(-1, 1)
228
+ w = np.cos(0.5 * theta).reshape(-1, 1)
229
+ xyz = 0.5 * np.sinc(0.5 * theta / np.pi) * e
230
+ return np.concatenate((w, xyz), axis=1).reshape(original_shape)
231
+
232
+
233
+ def euler_to_quaternion(e, order):
234
+ """
235
+ Convert Euler angles to quaternions.
236
+ """
237
+ assert e.shape[-1] == 3
238
+
239
+ original_shape = list(e.shape)
240
+ original_shape[-1] = 4
241
+
242
+ e = e.reshape(-1, 3)
243
+
244
+ x = e[:, 0]
245
+ y = e[:, 1]
246
+ z = e[:, 2]
247
+
248
+ rx = np.stack((np.cos(x / 2), np.sin(x / 2), np.zeros_like(x), np.zeros_like(x)), axis=1)
249
+ ry = np.stack((np.cos(y / 2), np.zeros_like(y), np.sin(y / 2), np.zeros_like(y)), axis=1)
250
+ rz = np.stack((np.cos(z / 2), np.zeros_like(z), np.zeros_like(z), np.sin(z / 2)), axis=1)
251
+
252
+ result = None
253
+ for coord in order:
254
+ if coord == 'x':
255
+ r = rx
256
+ elif coord == 'y':
257
+ r = ry
258
+ elif coord == 'z':
259
+ r = rz
260
+ else:
261
+ raise
262
+ if result is None:
263
+ result = r
264
+ else:
265
+ result = qmul_np(result, r)
266
+
267
+ # Reverse antipodal representation to have a non-negative "w"
268
+ if order in ['xyz', 'yzx', 'zxy']:
269
+ result *= -1
270
+
271
+ return result.reshape(original_shape)
272
+
273
+
274
+ def quaternion_to_matrix(quaternions):
275
+ """
276
+ Convert rotations given as quaternions to rotation matrices.
277
+ Args:
278
+ quaternions: quaternions with real part first,
279
+ as tensor of shape (..., 4).
280
+ Returns:
281
+ Rotation matrices as tensor of shape (..., 3, 3).
282
+ """
283
+ r, i, j, k = torch.unbind(quaternions, -1)
284
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
285
+
286
+ o = torch.stack(
287
+ (
288
+ 1 - two_s * (j * j + k * k),
289
+ two_s * (i * j - k * r),
290
+ two_s * (i * k + j * r),
291
+ two_s * (i * j + k * r),
292
+ 1 - two_s * (i * i + k * k),
293
+ two_s * (j * k - i * r),
294
+ two_s * (i * k - j * r),
295
+ two_s * (j * k + i * r),
296
+ 1 - two_s * (i * i + j * j),
297
+ ),
298
+ -1,
299
+ )
300
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
301
+
302
+
303
+ def quaternion_to_matrix_np(quaternions):
304
+ q = torch.from_numpy(quaternions).contiguous().float()
305
+ return quaternion_to_matrix(q).numpy()
306
+
307
+
308
+ def quaternion_to_cont6d_np(quaternions):
309
+ rotation_mat = quaternion_to_matrix_np(quaternions)
310
+ cont_6d = np.concatenate([rotation_mat[..., 0], rotation_mat[..., 1]], axis=-1)
311
+ return cont_6d
312
+
313
+
314
+ def quaternion_to_cont6d(quaternions):
315
+ rotation_mat = quaternion_to_matrix(quaternions)
316
+ cont_6d = torch.cat([rotation_mat[..., 0], rotation_mat[..., 1]], dim=-1)
317
+ return cont_6d
318
+
319
+
320
+ def cont6d_to_matrix(cont6d):
321
+ assert cont6d.shape[-1] == 6, "The last dimension must be 6"
322
+ x_raw = cont6d[..., 0:3]
323
+ y_raw = cont6d[..., 3:6]
324
+
325
+ x = x_raw / torch.norm(x_raw, dim=-1, keepdim=True)
326
+ z = torch.cross(x, y_raw, dim=-1)
327
+ z = z / torch.norm(z, dim=-1, keepdim=True)
328
+
329
+ y = torch.cross(z, x, dim=-1)
330
+
331
+ x = x[..., None]
332
+ y = y[..., None]
333
+ z = z[..., None]
334
+
335
+ mat = torch.cat([x, y, z], dim=-1)
336
+ return mat
337
+
338
+
339
+ def cont6d_to_matrix_np(cont6d):
340
+ q = torch.from_numpy(cont6d).contiguous().float()
341
+ return cont6d_to_matrix(q).numpy()
342
+
343
+
344
+ def qpow(q0, t, dtype=torch.float):
345
+ ''' q0 : tensor of quaternions
346
+ t: tensor of powers
347
+ '''
348
+ q0 = qnormalize(q0)
349
+ theta0 = torch.acos(q0[..., 0])
350
+
351
+ ## if theta0 is close to zero, add epsilon to avoid NaNs
352
+ mask = (theta0 <= 10e-10) * (theta0 >= -10e-10)
353
+ theta0 = (1 - mask) * theta0 + mask * 10e-10
354
+ v0 = q0[..., 1:] / torch.sin(theta0).view(-1, 1)
355
+
356
+ if isinstance(t, torch.Tensor):
357
+ q = torch.zeros(t.shape + q0.shape)
358
+ theta = t.view(-1, 1) * theta0.view(1, -1)
359
+ else: ## if t is a number
360
+ q = torch.zeros(q0.shape)
361
+ theta = t * theta0
362
+
363
+ q[..., 0] = torch.cos(theta)
364
+ q[..., 1:] = v0 * torch.sin(theta).unsqueeze(-1)
365
+
366
+ return q.to(dtype)
367
+
368
+
369
+ def qslerp(q0, q1, t):
370
+ '''
371
+ q0: starting quaternion
372
+ q1: ending quaternion
373
+ t: array of points along the way
374
+
375
+ Returns:
376
+ Tensor of Slerps: t.shape + q0.shape
377
+ '''
378
+
379
+ q0 = qnormalize(q0)
380
+ q1 = qnormalize(q1)
381
+ q_ = qpow(qmul(q1, qinv(q0)), t)
382
+
383
+ return qmul(q_,
384
+ q0.contiguous().view(torch.Size([1] * len(t.shape)) + q0.shape).expand(t.shape + q0.shape).contiguous())
385
+
386
+
387
+ def qbetween(v0, v1):
388
+ '''
389
+ find the quaternion used to rotate v0 to v1
390
+ '''
391
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
392
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
393
+
394
+ v = torch.cross(v0, v1)
395
+ w = torch.sqrt((v0 ** 2).sum(dim=-1, keepdim=True) * (v1 ** 2).sum(dim=-1, keepdim=True)) + (v0 * v1).sum(dim=-1,
396
+ keepdim=True)
397
+ return qnormalize(torch.cat([w, v], dim=-1))
398
+
399
+
400
+ def qbetween_np(v0, v1):
401
+ '''
402
+ find the quaternion used to rotate v0 to v1
403
+ '''
404
+ assert v0.shape[-1] == 3, 'v0 must be of the shape (*, 3)'
405
+ assert v1.shape[-1] == 3, 'v1 must be of the shape (*, 3)'
406
+
407
+ v0 = torch.from_numpy(v0).float()
408
+ v1 = torch.from_numpy(v1).float()
409
+ return qbetween(v0, v1).numpy()
410
+
411
+
412
+ def lerp(p0, p1, t):
413
+ if not isinstance(t, torch.Tensor):
414
+ t = torch.Tensor([t])
415
+
416
+ new_shape = t.shape + p0.shape
417
+ new_view_t = t.shape + torch.Size([1] * len(p0.shape))
418
+ new_view_p = torch.Size([1] * len(t.shape)) + p0.shape
419
+ p0 = p0.view(new_view_p).expand(new_shape)
420
+ p1 = p1.view(new_view_p).expand(new_shape)
421
+ t = t.view(new_view_t).expand(new_shape)
422
+
423
+ return p0 + t * (p1 - p0)
utils/skeleton.py ADDED
@@ -0,0 +1,199 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from utils.quaternion import *
2
+ import scipy.ndimage.filters as filters
3
+
4
+ class Skeleton(object):
5
+ def __init__(self, offset, kinematic_tree, device):
6
+ self.device = device
7
+ self._raw_offset_np = offset.numpy()
8
+ self._raw_offset = offset.clone().detach().to(device).float()
9
+ self._kinematic_tree = kinematic_tree
10
+ self._offset = None
11
+ self._parents = [0] * len(self._raw_offset)
12
+ self._parents[0] = -1
13
+ for chain in self._kinematic_tree:
14
+ for j in range(1, len(chain)):
15
+ self._parents[chain[j]] = chain[j-1]
16
+
17
+ def njoints(self):
18
+ return len(self._raw_offset)
19
+
20
+ def offset(self):
21
+ return self._offset
22
+
23
+ def set_offset(self, offsets):
24
+ self._offset = offsets.clone().detach().to(self.device).float()
25
+
26
+ def kinematic_tree(self):
27
+ return self._kinematic_tree
28
+
29
+ def parents(self):
30
+ return self._parents
31
+
32
+ # joints (batch_size, joints_num, 3)
33
+ def get_offsets_joints_batch(self, joints):
34
+ assert len(joints.shape) == 3
35
+ _offsets = self._raw_offset.expand(joints.shape[0], -1, -1).clone()
36
+ for i in range(1, self._raw_offset.shape[0]):
37
+ _offsets[:, i] = torch.norm(joints[:, i] - joints[:, self._parents[i]], p=2, dim=1)[:, None] * _offsets[:, i]
38
+
39
+ self._offset = _offsets.detach()
40
+ return _offsets
41
+
42
+ # joints (joints_num, 3)
43
+ def get_offsets_joints(self, joints):
44
+ assert len(joints.shape) == 2
45
+ _offsets = self._raw_offset.clone()
46
+ for i in range(1, self._raw_offset.shape[0]):
47
+ # print(joints.shape)
48
+ _offsets[i] = torch.norm(joints[i] - joints[self._parents[i]], p=2, dim=0) * _offsets[i]
49
+
50
+ self._offset = _offsets.detach()
51
+ return _offsets
52
+
53
+ # face_joint_idx should follow the order of right hip, left hip, right shoulder, left shoulder
54
+ # joints (batch_size, joints_num, 3)
55
+ def inverse_kinematics_np(self, joints, face_joint_idx, smooth_forward=False):
56
+ assert len(face_joint_idx) == 4
57
+ '''Get Forward Direction'''
58
+ l_hip, r_hip, sdr_r, sdr_l = face_joint_idx
59
+ across1 = joints[:, r_hip] - joints[:, l_hip]
60
+ across2 = joints[:, sdr_r] - joints[:, sdr_l]
61
+ across = across1 + across2
62
+ across = across / np.sqrt((across**2).sum(axis=-1))[:, np.newaxis]
63
+ # print(across1.shape, across2.shape)
64
+
65
+ # forward (batch_size, 3)
66
+ forward = np.cross(np.array([[0, 1, 0]]), across, axis=-1)
67
+ if smooth_forward:
68
+ forward = filters.gaussian_filter1d(forward, 20, axis=0, mode='nearest')
69
+ # forward (batch_size, 3)
70
+ forward = forward / np.sqrt((forward**2).sum(axis=-1))[..., np.newaxis]
71
+
72
+ '''Get Root Rotation'''
73
+ target = np.array([[0,0,1]]).repeat(len(forward), axis=0)
74
+ root_quat = qbetween_np(forward, target)
75
+
76
+ '''Inverse Kinematics'''
77
+ # quat_params (batch_size, joints_num, 4)
78
+ # print(joints.shape[:-1])
79
+ quat_params = np.zeros(joints.shape[:-1] + (4,))
80
+ # print(quat_params.shape)
81
+ root_quat[0] = np.array([[1.0, 0.0, 0.0, 0.0]])
82
+ quat_params[:, 0] = root_quat
83
+ # quat_params[0, 0] = np.array([[1.0, 0.0, 0.0, 0.0]])
84
+ for chain in self._kinematic_tree:
85
+ R = root_quat
86
+ for j in range(len(chain) - 1):
87
+ # (batch, 3)
88
+ u = self._raw_offset_np[chain[j+1]][np.newaxis,...].repeat(len(joints), axis=0)
89
+ # print(u.shape)
90
+ # (batch, 3)
91
+ v = joints[:, chain[j+1]] - joints[:, chain[j]]
92
+ v = v / np.sqrt((v**2).sum(axis=-1))[:, np.newaxis]
93
+ # print(u.shape, v.shape)
94
+ rot_u_v = qbetween_np(u, v)
95
+
96
+ R_loc = qmul_np(qinv_np(R), rot_u_v)
97
+
98
+ quat_params[:,chain[j + 1], :] = R_loc
99
+ R = qmul_np(R, R_loc)
100
+
101
+ return quat_params
102
+
103
+ # Be sure root joint is at the beginning of kinematic chains
104
+ def forward_kinematics(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
105
+ # quat_params (batch_size, joints_num, 4)
106
+ # joints (batch_size, joints_num, 3)
107
+ # root_pos (batch_size, 3)
108
+ if skel_joints is not None:
109
+ offsets = self.get_offsets_joints_batch(skel_joints)
110
+ if len(self._offset.shape) == 2:
111
+ offsets = self._offset.expand(quat_params.shape[0], -1, -1)
112
+ joints = torch.zeros(quat_params.shape[:-1] + (3,)).to(self.device)
113
+ joints[:, 0] = root_pos
114
+ for chain in self._kinematic_tree:
115
+ if do_root_R:
116
+ R = quat_params[:, 0]
117
+ else:
118
+ R = torch.tensor([[1.0, 0.0, 0.0, 0.0]]).expand(len(quat_params), -1).detach().to(self.device)
119
+ for i in range(1, len(chain)):
120
+ R = qmul(R, quat_params[:, chain[i]])
121
+ offset_vec = offsets[:, chain[i]]
122
+ joints[:, chain[i]] = qrot(R, offset_vec) + joints[:, chain[i-1]]
123
+ return joints
124
+
125
+ # Be sure root joint is at the beginning of kinematic chains
126
+ def forward_kinematics_np(self, quat_params, root_pos, skel_joints=None, do_root_R=True):
127
+ # quat_params (batch_size, joints_num, 4)
128
+ # joints (batch_size, joints_num, 3)
129
+ # root_pos (batch_size, 3)
130
+ if skel_joints is not None:
131
+ skel_joints = torch.from_numpy(skel_joints)
132
+ offsets = self.get_offsets_joints_batch(skel_joints)
133
+ if len(self._offset.shape) == 2:
134
+ offsets = self._offset.expand(quat_params.shape[0], -1, -1)
135
+ offsets = offsets.numpy()
136
+ joints = np.zeros(quat_params.shape[:-1] + (3,))
137
+ joints[:, 0] = root_pos
138
+ for chain in self._kinematic_tree:
139
+ if do_root_R:
140
+ R = quat_params[:, 0]
141
+ else:
142
+ R = np.array([[1.0, 0.0, 0.0, 0.0]]).repeat(len(quat_params), axis=0)
143
+ for i in range(1, len(chain)):
144
+ R = qmul_np(R, quat_params[:, chain[i]])
145
+ offset_vec = offsets[:, chain[i]]
146
+ joints[:, chain[i]] = qrot_np(R, offset_vec) + joints[:, chain[i - 1]]
147
+ return joints
148
+
149
+ def forward_kinematics_cont6d_np(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
150
+ # cont6d_params (batch_size, joints_num, 6)
151
+ # joints (batch_size, joints_num, 3)
152
+ # root_pos (batch_size, 3)
153
+ if skel_joints is not None:
154
+ skel_joints = torch.from_numpy(skel_joints)
155
+ offsets = self.get_offsets_joints_batch(skel_joints)
156
+ if len(self._offset.shape) == 2:
157
+ offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
158
+ offsets = offsets.numpy()
159
+ joints = np.zeros(cont6d_params.shape[:-1] + (3,))
160
+ joints[:, 0] = root_pos
161
+ for chain in self._kinematic_tree:
162
+ if do_root_R:
163
+ matR = cont6d_to_matrix_np(cont6d_params[:, 0])
164
+ else:
165
+ matR = np.eye(3)[np.newaxis, :].repeat(len(cont6d_params), axis=0)
166
+ for i in range(1, len(chain)):
167
+ matR = np.matmul(matR, cont6d_to_matrix_np(cont6d_params[:, chain[i]]))
168
+ offset_vec = offsets[:, chain[i]][..., np.newaxis]
169
+ # print(matR.shape, offset_vec.shape)
170
+ joints[:, chain[i]] = np.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
171
+ return joints
172
+
173
+ def forward_kinematics_cont6d(self, cont6d_params, root_pos, skel_joints=None, do_root_R=True):
174
+ # cont6d_params (batch_size, joints_num, 6)
175
+ # joints (batch_size, joints_num, 3)
176
+ # root_pos (batch_size, 3)
177
+ if skel_joints is not None:
178
+ # skel_joints = torch.from_numpy(skel_joints)
179
+ offsets = self.get_offsets_joints_batch(skel_joints)
180
+ if len(self._offset.shape) == 2:
181
+ offsets = self._offset.expand(cont6d_params.shape[0], -1, -1)
182
+ joints = torch.zeros(cont6d_params.shape[:-1] + (3,)).to(cont6d_params.device)
183
+ joints[..., 0, :] = root_pos
184
+ for chain in self._kinematic_tree:
185
+ if do_root_R:
186
+ matR = cont6d_to_matrix(cont6d_params[:, 0])
187
+ else:
188
+ matR = torch.eye(3).expand((len(cont6d_params), -1, -1)).detach().to(cont6d_params.device)
189
+ for i in range(1, len(chain)):
190
+ matR = torch.matmul(matR, cont6d_to_matrix(cont6d_params[:, chain[i]]))
191
+ offset_vec = offsets[:, chain[i]].unsqueeze(-1)
192
+ # print(matR.shape, offset_vec.shape)
193
+ joints[:, chain[i]] = torch.matmul(matR, offset_vec).squeeze(-1) + joints[:, chain[i-1]]
194
+ return joints
195
+
196
+
197
+
198
+
199
+
utils/smpl.py ADDED
@@ -0,0 +1,343 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ import torch
3
+ from tqdm import tqdm
4
+
5
+ import numpy as np
6
+ from utils.transforms import *
7
+
8
+ import pickle
9
+ from typing import Optional
10
+ # import smplx
11
+ # from smplx.lbs import vertices2joints
12
+ import os
13
+ # from smplx import SMPL as _SMPL
14
+ # from smplx.body_models import ModelOutput
15
+
16
+ smpl_joints = [
17
+ "root", # 0
18
+ "lhip", # 1
19
+ "rhip", # 2
20
+ "belly", # 3
21
+ "lknee", # 4
22
+ "rknee", # 5
23
+ "spine", # 6
24
+ "lankle",# 7
25
+ "rankle",# 8
26
+ "chest", # 9
27
+ "ltoes", # 10
28
+ "rtoes", # 11
29
+ "neck", # 12
30
+ "linshoulder", # 13
31
+ "rinshoulder", # 14
32
+ "head", # 15
33
+ "lshoulder", # 16
34
+ "rshoulder", # 17
35
+ "lelbow", # 18
36
+ "relbow", # 19
37
+ "lwrist", # 20
38
+ "rwrist", # 21
39
+ # "lhand", # 22
40
+ # "rhand", # 23
41
+ ]
42
+
43
+ smpl_parents = [
44
+ -1,
45
+ 0,
46
+ 0,
47
+ 0,
48
+ 1,
49
+ 2,
50
+ 3,
51
+ 4,
52
+ 5,
53
+ 6,
54
+ 7,
55
+ 8,
56
+ 9,
57
+ 9,
58
+ 9,
59
+ 12,
60
+ 13,
61
+ 14,
62
+ 16,
63
+ 17,
64
+ 18,
65
+ 19,
66
+ # 20,
67
+ # 21,
68
+ ]
69
+
70
+ smpl_offsets = [
71
+ [0.0, 0.0, 0.0],
72
+ [0.05858135, -0.08228004, -0.01766408],
73
+ [-0.06030973, -0.09051332, -0.01354254],
74
+ [0.00443945, 0.12440352, -0.03838522],
75
+ [0.04345142, -0.38646945, 0.008037],
76
+ [-0.04325663, -0.38368791, -0.00484304],
77
+ [0.00448844, 0.1379564, 0.02682033],
78
+ [-0.01479032, -0.42687458, -0.037428],
79
+ [0.01905555, -0.4200455, -0.03456167],
80
+ [-0.00226458, 0.05603239, 0.00285505],
81
+ [0.04105436, -0.06028581, 0.12204243],
82
+ [-0.03483987, -0.06210566, 0.13032329],
83
+ [-0.0133902, 0.21163553, -0.03346758],
84
+ [0.07170245, 0.11399969, -0.01889817],
85
+ [-0.08295366, 0.11247234, -0.02370739],
86
+ [0.01011321, 0.08893734, 0.05040987],
87
+ [0.12292141, 0.04520509, -0.019046],
88
+ [-0.11322832, 0.04685326, -0.00847207],
89
+ [0.2553319, -0.01564902, -0.02294649],
90
+ [-0.26012748, -0.01436928, -0.03126873],
91
+ [0.26570925, 0.01269811, -0.00737473],
92
+ [-0.26910836, 0.00679372, -0.00602676],
93
+ # [0.08669055, -0.01063603, -0.01559429],
94
+ # [-0.0887537, -0.00865157, -0.01010708],
95
+ ]
96
+
97
+
98
+ def set_line_data_3d(line, x):
99
+ line.set_data(x[:, :2].T)
100
+ line.set_3d_properties(x[:, 2])
101
+
102
+
103
+ def set_scatter_data_3d(scat, x, c):
104
+ scat.set_offsets(x[:, :2])
105
+ scat.set_3d_properties(x[:, 2], "z")
106
+ scat.set_facecolors([c])
107
+
108
+
109
+ def get_axrange(poses):
110
+ pose = poses[0]
111
+ x_min = pose[:, 0].min()
112
+ x_max = pose[:, 0].max()
113
+
114
+ y_min = pose[:, 1].min()
115
+ y_max = pose[:, 1].max()
116
+
117
+ z_min = pose[:, 2].min()
118
+ z_max = pose[:, 2].max()
119
+
120
+ xdiff = x_max - x_min
121
+ ydiff = y_max - y_min
122
+ zdiff = z_max - z_min
123
+
124
+ biggestdiff = max([xdiff, ydiff, zdiff])
125
+ return biggestdiff
126
+
127
+
128
+ def plot_single_pose(num, poses, lines, ax, axrange, scat, contact):
129
+ pose = poses[num]
130
+ static = contact[num]
131
+ indices = [7, 8, 10, 11]
132
+
133
+ for i, (point, idx) in enumerate(zip(scat, indices)):
134
+ position = pose[idx : idx + 1]
135
+ color = "r" if static[i] else "g"
136
+ set_scatter_data_3d(point, position, color)
137
+
138
+ for i, (p, line) in enumerate(zip(smpl_parents, lines)):
139
+ # don't plot root
140
+ if i == 0:
141
+ continue
142
+ # stack to create a line
143
+ data = np.stack((pose[i], pose[p]), axis=0)
144
+ set_line_data_3d(line, data)
145
+
146
+ if num == 0:
147
+ if isinstance(axrange, int):
148
+ axrange = (axrange, axrange, axrange)
149
+ xcenter, ycenter, zcenter = 0, 0, 2.5
150
+ stepx, stepy, stepz = axrange[0] / 2, axrange[1] / 2, axrange[2] / 2
151
+
152
+ x_min, x_max = xcenter - stepx, xcenter + stepx
153
+ y_min, y_max = ycenter - stepy, ycenter + stepy
154
+ z_min, z_max = zcenter - stepz, zcenter + stepz
155
+
156
+ ax.set_xlim(x_min, x_max)
157
+ ax.set_ylim(y_min, y_max)
158
+ ax.set_zlim(z_min, z_max)
159
+
160
+
161
+ class SMPLSkeleton:
162
+ def __init__(
163
+ self, device=None,
164
+ ):
165
+ offsets = smpl_offsets
166
+ parents = smpl_parents
167
+ assert len(offsets) == len(parents)
168
+
169
+ self._offsets = torch.Tensor(offsets).to(device)
170
+ self._parents = np.array(parents)
171
+ self._compute_metadata()
172
+
173
+ def _compute_metadata(self):
174
+ self._has_children = np.zeros(len(self._parents)).astype(bool)
175
+ for i, parent in enumerate(self._parents):
176
+ if parent != -1:
177
+ self._has_children[parent] = True
178
+
179
+ self._children = []
180
+ for i, parent in enumerate(self._parents):
181
+ self._children.append([])
182
+ for i, parent in enumerate(self._parents):
183
+ if parent != -1:
184
+ self._children[parent].append(i)
185
+
186
+ def forward(self, rotations, root_positions):
187
+ """
188
+ Perform forward kinematics using the given trajectory and local rotations.
189
+ Arguments (where N = batch size, L = sequence length, J = number of joints):
190
+ -- rotations: (N, L, J, 3) tensor of axis-angle rotations describing the local rotations of each joint.
191
+ -- root_positions: (N, L, 3) tensor describing the root joint positions.
192
+ """
193
+ assert len(rotations.shape) == 4
194
+ assert len(root_positions.shape) == 3
195
+ # transform from axis angle to quaternion
196
+ rotations = axis_angle_to_quaternion(rotations)
197
+
198
+ positions_world = []
199
+ rotations_world = []
200
+
201
+ expanded_offsets = self._offsets.expand(
202
+ rotations.shape[0],
203
+ rotations.shape[1],
204
+ self._offsets.shape[0],
205
+ self._offsets.shape[1],
206
+ )
207
+
208
+ # Parallelize along the batch and time dimensions
209
+ for i in range(self._offsets.shape[0]):
210
+ if self._parents[i] == -1:
211
+ positions_world.append(root_positions)
212
+ rotations_world.append(rotations[:, :, 0])
213
+ else:
214
+ positions_world.append(
215
+ quaternion_apply(
216
+ rotations_world[self._parents[i]], expanded_offsets[:, :, i]
217
+ )
218
+ + positions_world[self._parents[i]]
219
+ )
220
+ if self._has_children[i]:
221
+ rotations_world.append(
222
+ quaternion_multiply(
223
+ rotations_world[self._parents[i]], rotations[:, :, i]
224
+ )
225
+ )
226
+ else:
227
+ # This joint is a terminal node -> it would be useless to compute the transformation
228
+ rotations_world.append(None)
229
+
230
+ return torch.stack(positions_world, dim=3).permute(0, 1, 3, 2)
231
+
232
+
233
+ # class SMPL_old(smplx.SMPLLayer):
234
+ # def __init__(self, *args, joint_regressor_extra: Optional[str] = None, update_hips: bool = False, **kwargs):
235
+ # """
236
+ # Extension of the official SMPL implementation to support more joints.
237
+ # Args:
238
+ # Same as SMPLLayer.
239
+ # joint_regressor_extra (str): Path to extra joint regressor.
240
+ # """
241
+ # super(SMPL, self).__init__(*args, **kwargs)
242
+ # smpl_to_openpose = [24, 12, 17, 19, 21, 16, 18, 20, 0, 2, 5, 8, 1, 4,
243
+ # 7, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34]
244
+
245
+ # if joint_regressor_extra is not None:
246
+ # self.register_buffer('joint_regressor_extra', torch.tensor(pickle.load(open(joint_regressor_extra, 'rb'), encoding='latin1'), dtype=torch.float32))
247
+ # self.register_buffer('joint_map', torch.tensor(smpl_to_openpose, dtype=torch.long))
248
+ # self.update_hips = update_hips
249
+
250
+ # def forward(self, *args, **kwargs):
251
+ # """
252
+ # Run forward pass. Same as SMPL and also append an extra set of joints if joint_regressor_extra is specified.
253
+ # """
254
+ # smpl_output = super(SMPL, self).forward(*args, **kwargs)
255
+ # joints = smpl_output.joints[:, self.joint_map, :]
256
+ # if self.update_hips:
257
+ # joints[:,[9,12]] = joints[:,[9,12]] + \
258
+ # 0.25*(joints[:,[9,12]]-joints[:,[12,9]]) + \
259
+ # 0.5*(joints[:,[8]] - 0.5*(joints[:,[9,12]] + joints[:,[12,9]]))
260
+ # if hasattr(self, 'joint_regressor_extra'):
261
+ # extra_joints = vertices2joints(self.joint_regressor_extra, smpl_output.vertices)
262
+ # joints = torch.cat([joints, extra_joints], dim=1)
263
+ # smpl_output.joints = joints
264
+ # return smpl_output
265
+
266
+ # Map joints to SMPL joints
267
+ JOINT_MAP = {
268
+ 'OP Nose': 24, 'OP Neck': 12, 'OP RShoulder': 17,
269
+ 'OP RElbow': 19, 'OP RWrist': 21, 'OP LShoulder': 16,
270
+ 'OP LElbow': 18, 'OP LWrist': 20, 'OP MidHip': 0,
271
+ 'OP RHip': 2, 'OP RKnee': 5, 'OP RAnkle': 8,
272
+ 'OP LHip': 1, 'OP LKnee': 4, 'OP LAnkle': 7,
273
+ 'OP REye': 25, 'OP LEye': 26, 'OP REar': 27,
274
+ 'OP LEar': 28, 'OP LBigToe': 29, 'OP LSmallToe': 30,
275
+ 'OP LHeel': 31, 'OP RBigToe': 32, 'OP RSmallToe': 33, 'OP RHeel': 34,
276
+ 'Right Ankle': 8, 'Right Knee': 5, 'Right Hip': 45,
277
+ 'Left Hip': 46, 'Left Knee': 4, 'Left Ankle': 7,
278
+ 'Right Wrist': 21, 'Right Elbow': 19, 'Right Shoulder': 17,
279
+ 'Left Shoulder': 16, 'Left Elbow': 18, 'Left Wrist': 20,
280
+ 'Neck (LSP)': 47, 'Top of Head (LSP)': 48,
281
+ 'Pelvis (MPII)': 49, 'Thorax (MPII)': 50,
282
+ 'Spine (H36M)': 51, 'Jaw (H36M)': 52,
283
+ 'Head (H36M)': 53, 'Nose': 24, 'Left Eye': 26,
284
+ 'Right Eye': 25, 'Left Ear': 28, 'Right Ear': 27
285
+ }
286
+ JOINT_NAMES = [
287
+ 'OP Nose', 'OP Neck', 'OP RShoulder',
288
+ 'OP RElbow', 'OP RWrist', 'OP LShoulder',
289
+ 'OP LElbow', 'OP LWrist', 'OP MidHip',
290
+ 'OP RHip', 'OP RKnee', 'OP RAnkle',
291
+ 'OP LHip', 'OP LKnee', 'OP LAnkle',
292
+ 'OP REye', 'OP LEye', 'OP REar',
293
+ 'OP LEar', 'OP LBigToe', 'OP LSmallToe',
294
+ 'OP LHeel', 'OP RBigToe', 'OP RSmallToe', 'OP RHeel',
295
+ 'Right Ankle', 'Right Knee', 'Right Hip',
296
+ 'Left Hip', 'Left Knee', 'Left Ankle',
297
+ 'Right Wrist', 'Right Elbow', 'Right Shoulder',
298
+ 'Left Shoulder', 'Left Elbow', 'Left Wrist',
299
+ 'Neck (LSP)', 'Top of Head (LSP)',
300
+ 'Pelvis (MPII)', 'Thorax (MPII)',
301
+ 'Spine (H36M)', 'Jaw (H36M)',
302
+ 'Head (H36M)', 'Nose', 'Left Eye',
303
+ 'Right Eye', 'Left Ear', 'Right Ear'
304
+ ]
305
+ BASE_DATA_DIR = "/data2/TSMC_data/base_data"
306
+ JOINT_IDS = {JOINT_NAMES[i]: i for i in range(len(JOINT_NAMES))}
307
+ JOINT_REGRESSOR_TRAIN_EXTRA = os.path.join(BASE_DATA_DIR, 'J_regressor_extra.npy')
308
+ SMPL_MEAN_PARAMS = os.path.join(BASE_DATA_DIR, 'smpl_mean_params.npz')
309
+ SMPL_MODEL_DIR = BASE_DATA_DIR
310
+ H36M_TO_J17 = [6, 5, 4, 1, 2, 3, 16, 15, 14, 11, 12, 13, 8, 10, 0, 7, 9]
311
+ H36M_TO_J14 = H36M_TO_J17[:14]
312
+
313
+
314
+ # class SMPL(_SMPL):
315
+ # """ Extension of the official SMPL implementation to support more joints """
316
+
317
+ # def __init__(self, *args, **kwargs):
318
+ # super(SMPL, self).__init__(*args, **kwargs)
319
+ # joints = [JOINT_MAP[i] for i in JOINT_NAMES]
320
+ # J_regressor_extra = np.load(JOINT_REGRESSOR_TRAIN_EXTRA)
321
+ # self.register_buffer('J_regressor_extra', torch.tensor(J_regressor_extra, dtype=torch.float32))
322
+ # self.joint_map = torch.tensor(joints, dtype=torch.long)
323
+
324
+
325
+ # def forward(self, *args, **kwargs):
326
+ # kwargs['get_skin'] = True
327
+ # smpl_output = super(SMPL, self).forward(*args, **kwargs)
328
+ # extra_joints = vertices2joints(self.J_regressor_extra, smpl_output.vertices)
329
+ # joints = torch.cat([smpl_output.joints, extra_joints], dim=1)
330
+ # joints = joints[:, self.joint_map, :]
331
+ # output = ModelOutput(vertices=smpl_output.vertices,
332
+ # global_orient=smpl_output.global_orient,
333
+ # body_pose=smpl_output.body_pose,
334
+ # joints=joints,
335
+ # betas=smpl_output.betas,
336
+ # full_pose=smpl_output.full_pose)
337
+ # return output
338
+
339
+
340
+ # def get_smpl_faces():
341
+ # print("Get SMPL faces")
342
+ # smpl = SMPL(SMPL_MODEL_DIR, batch_size=1, create_transl=False)
343
+ # return smpl.faces
utils/transforms.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # This code is based on https://github.com/Mathux/ACTOR.git
2
+ # Copyright (c) Facebook, Inc. and its affiliates. All rights reserved.
3
+ # Check PYTORCH3D_LICENCE before use
4
+
5
+ import functools
6
+ from typing import Optional
7
+
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import numpy as np
11
+
12
+ """
13
+ The transformation matrices returned from the functions in this file assume
14
+ the points on which the transformation will be applied are column vectors.
15
+ i.e. the R matrix is structured as
16
+
17
+ R = [
18
+ [Rxx, Rxy, Rxz],
19
+ [Ryx, Ryy, Ryz],
20
+ [Rzx, Rzy, Rzz],
21
+ ] # (3, 3)
22
+
23
+ This matrix can be applied to column vectors by post multiplication
24
+ by the points e.g.
25
+
26
+ points = [[0], [1], [2]] # (3 x 1) xyz coordinates of a point
27
+ transformed_points = R * points
28
+
29
+ To apply the same matrix to points which are row vectors, the R matrix
30
+ can be transposed and pre multiplied by the points:
31
+
32
+ e.g.
33
+ points = [[0, 1, 2]] # (1 x 3) xyz coordinates of a point
34
+ transformed_points = points * R.transpose(1, 0)
35
+ """
36
+
37
+
38
+ def quaternion_to_matrix(quaternions):
39
+ """
40
+ Convert rotations given as quaternions to rotation matrices.
41
+
42
+ Args:
43
+ quaternions: quaternions with real part first,
44
+ as tensor of shape (..., 4).
45
+
46
+ Returns:
47
+ Rotation matrices as tensor of shape (..., 3, 3).
48
+ """
49
+ r, i, j, k = torch.unbind(quaternions, -1)
50
+ two_s = 2.0 / (quaternions * quaternions).sum(-1)
51
+
52
+ o = torch.stack(
53
+ (
54
+ 1 - two_s * (j * j + k * k),
55
+ two_s * (i * j - k * r),
56
+ two_s * (i * k + j * r),
57
+ two_s * (i * j + k * r),
58
+ 1 - two_s * (i * i + k * k),
59
+ two_s * (j * k - i * r),
60
+ two_s * (i * k - j * r),
61
+ two_s * (j * k + i * r),
62
+ 1 - two_s * (i * i + j * j),
63
+ ),
64
+ -1,
65
+ )
66
+ return o.reshape(quaternions.shape[:-1] + (3, 3))
67
+
68
+
69
+ def _copysign(a, b):
70
+ """
71
+ Return a tensor where each element has the absolute value taken from the,
72
+ corresponding element of a, with sign taken from the corresponding
73
+ element of b. This is like the standard copysign floating-point operation,
74
+ but is not careful about negative 0 and NaN.
75
+
76
+ Args:
77
+ a: source tensor.
78
+ b: tensor whose signs will be used, of the same shape as a.
79
+
80
+ Returns:
81
+ Tensor of the same shape as a with the signs of b.
82
+ """
83
+ signs_differ = (a < 0) != (b < 0)
84
+ return torch.where(signs_differ, -a, a)
85
+
86
+
87
+ def _sqrt_positive_part(x):
88
+ """
89
+ Returns torch.sqrt(torch.max(0, x))
90
+ but with a zero subgradient where x is 0.
91
+ """
92
+ ret = torch.zeros_like(x)
93
+ positive_mask = x > 0
94
+ ret[positive_mask] = torch.sqrt(x[positive_mask])
95
+ return ret
96
+
97
+
98
+ def matrix_to_quaternion(matrix):
99
+ """
100
+ Convert rotations given as rotation matrices to quaternions.
101
+
102
+ Args:
103
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
104
+
105
+ Returns:
106
+ quaternions with real part first, as tensor of shape (..., 4).
107
+ """
108
+ # print("matrix.size:", matrix.size())
109
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
110
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
111
+ m00 = matrix[..., 0, 0]
112
+ m11 = matrix[..., 1, 1]
113
+ m22 = matrix[..., 2, 2]
114
+ o0 = 0.5 * _sqrt_positive_part(1 + m00 + m11 + m22)
115
+ x = 0.5 * _sqrt_positive_part(1 + m00 - m11 - m22)
116
+ y = 0.5 * _sqrt_positive_part(1 - m00 + m11 - m22)
117
+ z = 0.5 * _sqrt_positive_part(1 - m00 - m11 + m22)
118
+ o1 = _copysign(x, matrix[..., 2, 1] - matrix[..., 1, 2])
119
+ o2 = _copysign(y, matrix[..., 0, 2] - matrix[..., 2, 0])
120
+ o3 = _copysign(z, matrix[..., 1, 0] - matrix[..., 0, 1])
121
+ return torch.stack((o0, o1, o2, o3), -1)
122
+
123
+
124
+ def _axis_angle_rotation(axis: str, angle):
125
+ """
126
+ Return the rotation matrices for one of the rotations about an axis
127
+ of which Euler angles describe, for each value of the angle given.
128
+
129
+ Args:
130
+ axis: Axis label "X" or "Y or "Z".
131
+ angle: any shape tensor of Euler angles in radians
132
+
133
+ Returns:
134
+ Rotation matrices as tensor of shape (..., 3, 3).
135
+ """
136
+
137
+ cos = torch.cos(angle)
138
+ sin = torch.sin(angle)
139
+ one = torch.ones_like(angle)
140
+ zero = torch.zeros_like(angle)
141
+
142
+ if axis == "X":
143
+ R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
144
+ if axis == "Y":
145
+ R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
146
+ if axis == "Z":
147
+ R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
148
+
149
+ return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
150
+
151
+
152
+ def euler_angles_to_matrix(euler_angles, convention: str):
153
+ """
154
+ Convert rotations given as Euler angles in radians to rotation matrices.
155
+
156
+ Args:
157
+ euler_angles: Euler angles in radians as tensor of shape (..., 3).
158
+ convention: Convention string of three uppercase letters from
159
+ {"X", "Y", and "Z"}.
160
+
161
+ Returns:
162
+ Rotation matrices as tensor of shape (..., 3, 3).
163
+ """
164
+ if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
165
+ raise ValueError("Invalid input euler angles.")
166
+ if len(convention) != 3:
167
+ raise ValueError("Convention must have 3 letters.")
168
+ if convention[1] in (convention[0], convention[2]):
169
+ raise ValueError(f"Invalid convention {convention}.")
170
+ for letter in convention:
171
+ if letter not in ("X", "Y", "Z"):
172
+ raise ValueError(f"Invalid letter {letter} in convention string.")
173
+ matrices = map(_axis_angle_rotation, convention, torch.unbind(euler_angles, -1))
174
+ return functools.reduce(torch.matmul, matrices)
175
+
176
+
177
+ def _angle_from_tan(
178
+ axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool
179
+ ):
180
+ """
181
+ Extract the first or third Euler angle from the two members of
182
+ the matrix which are positive constant times its sine and cosine.
183
+
184
+ Args:
185
+ axis: Axis label "X" or "Y or "Z" for the angle we are finding.
186
+ other_axis: Axis label "X" or "Y or "Z" for the middle axis in the
187
+ convention.
188
+ data: Rotation matrices as tensor of shape (..., 3, 3).
189
+ horizontal: Whether we are looking for the angle for the third axis,
190
+ which means the relevant entries are in the same row of the
191
+ rotation matrix. If not, they are in the same column.
192
+ tait_bryan: Whether the first and third axes in the convention differ.
193
+
194
+ Returns:
195
+ Euler Angles in radians for each matrix in dataset as a tensor
196
+ of shape (...).
197
+ """
198
+
199
+ i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis]
200
+ if horizontal:
201
+ i2, i1 = i1, i2
202
+ even = (axis + other_axis) in ["XY", "YZ", "ZX"]
203
+ if horizontal == even:
204
+ return torch.atan2(data[..., i1], data[..., i2])
205
+ if tait_bryan:
206
+ return torch.atan2(-data[..., i2], data[..., i1])
207
+ return torch.atan2(data[..., i2], -data[..., i1])
208
+
209
+
210
+ def _index_from_letter(letter: str):
211
+ if letter == "X":
212
+ return 0
213
+ if letter == "Y":
214
+ return 1
215
+ if letter == "Z":
216
+ return 2
217
+
218
+
219
+ def matrix_to_euler_angles(matrix, convention: str):
220
+ """
221
+ Convert rotations given as rotation matrices to Euler angles in radians.
222
+
223
+ Args:
224
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
225
+ convention: Convention string of three uppercase letters.
226
+
227
+ Returns:
228
+ Euler angles in radians as tensor of shape (..., 3).
229
+ """
230
+ if len(convention) != 3:
231
+ raise ValueError("Convention must have 3 letters.")
232
+ if convention[1] in (convention[0], convention[2]):
233
+ raise ValueError(f"Invalid convention {convention}.")
234
+ for letter in convention:
235
+ if letter not in ("X", "Y", "Z"):
236
+ raise ValueError(f"Invalid letter {letter} in convention string.")
237
+ if matrix.size(-1) != 3 or matrix.size(-2) != 3:
238
+ raise ValueError(f"Invalid rotation matrix shape f{matrix.shape}.")
239
+ i0 = _index_from_letter(convention[0])
240
+ i2 = _index_from_letter(convention[2])
241
+ tait_bryan = i0 != i2
242
+ if tait_bryan:
243
+ central_angle = torch.asin(
244
+ matrix[..., i0, i2] * (-1.0 if i0 - i2 in [-1, 2] else 1.0)
245
+ )
246
+ else:
247
+ central_angle = torch.acos(matrix[..., i0, i0])
248
+
249
+ o = (
250
+ _angle_from_tan(
251
+ convention[0], convention[1], matrix[..., i2], False, tait_bryan
252
+ ),
253
+ central_angle,
254
+ _angle_from_tan(
255
+ convention[2], convention[1], matrix[..., i0, :], True, tait_bryan
256
+ ),
257
+ )
258
+ return torch.stack(o, -1)
259
+
260
+
261
+ def random_quaternions(
262
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
263
+ ):
264
+ """
265
+ Generate random quaternions representing rotations,
266
+ i.e. versors with nonnegative real part.
267
+
268
+ Args:
269
+ n: Number of quaternions in a batch to return.
270
+ dtype: Type to return.
271
+ device: Desired device of returned tensor. Default:
272
+ uses the current device for the default tensor type.
273
+ requires_grad: Whether the resulting tensor should have the gradient
274
+ flag set.
275
+
276
+ Returns:
277
+ Quaternions as tensor of shape (N, 4).
278
+ """
279
+ o = torch.randn((n, 4), dtype=dtype, device=device, requires_grad=requires_grad)
280
+ s = (o * o).sum(1)
281
+ o = o / _copysign(torch.sqrt(s), o[:, 0])[:, None]
282
+ return o
283
+
284
+
285
+ def random_rotations(
286
+ n: int, dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
287
+ ):
288
+ """
289
+ Generate random rotations as 3x3 rotation matrices.
290
+
291
+ Args:
292
+ n: Number of rotation matrices in a batch to return.
293
+ dtype: Type to return.
294
+ device: Device of returned tensor. Default: if None,
295
+ uses the current device for the default tensor type.
296
+ requires_grad: Whether the resulting tensor should have the gradient
297
+ flag set.
298
+
299
+ Returns:
300
+ Rotation matrices as tensor of shape (n, 3, 3).
301
+ """
302
+ quaternions = random_quaternions(
303
+ n, dtype=dtype, device=device, requires_grad=requires_grad
304
+ )
305
+ return quaternion_to_matrix(quaternions)
306
+
307
+
308
+ def random_rotation(
309
+ dtype: Optional[torch.dtype] = None, device=None, requires_grad=False
310
+ ):
311
+ """
312
+ Generate a single random 3x3 rotation matrix.
313
+
314
+ Args:
315
+ dtype: Type to return
316
+ device: Device of returned tensor. Default: if None,
317
+ uses the current device for the default tensor type
318
+ requires_grad: Whether the resulting tensor should have the gradient
319
+ flag set
320
+
321
+ Returns:
322
+ Rotation matrix as tensor of shape (3, 3).
323
+ """
324
+ return random_rotations(1, dtype, device, requires_grad)[0]
325
+
326
+
327
+ def standardize_quaternion(quaternions):
328
+ """
329
+ Convert a unit quaternion to a standard form: one in which the real
330
+ part is non negative.
331
+
332
+ Args:
333
+ quaternions: Quaternions with real part first,
334
+ as tensor of shape (..., 4).
335
+
336
+ Returns:
337
+ Standardized quaternions as tensor of shape (..., 4).
338
+ """
339
+ return torch.where(quaternions[..., 0:1] < 0, -quaternions, quaternions)
340
+
341
+
342
+ def quaternion_raw_multiply(a, b):
343
+ """
344
+ Multiply two quaternions.
345
+ Usual torch rules for broadcasting apply.
346
+
347
+ Args:
348
+ a: Quaternions as tensor of shape (..., 4), real part first.
349
+ b: Quaternions as tensor of shape (..., 4), real part first.
350
+
351
+ Returns:
352
+ The product of a and b, a tensor of quaternions shape (..., 4).
353
+ """
354
+ aw, ax, ay, az = torch.unbind(a, -1)
355
+ bw, bx, by, bz = torch.unbind(b, -1)
356
+ ow = aw * bw - ax * bx - ay * by - az * bz
357
+ ox = aw * bx + ax * bw + ay * bz - az * by
358
+ oy = aw * by - ax * bz + ay * bw + az * bx
359
+ oz = aw * bz + ax * by - ay * bx + az * bw
360
+ return torch.stack((ow, ox, oy, oz), -1)
361
+
362
+
363
+ def quaternion_multiply(a, b):
364
+ """
365
+ Multiply two quaternions representing rotations, returning the quaternion
366
+ representing their composition, i.e. the versor with nonnegative real part.
367
+ Usual torch rules for broadcasting apply.
368
+
369
+ Args:
370
+ a: Quaternions as tensor of shape (..., 4), real part first.
371
+ b: Quaternions as tensor of shape (..., 4), real part first.
372
+
373
+ Returns:
374
+ The product of a and b, a tensor of quaternions of shape (..., 4).
375
+ """
376
+ ab = quaternion_raw_multiply(a, b)
377
+ return standardize_quaternion(ab)
378
+
379
+
380
+ def quaternion_invert(quaternion):
381
+ """
382
+ Given a quaternion representing rotation, get the quaternion representing
383
+ its inverse.
384
+
385
+ Args:
386
+ quaternion: Quaternions as tensor of shape (..., 4), with real part
387
+ first, which must be versors (unit quaternions).
388
+
389
+ Returns:
390
+ The inverse, a tensor of quaternions of shape (..., 4).
391
+ """
392
+
393
+ return quaternion * quaternion.new_tensor([1, -1, -1, -1])
394
+
395
+
396
+ def quaternion_apply(quaternion, point):
397
+ """
398
+ Apply the rotation given by a quaternion to a 3D point.
399
+ Usual torch rules for broadcasting apply.
400
+
401
+ Args:
402
+ quaternion: Tensor of quaternions, real part first, of shape (..., 4).
403
+ point: Tensor of 3D points of shape (..., 3).
404
+
405
+ Returns:
406
+ Tensor of rotated points of shape (..., 3).
407
+ """
408
+ if point.size(-1) != 3:
409
+ raise ValueError(f"Points are not in 3D, f{point.shape}.")
410
+ real_parts = point.new_zeros(point.shape[:-1] + (1,))
411
+ point_as_quaternion = torch.cat((real_parts, point), -1)
412
+ out = quaternion_raw_multiply(
413
+ quaternion_raw_multiply(quaternion, point_as_quaternion),
414
+ quaternion_invert(quaternion),
415
+ )
416
+ return out[..., 1:]
417
+
418
+
419
+ def axis_angle_to_matrix(axis_angle):
420
+ """
421
+ Convert rotations given as axis/angle to rotation matrices.
422
+
423
+ Args:
424
+ axis_angle: Rotations given as a vector in axis angle form,
425
+ as a tensor of shape (..., 3), where the magnitude is
426
+ the angle turned anticlockwise in radians around the
427
+ vector's direction.
428
+
429
+ Returns:
430
+ Rotation matrices as tensor of shape (..., 3, 3).
431
+ """
432
+ return quaternion_to_matrix(axis_angle_to_quaternion(axis_angle))
433
+
434
+
435
+ def matrix_to_axis_angle(matrix):
436
+ """
437
+ Convert rotations given as rotation matrices to axis/angle.
438
+
439
+ Args:
440
+ matrix: Rotation matrices as tensor of shape (..., 3, 3).
441
+
442
+ Returns:
443
+ Rotations given as a vector in axis angle form, as a tensor
444
+ of shape (..., 3), where the magnitude is the angle
445
+ turned anticlockwise in radians around the vector's
446
+ direction.
447
+ """
448
+ return quaternion_to_axis_angle(matrix_to_quaternion(matrix))
449
+
450
+
451
+ def axis_angle_to_quaternion(axis_angle):
452
+ """
453
+ Convert rotations given as axis/angle to quaternions.
454
+
455
+ Args:
456
+ axis_angle: Rotations given as a vector in axis angle form,
457
+ as a tensor of shape (..., 3), where the magnitude is
458
+ the angle turned anticlockwise in radians around the
459
+ vector's direction.
460
+
461
+ Returns:
462
+ quaternions with real part first, as tensor of shape (..., 4).
463
+ """
464
+ angles = torch.norm(axis_angle, p=2, dim=-1, keepdim=True)
465
+ half_angles = 0.5 * angles
466
+ eps = 1e-6
467
+ small_angles = angles.abs() < eps
468
+ sin_half_angles_over_angles = torch.empty_like(angles)
469
+ sin_half_angles_over_angles[~small_angles] = (
470
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
471
+ )
472
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
473
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
474
+ sin_half_angles_over_angles[small_angles] = (
475
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
476
+ )
477
+ quaternions = torch.cat(
478
+ [torch.cos(half_angles), axis_angle * sin_half_angles_over_angles], dim=-1
479
+ )
480
+ return quaternions
481
+
482
+
483
+ def quaternion_to_axis_angle(quaternions):
484
+ """
485
+ Convert rotations given as quaternions to axis/angle.
486
+
487
+ Args:
488
+ quaternions: quaternions with real part first,
489
+ as tensor of shape (..., 4).
490
+
491
+ Returns:
492
+ Rotations given as a vector in axis angle form, as a tensor
493
+ of shape (..., 3), where the magnitude is the angle
494
+ turned anticlockwise in radians around the vector's
495
+ direction.
496
+ """
497
+ norms = torch.norm(quaternions[..., 1:], p=2, dim=-1, keepdim=True)
498
+ half_angles = torch.atan2(norms, quaternions[..., :1])
499
+ angles = 2 * half_angles
500
+ eps = 1e-6
501
+ small_angles = angles.abs() < eps
502
+ sin_half_angles_over_angles = torch.empty_like(angles)
503
+ sin_half_angles_over_angles[~small_angles] = (
504
+ torch.sin(half_angles[~small_angles]) / angles[~small_angles]
505
+ )
506
+ # for x small, sin(x/2) is about x/2 - (x/2)^3/6
507
+ # so sin(x/2)/x is about 1/2 - (x*x)/48
508
+ sin_half_angles_over_angles[small_angles] = (
509
+ 0.5 - (angles[small_angles] * angles[small_angles]) / 48
510
+ )
511
+ return quaternions[..., 1:] / sin_half_angles_over_angles
512
+
513
+
514
+ def rotation_6d_to_matrix(d6: torch.Tensor) -> torch.Tensor:
515
+ """
516
+ Converts 6D rotation representation by Zhou et al. [1] to rotation matrix
517
+ using Gram--Schmidt orthogonalisation per Section B of [1].
518
+ Args:
519
+ d6: 6D rotation representation, of size (*, 6)
520
+
521
+ Returns:
522
+ batch of rotation matrices of size (*, 3, 3)
523
+
524
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
525
+ On the Continuity of Rotation Representations in Neural Networks.
526
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
527
+ Retrieved from http://arxiv.org/abs/1812.07035
528
+ """
529
+
530
+ a1, a2 = d6[..., :3], d6[..., 3:]
531
+ b1 = F.normalize(a1, dim=-1)
532
+ b2 = a2 - (b1 * a2).sum(-1, keepdim=True) * b1
533
+ b2 = F.normalize(b2, dim=-1)
534
+ b3 = torch.cross(b1, b2, dim=-1)
535
+ return torch.stack((b1, b2, b3), dim=-2)
536
+
537
+
538
+ def matrix_to_rotation_6d(matrix: torch.Tensor) -> torch.Tensor:
539
+ """
540
+ Converts rotation matrices to 6D rotation representation by Zhou et al. [1]
541
+ by dropping the last row. Note that 6D representation is not unique.
542
+ Args:
543
+ matrix: batch of rotation matrices of size (*, 3, 3)
544
+
545
+ Returns:
546
+ 6D rotation representation, of size (*, 6)
547
+
548
+ [1] Zhou, Y., Barnes, C., Lu, J., Yang, J., & Li, H.
549
+ On the Continuity of Rotation Representations in Neural Networks.
550
+ IEEE Conference on Computer Vision and Pattern Recognition, 2019.
551
+ Retrieved from http://arxiv.org/abs/1812.07035
552
+ """
553
+ return matrix[..., :2, :].clone().reshape(*matrix.size()[:-2], 6)
554
+
555
+ def axis_angle_to_rotation_6d(axis_angle):
556
+ matrix = axis_angle_to_matrix(axis_angle)
557
+ return matrix_to_rotation_6d(matrix)
558
+
559
+ def rotateXYZ(mesh_v, Rxyz):
560
+ for rot in Rxyz:
561
+ angle = np.radians(rot[0])
562
+ rx = np.array([
563
+ [1., 0., 0. ],
564
+ [0., np.cos(angle), -np.sin(angle)],
565
+ [0., np.sin(angle), np.cos(angle) ]
566
+ ])
567
+
568
+ angle = np.radians(rot[1])
569
+ ry = np.array([
570
+ [np.cos(angle), 0., np.sin(angle)],
571
+ [0., 1., 0. ],
572
+ [-np.sin(angle), 0., np.cos(angle)]
573
+ ])
574
+
575
+ angle = np.radians(rot[2])
576
+ rz = np.array([
577
+ [np.cos(angle), -np.sin(angle), 0. ],
578
+ [np.sin(angle), np.cos(angle), 0. ],
579
+ [0., 0., 1. ]
580
+ ])
581
+ # return rotateZ(rotateY(rotateX(mesh_v, Rxyz[0]), Rxyz[1]), Rxyz[2])
582
+ mesh_v = rz.dot(ry.dot(rx.dot(mesh_v.T))).T
583
+ # return rx.dot(mesh_v.T).T
584
+ return mesh_v
585
+
586
+ def rotate_trans(trans_3d, rot_body=None):
587
+ if rot_body is not None:
588
+ trans_3d = rotateXYZ(trans_3d, rot_body)
589
+ trans_3d = torch.from_numpy(trans_3d)
590
+ return trans_3d
591
+
592
+ def rotate_root(pose_6d, rot_body=None):
593
+ root_6d = pose_6d[:, :6]
594
+ root_6d = rotation_6d_to_matrix(root_6d)
595
+ if rot_body is not None:
596
+ root_6d = rotateXYZ(root_6d, rot_body)
597
+ root_6d = torch.from_numpy(root_6d)
598
+ root_6d = matrix_to_rotation_6d(root_6d)
599
+ pose_6d[:, :6] = root_6d
600
+ return pose_6d