diff --git a/1gpu.yaml b/1gpu.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0219628eed01926b40ced235c104f48038871f31 --- /dev/null +++ b/1gpu.yaml @@ -0,0 +1,16 @@ +compute_environment: LOCAL_MACHINE +debug: false +distributed_type: 'NO' +downcast_bf16: 'no' +machine_rank: 0 +main_training_function: main +mixed_precision: no +num_machines: 1 +num_processes: 1 +rdzv_backend: static +same_network: true +tpu_env: [] +tpu_use_cluster: false +tpu_use_sudo: false +use_cpu: false +main_process_port: 21000 \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..a2887ea66ae5a34ab84df77f0e633cb514702de4 --- /dev/null +++ b/LICENSE @@ -0,0 +1,57 @@ +#-------------------------------------------- +# MotionCLR +# Copyright (c) 2024 IDEA. All Rights Reserved. +# Licensed under the IDEA License, Version 1.0 [see LICENSE for details] +#-------------------------------------------- + +IDEA License 1.0 + +This License Agreement (as may be amended in accordance with this License Agreement, “License”), between you, or your employer or other entity (if you are entering into this agreement on behalf of your employer or other entity) (“Licensee” or “you”) and the International Digital Economy Academy (“IDEA” or “we”) applies to your use of any computer program, algorithm, source code, object code, or software that is made available by IDEA under this License (“Software”) and any specifications, manuals, documentation, and other written information provided by IDEA related to the Software (“Documentation”). + +By downloading the Software or by using the Software, you agree to the terms of this License. If you do not agree to this License, then you do not have any rights to use the Software or Documentation (collectively, the “Software Products”), and you must immediately cease using the Software Products. If you are agreeing to be bound by the terms of this License on behalf of your employer or other entity, you represent and warrant to IDEA that you have full legal authority to bind your employer or such entity to this License. If you do not have the requisite authority, you may not accept the License or access the Software Products on behalf of your employer or other entity. + +1. LICENSE GRANT + +a. You are granted a non-exclusive, worldwide, transferable, sublicensable, irrevocable, royalty free and limited license under IDEA’s copyright interests to use, reproduce, distribute, copy, create derivative works of, and make modifications to the Software solely for your non-commercial research purposes. + +b. The grant of rights expressly set forth in this Section 1 (License Grant) are the complete grant of rights to you in the Software Products, and no other licenses are granted, whether by waiver, estoppel, implication, equity or otherwise. IDEA and its licensors reserve all rights not expressly granted by this License. + +c. If you intend to use the Software Products for any commercial purposes, you must request a license from IDEA, which IDEA may grant to you in its sole discretion. + +2. REDISTRIBUTION AND USE + +a. If you distribute or make the Software Products, or any derivative works thereof, available to a third party, you shall provide a copy of this Agreement to such third party. + +b. You must retain in all copies of the Software Products that you distribute the following attribution notice: "MotionCLR is licensed under the IDEA License 1.0, Copyright (c) IDEA. All Rights Reserved." + +d. 
Your use of the Software Products must comply with applicable laws and regulations (including trade compliance laws and regulations). + +e. You will not, and will not permit, assist or cause any third party to use, modify, copy, reproduce, create derivative works of, or distribute the Software Products (or any derivative works thereof, works incorporating the Software Products, or any data produced by the Software), in whole or in part, for in any manner that infringes, misappropriates, or otherwise violates any third-party rights. + +3. DISCLAIMER OF WARRANTY + +UNLESS REQUIRED BY APPLICABLE LAW, THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS THEREFROM ARE PROVIDED ON AN "AS IS" BASIS, WITHOUT WARRANTIES OF ANY KIND, EITHER EXPRESS OR IMPLIED, INCLUDING, WITHOUT LIMITATION, ANY WARRANTIES OF TITLE, NON-INFRINGEMENT, MERCHANTABILITY, OR FITNESS FOR A PARTICULAR PURPOSE. YOU ARE SOLELY RESPONSIBLE FOR DETERMINING THE APPROPRIATENESS OF USING OR REDISTRIBUTING THE SOFTWARE PRODUCTS AND ASSUME ANY RISKS ASSOCIATED WITH YOUR USE OF THE SOFTWARE PRODUCTS AND ANY OUTPUT AND RESULTS. + +4. LIMITATION OF LIABILITY + +IN NO EVENT WILL IDEA OR ITS AFFILIATES BE LIABLE UNDER ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, TORT, NEGLIGENCE, PRODUCTS LIABILITY, OR OTHERWISE, ARISING OUT OF THIS AGREEMENT, FOR ANY LOST PROFITS OR ANY INDIRECT, SPECIAL, CONSEQUENTIAL, INCIDENTAL, EXEMPLARY OR PUNITIVE DAMAGES, EVEN IF IDEA OR ITS AFFILIATES HAVE BEEN ADVISED OF THE POSSIBILITY OF ANY OF THE FOREGOING. + +5. INDEMNIFICATION + +You will indemnify, defend and hold harmless IDEA and our subsidiaries and affiliates, and each of our respective shareholders, directors, officers, employees, agents, successors, and assigns (collectively, the “IDEA Parties”) from and against any losses, liabilities, damages, fines, penalties, and expenses (including reasonable attorneys’ fees) incurred by any IDEA Party in connection with any claim, demand, allegation, lawsuit, proceeding, or investigation (collectively, “Claims”) arising out of or related to: (a) your access to or use of the Software Products (as well as any results or data generated from such access or use); (b) your violation of this License; or (c) your violation, misappropriation or infringement of any rights of another (including intellectual property or other proprietary rights and privacy rights). You will promptly notify the IDEA Parties of any such Claims, and cooperate with IDEA Parties in defending such Claims. You will also grant the IDEA Parties sole control of the defense or settlement, at IDEA’s sole option, of any Claims. This indemnity is in addition to, and not in lieu of, any other indemnities or remedies set forth in a written agreement between you and IDEA or the other IDEA Parties. + +6. TERMINATION; SURVIVAL + +a. This License will automatically terminate upon any breach by you of the terms of this License. + +b. If you institute litigation or other proceedings against IDEA or any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Software Products, or any portion of any of the foregoing, constitutes infringement of intellectual property or other rights owned or licensable by you, then any licenses granted to you under this Agreement shall terminate as of the date such litigation or claim is filed or instituted. + +c. 
The following sections survive termination of this License: 2 (Redistribution and use), 3 (Disclaimers of Warranty), 4 (Limitation of Liability), 5 (Indemnification), 6 (Termination; Survival), 7 (Trademarks) and 8 (Applicable Law; Dispute Resolution). + +7. TRADEMARKS + +Licensee has not been granted any trademark license as part of this License and may not use any name or mark associated with IDEA without the prior written permission of IDEA, except to the extent necessary to make the reference required by the attribution notice of this Agreement. + +8. APPLICABLE LAW; DISPUTE RESOLUTION + +This License will be governed and construed under the laws of the People’s Republic of China without regard to conflicts of law provisions. The parties expressly agree that the United Nations Convention on Contracts for the International Sale of Goods will not apply. Any suit or proceeding arising out of or relating to this License will be brought in the courts, as applicable, in Shenzhen, Guangdong, and each party irrevocably submits to the jurisdiction and venue of such courts. \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..c8c3e1d610028bb3b173e046daa74da1ece988c9 --- /dev/null +++ b/app.py @@ -0,0 +1,540 @@ +import spaces +import gradio as gr +import sys +import os +import torch +import numpy as np +from os.path import join as pjoin +import utils.paramUtil as paramUtil +from utils.plot_script import * +from utils.utils import * +from utils.motion_process import recover_from_ric +from accelerate.utils import set_seed +from models.gaussian_diffusion import DiffusePipeline +from options.generate_options import GenerateOptions +from utils.model_load import load_model_weights +from motion_loader import get_dataset_loader +from models import build_models +import yaml +import time +from box import Box +import hashlib +from huggingface_hub import hf_hub_download + +ckptdir = './checkpoints/t2m/release' +os.makedirs(ckptdir, exist_ok=True) + + +mean_path = hf_hub_download( + repo_id="EvanTHU/MotionCLR", + filename="meta/mean.npy", + local_dir=ckptdir, + local_dir_use_symlinks=False +) + +std_path = hf_hub_download( + repo_id="EvanTHU/MotionCLR", + filename="meta/std.npy", + local_dir=ckptdir, + local_dir_use_symlinks=False +) + +model_path = hf_hub_download( + repo_id="EvanTHU/MotionCLR", + filename="model/latest.tar", + local_dir=ckptdir, + local_dir_use_symlinks=False +) + +opt_path = hf_hub_download( + repo_id="EvanTHU/MotionCLR", + filename="opt.txt", + local_dir=ckptdir, + local_dir_use_symlinks=False +) + + + +os.makedirs("tmp", exist_ok=True) +os.environ['GRADIO_TEMP_DIR'] = './tmp' + +def generate_md5(input_string): + # Encode the string and compute the MD5 hash + md5_hash = hashlib.md5(input_string.encode()) + # Return the hexadecimal representation of the hash + return md5_hash.hexdigest() + +def set_all_use_to_false(data): + for key, value in data.items(): + if isinstance(value, Box): + set_all_use_to_false(value) + elif key == 'use': + data[key] = False + return data + +def yaml_to_box(yaml_file): + with open(yaml_file, 'r') as file: + yaml_data = yaml.safe_load(file) + + return Box(yaml_data) + +HEAD = """
Content Reference
""" + global edit_config + edit_config = set_all_use_to_false(edit_config) + return video_dis, video_dis, video_dis, video_dis, style_dis, video_dis, gr.update(visible=True) + +def reweighting(text, idx, weight, opt, pipeline): + global edit_config + edit_config.reweighting_attn.use = True + edit_config.reweighting_attn.idx = idx + edit_config.reweighting_attn.reweighting_attn_weight = weight + + + gr.Info("Loading Configurations...", duration = 3) + model = build_models(opt, edit_config=edit_config) + ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar') + niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema) + + pipeline = DiffusePipeline( + opt = opt, + model = model, + diffuser_name = opt.diffuser_name, + device=opt.device, + num_inference_steps=opt.num_inference_steps, + torch_dtype=torch.float16, + ) + + print(edit_config) + + width = 500 + height = 500 + texts = [text, text] + motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)] + + save_dir = './tmp/gen/' + filenames = [generate_md5(str(time.time())) + ".mp4", generate_md5(str(time.time())) + ".mp4"] + save_paths = [pjoin(save_dir, str(filenames[0])), pjoin(save_dir, str(filenames[1]))] + os.makedirs(save_dir, exist_ok=True) + + start_time = time.perf_counter() + gr.Info("Generating motion...", duration = 3) + pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens])) + end_time = time.perf_counter() + exc = end_time - start_time + gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3) + start_time = time.perf_counter() + mean = np.load(pjoin(opt.meta_dir, 'mean.npy')) + std = np.load(pjoin(opt.meta_dir, 'std.npy')) + + + samples = [] + + root_list = [] + for i, motion in enumerate(pred_motions): + motion = motion.cpu().numpy() * std + mean + # 1. recover 3d joints representation by ik + motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num) + # 2. put on Floor (Y axis) + floor_height = motion.min(dim=0)[0].min(dim=0)[0][1] + motion[:, :, 1] -= floor_height + motion = motion.numpy() + # 3. 
remove jitter + motion = motion_temporal_filter(motion, sigma=1) + + samples.append(motion) + + i = 1 + title = texts[i] + motion = samples[i] + kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain + plot_3d_motion(save_paths[1], kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius) + + + gr.Info("Rendered motion...", duration = 3) + end_time = time.perf_counter() + exc = end_time - start_time + gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3) + + video_dis = f'' + + + edit_config = set_all_use_to_false(edit_config) + return video_dis + +def generate_example_based_motion(text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion, opt, pipeline): + global edit_config + edit_config.example_based.use = True + edit_config.example_based.chunk_size = chunk_size + edit_config.example_based.example_based_steps_end = example_based_steps_end + edit_config.example_based.temp_seed = temp_seed + edit_config.example_based.temp_seed_bar = temp_seed_bar + + + gr.Info("Loading Configurations...", duration = 3) + model = build_models(opt, edit_config=edit_config) + ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar') + niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema) + + pipeline = DiffusePipeline( + opt = opt, + model = model, + diffuser_name = opt.diffuser_name, + device=opt.device, + num_inference_steps=opt.num_inference_steps, + torch_dtype=torch.float16, + ) + + width = 500 + height = 500 + texts = [text for _ in range(num_motion)] + motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)] + + save_dir = './tmp/gen/' + filenames = [generate_md5(str(time.time())) + ".mp4" for _ in range(num_motion)] + save_paths = [pjoin(save_dir, str(filenames[i])) for i in range(num_motion)] + os.makedirs(save_dir, exist_ok=True) + + start_time = time.perf_counter() + gr.Info("Generating motion...", duration = 3) + pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens])) + end_time = time.perf_counter() + exc = end_time - start_time + gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3) + start_time = time.perf_counter() + mean = np.load(pjoin(opt.meta_dir, 'mean.npy')) + std = np.load(pjoin(opt.meta_dir, 'std.npy')) + + + samples = [] + + root_list = [] + progress=gr.Progress() + progress(0, desc="Starting...") + for i, motion in enumerate(pred_motions): + motion = motion.cpu().numpy() * std + mean + # 1. recover 3d joints representation by ik + motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num) + # 2. put on Floor (Y axis) + floor_height = motion.min(dim=0)[0].min(dim=0)[0][1] + motion[:, :, 1] -= floor_height + motion = motion.numpy() + # 3. 
remove jitter + motion = motion_temporal_filter(motion, sigma=1) + + samples.append(motion) + + video_dis = [] + i = 0 + for title in progress.tqdm(texts): + print(save_paths[i]) + title = texts[i] + motion = samples[i] + kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain + plot_3d_motion(save_paths[i], kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius) + video_html = f''' + + ''' + video_dis.append(video_html) + i += 1 + + for _ in range(24 - num_motion): + video_dis.append(None) + gr.Info("Rendered motion...", duration = 3) + end_time = time.perf_counter() + exc = end_time - start_time + gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3) + + edit_config = set_all_use_to_false(edit_config) + return video_dis + +def transfer_style(text, style_text, style_transfer_steps_end, opt, pipeline): + global edit_config + edit_config.style_tranfer.use = True + edit_config.style_tranfer.style_transfer_steps_end = style_transfer_steps_end + + gr.Info("Loading Configurations...", duration = 3) + model = build_models(opt, edit_config=edit_config) + ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar') + niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema) + + pipeline = DiffusePipeline( + opt = opt, + model = model, + diffuser_name = opt.diffuser_name, + device=opt.device, + num_inference_steps=opt.num_inference_steps, + torch_dtype=torch.float16, + ) + + print(edit_config) + + width = 500 + height = 500 + texts = [style_text, text, text] + motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)] + + save_dir = './tmp/gen/' + filenames = [generate_md5(str(time.time())) + ".mp4", generate_md5(str(time.time())) + ".mp4", generate_md5(str(time.time())) + ".mp4"] + save_paths = [pjoin(save_dir, str(filenames[0])), pjoin(save_dir, str(filenames[1])), pjoin(save_dir, str(filenames[2]))] + os.makedirs(save_dir, exist_ok=True) + + start_time = time.perf_counter() + gr.Info("Generating motion...", duration = 3) + pred_motions, _ = pipeline.generate(texts, torch.LongTensor([int(x) for x in motion_lens])) + end_time = time.perf_counter() + exc = end_time - start_time + gr.Info(f"Generating time cost: {exc:.2f} s, rendering starts...", duration = 3) + start_time = time.perf_counter() + mean = np.load(pjoin(opt.meta_dir, 'mean.npy')) + std = np.load(pjoin(opt.meta_dir, 'std.npy')) + + samples = [] + + root_list = [] + for i, motion in enumerate(pred_motions): + motion = motion.cpu().numpy() * std + mean + # 1. recover 3d joints representation by ik + motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num) + # 2. put on Floor (Y axis) + floor_height = motion.min(dim=0)[0].min(dim=0)[0][1] + motion[:, :, 1] -= floor_height + motion = motion.numpy() + # 3. remove jitter + motion = motion_temporal_filter(motion, sigma=1) + + samples.append(motion) + + for i,title in enumerate(texts): + title = texts[i] + motion = samples[i] + kinematic_tree = paramUtil.t2m_kinematic_chain if (opt.dataset_name == 't2m') else paramUtil.kit_kinematic_chain + plot_3d_motion(save_paths[i], kinematic_tree, motion, title=title, fps=opt.fps, radius=opt.radius) + + gr.Info("Rendered motion...", duration = 3) + end_time = time.perf_counter() + exc = end_time - start_time + gr.Info(f"Rendering time cost: {exc:.2f} s", duration = 3) + + video_dis0 = f"""Style Reference
""" + video_dis1 = f"""Content Reference
""" + video_dis2 = f"""Transfered Result
""" + + edit_config = set_all_use_to_false(edit_config) + return video_dis0, video_dis2 + + +@spaces.GPU +def main(): + parser = GenerateOptions() + opt = parser.parse_app() + set_seed(opt.seed) + device_id = opt.gpu_id + device = torch.device('cuda:%d' % device_id if torch.cuda.is_available() else 'cpu') + opt.device = device + + + # load model + model = build_models(opt, edit_config=edit_config) + ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + '.tar') + niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema) + + pipeline = DiffusePipeline( + opt = opt, + model = model, + diffuser_name = opt.diffuser_name, + device=device, + num_inference_steps=opt.num_inference_steps, + torch_dtype=torch.float16, + ) + + with gr.Blocks() as demo: + gr.Markdown(HEAD) + with gr.Row(): + with gr.Column(scale=7): + text_input = gr.Textbox(label="Input the text prompt to generate motion...") + with gr.Column(scale=3): + sequence_length = gr.Slider(minimum=1, maximum=9.6, step=0.1, label="Motion length", value=8) + with gr.Row(): + generate_button = gr.Button("Generate motion") + + with gr.Row(): + video_display = gr.HTML(label="生成的视频", visible=True) + + + tabs = gr.Tabs(visible=True) + with tabs: + with gr.Tab("Motion (de-)emphasizing"): + with gr.Row(): + int_input = gr.Number(label="Editing word index", minimum=0, maximum=70) + weight_input = gr.Slider(minimum=-1, maximum=1, step=0.01, label="Input weight for (de-)emphasizing [-1, 1]", value=0) + + trim_button = gr.Button("Edit reweighting") + + with gr.Row(): + original_video1 = gr.HTML(label="before editing", visible=False) + edited_video = gr.HTML(label="after editing") + + trim_button.click( + fn=lambda x, int_input, weight_input : reweighting(x, int_input, weight_input, opt, pipeline), + inputs=[text_input, int_input, weight_input], + outputs=edited_video, + ) + + with gr.Tab("Example-based motion genration"): + with gr.Row(): + with gr.Column(scale=4): + chunk_size = gr.Number(minimum=10, maximum=20, step=10,label="Chunk size (#frames)", value=20) + example_based_steps_end = gr.Number(minimum=0, maximum=9,label="Ending step of manipulation", value=6) + with gr.Column(scale=3): + temp_seed = gr.Number(label="Seed for random", value=200, minimum=0) + temp_seed_bar = gr.Slider(minimum=0, maximum=100, step=1, label="Seed for random bar", value=15) + with gr.Column(scale=3): + num_motion = gr.Radio(choices=[4, 8, 12, 16, 24], value=8, label="Select number of motions") + + gen_button = gr.Button("Generate example-based motion") + + + example_video_display = [] + for _ in range(6): + with gr.Row(): + for _ in range(4): + video = gr.HTML(label="Example-based motion", visible=True) + example_video_display.append(video) + + gen_button.click( + fn=lambda text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion: generate_example_based_motion(text, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion, opt, pipeline), + inputs=[text_input, chunk_size, example_based_steps_end, temp_seed, temp_seed_bar, num_motion], + outputs=example_video_display + ) + + with gr.Tab("Style transfer"): + with gr.Row(): + style_text = gr.Textbox(label="Reference prompt (e.g. 
'a man walks.')", value="a man walks.") + style_transfer_steps_end = gr.Number(label="The end step of diffusion (0~9)", minimum=0, maximum=9, value=5) + + style_transfer_button = gr.Button("Transfer style") + + with gr.Row(): + style_reference = gr.HTML(label="style reference") + original_video4 = gr.HTML(label="before style transfer", visible=False) + styled_video = gr.HTML(label="after style transfer") + + style_transfer_button.click( + fn=lambda text, style_text, style_transfer_steps_end: transfer_style(text, style_text, style_transfer_steps_end, opt, pipeline), + inputs=[text_input, style_text, style_transfer_steps_end], + outputs=[style_reference, styled_video], + ) + + def update_motion_length(sequence_length): + opt.motion_length = sequence_length + + def on_generate(text, length, pipeline): + update_motion_length(length) + return generate_video_from_text(text, opt, pipeline) + + + generate_button.click( + fn=lambda text, length: on_generate(text, length, pipeline), + inputs=[text_input, sequence_length], + outputs=[ + video_display, + original_video1, + original_video4, + tabs, + ], + show_progress=True + ) + + generate_button.click( + fn=lambda: [gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True), gr.update(visible=True)], + inputs=None, + outputs=[video_display, original_video1, original_video4] + ) + + demo.launch() + + +if __name__ == '__main__': + main() \ No newline at end of file diff --git a/assets/motion_lens.txt b/assets/motion_lens.txt new file mode 100644 index 0000000000000000000000000000000000000000..9da06a18339c312235ed2a16c054f2bb70981398 --- /dev/null +++ b/assets/motion_lens.txt @@ -0,0 +1 @@ +160 \ No newline at end of file diff --git a/assets/prompts.txt b/assets/prompts.txt new file mode 100644 index 0000000000000000000000000000000000000000..96f0234f56528241c2bcc21d7703f10ecb18f9c3 --- /dev/null +++ b/assets/prompts.txt @@ -0,0 +1 @@ +a man jumps. 
\ No newline at end of file diff --git a/config/diffuser_params.yaml b/config/diffuser_params.yaml new file mode 100644 index 0000000000000000000000000000000000000000..614a72e401e213695c5914cc0bc04bebd2030326 --- /dev/null +++ b/config/diffuser_params.yaml @@ -0,0 +1,26 @@ +dpmsolver: + scheduler_class: DPMSolverMultistepScheduler + additional_params: + algorithm_type: sde-dpmsolver++ + use_karras_sigmas: true + +ddpm: + scheduler_class: DDPMScheduler + additional_params: + variance_type: fixed_small + clip_sample: false + +ddim: + scheduler_class: DDIMScheduler + additional_params: + clip_sample: false + +deis: + scheduler_class: DEISMultistepScheduler + additional_params: + num_train_timesteps: 1000 + +pndm: + scheduler_class: PNDMScheduler + additional_params: + num_train_timesteps: 1000 diff --git a/config/evaluator.yaml b/config/evaluator.yaml new file mode 100644 index 0000000000000000000000000000000000000000..6021113fe989a79f40c9c3df7c6c4f841ca34bd5 --- /dev/null +++ b/config/evaluator.yaml @@ -0,0 +1,14 @@ +unit_length: 4 +max_text_len: 20 +text_enc_mod: bigru +estimator_mod: bigru +dim_text_hidden: 512 +dim_att_vec: 512 +dim_z: 128 +dim_movement_enc_hidden: 512 +dim_movement_dec_hidden: 512 +dim_movement_latent: 512 +dim_word: 300 +dim_pos_ohot: 15 +dim_motion_hidden: 1024 +dim_coemb_hidden: 512 \ No newline at end of file diff --git a/datasets/__init__.py b/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..bba1613ad0e958b2fa110b8f610ecab5892079ae --- /dev/null +++ b/datasets/__init__.py @@ -0,0 +1,22 @@ + +from .t2m_dataset import HumanML3D,KIT + +from os.path import join as pjoin +__all__ = [ + 'HumanML3D', 'KIT', 'get_dataset',] + +def get_dataset(opt, split='train', mode='train', accelerator=None): + if opt.dataset_name == 't2m' : + dataset = HumanML3D(opt, split, mode, accelerator) + elif opt.dataset_name == 'kit' : + dataset = KIT(opt,split, mode, accelerator) + else: + raise KeyError('Dataset Does Not Exist') + + if accelerator: + accelerator.print('Completing loading %s dataset' % opt.dataset_name) + else: + print('Completing loading %s dataset' % opt.dataset_name) + + return dataset + diff --git a/datasets/t2m_dataset.py b/datasets/t2m_dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce74097e879df495a785707c6addcc728ad049d --- /dev/null +++ b/datasets/t2m_dataset.py @@ -0,0 +1,304 @@ +import torch +from torch.utils import data +import numpy as np +from os.path import join as pjoin +import random +import codecs as cs +from tqdm.auto import tqdm +from utils.word_vectorizer import WordVectorizer, POS_enumerator +from utils.motion_process import recover_from_ric + + +class Text2MotionDataset(data.Dataset): + """ + Dataset for Text2Motion generation task. + """ + + data_root = "" + min_motion_len = 40 + joints_num = None + dim_pose = None + max_motion_length = 196 + + def __init__(self, opt, split, mode="train", accelerator=None): + self.max_text_len = getattr(opt, "max_text_len", 20) + self.unit_length = getattr(opt, "unit_length", 4) + self.mode = mode + motion_dir = pjoin(self.data_root, "new_joint_vecs") + text_dir = pjoin(self.data_root, "texts") + + if mode not in ["train", "eval", "gt_eval", "xyz_gt", "hml_gt"]: + raise ValueError( + f"Mode '{mode}' is not supported. Please use one of: 'train', 'eval', 'gt_eval', 'xyz_gt','hml_gt'." 
+ ) + + mean, std = None, None + if mode == "gt_eval": + print(pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_std.npy")) + # used by T2M models (including evaluators) + mean = np.load(pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_mean.npy")) + std = np.load(pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_std.npy")) + elif mode in ["eval"]: + print(pjoin(opt.meta_dir, "std.npy")) + # used by our models during inference + mean = np.load(pjoin(opt.meta_dir, "mean.npy")) + std = np.load(pjoin(opt.meta_dir, "std.npy")) + else: + # used by our models during train + mean = np.load(pjoin(self.data_root, "Mean.npy")) + std = np.load(pjoin(self.data_root, "Std.npy")) + + if mode == "eval": + # used by T2M models (including evaluators) + # this is to translate ours norms to theirs + self.mean_for_eval = np.load( + pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_mean.npy") + ) + self.std_for_eval = np.load( + pjoin(opt.eval_meta_dir, f"{opt.dataset_name}_std.npy") + ) + if mode in ["gt_eval", "eval"]: + self.w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab") + + data_dict = {} + id_list = [] + split_file = pjoin(self.data_root, f"{split}.txt") + with cs.open(split_file, "r") as f: + for line in f.readlines(): + id_list.append(line.strip()) + + if opt.debug == True: + id_list = id_list[:1000] + + new_name_list = [] + length_list = [] + for name in tqdm( + id_list, + disable=( + not accelerator.is_local_main_process + if accelerator is not None + else False + ), + ): + motion = np.load(pjoin(motion_dir, name + ".npy")) + if (len(motion)) < self.min_motion_len or (len(motion) >= 200): + continue + text_data = [] + flag = False + with cs.open(pjoin(text_dir, name + ".txt")) as f: + for line in f.readlines(): + text_dict = {} + line_split = line.strip().split("#") + caption = line_split[0] + try: + tokens = line_split[1].split(" ") + f_tag = float(line_split[2]) + to_tag = float(line_split[3]) + f_tag = 0.0 if np.isnan(f_tag) else f_tag + to_tag = 0.0 if np.isnan(to_tag) else to_tag + except: + tokens = ["a/NUM", "a/NUM"] + f_tag = 0.0 + to_tag = 8.0 + text_dict["caption"] = caption + text_dict["tokens"] = tokens + if f_tag == 0.0 and to_tag == 0.0: + flag = True + text_data.append(text_dict) + else: + n_motion = motion[int(f_tag * 20) : int(to_tag * 20)] + if (len(n_motion)) < self.min_motion_len or ( + len(n_motion) >= 200 + ): + continue + new_name = random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name + while new_name in data_dict: + new_name = ( + random.choice("ABCDEFGHIJKLMNOPQRSTUVW") + "_" + name + ) + data_dict[new_name] = { + "motion": n_motion, + "length": len(n_motion), + "text": [text_dict], + } + new_name_list.append(new_name) + length_list.append(len(n_motion)) + if flag: + data_dict[name] = { + "motion": motion, + "length": len(motion), + "text": text_data, + } + new_name_list.append(name) + length_list.append(len(motion)) + + name_list, length_list = zip( + *sorted(zip(new_name_list, length_list), key=lambda x: x[1]) + ) + + if mode == "train": + if opt.dataset_name != "amass": + joints_num = self.joints_num + # root_rot_velocity (B, seq_len, 1) + std[0:1] = std[0:1] / opt.feat_bias + # root_linear_velocity (B, seq_len, 2) + std[1:3] = std[1:3] / opt.feat_bias + # root_y (B, seq_len, 1) + std[3:4] = std[3:4] / opt.feat_bias + # ric_data (B, seq_len, (joint_num - 1)*3) + std[4 : 4 + (joints_num - 1) * 3] = ( + std[4 : 4 + (joints_num - 1) * 3] / 1.0 + ) + # rot_data (B, seq_len, (joint_num - 1)*6) + std[4 + (joints_num - 1) * 3 : 4 + (joints_num - 1) * 9] = ( + std[4 + (joints_num - 1) * 
3 : 4 + (joints_num - 1) * 9] / 1.0 + ) + # local_velocity (B, seq_len, joint_num*3) + std[ + 4 + (joints_num - 1) * 9 : 4 + (joints_num - 1) * 9 + joints_num * 3 + ] = ( + std[ + 4 + + (joints_num - 1) * 9 : 4 + + (joints_num - 1) * 9 + + joints_num * 3 + ] + / 1.0 + ) + # foot contact (B, seq_len, 4) + std[4 + (joints_num - 1) * 9 + joints_num * 3 :] = ( + std[4 + (joints_num - 1) * 9 + joints_num * 3 :] / opt.feat_bias + ) + + assert 4 + (joints_num - 1) * 9 + joints_num * 3 + 4 == mean.shape[-1] + + if accelerator is not None and accelerator.is_main_process: + np.save(pjoin(opt.meta_dir, "mean.npy"), mean) + np.save(pjoin(opt.meta_dir, "std.npy"), std) + + self.mean = mean + self.std = std + self.data_dict = data_dict + self.name_list = name_list + + def inv_transform(self, data): + return data * self.std + self.mean + + def __len__(self): + return len(self.data_dict) + + def __getitem__(self, idx): + data = self.data_dict[self.name_list[idx]] + motion, m_length, text_list = data["motion"], data["length"], data["text"] + + # Randomly select a caption + text_data = random.choice(text_list) + caption = text_data["caption"] + + "Z Normalization" + if self.mode not in ["xyz_gt", "hml_gt"]: + motion = (motion - self.mean) / self.std + + "crop motion" + if self.mode in ["eval", "gt_eval"]: + # Crop the motions in to times of 4, and introduce small variations + if self.unit_length < 10: + coin2 = np.random.choice(["single", "single", "double"]) + else: + coin2 = "single" + if coin2 == "double": + m_length = (m_length // self.unit_length - 1) * self.unit_length + elif coin2 == "single": + m_length = (m_length // self.unit_length) * self.unit_length + idx = random.randint(0, len(motion) - m_length) + motion = motion[idx : idx + m_length] + elif m_length >= self.max_motion_length: + idx = random.randint(0, len(motion) - self.max_motion_length) + motion = motion[idx : idx + self.max_motion_length] + m_length = self.max_motion_length + + "pad motion" + if m_length < self.max_motion_length: + motion = np.concatenate( + [ + motion, + np.zeros((self.max_motion_length - m_length, motion.shape[1])), + ], + axis=0, + ) + assert len(motion) == self.max_motion_length + + if self.mode in ["gt_eval", "eval"]: + "word embedding for text-to-motion evaluation" + tokens = text_data["tokens"] + if len(tokens) < self.max_text_len: + # pad with "unk" + tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"] + sent_len = len(tokens) + tokens = tokens + ["unk/OTHER"] * (self.max_text_len + 2 - sent_len) + else: + # crop + tokens = tokens[: self.max_text_len] + tokens = ["sos/OTHER"] + tokens + ["eos/OTHER"] + sent_len = len(tokens) + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + return ( + word_embeddings, + pos_one_hots, + caption, + sent_len, + motion, + m_length, + "_".join(tokens), + ) + elif self.mode in ["xyz_gt"]: + "Convert motion hml representation to skeleton points xyz" + # 1. Use kn to get the keypoints position (the padding position after kn is all zero) + motion = torch.from_numpy(motion).float() + pred_joints = recover_from_ric( + motion, self.joints_num + ) # (nframe, njoints, 3) + + # 2. 
Put on Floor (Y axis) + floor_height = pred_joints.min(dim=0)[0].min(dim=0)[0][1] + pred_joints[:, :, 1] -= floor_height + return pred_joints + + return caption, motion, m_length + + +class HumanML3D(Text2MotionDataset): + def __init__(self, opt, split="train", mode="train", accelerator=None): + self.data_root = "./data/HumanML3D" + self.min_motion_len = 40 + self.joints_num = 22 + self.dim_pose = 263 + self.max_motion_length = 196 + if accelerator: + accelerator.print( + "\n Loading %s mode HumanML3D %s dataset ..." % (mode, split) + ) + else: + print("\n Loading %s mode HumanML3D dataset ..." % mode) + super(HumanML3D, self).__init__(opt, split, mode, accelerator) + + +class KIT(Text2MotionDataset): + def __init__(self, opt, split="train", mode="train", accelerator=None): + self.data_root = "./data/KIT-ML" + self.min_motion_len = 24 + self.joints_num = 21 + self.dim_pose = 251 + self.max_motion_length = 196 + if accelerator: + accelerator.print("\n Loading %s mode KIT %s dataset ..." % (mode, split)) + else: + print("\n Loading %s mode KIT dataset ..." % mode) + super(KIT, self).__init__(opt, split, mode, accelerator) diff --git a/eval/__init__.py b/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..20beab320965ac72a1ccec12ae94bef4b87eb8f9 --- /dev/null +++ b/eval/__init__.py @@ -0,0 +1,2 @@ +from .evaluator_wrapper import EvaluatorModelWrapper +from .eval_t2m import evaluation \ No newline at end of file diff --git a/eval/eval_t2m.py b/eval/eval_t2m.py new file mode 100644 index 0000000000000000000000000000000000000000..0929d929697df084bc7ad6b34e66a50bfe750cb7 --- /dev/null +++ b/eval/eval_t2m.py @@ -0,0 +1,222 @@ +# This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE. 
+# Copyright (c) 2022 Chuan Guo +from datetime import datetime +import numpy as np +import torch +from utils.metrics import * +from collections import OrderedDict + + +def evaluate_matching_score(eval_wrapper,motion_loaders, file): + match_score_dict = OrderedDict({}) + R_precision_dict = OrderedDict({}) + activation_dict = OrderedDict({}) + # print(motion_loaders.keys()) + print('========== Evaluating Matching Score ==========') + for motion_loader_name, motion_loader in motion_loaders.items(): + all_motion_embeddings = [] + score_list = [] + all_size = 0 + matching_score_sum = 0 + top_k_count = 0 + # print(motion_loader_name) + with torch.no_grad(): + for idx, batch in enumerate(motion_loader): + word_embeddings, pos_one_hots, _, sent_lens, motions, m_lens, _ = batch + text_embeddings, motion_embeddings = eval_wrapper.get_co_embeddings( + word_embs=word_embeddings, + pos_ohot=pos_one_hots, + cap_lens=sent_lens, + motions=motions, + m_lens=m_lens + ) + dist_mat = euclidean_distance_matrix(text_embeddings.cpu().numpy(), + motion_embeddings.cpu().numpy()) + matching_score_sum += dist_mat.trace() + # import pdb;pdb.set_trace() + + argsmax = np.argsort(dist_mat, axis=1) + top_k_mat = calculate_top_k(argsmax, top_k=3) + top_k_count += top_k_mat.sum(axis=0) + + all_size += text_embeddings.shape[0] + + all_motion_embeddings.append(motion_embeddings.cpu().numpy()) + + all_motion_embeddings = np.concatenate(all_motion_embeddings, axis=0) + # import pdb;pdb.set_trace() + matching_score = matching_score_sum / all_size + R_precision = top_k_count / all_size + match_score_dict[motion_loader_name] = matching_score + R_precision_dict[motion_loader_name] = R_precision + activation_dict[motion_loader_name] = all_motion_embeddings + + print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}') + print(f'---> [{motion_loader_name}] Matching Score: {matching_score:.4f}', file=file, flush=True) + + line = f'---> [{motion_loader_name}] R_precision: ' + for i in range(len(R_precision)): + line += '(top %d): %.4f ' % (i+1, R_precision[i]) + print(line) + print(line, file=file, flush=True) + + return match_score_dict, R_precision_dict, activation_dict + + +def evaluate_fid(eval_wrapper,groundtruth_loader, activation_dict, file): + eval_dict = OrderedDict({}) + gt_motion_embeddings = [] + print('========== Evaluating FID ==========') + with torch.no_grad(): + for idx, batch in enumerate(groundtruth_loader): + _, _, _, sent_lens, motions, m_lens, _ = batch + motion_embeddings = eval_wrapper.get_motion_embeddings( + motions=motions, + m_lens=m_lens + ) + gt_motion_embeddings.append(motion_embeddings.cpu().numpy()) + gt_motion_embeddings = np.concatenate(gt_motion_embeddings, axis=0) + gt_mu, gt_cov = calculate_activation_statistics(gt_motion_embeddings) + + for model_name, motion_embeddings in activation_dict.items(): + mu, cov = calculate_activation_statistics(motion_embeddings) + # print(mu) + fid = calculate_frechet_distance(gt_mu, gt_cov, mu, cov) + print(f'---> [{model_name}] FID: {fid:.4f}') + print(f'---> [{model_name}] FID: {fid:.4f}', file=file, flush=True) + eval_dict[model_name] = fid + return eval_dict + + +def evaluate_diversity(activation_dict, file, diversity_times): + eval_dict = OrderedDict({}) + print('========== Evaluating Diversity ==========') + for model_name, motion_embeddings in activation_dict.items(): + diversity = calculate_diversity(motion_embeddings, diversity_times) + eval_dict[model_name] = diversity + print(f'---> [{model_name}] Diversity: {diversity:.4f}') + 
print(f'---> [{model_name}] Diversity: {diversity:.4f}', file=file, flush=True) + return eval_dict + + +def evaluate_multimodality(eval_wrapper, mm_motion_loaders, file, mm_num_times): + eval_dict = OrderedDict({}) + print('========== Evaluating MultiModality ==========') + for model_name, mm_motion_loader in mm_motion_loaders.items(): + mm_motion_embeddings = [] + with torch.no_grad(): + for idx, batch in enumerate(mm_motion_loader): + # (1, mm_replications, dim_pos) + motions, m_lens = batch + motion_embedings = eval_wrapper.get_motion_embeddings(motions[0], m_lens[0]) + mm_motion_embeddings.append(motion_embedings.unsqueeze(0)) + if len(mm_motion_embeddings) == 0: + multimodality = 0 + else: + mm_motion_embeddings = torch.cat(mm_motion_embeddings, dim=0).cpu().numpy() + multimodality = calculate_multimodality(mm_motion_embeddings, mm_num_times) + print(f'---> [{model_name}] Multimodality: {multimodality:.4f}') + print(f'---> [{model_name}] Multimodality: {multimodality:.4f}', file=file, flush=True) + eval_dict[model_name] = multimodality + return eval_dict + + +def get_metric_statistics(values, replication_times): + mean = np.mean(values, axis=0) + std = np.std(values, axis=0) + conf_interval = 1.96 * std / np.sqrt(replication_times) + return mean, conf_interval + + +def evaluation(eval_wrapper, gt_loader, eval_motion_loaders, log_file, replication_times, diversity_times, mm_num_times, run_mm=False): + with open(log_file, 'a') as f: + all_metrics = OrderedDict({'Matching Score': OrderedDict({}), + 'R_precision': OrderedDict({}), + 'FID': OrderedDict({}), + 'Diversity': OrderedDict({}), + 'MultiModality': OrderedDict({})}) + + for replication in range(replication_times): + print(f'Time: {datetime.now()}') + print(f'Time: {datetime.now()}', file=f, flush=True) + motion_loaders = {} + motion_loaders['ground truth'] = gt_loader + mm_motion_loaders = {} + # motion_loaders['ground truth'] = gt_loader + for motion_loader_name, motion_loader_getter in eval_motion_loaders.items(): + motion_loader, mm_motion_loader,eval_generate_time = motion_loader_getter() + print(f'---> [{motion_loader_name}] batch_generate_time: {eval_generate_time}s', file=f, flush=True) + motion_loaders[motion_loader_name] = motion_loader + mm_motion_loaders[motion_loader_name] = mm_motion_loader + + if replication_times>1: + print(f'==================== Replication {replication} ====================') + print(f'==================== Replication {replication} ====================', file=f, flush=True) + + + mat_score_dict, R_precision_dict, acti_dict = evaluate_matching_score(eval_wrapper, motion_loaders, f) + + fid_score_dict = evaluate_fid(eval_wrapper, gt_loader, acti_dict, f) + + div_score_dict = evaluate_diversity(acti_dict, f, diversity_times) + + if run_mm: + mm_score_dict = evaluate_multimodality(eval_wrapper, mm_motion_loaders, f, mm_num_times) + + print(f'!!! DONE !!!') + print(f'!!! 
DONE !!!', file=f, flush=True) + + for key, item in mat_score_dict.items(): + if key not in all_metrics['Matching Score']: + all_metrics['Matching Score'][key] = [item] + else: + all_metrics['Matching Score'][key] += [item] + + for key, item in R_precision_dict.items(): + if key not in all_metrics['R_precision']: + all_metrics['R_precision'][key] = [item] + else: + all_metrics['R_precision'][key] += [item] + + for key, item in fid_score_dict.items(): + if key not in all_metrics['FID']: + all_metrics['FID'][key] = [item] + else: + all_metrics['FID'][key] += [item] + + for key, item in div_score_dict.items(): + if key not in all_metrics['Diversity']: + all_metrics['Diversity'][key] = [item] + else: + all_metrics['Diversity'][key] += [item] + + for key, item in mm_score_dict.items(): + if key not in all_metrics['MultiModality']: + all_metrics['MultiModality'][key] = [item] + else: + all_metrics['MultiModality'][key] += [item] + + + mean_dict = {} + if replication_times>1: + for metric_name, metric_dict in all_metrics.items(): + print('========== %s Summary ==========' % metric_name) + print('========== %s Summary ==========' % metric_name, file=f, flush=True) + + for model_name, values in metric_dict.items(): + # print(metric_name, model_name) + mean, conf_interval = get_metric_statistics(np.array(values),replication_times) + mean_dict[metric_name + '_' + model_name] = mean + # print(mean, mean.dtype) + if isinstance(mean, np.float64) or isinstance(mean, np.float32): + print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}') + print(f'---> [{model_name}] Mean: {mean:.4f} CInterval: {conf_interval:.4f}', file=f, flush=True) + elif isinstance(mean, np.ndarray): + line = f'---> [{model_name}]' + for i in range(len(mean)): + line += '(top %d) Mean: %.4f CInt: %.4f;' % (i+1, mean[i], conf_interval[i]) + print(line) + print(line, file=f, flush=True) + return mean_dict + else: + return all_metrics diff --git a/eval/evaluator_modules.py b/eval/evaluator_modules.py new file mode 100644 index 0000000000000000000000000000000000000000..b553599941e4f7e50471f28b81581fc9291ec09e --- /dev/null +++ b/eval/evaluator_modules.py @@ -0,0 +1,436 @@ +# This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE. +# Copyright (c) 2022 Chuan Guo +import torch +import torch.nn as nn +import numpy as np +import time +import math +from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence +import torch.nn.functional as F + + +class ContrastiveLoss(torch.nn.Module): + """ + Contrastive loss function. 
+ Based on: http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf + """ + def __init__(self, margin=3.0): + super(ContrastiveLoss, self).__init__() + self.margin = margin + + def forward(self, output1, output2, label): + euclidean_distance = F.pairwise_distance(output1, output2, keepdim=True) + loss_contrastive = torch.mean((1-label) * torch.pow(euclidean_distance, 2) + + (label) * torch.pow(torch.clamp(self.margin - euclidean_distance, min=0.0), 2)) + return loss_contrastive + + +def init_weight(m): + if isinstance(m, nn.Conv1d) or isinstance(m, nn.Linear) or isinstance(m, nn.ConvTranspose1d): + nn.init.xavier_normal_(m.weight) + # m.bias.data.fill_(0.01) + if m.bias is not None: + nn.init.constant_(m.bias, 0) + + +def reparameterize(mu, logvar): + s_var = logvar.mul(0.5).exp_() + eps = s_var.data.new(s_var.size()).normal_() + return eps.mul(s_var).add_(mu) + + +# batch_size, dimension and position +# output: (batch_size, dim) +def positional_encoding(batch_size, dim, pos): + assert batch_size == pos.shape[0] + positions_enc = np.array([ + [pos[j] / np.power(10000, (i-i%2)/dim) for i in range(dim)] + for j in range(batch_size) + ], dtype=np.float32) + positions_enc[:, 0::2] = np.sin(positions_enc[:, 0::2]) + positions_enc[:, 1::2] = np.cos(positions_enc[:, 1::2]) + return torch.from_numpy(positions_enc).float() + + +def get_padding_mask(batch_size, seq_len, cap_lens): + cap_lens = cap_lens.data.tolist() + mask_2d = torch.ones((batch_size, seq_len, seq_len), dtype=torch.float32) + for i, cap_len in enumerate(cap_lens): + mask_2d[i, :, :cap_len] = 0 + return mask_2d.bool(), 1 - mask_2d[:, :, 0].clone() + + +class PositionalEncoding(nn.Module): + + def __init__(self, d_model, max_len=300): + super(PositionalEncoding, self).__init__() + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + # pe = pe.unsqueeze(0).transpose(0, 1) + self.register_buffer('pe', pe) + + def forward(self, pos): + return self.pe[pos] + + +class MovementConvEncoder(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MovementConvEncoder, self).__init__() + self.main = nn.Sequential( + nn.Conv1d(input_size, hidden_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.Conv1d(hidden_size, output_size, 4, 2, 1), + nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ) + self.out_net = nn.Linear(output_size, output_size) + self.main.apply(init_weight) + self.out_net.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = self.main(inputs).permute(0, 2, 1) + # print(outputs.shape) + return self.out_net(outputs) + + +class MovementConvDecoder(nn.Module): + def __init__(self, input_size, hidden_size, output_size): + super(MovementConvDecoder, self).__init__() + self.main = nn.Sequential( + nn.ConvTranspose1d(input_size, hidden_size, 4, 2, 1), + # nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + nn.ConvTranspose1d(hidden_size, output_size, 4, 2, 1), + # nn.Dropout(0.2, inplace=True), + nn.LeakyReLU(0.2, inplace=True), + ) + self.out_net = nn.Linear(output_size, output_size) + + self.main.apply(init_weight) + self.out_net.apply(init_weight) + + def forward(self, inputs): + inputs = inputs.permute(0, 2, 1) + outputs = 
self.main(inputs).permute(0, 2, 1) + return self.out_net(outputs) + + +class TextVAEDecoder(nn.Module): + def __init__(self, text_size, input_size, output_size, hidden_size, n_layers): + super(TextVAEDecoder, self).__init__() + self.input_size = input_size + self.output_size = output_size + self.hidden_size = hidden_size + self.n_layers = n_layers + self.emb = nn.Sequential( + nn.Linear(input_size, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True)) + + self.z2init = nn.Linear(text_size, hidden_size * n_layers) + self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)]) + self.positional_encoder = PositionalEncoding(hidden_size) + + + self.output = nn.Sequential( + nn.Linear(hidden_size, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + # + # self.output = nn.Sequential( + # nn.Linear(hidden_size, hidden_size), + # nn.LayerNorm(hidden_size), + # nn.LeakyReLU(0.2, inplace=True), + # nn.Linear(hidden_size, output_size-4) + # ) + + # self.contact_net = nn.Sequential( + # nn.Linear(output_size-4, 64), + # nn.LayerNorm(64), + # nn.LeakyReLU(0.2, inplace=True), + # nn.Linear(64, 4) + # ) + + self.output.apply(init_weight) + self.emb.apply(init_weight) + self.z2init.apply(init_weight) + # self.contact_net.apply(init_weight) + + def get_init_hidden(self, latent): + hidden = self.z2init(latent) + hidden = torch.split(hidden, self.hidden_size, dim=-1) + return list(hidden) + + def forward(self, inputs, last_pred, hidden, p): + h_in = self.emb(inputs) + pos_enc = self.positional_encoder(p).to(inputs.device).detach() + h_in = h_in + pos_enc + for i in range(self.n_layers): + # print(h_in.shape) + hidden[i] = self.gru[i](h_in, hidden[i]) + h_in = hidden[i] + pose_pred = self.output(h_in) + # pose_pred = self.output(h_in) + last_pred.detach() + # contact = self.contact_net(pose_pred) + # return torch.cat([pose_pred, contact], dim=-1), hidden + return pose_pred, hidden + + +class TextDecoder(nn.Module): + def __init__(self, text_size, input_size, output_size, hidden_size, n_layers): + super(TextDecoder, self).__init__() + self.input_size = input_size + self.output_size = output_size + self.hidden_size = hidden_size + self.n_layers = n_layers + self.emb = nn.Sequential( + nn.Linear(input_size, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True)) + + self.gru = nn.ModuleList([nn.GRUCell(hidden_size, hidden_size) for i in range(self.n_layers)]) + self.z2init = nn.Linear(text_size, hidden_size * n_layers) + self.positional_encoder = PositionalEncoding(hidden_size) + + self.mu_net = nn.Linear(hidden_size, output_size) + self.logvar_net = nn.Linear(hidden_size, output_size) + + self.emb.apply(init_weight) + self.z2init.apply(init_weight) + self.mu_net.apply(init_weight) + self.logvar_net.apply(init_weight) + + def get_init_hidden(self, latent): + + hidden = self.z2init(latent) + hidden = torch.split(hidden, self.hidden_size, dim=-1) + + return list(hidden) + + def forward(self, inputs, hidden, p): + # print(inputs.shape) + x_in = self.emb(inputs) + pos_enc = self.positional_encoder(p).to(inputs.device).detach() + x_in = x_in + pos_enc + + for i in range(self.n_layers): + hidden[i] = self.gru[i](x_in, hidden[i]) + h_in = hidden[i] + mu = self.mu_net(h_in) + logvar = self.logvar_net(h_in) + z = reparameterize(mu, logvar) + return z, mu, logvar, hidden + +class AttLayer(nn.Module): + def __init__(self, query_dim, key_dim, value_dim): + super(AttLayer, 
self).__init__() + self.W_q = nn.Linear(query_dim, value_dim) + self.W_k = nn.Linear(key_dim, value_dim, bias=False) + self.W_v = nn.Linear(key_dim, value_dim) + + self.softmax = nn.Softmax(dim=1) + self.dim = value_dim + + self.W_q.apply(init_weight) + self.W_k.apply(init_weight) + self.W_v.apply(init_weight) + + def forward(self, query, key_mat): + ''' + query (batch, query_dim) + key (batch, seq_len, key_dim) + ''' + # print(query.shape) + query_vec = self.W_q(query).unsqueeze(-1) # (batch, value_dim, 1) + val_set = self.W_v(key_mat) # (batch, seq_len, value_dim) + key_set = self.W_k(key_mat) # (batch, seq_len, value_dim) + + weights = torch.matmul(key_set, query_vec) / np.sqrt(self.dim) + + co_weights = self.softmax(weights) # (batch, seq_len, 1) + values = val_set * co_weights # (batch, seq_len, value_dim) + pred = values.sum(dim=1) # (batch, value_dim) + return pred, co_weights + + def short_cut(self, querys, keys): + return self.W_q(querys), self.W_k(keys) + + +class TextEncoderBiGRU(nn.Module): + def __init__(self, word_size, pos_size, hidden_size, device): + super(TextEncoderBiGRU, self).__init__() + self.device = device + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + # self.linear2 = nn.Linear(hidden_size, output_size) + + self.input_emb.apply(init_weight) + self.pos_emb.apply(init_weight) + # self.linear2.apply(init_weight) + # self.batch_size = batch_size + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, word_embs, pos_onehot, cap_lens): + num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + gru_seq = pad_packed_sequence(gru_seq, batch_first=True)[0] + forward_seq = gru_seq[..., :self.hidden_size] + backward_seq = gru_seq[..., self.hidden_size:].clone() + + # Concate the forward and backward word embeddings + for i, length in enumerate(cap_lens): + backward_seq[i:i+1, :length] = torch.flip(backward_seq[i:i+1, :length].clone(), dims=[1]) + gru_seq = torch.cat([forward_seq, backward_seq], dim=-1) + + return gru_seq, gru_last + + +class TextEncoderBiGRUCo(nn.Module): + def __init__(self, word_size, pos_size, hidden_size, output_size, device): + super(TextEncoderBiGRUCo, self).__init__() + self.device = device + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_net = nn.Sequential( + nn.Linear(hidden_size * 2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + self.input_emb.apply(init_weight) + self.pos_emb.apply(init_weight) + self.output_net.apply(init_weight) + # self.linear2.apply(init_weight) + # self.batch_size = batch_size + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, word_embs, pos_onehot, cap_lens): + 
num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) + + +class MotionEncoderBiGRUCo(nn.Module): + def __init__(self, input_size, hidden_size, output_size, device): + super(MotionEncoderBiGRUCo, self).__init__() + self.device = device + + self.input_emb = nn.Linear(input_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + self.output_net = nn.Sequential( + nn.Linear(hidden_size*2, hidden_size), + nn.LayerNorm(hidden_size), + nn.LeakyReLU(0.2, inplace=True), + nn.Linear(hidden_size, output_size) + ) + + self.input_emb.apply(init_weight) + self.output_net.apply(init_weight) + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, inputs, m_lens): + num_samples = inputs.shape[0] + + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = m_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output_net(gru_last) + + +class MotionLenEstimatorBiGRU(nn.Module): + def __init__(self, word_size, pos_size, hidden_size, output_size): + super(MotionLenEstimatorBiGRU, self).__init__() + + self.pos_emb = nn.Linear(pos_size, word_size) + self.input_emb = nn.Linear(word_size, hidden_size) + self.gru = nn.GRU(hidden_size, hidden_size, batch_first=True, bidirectional=True) + nd = 512 + self.output = nn.Sequential( + nn.Linear(hidden_size*2, nd), + nn.LayerNorm(nd), + nn.LeakyReLU(0.2, inplace=True), + + nn.Linear(nd, nd // 2), + nn.LayerNorm(nd // 2), + nn.LeakyReLU(0.2, inplace=True), + + nn.Linear(nd // 2, nd // 4), + nn.LayerNorm(nd // 4), + nn.LeakyReLU(0.2, inplace=True), + + nn.Linear(nd // 4, output_size) + ) + + self.input_emb.apply(init_weight) + self.pos_emb.apply(init_weight) + self.output.apply(init_weight) + self.hidden_size = hidden_size + self.hidden = nn.Parameter(torch.randn((2, 1, self.hidden_size), requires_grad=True)) + + # input(batch_size, seq_len, dim) + def forward(self, word_embs, pos_onehot, cap_lens): + num_samples = word_embs.shape[0] + + pos_embs = self.pos_emb(pos_onehot) + inputs = word_embs + pos_embs + input_embs = self.input_emb(inputs) + hidden = self.hidden.repeat(1, num_samples, 1) + + cap_lens = cap_lens.data.tolist() + emb = pack_padded_sequence(input_embs, cap_lens, batch_first=True) + + gru_seq, gru_last = self.gru(emb, hidden) + + gru_last = torch.cat([gru_last[0], gru_last[1]], dim=-1) + + return self.output(gru_last) diff --git a/eval/evaluator_wrapper.py b/eval/evaluator_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..25d63d5dd16e1d55de62d5a52074483dd8c406b4 --- /dev/null +++ b/eval/evaluator_wrapper.py @@ -0,0 +1,78 @@ +# This file code from T2M(https://github.com/EricGuo5513/text-to-motion), licensed under the https://github.com/EricGuo5513/text-to-motion/blob/main/LICENSE. 
+# Copyright (c) 2022 Chuan Guo +import torch +from os.path import join as pjoin +import numpy as np +from .evaluator_modules import * + + +def build_models(opt): + movement_enc = MovementConvEncoder(opt.dim_pose-4, opt.dim_movement_enc_hidden, opt.dim_movement_latent) + text_enc = TextEncoderBiGRUCo(word_size=opt.dim_word, + pos_size=opt.dim_pos_ohot, + hidden_size=opt.dim_text_hidden, + output_size=opt.dim_coemb_hidden, + device=opt.device) + + motion_enc = MotionEncoderBiGRUCo(input_size=opt.dim_movement_latent, + hidden_size=opt.dim_motion_hidden, + output_size=opt.dim_coemb_hidden, + device=opt.device) + checkpoint = torch.load(pjoin(opt.evaluator_dir, opt.dataset_name, 'text_mot_match', 'model', 'finest.tar'), + map_location=opt.device) + movement_enc.load_state_dict(checkpoint['movement_encoder']) + text_enc.load_state_dict(checkpoint['text_encoder']) + motion_enc.load_state_dict(checkpoint['motion_encoder']) + print('\nLoading Evaluation Model Wrapper (Epoch %d) Completed!!' % (checkpoint['epoch'])) + return text_enc, motion_enc, movement_enc + +class EvaluatorModelWrapper(object): + + def __init__(self, opt): + self.text_encoder, self.motion_encoder, self.movement_encoder = build_models(opt) + self.opt = opt + self.device = opt.device + + self.text_encoder.to(opt.device) + self.motion_encoder.to(opt.device) + self.movement_encoder.to(opt.device) + + self.text_encoder.eval() + self.motion_encoder.eval() + self.movement_encoder.eval() + + # Please note that the results does not following the order of inputs + def get_co_embeddings(self, word_embs, pos_ohot, cap_lens, motions, m_lens): + with torch.no_grad(): + word_embs = word_embs.detach().to(self.device).float() + pos_ohot = pos_ohot.detach().to(self.device).float() + motions = motions.detach().to(self.device).float() + + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + m_lens = m_lens[align_idx] + + '''Movement Encoding''' + movements = self.movement_encoder(motions[..., :-4]).detach() + m_lens = torch.div(m_lens, self.opt.unit_length, rounding_mode='trunc') + motion_embedding = self.motion_encoder(movements, m_lens) + + '''Text Encoding''' + text_embedding = self.text_encoder(word_embs, pos_ohot, cap_lens) + text_embedding = text_embedding[align_idx] + return text_embedding, motion_embedding + + # Please note that the results does not following the order of inputs + def get_motion_embeddings(self, motions, m_lens): + with torch.no_grad(): + motions = motions.detach().to(self.device).float() + + align_idx = np.argsort(m_lens.data.tolist())[::-1].copy() + motions = motions[align_idx] + m_lens = m_lens[align_idx] + + '''Movement Encoding''' + movements = self.movement_encoder(motions[..., :-4]).detach() + m_lens = torch.div(m_lens, self.opt.unit_length, rounding_mode='trunc') + motion_embedding = self.motion_encoder(movements, m_lens) + return motion_embedding diff --git a/models/__init__.py b/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..64b1e5a31ece560a43a80ce4ea5850dc1749a0e9 --- /dev/null +++ b/models/__init__.py @@ -0,0 +1,26 @@ +from .unet import MotionCLR + + +__all__ = ["MotionCLR"] + + +def build_models(opt, edit_config=None, out_path=None): + print("\nInitializing model ...") + model = MotionCLR( + input_feats=opt.dim_pose, + text_latent_dim=opt.text_latent_dim, + base_dim=opt.base_dim, + dim_mults=opt.dim_mults, + time_dim=opt.time_dim, + adagn=not opt.no_adagn, + zero=True, + dropout=opt.dropout, + no_eff=opt.no_eff, + 
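+ # cond_mask_prob falls back to 0.0 when opt does not define it; a value > 0 enables classifier-free guidance at sampling time (see DiffusePipeline.generate_batch)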
cond_mask_prob=getattr(opt, "cond_mask_prob", 0.0), + self_attention=opt.self_attention, + vis_attn=opt.vis_attn, + edit_config=edit_config, + out_path=out_path, + ) + + return model diff --git a/models/gaussian_diffusion.py b/models/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..34a2b001c4436cc0cf6182476728d612412285b8 --- /dev/null +++ b/models/gaussian_diffusion.py @@ -0,0 +1,134 @@ +from diffusers import ( + DPMSolverMultistepScheduler, + DDPMScheduler, + DDIMScheduler, + PNDMScheduler, + DEISMultistepScheduler, +) +import torch +import yaml +import math +import tqdm +import time + + +class DiffusePipeline(object): + + def __init__( + self, + opt, + model, + diffuser_name, + num_inference_steps, + device, + torch_dtype=torch.float16, + ): + self.device = device + self.torch_dtype = torch_dtype + self.diffuser_name = diffuser_name + self.num_inference_steps = num_inference_steps + if self.torch_dtype == torch.float16: + model = model.half() + self.model = model.to(device) + self.opt = opt + + # Load parameters from YAML file + with open("config/diffuser_params.yaml", "r") as yaml_file: + diffuser_params = yaml.safe_load(yaml_file) + + # Select diffusion'parameters based on diffuser_name + if diffuser_name in diffuser_params: + params = diffuser_params[diffuser_name] + scheduler_class_name = params["scheduler_class"] + additional_params = params["additional_params"] + + # align training parameters + additional_params["num_train_timesteps"] = opt.diffusion_steps + additional_params["beta_schedule"] = opt.beta_schedule + additional_params["prediction_type"] = opt.prediction_type + + try: + scheduler_class = globals()[scheduler_class_name] + except KeyError: + raise ValueError(f"Class '{scheduler_class_name}' not found.") + + self.scheduler = scheduler_class(**additional_params) + else: + raise ValueError(f"Unsupported diffuser_name: {diffuser_name}") + + def generate_batch(self, caption, m_lens): + B = len(caption) + T = m_lens.max() + shape = (B, T, self.model.input_feats) + + # random sampling noise x_T + sample = torch.randn(shape, device=self.device, dtype=self.torch_dtype) + + # set timesteps + self.scheduler.set_timesteps(self.num_inference_steps, self.device) + timesteps = [ + torch.tensor([t] * B, device=self.device).long() + for t in self.scheduler.timesteps + ] + + # cache text_embedded + enc_text = self.model.encode_text(caption, self.device) + + for i, t in enumerate(timesteps): + # 1. model predict + with torch.no_grad(): + if getattr(self.model, "cond_mask_prob", 0) > 0: + predict = self.model.forward_with_cfg(sample, t, enc_text=enc_text) + else: + + predict = self.model(sample, t, enc_text=enc_text) + + # 2. compute less noisy motion and set x_t -> x_t-1 + sample = self.scheduler.step(predict, t[0], sample).prev_sample + + return sample + + def generate(self, caption, m_lens, batch_size=32): + N = len(caption) + infer_mode = "" + if getattr(self.model, "cond_mask_prob", 0) > 0: + infer_mode = "classifier-free-guidance" + print( + f"\nUsing {self.diffuser_name} diffusion scheduler to {infer_mode} generate {N} motions, sampling {self.num_inference_steps} steps." 
+ ) + self.model.eval() + + all_output = [] + t_sum = 0 + cur_idx = 0 + for bacth_idx in tqdm.tqdm(range(math.ceil(N / batch_size))): + if cur_idx + batch_size >= N: + batch_caption = caption[cur_idx:] + batch_m_lens = m_lens[cur_idx:] + else: + batch_caption = caption[cur_idx : cur_idx + batch_size] + batch_m_lens = m_lens[cur_idx : cur_idx + batch_size] + torch.cuda.synchronize() + start_time = time.time() + output = self.generate_batch(batch_caption, batch_m_lens) + torch.cuda.synchronize() + now_time = time.time() + + # The average inference time is calculated after GPU warm-up in the first 50 steps. + if (bacth_idx + 1) * self.num_inference_steps >= 50: + t_sum += now_time - start_time + + # Crop motion with gt/predicted motion length + B = output.shape[0] + for i in range(B): + all_output.append(output[i, : batch_m_lens[i]]) + + cur_idx += batch_size + + # calcalate average inference time + t_eval = t_sum / (bacth_idx - 1) + print( + "The average generation time of a batch motion (bs=%d) is %f seconds" + % (batch_size, t_eval) + ) + return all_output, t_eval diff --git a/models/unet.py b/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..0afc83f69c819eaf23263940cb7c4de5ab88352c --- /dev/null +++ b/models/unet.py @@ -0,0 +1,1073 @@ +import clip +import math + +import torch +import torch.nn.functional as F +from torch import nn +import numpy as np +from einops.layers.torch import Rearrange +from einops import rearrange +import matplotlib.pyplot as plt +import os + + +MONITOR_ATTN = [] +SELF_ATTN = [] + + +def vis_attn(att, out_path, step, layer, shape, type_="self", lines=True): + if lines: + plt.figure(figsize=(10, 3)) + for token_index in range(att.shape[1]): + plt.plot(att[:, token_index], label=f"Token {token_index}") + + plt.title("Attention Values for Each Token") + plt.xlabel("time") + plt.ylabel("Attention Value") + plt.legend(loc="upper right", bbox_to_anchor=(1.15, 1)) + + # save image + savepath = os.path.join(out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_lines_{shape}.png") + os.makedirs(os.path.dirname(savepath), exist_ok=True) + plt.savefig(savepath, bbox_inches="tight") + np.save(savepath.replace(".png", ".npy"), att) + else: + plt.figure(figsize=(10, 10)) + plt.imshow(att.transpose(), cmap="viridis", aspect="auto") + plt.colorbar() + plt.title("Attention Matrix Heatmap") + plt.ylabel("time") + plt.xlabel("time") + + # save image + savepath = os.path.join(out_path, f"vis-{type_}/step{str(step)}/layer{str(layer)}_heatmap_{shape}.png") + os.makedirs(os.path.dirname(savepath), exist_ok=True) + plt.savefig(savepath, bbox_inches="tight") + np.save(savepath.replace(".png", ".npy"), att) + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. 
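+ In this file it is used to zero-initialize the FFN output projection so that the residual branch starts as an identity mapping.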
+ """ + for p in module.parameters(): + p.detach().zero_() + return module + + +class FFN(nn.Module): + + def __init__(self, latent_dim, ffn_dim, dropout): + super().__init__() + self.linear1 = nn.Linear(latent_dim, ffn_dim) + self.linear2 = zero_module(nn.Linear(ffn_dim, latent_dim)) + self.activation = nn.GELU() + self.dropout = nn.Dropout(dropout) + + def forward(self, x): + y = self.linear2(self.dropout(self.activation(self.linear1(x)))) + y = x + y + return y + + +class Conv1dAdaGNBlock(nn.Module): + """ + Conv1d --> GroupNorm --> scale,shift --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4): + super().__init__() + self.out_channels = out_channels + self.block = nn.Conv1d( + inp_channels, out_channels, kernel_size, padding=kernel_size // 2 + ) + self.group_norm = nn.GroupNorm(n_groups, out_channels) + self.avtication = nn.Mish() + + def forward(self, x, scale, shift): + """ + Args: + x: [bs, nfeat, nframes] + scale: [bs, out_feat, 1] + shift: [bs, out_feat, 1] + """ + x = self.block(x) + + batch_size, channels, horizon = x.size() + x = rearrange( + x, "batch channels horizon -> (batch horizon) channels" + ) # [bs*seq, nfeats] + x = self.group_norm(x) + x = rearrange( + x.reshape(batch_size, horizon, channels), + "batch horizon channels -> batch channels horizon", + ) + x = ada_shift_scale(x, shift, scale) + + return self.avtication(x) + + +class SelfAttention(nn.Module): + + def __init__( + self, + latent_dim, + text_latent_dim, + num_heads: int = 8, + dropout: float = 0.0, + log_attn=False, + edit_config=None, + ): + super().__init__() + self.num_head = num_heads + self.norm = nn.LayerNorm(latent_dim) + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(latent_dim, latent_dim) + self.value = nn.Linear(latent_dim, latent_dim) + self.dropout = nn.Dropout(dropout) + + self.edit_config = edit_config + self.log_attn = log_attn + + def forward(self, x): + """ + x: B, T, D + xf: B, N, L + """ + B, T, D = x.shape + N = x.shape[1] + assert N == T + H = self.num_head + + # B, T, 1, D + query = self.query(self.norm(x)).unsqueeze(2) + # B, 1, N, D + key = self.key(self.norm(x)).unsqueeze(1) + query = query.view(B, T, H, -1) + key = key.view(B, N, H, -1) + + # style transfer motion editing + style_tranfer = self.edit_config.style_tranfer.use + if style_tranfer: + if ( + len(SELF_ATTN) + <= self.edit_config.style_tranfer.style_transfer_steps_end + ): + query[1] = query[0] + + # example based motion generation + example_based = self.edit_config.example_based.use + if example_based: + if len(SELF_ATTN) == self.edit_config.example_based.example_based_steps_end: + + temp_seed = self.edit_config.example_based.temp_seed + for id_ in range(query.shape[0] - 1): + with torch.random.fork_rng(): + torch.manual_seed(temp_seed) + tensor = query[0] + chunks = torch.split( + tensor, self.edit_config.example_based.chunk_size, dim=0 + ) + shuffled_indices = torch.randperm(len(chunks)) + shuffled_chunks = [chunks[i] for i in shuffled_indices] + shuffled_tensor = torch.cat(shuffled_chunks, dim=0) + query[1 + id_] = shuffled_tensor + temp_seed += self.edit_config.example_based.temp_seed_bar + + # time shift motion editing (q, k) + time_shift = self.edit_config.time_shift.use + if time_shift: + if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end: + part1 = int( + key.shape[1] * self.edit_config.time_shift.time_shift_ratio // 1 + ) + part2 = int( + key.shape[1] + * (1 - self.edit_config.time_shift.time_shift_ratio) + // 1 + ) + 
q_front_part = query[0, :part1, :, :] + q_back_part = query[0, -part2:, :, :] + + new_q = torch.cat((q_back_part, q_front_part), dim=0) + query[1] = new_q + + k_front_part = key[0, :part1, :, :] + k_back_part = key[0, -part2:, :, :] + new_k = torch.cat((k_back_part, k_front_part), dim=0) + key[1] = new_k + + # B, T, N, H + attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H) + weight = self.dropout(F.softmax(attention, dim=2)) + + # for counting the step and logging attention maps + try: + attention_matrix = ( + weight[0, :, :].mean(dim=-1).detach().cpu().numpy().astype(float) + ) + SELF_ATTN[-1].append(attention_matrix) + except: + pass + + # attention manipulation for replacement + attention_manipulation = self.edit_config.manipulation.use + if attention_manipulation: + if len(SELF_ATTN) <= self.edit_config.manipulation.manipulation_steps_end: + weight[1, :, :, :] = weight[0, :, :, :] + + value = self.value(self.norm(x)).view(B, N, H, -1) + + # time shift motion editing (v) + if time_shift: + if len(MONITOR_ATTN) <= self.edit_config.time_shift.time_shift_steps_end: + v_front_part = value[0, :part1, :, :] + v_back_part = value[0, -part2:, :, :] + new_v = torch.cat((v_back_part, v_front_part), dim=0) + value[1] = new_v + y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D) + return y + + +class TimestepEmbedder(nn.Module): + def __init__(self, d_model, max_len=5000): + super(TimestepEmbedder, self).__init__() + + pe = torch.zeros(max_len, d_model) + position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) + div_term = torch.exp( + torch.arange(0, d_model, 2).float() * (-np.log(10000.0) / d_model) + ) + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + self.register_buffer("pe", pe) + + def forward(self, x): + return self.pe[x] + + +class Downsample1d(nn.Module): + def __init__(self, dim): + super().__init__() + self.conv = nn.Conv1d(dim, dim, 3, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Upsample1d(nn.Module): + def __init__(self, dim_in, dim_out=None): + super().__init__() + dim_out = dim_out or dim_in + self.conv = nn.ConvTranspose1d(dim_in, dim_out, 4, 2, 1) + + def forward(self, x): + return self.conv(x) + + +class Conv1dBlock(nn.Module): + """ + Conv1d --> GroupNorm --> Mish + """ + + def __init__(self, inp_channels, out_channels, kernel_size, n_groups=4, zero=False): + super().__init__() + self.out_channels = out_channels + self.block = nn.Conv1d( + inp_channels, out_channels, kernel_size, padding=kernel_size // 2 + ) + self.norm = nn.GroupNorm(n_groups, out_channels) + self.activation = nn.Mish() + + if zero: + # zero init the convolution + nn.init.zeros_(self.block.weight) + nn.init.zeros_(self.block.bias) + + def forward(self, x): + """ + Args: + x: [bs, nfeat, nframes] + """ + x = self.block(x) + + batch_size, channels, horizon = x.size() + x = rearrange( + x, "batch channels horizon -> (batch horizon) channels" + ) # [bs*seq, nfeats] + x = self.norm(x) + x = rearrange( + x.reshape(batch_size, horizon, channels), + "batch horizon channels -> batch channels horizon", + ) + + return self.activation(x) + + +def ada_shift_scale(x, shift, scale): + return x * (1 + scale) + shift + + +class ResidualTemporalBlock(nn.Module): + def __init__( + self, + inp_channels, + out_channels, + embed_dim, + kernel_size=5, + zero=True, + n_groups=8, + dropout: float = 0.1, + adagn=True, + ): + super().__init__() + self.adagn = adagn + + self.blocks = nn.ModuleList( + [ + # adagn only 
the first conv (following guided-diffusion) + ( + Conv1dAdaGNBlock(inp_channels, out_channels, kernel_size, n_groups) + if adagn + else Conv1dBlock(inp_channels, out_channels, kernel_size) + ), + Conv1dBlock( + out_channels, out_channels, kernel_size, n_groups, zero=zero + ), + ] + ) + + self.time_mlp = nn.Sequential( + nn.Mish(), + # adagn = scale and shift + nn.Linear(embed_dim, out_channels * 2 if adagn else out_channels), + Rearrange("batch t -> batch t 1"), + ) + self.dropout = nn.Dropout(dropout) + if zero: + nn.init.zeros_(self.time_mlp[1].weight) + nn.init.zeros_(self.time_mlp[1].bias) + + self.residual_conv = ( + nn.Conv1d(inp_channels, out_channels, 1) + if inp_channels != out_channels + else nn.Identity() + ) + + def forward(self, x, time_embeds=None): + """ + x : [ batch_size x inp_channels x nframes ] + t : [ batch_size x embed_dim ] + returns: [ batch_size x out_channels x nframes ] + """ + if self.adagn: + scale, shift = self.time_mlp(time_embeds).chunk(2, dim=1) + out = self.blocks[0](x, scale, shift) + else: + out = self.blocks[0](x) + self.time_mlp(time_embeds) + out = self.blocks[1](out) + out = self.dropout(out) + return out + self.residual_conv(x) + + +class CrossAttention(nn.Module): + + def __init__( + self, + latent_dim, + text_latent_dim, + num_heads: int = 8, + dropout: float = 0.0, + log_attn=False, + edit_config=None, + ): + super().__init__() + self.num_head = num_heads + self.norm = nn.LayerNorm(latent_dim) + self.text_norm = nn.LayerNorm(text_latent_dim) + self.query = nn.Linear(latent_dim, latent_dim) + self.key = nn.Linear(text_latent_dim, latent_dim) + self.value = nn.Linear(text_latent_dim, latent_dim) + self.dropout = nn.Dropout(dropout) + + self.edit_config = edit_config + self.log_attn = log_attn + + def forward(self, x, xf): + """ + x: B, T, D + xf: B, N, L + """ + B, T, D = x.shape + N = xf.shape[1] + H = self.num_head + # B, T, 1, D + query = self.query(self.norm(x)).unsqueeze(2) + # B, 1, N, D + key = self.key(self.text_norm(xf)).unsqueeze(1) + query = query.view(B, T, H, -1) + key = key.view(B, N, H, -1) + # B, T, N, H + attention = torch.einsum("bnhd,bmhd->bnmh", query, key) / math.sqrt(D // H) + weight = self.dropout(F.softmax(attention, dim=2)) + + # attention reweighting for (de)-emphasizing motion + if self.edit_config.reweighting_attn.use: + reweighting_attn = self.edit_config.reweighting_attn.reweighting_attn_weight + if self.edit_config.reweighting_attn.idx == -1: + # read idxs from txt file + with open("./assets/reweighting_idx.txt", "r") as f: + idxs = f.readlines() + else: + # gradio demo mode + idxs = [0, self.edit_config.reweighting_attn.idx] + idxs = [int(idx) for idx in idxs] + for i in range(len(idxs)): + weight[i, :, 1 + idxs[i]] = weight[i, :, 1 + idxs[i]] + reweighting_attn + weight[i, :, 1 + idxs[i] + 1] = ( + weight[i, :, 1 + idxs[i] + 1] + reweighting_attn + ) + + # for counting the step and logging attention maps + try: + attention_matrix = ( + weight[0, :, 1 : 1 + 3] + .mean(dim=-1) + .detach() + .cpu() + .numpy() + .astype(float) + ) + MONITOR_ATTN[-1].append(attention_matrix) + except: + pass + + # erasing motion (autually is the deemphasizing motion) + erasing_motion = self.edit_config.erasing_motion.use + if erasing_motion: + reweighting_attn = self.edit_config.erasing_motion.erasing_motion_weight + begin = self.edit_config.erasing_motion.time_start + end = self.edit_config.erasing_motion.time_end + idx = self.edit_config.erasing_motion.idx + if reweighting_attn > 0.01 or reweighting_attn < -0.01: + weight[1, int(T * 
begin) : int(T * end), idx] = ( + weight[1, int(T * begin) : int(T * end) :, idx] * reweighting_attn + ) + weight[1, int(T * begin) : int(T * end), idx + 1] = ( + weight[1, int(T * begin) : int(T * end), idx + 1] * reweighting_attn + ) + + # attention manipulation for motion replacement + manipulation = self.edit_config.manipulation.use + if manipulation: + if ( + len(MONITOR_ATTN) + <= self.edit_config.manipulation.manipulation_steps_end_crossattn + ): + word_idx = self.edit_config.manipulation.word_idx + weight[1, :, : 1 + word_idx, :] = weight[0, :, : 1 + word_idx, :] + weight[1, :, 1 + word_idx + 1 :, :] = weight[ + 0, :, 1 + word_idx + 1 :, : + ] + + value = self.value(self.text_norm(xf)).view(B, N, H, -1) + y = torch.einsum("bnmh,bmhd->bnhd", weight, value).reshape(B, T, D) + return y + + +class ResidualCLRAttentionLayer(nn.Module): + def __init__( + self, + dim1, + dim2, + num_heads: int = 8, + dropout: float = 0.1, + no_eff: bool = False, + self_attention: bool = False, + log_attn=False, + edit_config=None, + ): + super(ResidualCLRAttentionLayer, self).__init__() + self.dim1 = dim1 + self.dim2 = dim2 + self.num_heads = num_heads + + # Multi-Head Attention Layer + if no_eff: + self.cross_attention = CrossAttention( + latent_dim=dim1, + text_latent_dim=dim2, + num_heads=num_heads, + dropout=dropout, + log_attn=log_attn, + edit_config=edit_config, + ) + else: + self.cross_attention = LinearCrossAttention( + latent_dim=dim1, + text_latent_dim=dim2, + num_heads=num_heads, + dropout=dropout, + log_attn=log_attn, + ) + if self_attention: + self.self_attn_use = True + self.self_attention = SelfAttention( + latent_dim=dim1, + text_latent_dim=dim2, + num_heads=num_heads, + dropout=dropout, + log_attn=log_attn, + edit_config=edit_config, + ) + else: + self.self_attn_use = False + + def forward(self, input_tensor, condition_tensor, cond_indices): + """ + input_tensor :B, D, L + condition_tensor: B, L, D + """ + if cond_indices.numel() == 0: + return input_tensor + + # self attention + if self.self_attn_use: + x = input_tensor + x = x.permute(0, 2, 1) # (batch_size, seq_length, feat_dim) + x = self.self_attention(x) + x = x.permute(0, 2, 1) # (batch_size, feat_dim, seq_length) + input_tensor = input_tensor + x + x = input_tensor + + # cross attention + x = x[cond_indices].permute(0, 2, 1) # (batch_size, seq_length, feat_dim) + x = self.cross_attention(x, condition_tensor[cond_indices]) + x = x.permute(0, 2, 1) # (batch_size, feat_dim, seq_length) + + input_tensor[cond_indices] = input_tensor[cond_indices] + x + + return input_tensor + + +class CLRBlock(nn.Module): + def __init__( + self, + dim_in, + dim_out, + cond_dim, + time_dim, + adagn=True, + zero=True, + no_eff=False, + self_attention=False, + dropout: float = 0.1, + log_attn=False, + edit_config=None, + ) -> None: + super().__init__() + self.conv1d = ResidualTemporalBlock( + dim_in, dim_out, embed_dim=time_dim, adagn=adagn, zero=zero, dropout=dropout + ) + self.clr_attn = ResidualCLRAttentionLayer( + dim1=dim_out, + dim2=cond_dim, + no_eff=no_eff, + dropout=dropout, + self_attention=self_attention, + log_attn=log_attn, + edit_config=edit_config, + ) + # import pdb; pdb.set_trace() + self.ffn = FFN(dim_out, dim_out * 4, dropout=dropout) + + def forward(self, x, t, cond, cond_indices=None): + x = self.conv1d(x, t) + x = self.clr_attn(x, cond, cond_indices) + x = self.ffn(x.permute(0, 2, 1)).permute(0, 2, 1) + return x + + +class CondUnet1D(nn.Module): + """ + Diffusion's style UNET with 1D convolution and adaptive group normalization for 
motion suquence denoising, + cross-attention to introduce conditional prompts (like text). + """ + + def __init__( + self, + input_dim, + cond_dim, + dim=128, + dim_mults=(1, 2, 4, 8), + dims=None, + time_dim=512, + adagn=True, + zero=True, + dropout=0.1, + no_eff=False, + self_attention=False, + log_attn=False, + edit_config=None, + ): + super().__init__() + if not dims: + dims = [input_dim, *map(lambda m: int(dim * m), dim_mults)] ##[d, d,2d,4d] + print("dims: ", dims, "mults: ", dim_mults) + in_out = list(zip(dims[:-1], dims[1:])) + + self.time_mlp = nn.Sequential( + TimestepEmbedder(time_dim), + nn.Linear(time_dim, time_dim * 4), + nn.Mish(), + nn.Linear(time_dim * 4, time_dim), + ) + + self.downs = nn.ModuleList([]) + self.ups = nn.ModuleList([]) + + for ind, (dim_in, dim_out) in enumerate(in_out): + self.downs.append( + nn.ModuleList( + [ + CLRBlock( + dim_in, + dim_out, + cond_dim, + time_dim, + adagn=adagn, + zero=zero, + no_eff=no_eff, + dropout=dropout, + self_attention=self_attention, + log_attn=log_attn, + edit_config=edit_config, + ), + CLRBlock( + dim_out, + dim_out, + cond_dim, + time_dim, + adagn=adagn, + zero=zero, + no_eff=no_eff, + dropout=dropout, + self_attention=self_attention, + log_attn=log_attn, + edit_config=edit_config, + ), + Downsample1d(dim_out), + ] + ) + ) + + mid_dim = dims[-1] + self.mid_block1 = CLRBlock( + dim_in=mid_dim, + dim_out=mid_dim, + cond_dim=cond_dim, + time_dim=time_dim, + adagn=adagn, + zero=zero, + no_eff=no_eff, + dropout=dropout, + self_attention=self_attention, + log_attn=log_attn, + edit_config=edit_config, + ) + self.mid_block2 = CLRBlock( + dim_in=mid_dim, + dim_out=mid_dim, + cond_dim=cond_dim, + time_dim=time_dim, + adagn=adagn, + zero=zero, + no_eff=no_eff, + dropout=dropout, + self_attention=self_attention, + log_attn=log_attn, + edit_config=edit_config, + ) + + last_dim = mid_dim + for ind, dim_out in enumerate(reversed(dims[1:])): + self.ups.append( + nn.ModuleList( + [ + Upsample1d(last_dim, dim_out), + CLRBlock( + dim_out * 2, + dim_out, + cond_dim, + time_dim, + adagn=adagn, + zero=zero, + no_eff=no_eff, + dropout=dropout, + self_attention=self_attention, + log_attn=log_attn, + edit_config=edit_config, + ), + CLRBlock( + dim_out, + dim_out, + cond_dim, + time_dim, + adagn=adagn, + zero=zero, + no_eff=no_eff, + dropout=dropout, + self_attention=self_attention, + log_attn=log_attn, + edit_config=edit_config, + ), + ] + ) + ) + last_dim = dim_out + self.final_conv = nn.Conv1d(dim_out, input_dim, 1) + + if zero: + nn.init.zeros_(self.final_conv.weight) + nn.init.zeros_(self.final_conv.bias) + + def forward( + self, + x, + t, + cond, + cond_indices, + ): + temb = self.time_mlp(t) + + h = [] + for block1, block2, downsample in self.downs: + x = block1(x, temb, cond, cond_indices) + x = block2(x, temb, cond, cond_indices) + h.append(x) + x = downsample(x) + + x = self.mid_block1(x, temb, cond, cond_indices) + x = self.mid_block2(x, temb, cond, cond_indices) + + for upsample, block1, block2 in self.ups: + x = upsample(x) + x = torch.cat((x, h.pop()), dim=1) + x = block1(x, temb, cond, cond_indices) + x = block2(x, temb, cond, cond_indices) + + x = self.final_conv(x) + return x + + +class MotionCLR(nn.Module): + """ + Diffuser's style UNET for text-to-motion task. 
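+ A frozen CLIP text encoder followed by a small transformer produces per-token text features, which condition a 1D U-Net (CondUnet1D) through cross-attention (and optional self-attention) blocks.
+ forward() expects x of shape [batch_size, nframes, nfeats] and integer timesteps of shape [batch_size]; forward_with_cfg() additionally applies classifier-free guidance with cfg_scale.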
+ """ + + def __init__( + self, + input_feats, + base_dim=128, + dim_mults=(1, 2, 2, 2), + dims=None, + adagn=True, + zero=True, + dropout=0.1, + no_eff=False, + time_dim=512, + latent_dim=256, + cond_mask_prob=0.1, + clip_dim=512, + clip_version="ViT-B/32", + text_latent_dim=256, + text_ff_size=2048, + text_num_heads=4, + activation="gelu", + num_text_layers=4, + self_attention=False, + vis_attn=False, + edit_config=None, + out_path=None, + ): + super().__init__() + self.input_feats = input_feats + self.dim_mults = dim_mults + self.base_dim = base_dim + self.latent_dim = latent_dim + self.cond_mask_prob = cond_mask_prob + self.vis_attn = vis_attn + self.counting_map = [] + self.out_path = out_path + + print( + f"The T2M Unet mask the text prompt by {self.cond_mask_prob} prob. in training" + ) + + # text encoder + self.embed_text = nn.Linear(clip_dim, text_latent_dim) + self.clip_version = clip_version + self.clip_model = self.load_and_freeze_clip(clip_version) + textTransEncoderLayer = nn.TransformerEncoderLayer( + d_model=text_latent_dim, + nhead=text_num_heads, + dim_feedforward=text_ff_size, + dropout=dropout, + activation=activation, + ) + self.textTransEncoder = nn.TransformerEncoder( + textTransEncoderLayer, num_layers=num_text_layers + ) + self.text_ln = nn.LayerNorm(text_latent_dim) + + self.unet = CondUnet1D( + input_dim=self.input_feats, + cond_dim=text_latent_dim, + dim=self.base_dim, + dim_mults=self.dim_mults, + adagn=adagn, + zero=zero, + dropout=dropout, + no_eff=no_eff, + dims=dims, + time_dim=time_dim, + self_attention=self_attention, + log_attn=self.vis_attn, + edit_config=edit_config, + ) + + def encode_text(self, raw_text, device): + with torch.no_grad(): + texts = clip.tokenize(raw_text, truncate=True).to( + device + ) # [bs, context_length] # if n_tokens > 77 -> will truncate + x = self.clip_model.token_embedding(texts).type( + self.clip_model.dtype + ) # [batch_size, n_ctx, d_model] + x = x + self.clip_model.positional_embedding.type(self.clip_model.dtype) + x = x.permute(1, 0, 2) # NLD -> LND + x = self.clip_model.transformer(x) + x = self.clip_model.ln_final(x).type( + self.clip_model.dtype + ) # [len, batch_size, 512] + + x = self.embed_text(x) # [len, batch_size, 256] + x = self.textTransEncoder(x) + x = self.text_ln(x) + + # T, B, D -> B, T, D + xf_out = x.permute(1, 0, 2) + + ablation_text = False + if ablation_text: + xf_out[:, 1:, :] = xf_out[:, 0, :].unsqueeze(1) + return xf_out + + def load_and_freeze_clip(self, clip_version): + clip_model, _ = clip.load( # clip_model.dtype=float32 + clip_version, device="cpu", jit=False + ) # Must set jit=False for training + + # Freeze CLIP weights + clip_model.eval() + for p in clip_model.parameters(): + p.requires_grad = False + + return clip_model + + def mask_cond(self, bs, force_mask=False): + """ + mask motion condition , return contitional motion index in the batch + """ + if force_mask: + cond_indices = torch.empty(0) + elif self.training and self.cond_mask_prob > 0.0: + mask = torch.bernoulli( + torch.ones( + bs, + ) + * self.cond_mask_prob + ) # 1-> use null_cond, 0-> use real cond + mask = 1.0 - mask + cond_indices = torch.nonzero(mask).squeeze(-1) + else: + cond_indices = torch.arange(bs) + + return cond_indices + + def forward( + self, + x, + timesteps, + text=None, + uncond=False, + enc_text=None, + ): + """ + Args: + x: [batch_size, nframes, nfeats], + timesteps: [batch_size] (int) + text: list (batch_size length) of strings with input text prompts + uncond: whethere using text condition + + Returns: 
[batch_size, seq_length, nfeats] + """ + B, T, _ = x.shape + x = x.transpose(1, 2) # [bs, nfeats, nframes] + + if enc_text is None: + enc_text = self.encode_text(text, x.device) # [bs, seqlen, text_dim] + + cond_indices = self.mask_cond(x.shape[0], force_mask=uncond) + + # NOTE: need to pad to be the multiplier of 8 for the unet + PADDING_NEEEDED = (16 - (T % 16)) % 16 + + padding = (0, PADDING_NEEEDED) + x = F.pad(x, padding, value=0) + + x = self.unet( + x, + t=timesteps, + cond=enc_text, + cond_indices=cond_indices, + ) # [bs, nfeats,, nframes] + + x = x[:, :, :T].transpose(1, 2) # [bs, nframes, nfeats,] + + return x + + def forward_with_cfg(self, x, timesteps, text=None, enc_text=None, cfg_scale=2.5): + """ + Args: + x: [batch_size, nframes, nfeats], + timesteps: [batch_size] (int) + text: list (batch_size length) of strings with input text prompts + + Returns: [batch_size, max_frames, nfeats] + """ + global SELF_ATTN + global MONITOR_ATTN + MONITOR_ATTN.append([]) + SELF_ATTN.append([]) + + B, T, _ = x.shape + x = x.transpose(1, 2) # [bs, nfeats, nframes] + if enc_text is None: + enc_text = self.encode_text(text, x.device) # [bs, seqlen, text_dim] + + cond_indices = self.mask_cond(B) + + # NOTE: need to pad to be the multiplier of 8 for the unet + PADDING_NEEEDED = (16 - (T % 16)) % 16 + + padding = (0, PADDING_NEEEDED) + x = F.pad(x, padding, value=0) + + combined_x = torch.cat([x, x], dim=0) + combined_t = torch.cat([timesteps, timesteps], dim=0) + out = self.unet( + x=combined_x, + t=combined_t, + cond=enc_text, + cond_indices=cond_indices, + ) # [bs, nfeats, nframes] + + out = out[:, :, :T].transpose(1, 2) # [bs, nframes, nfeats,] + + out_cond, out_uncond = torch.split(out, len(out) // 2, dim=0) + + if self.vis_attn == True: + i = len(MONITOR_ATTN) + attnlist = MONITOR_ATTN[-1] + print(i, "cross", len(attnlist)) + for j, att in enumerate(attnlist): + vis_attn( + att, + out_path=self.out_path, + step=i, + layer=j, + shape="_".join(map(str, att.shape)), + type_="cross", + ) + + attnlist = SELF_ATTN[-1] + print(i, "self", len(attnlist)) + for j, att in enumerate(attnlist): + vis_attn( + att, + out_path=self.out_path, + step=i, + layer=j, + shape="_".join(map(str, att.shape)), + type_="self", + lines=False, + ) + + if len(SELF_ATTN) % 10 == 0: + SELF_ATTN = [] + MONITOR_ATTN = [] + + return out_uncond + (cfg_scale * (out_cond - out_uncond)) + + +if __name__ == "__main__": + + device = "cuda:0" + n_feats = 263 + num_frames = 196 + text_latent_dim = 256 + dim_mults = [2, 2, 2, 2] + base_dim = 512 + model = MotionCLR( + input_feats=n_feats, + text_latent_dim=text_latent_dim, + base_dim=base_dim, + dim_mults=dim_mults, + adagn=True, + zero=True, + dropout=0.1, + no_eff=True, + cond_mask_prob=0.1, + self_attention=True, + ) + + model = model.to(device) + from utils.model_load import load_model_weights + + checkpoint_path = "/comp_robot/chenlinghao/StableMoFusion/checkpoints/t2m/self_attn—fulllayer-ffn-drop0_1-lr1e4/model/latest.tar" + new_state_dict = {} + checkpoint = torch.load(checkpoint_path) + ckpt2 = checkpoint.copy() + ckpt2["model_ema"] = {} + ckpt2["encoder"] = {} + + for key, value in list(checkpoint["model_ema"].items()): + new_key = key.replace( + "cross_attn", "clr_attn" + ) # Replace 'cross_attn' with 'clr_attn' + ckpt2["model_ema"][new_key] = value + for key, value in list(checkpoint["encoder"].items()): + new_key = key.replace( + "cross_attn", "clr_attn" + ) # Replace 'cross_attn' with 'clr_attn' + ckpt2["encoder"][new_key] = value + + torch.save( + ckpt2, + 
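+ # developer utility: the converted checkpoint (legacy 'cross_attn' keys renamed to 'clr_attn') is written to a hard-coded path below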
"/comp_robot/chenlinghao/CLRpreview/checkpoints/t2m/release/model/latest.tar", + ) + + dtype = torch.float32 + bs = 1 + x = torch.rand((bs, 196, 263), dtype=dtype).to(device) + timesteps = torch.randint(low=0, high=1000, size=(bs,)).to(device) + y = ["A man jumps to his left." for i in range(bs)] + length = torch.randint(low=20, high=196, size=(bs,)).to(device) + + out = model(x, timesteps, text=y) + print(out.shape) + model.eval() + out = model.forward_with_cfg(x, timesteps, text=y) + print(out.shape) diff --git a/motion_loader/__init__.py b/motion_loader/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c78ad0cd5a722a98939f9daceb029de920f05adb --- /dev/null +++ b/motion_loader/__init__.py @@ -0,0 +1,2 @@ +from .model_motion_loaders import get_motion_loader +from .dataset_motion_loaders import get_dataset_loader diff --git a/motion_loader/dataset_motion_loaders.py b/motion_loader/dataset_motion_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..c4d8ae699351045a1b2e39f0f86997876ca8207f --- /dev/null +++ b/motion_loader/dataset_motion_loaders.py @@ -0,0 +1,31 @@ +from datasets import get_dataset +from torch.utils.data import DataLoader +from torch.utils.data._utils.collate import default_collate + + +def collate_fn(batch): + batch.sort(key=lambda x: x[3], reverse=True) + return default_collate(batch) + + +def get_dataset_loader(opt, batch_size, mode="eval", split="test", accelerator=None): + dataset = get_dataset(opt, split, mode, accelerator) + if mode in ["eval", "gt_eval"]: + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=4, + drop_last=True, + collate_fn=collate_fn, + ) + else: + dataloader = DataLoader( + dataset, + batch_size=batch_size, + shuffle=True, + num_workers=4, + drop_last=True, + persistent_workers=True, + ) + return dataloader diff --git a/motion_loader/model_motion_loaders.py b/motion_loader/model_motion_loaders.py new file mode 100644 index 0000000000000000000000000000000000000000..e9d6ce31f446f18fd8f03f5a152814b0c1456f37 --- /dev/null +++ b/motion_loader/model_motion_loaders.py @@ -0,0 +1,243 @@ +import torch +from utils.word_vectorizer import WordVectorizer +from torch.utils.data import Dataset, DataLoader +from os.path import join as pjoin +from tqdm import tqdm +import numpy as np +from eval.evaluator_modules import * + +from torch.utils.data._utils.collate import default_collate + + +class GeneratedDataset(Dataset): + """ + opt.dataset_name + opt.max_motion_length + opt.unit_length + """ + + def __init__( + self, opt, pipeline, dataset, w_vectorizer, mm_num_samples, mm_num_repeats + ): + assert mm_num_samples < len(dataset) + self.dataset = dataset + dataloader = DataLoader(dataset, batch_size=1, num_workers=1, shuffle=True) + generated_motion = [] + min_mov_length = 10 if opt.dataset_name == "t2m" else 6 + + # Pre-process all target captions + mm_generated_motions = [] + if mm_num_samples > 0: + mm_idxs = np.random.choice(len(dataset), mm_num_samples, replace=False) + mm_idxs = np.sort(mm_idxs) + + all_caption = [] + all_m_lens = [] + all_data = [] + with torch.no_grad(): + for i, data in tqdm(enumerate(dataloader)): + word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data + all_data.append(data) + tokens = tokens[0].split("_") + mm_num_now = len(mm_generated_motions) + is_mm = ( + True + if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])) + else False + ) + repeat_times = mm_num_repeats if is_mm else 1 + m_lens = max( + 
torch.div(m_lens, opt.unit_length, rounding_mode="trunc") + * opt.unit_length, + min_mov_length * opt.unit_length, + ) + m_lens = min(m_lens, opt.max_motion_length) + if isinstance(m_lens, int): + m_lens = torch.LongTensor([m_lens]).to(opt.device) + else: + m_lens = m_lens.to(opt.device) + for t in range(repeat_times): + all_m_lens.append(m_lens) + all_caption.extend(caption) + if is_mm: + mm_generated_motions.append(0) + all_m_lens = torch.stack(all_m_lens) + + # Generate all sequences + with torch.no_grad(): + all_pred_motions, t_eval = pipeline.generate(all_caption, all_m_lens) + self.eval_generate_time = t_eval + + cur_idx = 0 + mm_generated_motions = [] + with torch.no_grad(): + for i, data_dummy in tqdm(enumerate(dataloader)): + data = all_data[i] + word_emb, pos_ohot, caption, cap_lens, motions, m_lens, tokens = data + tokens = tokens[0].split("_") + mm_num_now = len(mm_generated_motions) + is_mm = ( + True + if ((mm_num_now < mm_num_samples) and (i == mm_idxs[mm_num_now])) + else False + ) + repeat_times = mm_num_repeats if is_mm else 1 + mm_motions = [] + for t in range(repeat_times): + pred_motions = all_pred_motions[cur_idx] + cur_idx += 1 + if t == 0: + sub_dict = { + "motion": pred_motions.cpu().numpy(), + "length": pred_motions.shape[0], # m_lens[0].item(), # + "caption": caption[0], + "cap_len": cap_lens[0].item(), + "tokens": tokens, + } + generated_motion.append(sub_dict) + + if is_mm: + mm_motions.append( + { + "motion": pred_motions.cpu().numpy(), + "length": pred_motions.shape[ + 0 + ], # m_lens[0].item(), #m_lens[0].item() + } + ) + if is_mm: + mm_generated_motions.append( + { + "caption": caption[0], + "tokens": tokens, + "cap_len": cap_lens[0].item(), + "mm_motions": mm_motions, + } + ) + self.generated_motion = generated_motion + self.mm_generated_motion = mm_generated_motions + self.opt = opt + self.w_vectorizer = w_vectorizer + + def __len__(self): + return len(self.generated_motion) + + def __getitem__(self, item): + data = self.generated_motion[item] + motion, m_length, caption, tokens = ( + data["motion"], + data["length"], + data["caption"], + data["tokens"], + ) + sent_len = data["cap_len"] + + # This step is needed because T2M evaluators expect their norm convention + normed_motion = motion + denormed_motion = self.dataset.inv_transform(normed_motion) + renormed_motion = ( + denormed_motion - self.dataset.mean_for_eval + ) / self.dataset.std_for_eval # according to T2M norms + motion = renormed_motion + + pos_one_hots = [] + word_embeddings = [] + for token in tokens: + word_emb, pos_oh = self.w_vectorizer[token] + pos_one_hots.append(pos_oh[None, :]) + word_embeddings.append(word_emb[None, :]) + pos_one_hots = np.concatenate(pos_one_hots, axis=0) + word_embeddings = np.concatenate(word_embeddings, axis=0) + length = len(motion) + if length < self.opt.max_motion_length: + motion = np.concatenate( + [ + motion, + np.zeros((self.opt.max_motion_length - length, motion.shape[1])), + ], + axis=0, + ) + return ( + word_embeddings, + pos_one_hots, + caption, + sent_len, + motion, + m_length, + "_".join(tokens), + ) + + +def collate_fn(batch): + batch.sort(key=lambda x: x[3], reverse=True) + return default_collate(batch) + + +class MMGeneratedDataset(Dataset): + def __init__(self, opt, motion_dataset, w_vectorizer): + self.opt = opt + self.dataset = motion_dataset.mm_generated_motion + self.w_vectorizer = w_vectorizer + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, item): + data = self.dataset[item] + mm_motions = data["mm_motions"] + 
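+ # pad every repeated generation to max_motion_length, stack them, then sort by true length in descending order (required by pack_padded_sequence in the motion encoder)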
m_lens = [] + motions = [] + for mm_motion in mm_motions: + m_lens.append(mm_motion["length"]) + motion = mm_motion["motion"] + if len(motion) < self.opt.max_motion_length: + motion = np.concatenate( + [ + motion, + np.zeros( + (self.opt.max_motion_length - len(motion), motion.shape[1]) + ), + ], + axis=0, + ) + motion = motion[None, :] + motions.append(motion) + m_lens = np.array(m_lens, dtype=np.int32) + motions = np.concatenate(motions, axis=0) + sort_indx = np.argsort(m_lens)[::-1].copy() + + m_lens = m_lens[sort_indx] + motions = motions[sort_indx] + return motions, m_lens + + +def get_motion_loader( + opt, batch_size, pipeline, ground_truth_dataset, mm_num_samples, mm_num_repeats +): + + # Currently the configurations of two datasets are almost the same + if opt.dataset_name == "t2m" or opt.dataset_name == "kit": + w_vectorizer = WordVectorizer(opt.glove_dir, "our_vab") + else: + raise KeyError("Dataset not recognized!!") + + dataset = GeneratedDataset( + opt, + pipeline, + ground_truth_dataset, + w_vectorizer, + mm_num_samples, + mm_num_repeats, + ) + mm_dataset = MMGeneratedDataset(opt, dataset, w_vectorizer) + + motion_loader = DataLoader( + dataset, + batch_size=batch_size, + collate_fn=collate_fn, + drop_last=True, + num_workers=4, + ) + mm_motion_loader = DataLoader(mm_dataset, batch_size=1, num_workers=1) + + return motion_loader, mm_motion_loader, dataset.eval_generate_time diff --git a/options/__init__.py b/options/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/options/edit.yaml b/options/edit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..471226ca819d2312f9f7c7838795180f53c8d74b --- /dev/null +++ b/options/edit.yaml @@ -0,0 +1,40 @@ +# edit.yaml +reweighting_attn: + use: False + reweighting_attn_weight: 0.0 # the weight of reweighting attention for motion emphasizing and de-emphasizing + idx: -1 # the index of the word to be emphasized or de-emphasized (0 ~ 10) + +erasing_motion: + use: False + erasing_motion_weight: 0.1 # the weight of motion erasing + time_start: 0.5 # the start time of motion erasing (0.0 ~ 1.0), ratio of the total time + time_end: 1.0 # the end time of motion erasing (0.0 ~ 1.0), ratio of the total time + idx: -1 + +manipulation: # motion manipulation means in-place motion replacement + use: False + manipulation_steps_start: 0 # the start step of motion manipulation, 0 ~ 10 + manipulation_steps_end: 3 # the end step of motion manipulation, 0 ~ 10 + manipulation_steps_end_crossattn: 3 # the end step of cross-attention for motion manipulation, 0 ~ 10 + word_idx: 3 # the index of the word to be manipulated + +time_shift: + use: False + time_shift_steps_start: 0 # the start step of time shifting, 0 ~ 10 + time_shift_steps_end: 4 # the end step of time shifting, 0 ~ 10 + time_shift_ratio: 0.5 # the ratio of time shifting, 0.0 ~ 1.0 + +example_based: + use: False + chunk_size: 20 # the size of the chunk for example-based generation + example_based_steps_end: 6 # the end step of example-based generation, 0 ~ 10 + temp_seed: 200 # the inintial seed for example-based generation + temp_seed_bar: 15 # the the seed bar for example-based generation + +style_tranfer: + use: False + style_transfer_steps_start: 0 # the start step of style transfer, 0 ~ 10 + style_transfer_steps_end: 5 # the end step of style transfer, 0 ~ 10 + +grounded_generation: + use: False \ No newline at end of file diff --git a/options/evaluate_options.py 
b/options/evaluate_options.py new file mode 100644 index 0000000000000000000000000000000000000000..302cdffa49e060f835b445e235d1335a08389ffa --- /dev/null +++ b/options/evaluate_options.py @@ -0,0 +1,53 @@ +import argparse +from .get_opt import get_opt +import yaml + +class TestOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + self.initialize() + + def initialize(self): + self.parser.add_argument("--opt_path", type=str, default='./checkpoints/t2m/t2m_condunet1d_batch64/opt.txt',help='option file path for loading model') + self.parser.add_argument("--gpu_id", type=int, default=0, help='GPU id') + + # evaluator + self.parser.add_argument("--evaluator_dir", type=str, default='./data/checkpoints', help='Directory path where save T2M evaluator\'s checkpoints') + self.parser.add_argument("--eval_meta_dir", type=str, default='./data', help='Directory path where save T2M evaluator\'s normalization data.') + self.parser.add_argument("--glove_dir",type=str,default='./data/glove', help='Directory path where save glove') + + # inference + self.parser.add_argument("--num_inference_steps", type=int, default=10, help='Number of iterative denoising steps during inference.') + self.parser.add_argument("--which_ckpt", type=str, default='latest', help='name of checkpoint to load') + self.parser.add_argument("--diffuser_name", type=str, default='dpmsolver', help='sampler\'s scheduler class name in the diffuser library') + self.parser.add_argument("--no_ema", action="store_true", help='Where use EMA model in inference') + self.parser.add_argument("--no_fp16", action="store_true", help='Whether use FP16 in inference') + self.parser.add_argument('--debug', action="store_true", help='debug mode') + self.parser.add_argument('--self_attention', action="store_true", help='self_attention use or not') + self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient linear attention') + self.parser.add_argument('--vis_attn', action='store_true', help='vis attention value or not') + self.parser.add_argument('--dropout', type=float, default=0.1, help='dropout') + + # evaluation + self.parser.add_argument("--replication_times", type=int, default=1, help='Number of generation rounds for each text description') + self.parser.add_argument('--batch_size', type=int, default=32, help='Batch size for eval') + self.parser.add_argument('--diversity_times', type=int, default=300, help='') + self.parser.add_argument('--mm_num_samples', type=int, default=100, help='Number of samples for evaluating multimodality') + self.parser.add_argument('--mm_num_repeats', type=int, default=30, help='Number of generation rounds for each text description when evaluating multimodality') + self.parser.add_argument('--mm_num_times', type=int, default=10, help='') + self.parser.add_argument('--edit_mode', action='store_true', help='editing mode') + + def parse(self): + # load evaluation options + self.opt = self.parser.parse_args() + opt_dict = vars(self.opt) + + # load the model options of T2m evaluator + with open('./config/evaluator.yaml', 'r') as yaml_file: + yaml_config = yaml.safe_load(yaml_file) + opt_dict.update(yaml_config) + + # load the training options of the selected checkpoint + get_opt(self.opt, self.opt.opt_path) + + return self.opt \ No newline at end of file diff --git a/options/generate_options.py b/options/generate_options.py new file mode 100644 index 
0000000000000000000000000000000000000000..900efa6894d53ff0dc1abc62bff43aa6fc79252c --- /dev/null +++ b/options/generate_options.py @@ -0,0 +1,50 @@ +import argparse +from .get_opt import get_opt + +class GenerateOptions(): + def __init__(self, app=False): + self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + self.initialize() + + def initialize(self): + self.parser.add_argument("--opt_path", type=str, default='./checkpoints/t2m/t2m_condunet1d_batch64/opt.txt', help='option file path for loading model') + self.parser.add_argument("--gpu_id", type=int, default=0, help='GPU id') + self.parser.add_argument("--output_dir", type=str, default='', help='Directory path to save generation result') + self.parser.add_argument("--footskate_cleanup", action="store_true", help='Where use footskate cleanup in inference') + + # inference + self.parser.add_argument("--num_inference_steps", type=int, default=10, help='Number of iterative denoising steps during inference.') + self.parser.add_argument("--which_ckpt", type=str, default='latest', help='name of checkpoint to load') + self.parser.add_argument("--diffuser_name", type=str, default='dpmsolver', help='sampler\'s scheduler class name in the diffuser library') + self.parser.add_argument("--no_ema", action="store_true", help='Where use EMA model in inference') + self.parser.add_argument("--no_fp16", action="store_true", help='Whether use FP16 in inference') + self.parser.add_argument('--batch_size', type=int, default=1, help='Batch size for generate') + self.parser.add_argument("--seed", default=0, type=int, help="For fixing random seed.") + self.parser.add_argument('--dropout', type=float, default=0.1, help='dropout') + + # generate prompts + self.parser.add_argument('--text_prompt', type=str, default="", help='One text description pompt for motion generation') + self.parser.add_argument("--motion_length", default=4.0, type=float, help="The length of the generated motion [in seconds] when using prompts. 
Maximum is 9.8 for HumanML3D (text-to-motion)") + self.parser.add_argument('--input_text', type=str, default='', help='File path of texts when using multiple texts.') + self.parser.add_argument('--input_lens', type=str, default='', help='File path of expected motion frame lengths when using multitext.') + self.parser.add_argument("--num_samples", type=int, default=10, help='Number of samples for generate when using dataset.') + self.parser.add_argument('--debug', action="store_true", help='debug mode') + self.parser.add_argument('--self_attention', action="store_true", help='self_attention use or not') + self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient linear attention') + self.parser.add_argument('--vis_attn', action='store_true', help='vis attention value or not') + self.parser.add_argument('--edit_mode', action='store_true', help='editing mode') + + + def parse(self): + self.opt = self.parser.parse_args() + opt_path = self.opt.opt_path + get_opt(self.opt, opt_path) + return self.opt + + def parse_app(self): + self.opt = self.parser.parse_args( + args=['--motion_length', '8', '--self_attention', '--no_eff', '--opt_path', './checkpoints/t2m/release/opt.txt', '--edit_mode'] + ) + opt_path = self.opt.opt_path + get_opt(self.opt, opt_path) + return self.opt \ No newline at end of file diff --git a/options/get_opt.py b/options/get_opt.py new file mode 100644 index 0000000000000000000000000000000000000000..61a7ee06d631827e93b33164174695980c144861 --- /dev/null +++ b/options/get_opt.py @@ -0,0 +1,74 @@ +import os +from argparse import Namespace +import re +from os.path import join as pjoin + + +def is_float(numStr): + flag = False + numStr = str(numStr).strip().lstrip("-").lstrip("+") + try: + reg = re.compile(r"^[-+]?[0-9]+\.[0-9]+$") + res = reg.match(str(numStr)) + if res: + flag = True + except Exception as ex: + print("is_float() - error: " + str(ex)) + return flag + + +def is_number(numStr): + flag = False + numStr = str(numStr).strip().lstrip("-").lstrip("+") + if str(numStr).isdigit(): + flag = True + return flag + + +def get_opt(opt, opt_path): + opt_dict = vars(opt) + + skip = ( + "-------------- End ----------------", + "------------ Options -------------", + "\n", + ) + print("Reading", opt_path) + with open(opt_path) as f: + for line in f: + if line.strip() not in skip: + print(line.strip()) + key, value = line.strip().split(": ") + if getattr(opt, key, None) is not None: + continue + if value in ("True", "False"): + opt_dict[key] = True if value == "True" else False + elif is_float(value): + opt_dict[key] = float(value) + elif is_number(value): + opt_dict[key] = int(value) + elif "," in value: + value = value[1:-1].split(",") + opt_dict[key] = [int(i) for i in value] + else: + opt_dict[key] = str(value) + + # opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) + opt.save_root = os.path.dirname(opt_path) + opt.model_dir = pjoin(opt.save_root, "model") + opt.meta_dir = pjoin(opt.save_root, "meta") + + if opt.dataset_name == "t2m" or opt.dataset_name == "humanml": + opt.joints_num = 22 + opt.dim_pose = 263 + opt.max_motion_length = 196 + opt.radius = 4 + opt.fps = 20 + elif opt.dataset_name == "kit": + opt.joints_num = 21 + opt.dim_pose = 251 + opt.max_motion_length = 196 + opt.radius = 240 * 8 + opt.fps = 12.5 + else: + raise KeyError("Dataset not recognized") diff --git a/options/noedit.yaml b/options/noedit.yaml new file mode 100644 index 0000000000000000000000000000000000000000..d9149181586fd6c009e21efe5078a4730f1571b0 --- 
/dev/null +++ b/options/noedit.yaml @@ -0,0 +1,21 @@ +# noedit.yaml +reweighting_attn: + use: False + +erasing_motion: + use: False + +manipulation: + use: False + +time_shift: + use: False + +example_based: + use: False + +style_tranfer: + use: False + +grounded_generation: + use: False \ No newline at end of file diff --git a/options/train_options.py b/options/train_options.py new file mode 100644 index 0000000000000000000000000000000000000000..cb10d03da8aafaf751454ff912b423b5a5cad48e --- /dev/null +++ b/options/train_options.py @@ -0,0 +1,126 @@ +import argparse +from .get_opt import get_opt +from os.path import join as pjoin +import os + +class TrainOptions(): + def __init__(self): + self.parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + self.initialized = False + + def initialize(self): + # base set + self.parser.add_argument('--name', type=str, default="test", help='Name of this trial') + self.parser.add_argument('--dataset_name', type=str, default='t2m', help='Dataset Name') + self.parser.add_argument('--feat_bias', type=float, default=5, help='Scales for global motion features and foot contact') + self.parser.add_argument('--checkpoints_dir', type=str, default='./checkpoints', help='models are saved here') + self.parser.add_argument('--log_every', type=int, default=5, help='Frequency of printing training progress (by iteration)') + self.parser.add_argument('--save_interval', type=int, default=10_000, help='Frequency of evaluateing and saving models (by iteration)') + + + # network hyperparams + self.parser.add_argument('--num_layers', type=int, default=8, help='num_layers of transformer') + self.parser.add_argument('--latent_dim', type=int, default=512, help='latent_dim of transformer') + self.parser.add_argument('--text_latent_dim', type=int, default=256, help='latent_dim of text embeding') + self.parser.add_argument('--time_dim', type=int, default=512, help='latent_dim of timesteps') + self.parser.add_argument('--base_dim', type=int, default=512, help='Dimension of Unet base channel') + self.parser.add_argument('--dim_mults', type=int, default=[2,2,2,2], nargs='+', help='Unet channel multipliers.') + self.parser.add_argument('--no_eff', action='store_true', help='whether use efficient linear attention') + self.parser.add_argument('--no_adagn', action='store_true', help='whether use adagn block') + self.parser.add_argument('--diffusion_steps', type=int, default=1000, help='diffusion_steps of transformer') + self.parser.add_argument('--prediction_type', type=str, default='sample', help='diffusion_steps of transformer') + + # train hyperparams + self.parser.add_argument('--seed', type=int, default=0, help='seed for train') + self.parser.add_argument('--num_train_steps', type=int, default=50_000, help='Number of training iterations') + self.parser.add_argument('--lr', type=float, default=2e-4, help='Learning rate') + self.parser.add_argument("--decay_rate", default=0.9, type=float, help="the decay rate of lr (0-1 default 0.9)") + self.parser.add_argument("--update_lr_steps", default=5_000, type=int, help="") + self.parser.add_argument("--cond_mask_prob", default=0.1, type=float, + help="The probability of masking the condition during training." 
+ " For classifier-free guidance learning.") + self.parser.add_argument('--clip_grad_norm', type=float, default=1, help='Gradient clipping norm') + self.parser.add_argument('--weight_decay', type=float, default=1e-2, help='Optimizer weight decay') + self.parser.add_argument('--batch_size', type=int, default=64, help='Batch size per GPU') + self.parser.add_argument("--beta_schedule", default='linear', type=str, help="Type of beta schedule in diffusion (e.g. linear, cosine)") + self.parser.add_argument('--dropout', type=float, default=0.1, help='Dropout rate') + + # continue training + self.parser.add_argument('--is_continue', action="store_true", help='Is this trial continued from a previous trial?') + self.parser.add_argument('--continue_ckpt', type=str, default="latest.tar", help='checkpoint of the previous trial to continue from') + self.parser.add_argument("--opt_path", type=str, default='', help='option file path for loading model') + self.parser.add_argument('--debug', action="store_true", help='debug mode') + self.parser.add_argument('--self_attention', action="store_true", help='use self-attention or not') + self.parser.add_argument('--vis_attn', action='store_true', help='visualize attention values or not') + + self.parser.add_argument('--edit_mode', action='store_true', help='editing mode') + + # EMA params + self.parser.add_argument( + "--model-ema", action="store_true", help="enable tracking Exponential Moving Average of model parameters" + ) + self.parser.add_argument( + "--model-ema-steps", + type=int, + default=32, + help="the number of iterations that controls how often to update the EMA model (default: 32)", + ) + self.parser.add_argument( + "--model-ema-decay", + type=float, + default=0.9999, + help="decay factor for Exponential Moving Average of model parameters (default: 0.9999)", + ) + + self.initialized = True + + def parse(self, accelerator): + if not self.initialized: + self.initialize() + + self.opt = self.parser.parse_args() + + if self.opt.is_continue: + assert self.opt.opt_path.endswith('.txt') + get_opt(self.opt, self.opt.opt_path) + self.opt.is_train = True + self.opt.is_continue = True + elif accelerator.is_main_process: + args = vars(self.opt) + accelerator.print('------------ Options -------------') + for k, v in sorted(args.items()): + accelerator.print('%s: %s' % (str(k), str(v))) + accelerator.print('-------------- End ----------------') + # save to the disk + expr_dir = pjoin(self.opt.checkpoints_dir, self.opt.dataset_name, self.opt.name) + os.makedirs(expr_dir, exist_ok=True) + file_name = pjoin(expr_dir, 'opt.txt') + with open(file_name, 'wt') as opt_file: + opt_file.write('------------ Options -------------\n') + for k, v in sorted(args.items()): + if k == 'opt_path': + continue + opt_file.write('%s: %s\n' % (str(k), str(v))) + opt_file.write('-------------- End ----------------\n') + + + if self.opt.dataset_name == 't2m' or self.opt.dataset_name == 'humanml': + self.opt.joints_num = 22 + self.opt.dim_pose = 263 + self.opt.max_motion_length = 196 + self.opt.radius = 4 + self.opt.fps = 20 + elif self.opt.dataset_name == 'kit': + self.opt.joints_num = 21 + self.opt.dim_pose = 251 + self.opt.max_motion_length = 196 + self.opt.radius = 240 * 8 + self.opt.fps = 12.5 + else: + raise KeyError('Dataset not recognized') + + self.opt.device = accelerator.device + self.opt.is_train = True + return self.opt + + diff --git a/prepare/download_glove.sh b/prepare/download_glove.sh new file mode 100644 index 0000000000000000000000000000000000000000..eff947429dfc80e2e03a139aabb179374402bb18 --- /dev/null +++ 
b/prepare/download_glove.sh @@ -0,0 +1,12 @@ +cd ./data/ + +echo -e "Downloading glove (in use by the evaluators)" +gdown --fuzzy https://drive.google.com/file/d/1cmXKUT31pqd7_XpJAiWEo1K81TMYHA5n/view?usp=sharing +rm -rf glove + +unzip glove.zip +echo -e "Cleaning\n" +rm glove.zip +cd .. + +echo -e "Downloading done!" \ No newline at end of file diff --git a/prepare/download_t2m_evaluators.sh b/prepare/download_t2m_evaluators.sh new file mode 100644 index 0000000000000000000000000000000000000000..efa62b6a01c16697d5538d96c13a7f55374266e9 --- /dev/null +++ b/prepare/download_t2m_evaluators.sh @@ -0,0 +1,17 @@ +mkdir -p data/ +cd data/ +mkdir -p checkpoints/ +cd checkpoints/ + +echo "The t2m evaluators will be stored in the './deps' folder" + +echo "Downloading" +gdown --fuzzy https://drive.google.com/file/d/16hyR4XlEyksVyNVjhIWK684Lrm_7_pvX/view?usp=sharing +echo "Extracting" +unzip t2m.zip +echo "Cleaning" +rm t2m.zip + +cd ../.. + +echo "Downloading done!" diff --git a/prepare/prepare_clip.sh b/prepare/prepare_clip.sh new file mode 100644 index 0000000000000000000000000000000000000000..bf786225bdcd891623a69fea6a16f86f7c155cf7 --- /dev/null +++ b/prepare/prepare_clip.sh @@ -0,0 +1,5 @@ +mkdir -p deps/ +cd deps/ +git lfs install +git clone https://huggingface.co/openai/clip-vit-large-patch14 +cd .. diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..9a2386dc8fa457f28ffce1e38302bd9b3b8852fe --- /dev/null +++ b/requirements.txt @@ -0,0 +1,20 @@ +tqdm +opencv-python +scipy +matplotlib==3.3.1 +spacy +accelerate +transformers +einops +diffusers +panda3d +numpy==1.23.0 +git+https://github.com/openai/CLIP.git +diffusers==0.30.3 +transformers==4.45.2 + +# for train +tensorboard +accelerate==1.0.1 +smplx +python-box \ No newline at end of file diff --git a/scripts/__init__.py b/scripts/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/scripts/evaluation.py b/scripts/evaluation.py new file mode 100644 index 0000000000000000000000000000000000000000..2820fbc30333d55021e2025aa77c21d4a2dd4e9d --- /dev/null +++ b/scripts/evaluation.py @@ -0,0 +1,107 @@ +import sys +import torch +from motion_loader import get_dataset_loader, get_motion_loader +from datasets import get_dataset +from models import build_models +from eval import EvaluatorModelWrapper, evaluation +from utils.utils import * +from utils.model_load import load_model_weights +import os +from os.path import join as pjoin + +from models.gaussian_diffusion import DiffusePipeline +from accelerate.utils import set_seed + +from options.evaluate_options import TestOptions + +import yaml +from box import Box + + +def yaml_to_box(yaml_file): + with open(yaml_file, "r") as file: + yaml_data = yaml.safe_load(file) + + return Box(yaml_data) + + +if __name__ == "__main__": + parser = TestOptions() + opt = parser.parse() + set_seed(0) + + if opt.edit_mode: + edit_config = yaml_to_box("options/edit.yaml") + else: + edit_config = yaml_to_box("options/noedit.yaml") + + device_id = opt.gpu_id + device = torch.device("cuda:%d" % device_id if torch.cuda.is_available() else "cpu") + torch.cuda.set_device(device) + opt.device = device + + # load evaluator + eval_wrapper = EvaluatorModelWrapper(opt) + + # load dataset + gt_loader = get_dataset_loader(opt, opt.batch_size, mode="gt_eval", split="test") + gen_dataset = get_dataset(opt, mode="eval", split="test") + + # load model + model = build_models(opt, 
edit_config=edit_config) + ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + ".tar") + load_model_weights(model, ckpt_path, use_ema=not opt.no_ema, device=device) + + # Create a pipeline for generation in diffusion model framework + pipeline = DiffusePipeline( + opt=opt, + model=model, + diffuser_name=opt.diffuser_name, + device=device, + num_inference_steps=opt.num_inference_steps, + torch_dtype=torch.float32 if opt.no_fp16 else torch.float16, + ) + + eval_motion_loaders = { + "text2motion": lambda: get_motion_loader( + opt, + opt.batch_size, + pipeline, + gen_dataset, + opt.mm_num_samples, + opt.mm_num_repeats, + ) + } + + save_dir = pjoin(opt.save_root, "eval") + os.makedirs(save_dir, exist_ok=True) + if opt.no_ema: + log_file = ( + pjoin(save_dir, opt.diffuser_name) + + f"_{str(opt.num_inference_steps)}steps.log" + ) + else: + log_file = ( + pjoin(save_dir, opt.diffuser_name) + + f"_{str(opt.num_inference_steps)}steps_ema.log" + ) + + if not os.path.exists(log_file): + config_dict = dict(pipeline.scheduler.config) + config_dict["no_ema"] = opt.no_ema + with open(log_file, "wt") as f: + f.write("------------ Options -------------\n") + for k, v in sorted(config_dict.items()): + f.write("%s: %s\n" % (str(k), str(v))) + f.write("-------------- End ----------------\n") + + all_metrics = evaluation( + eval_wrapper, + gt_loader, + eval_motion_loaders, + log_file, + opt.replication_times, + opt.diversity_times, + opt.mm_num_times, + run_mm=True, + ) diff --git a/scripts/generate.py b/scripts/generate.py new file mode 100644 index 0000000000000000000000000000000000000000..4fa4ba61ded2ab2ba6dbc950ca2612cdc685cc10 --- /dev/null +++ b/scripts/generate.py @@ -0,0 +1,174 @@ +import sys +import os +import torch +import numpy as np +from os.path import join as pjoin +import utils.paramUtil as paramUtil +from utils.plot_script import * + +from utils.utils import * +from utils.motion_process import recover_from_ric +from accelerate.utils import set_seed +from models.gaussian_diffusion import DiffusePipeline +from options.generate_options import GenerateOptions +from utils.model_load import load_model_weights +from motion_loader import get_dataset_loader +from models import build_models +import yaml +from box import Box + + +def yaml_to_box(yaml_file): + with open(yaml_file, "r") as file: + yaml_data = yaml.safe_load(file) + return Box(yaml_data) + + +if __name__ == "__main__": + parser = GenerateOptions() + opt = parser.parse() + set_seed(opt.seed) + device_id = opt.gpu_id + device = torch.device("cuda:%d" % device_id if torch.cuda.is_available() else "cpu") + opt.device = device + + assert opt.dataset_name in ("t2m", "kit"), "dataset_name must be 't2m' or 'kit'" + + # Using a text prompt for generation + if opt.text_prompt != "": + texts = [opt.text_prompt] + opt.num_samples = 1 + motion_lens = [opt.motion_length * opt.fps] + + # Or using texts (in .txt file) for generation + elif opt.input_text != "": + with open(opt.input_text, "r") as fr: + texts = [line.strip() for line in fr.readlines()] + opt.num_samples = len(texts) + if opt.input_lens != "": + with open(opt.input_lens, "r") as fr: + motion_lens = [int(line.strip()) for line in fr.readlines()] + assert len(texts) == len( + motion_lens + ), f"Please ensure that the motion length in {opt.input_lens} corresponds to the text in {opt.input_text}."
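+ # Without --input_lens, every prompt falls back to the default length of motion_length * fps frames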
+ else: + motion_lens = [opt.motion_length * opt.fps for _ in range(opt.num_samples)] + + # Or using texts in dataset + else: + gen_datasetloader = get_dataset_loader( + opt, opt.num_samples, mode="hml_gt", split="test" + ) + texts, _, motion_lens = next(iter(gen_datasetloader)) + + # edit mode + if opt.edit_mode: + edit_config = yaml_to_box("options/edit.yaml") + else: + edit_config = yaml_to_box("options/noedit.yaml") + print(edit_config) + + ckpt_path = pjoin(opt.model_dir, opt.which_ckpt + ".tar") + checkpoint = torch.load(ckpt_path, map_location={'cuda:0': str(device)}) + niter = checkpoint.get('total_it', 0) + # make save dir + out_path = opt.output_dir + if out_path == "": + out_path = pjoin(opt.save_root, "samples_iter{}_seed{}".format(niter, opt.seed)) + if opt.text_prompt != "": + out_path += "_" + opt.text_prompt.replace(" ", "_").replace(".", "") + elif opt.input_text != "": + out_path += "_" + os.path.basename(opt.input_text).replace( + ".txt", "" + ).replace(" ", "_").replace(".", "") + os.makedirs(out_path, exist_ok=True) + + # load model + model = build_models(opt, edit_config=edit_config, out_path=out_path) + niter = load_model_weights(model, ckpt_path, use_ema=not opt.no_ema) + + # Create a pipeline for generation in diffusion model framework + pipeline = DiffusePipeline( + opt=opt, + model=model, + diffuser_name=opt.diffuser_name, + device=device, + num_inference_steps=opt.num_inference_steps, + torch_dtype=torch.float16, + ) + + # generate + pred_motions, _ = pipeline.generate( + texts, torch.LongTensor([int(x) for x in motion_lens]) + ) + + # Convert the generated motion representation into 3D joint coordinates and save as npy file + npy_dir = pjoin(out_path, "joints_npy") + root_dir = pjoin(out_path, "root_npy") + os.makedirs(npy_dir, exist_ok=True) + os.makedirs(root_dir, exist_ok=True) + print(f"saving results npy file (3d joints) to [{npy_dir}]") + mean = np.load(pjoin(opt.meta_dir, "mean.npy")) + std = np.load(pjoin(opt.meta_dir, "std.npy")) + samples = [] + + root_list = [] + for i, motion in enumerate(pred_motions): + motion = motion.cpu().numpy() * std + mean + np.save(pjoin(npy_dir, f"raw_{i:02}.npy"), motion) + npy_name = f"{i:02}.npy" + # 1. recover 3d joints representation by ik + motion = recover_from_ric(torch.from_numpy(motion).float(), opt.joints_num) + # 2. put on Floor (Y axis) + floor_height = motion.min(dim=0)[0].min(dim=0)[0][1] + motion[:, :, 1] -= floor_height + motion = motion.numpy() + # 3. 
remove jitter + motion = motion_temporal_filter(motion, sigma=1) + + # save root trajectory (Y axis) + root_trajectory = motion[:, 0, :] + root_list.append(root_trajectory) + np.save(pjoin(root_dir, f"root_{i:02}.npy"), root_trajectory) + y = root_trajectory[:, 1] + + plt.figure() + plt.plot(y) + + plt.legend() + + plt.title("Root Joint Trajectory") + plt.xlabel("Frame") + plt.ylabel("Position") + + plt.savefig("./root_trajectory_xyz.png") + np.save(pjoin(npy_dir, npy_name), motion) + samples.append(motion) + + root_list_res = np.concatenate(root_list, axis=0) + np.save("root_list.npy", root_list_res) + + # save the text and length conditions used for this generation + with open(pjoin(out_path, "results.txt"), "w") as fw: + fw.write("\n".join(texts)) + with open(pjoin(out_path, "results_lens.txt"), "w") as fw: + fw.write("\n".join([str(l) for l in motion_lens])) + + # skeletal animation visualization + print(f"saving motion videos to [{out_path}]...") + for i, title in enumerate(texts): + motion = samples[i] + fname = f"{i:02}.mp4" + kinematic_tree = ( + paramUtil.t2m_kinematic_chain + if (opt.dataset_name == "t2m") + else paramUtil.kit_kinematic_chain + ) + plot_3d_motion( + pjoin(out_path, fname), + kinematic_tree, + motion, + title=title, + fps=opt.fps, + radius=opt.radius, + ) diff --git a/scripts/train.py b/scripts/train.py new file mode 100644 index 0000000000000000000000000000000000000000..4b4d927ea12975fcfef780536054946dfb99e42d --- /dev/null +++ b/scripts/train.py @@ -0,0 +1,66 @@ +import sys +import os +from os.path import join as pjoin +from options.train_options import TrainOptions +from utils.plot_script import * + +from models import build_models +from utils.ema import ExponentialMovingAverage +from trainers import DDPMTrainer +from motion_loader import get_dataset_loader + +from accelerate.utils import set_seed +from accelerate import Accelerator +import torch + +import yaml +from box import Box + +def yaml_to_box(yaml_file): + with open(yaml_file, 'r') as file: + yaml_data = yaml.safe_load(file) + + return Box(yaml_data) + +if __name__ == '__main__': + accelerator = Accelerator() + + parser = TrainOptions() + opt = parser.parse(accelerator) + set_seed(opt.seed) + torch.autograd.set_detect_anomaly(True) + + opt.save_root = pjoin(opt.checkpoints_dir, opt.dataset_name, opt.name) + opt.model_dir = pjoin(opt.save_root, 'model') + opt.meta_dir = pjoin(opt.save_root, 'meta') + + if opt.edit_mode: + edit_config = yaml_to_box('options/edit.yaml') + else: + edit_config = yaml_to_box('options/noedit.yaml') + + if accelerator.is_main_process: + os.makedirs(opt.model_dir, exist_ok=True) + os.makedirs(opt.meta_dir, exist_ok=True) + + train_datasetloader = get_dataset_loader(opt, batch_size = opt.batch_size, split='train', accelerator=accelerator, mode='train') # 7169 + + + accelerator.print('\nInitializing model ...' 
) + encoder = build_models(opt, edit_config=edit_config) + model_ema = None + if opt.model_ema: + # Decay adjustment that aims to keep the decay independent of other hyper-parameters originally proposed at: + # https://github.com/facebookresearch/pycls/blob/f8cd9627/pycls/core/net.py#L123 + adjust = 106_667 * opt.model_ema_steps / opt.num_train_steps + alpha = 1.0 - opt.model_ema_decay + alpha = min(1.0, alpha * adjust) + print('EMA alpha:',alpha) + model_ema = ExponentialMovingAverage(encoder, decay=1.0 - alpha) + accelerator.print('Finish building Model.\n') + + trainer = DDPMTrainer(opt, encoder,accelerator, model_ema) + + trainer.train(train_datasetloader) + + diff --git a/trainers/__init__.py b/trainers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e35c648260b478cf5c9a8c6159e155ee8e0c8234 --- /dev/null +++ b/trainers/__init__.py @@ -0,0 +1,4 @@ +from .ddpm_trainer import DDPMTrainer + + +__all__ = ['DDPMTrainer'] \ No newline at end of file diff --git a/trainers/ddpm_trainer.py b/trainers/ddpm_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..3e0791564021e09e98912439768b710b3f9d4704 --- /dev/null +++ b/trainers/ddpm_trainer.py @@ -0,0 +1,245 @@ +import torch +import time +import torch.optim as optim +from collections import OrderedDict +from utils.utils import print_current_loss +from os.path import join as pjoin + +from diffusers import DDPMScheduler +from torch.utils.tensorboard import SummaryWriter +import time +import pdb +import sys +import os +from torch.optim.lr_scheduler import ExponentialLR + + +class DDPMTrainer(object): + + def __init__(self, args, model, accelerator, model_ema=None): + self.opt = args + self.accelerator = accelerator + self.device = self.accelerator.device + self.model = model + self.diffusion_steps = args.diffusion_steps + self.noise_scheduler = DDPMScheduler( + num_train_timesteps=self.diffusion_steps, + beta_schedule=args.beta_schedule, + variance_type="fixed_small", + prediction_type=args.prediction_type, + clip_sample=False, + ) + self.model_ema = model_ema + if args.is_train: + self.mse_criterion = torch.nn.MSELoss(reduction="none") + + accelerator.print("Diffusion_config:\n", self.noise_scheduler.config) + + if self.accelerator.is_main_process: + starttime = time.strftime("%Y-%m-%d_%H:%M:%S") + print("Start experiment:", starttime) + self.writer = SummaryWriter( + log_dir=pjoin(args.save_root, "logs_") + starttime[:16], + comment=starttime[:16], + flush_secs=60, + ) + self.accelerator.wait_for_everyone() + + self.optimizer = optim.AdamW( + self.model.parameters(), lr=self.opt.lr, weight_decay=self.opt.weight_decay + ) + self.scheduler = ( + ExponentialLR(self.optimizer, gamma=args.decay_rate) + if args.decay_rate > 0 + else None + ) + + @staticmethod + def zero_grad(opt_list): + for opt in opt_list: + opt.zero_grad() + + def clip_norm(self, network_list): + for network in network_list: + self.accelerator.clip_grad_norm_( + network.parameters(), self.opt.clip_grad_norm + ) # 0.5 -> 1 + + @staticmethod + def step(opt_list): + for opt in opt_list: + opt.step() + + def forward(self, batch_data): + caption, motions, m_lens = batch_data + motions = motions.detach().float() + + x_start = motions + B, T = x_start.shape[:2] + cur_len = torch.LongTensor([min(T, m_len) for m_len in m_lens]).to(self.device) + self.src_mask = self.generate_src_mask(T, cur_len).to(x_start.device) + + # 1. Sample noise that we'll add to the motion + real_noise = torch.randn_like(x_start) + + # 2. 
Sample a random timestep for each motion + t = torch.randint(0, self.diffusion_steps, (B,), device=self.device) + self.timesteps = t + + # 3. Add noise to the motion according to the noise magnitude at each timestep + # (this is the forward diffusion process) + x_t = self.noise_scheduler.add_noise(x_start, real_noise, t) + + # 4. network prediction + self.prediction = self.model(x_t, t, text=caption) + + if self.opt.prediction_type == "sample": + self.target = x_start + elif self.opt.prediction_type == "epsilon": + self.target = real_noise + elif self.opt.prediction_type == "v_prediction": + self.target = self.noise_scheduler.get_velocity(x_start, real_noise, t) + + def masked_l2(self, a, b, mask, weights): + + loss = self.mse_criterion(a, b).mean(dim=-1) # (bath_size, motion_length) + + loss = (loss * mask).sum(-1) / mask.sum(-1) # (batch_size, ) + + loss = (loss * weights).mean() + + return loss + + def backward_G(self): + loss_logs = OrderedDict({}) + mse_loss_weights = torch.ones_like(self.timesteps) + loss_logs["loss_mot_rec"] = self.masked_l2( + self.prediction, self.target, self.src_mask, mse_loss_weights + ) + + self.loss = loss_logs["loss_mot_rec"] + + return loss_logs + + def update(self): + self.zero_grad([self.optimizer]) + loss_logs = self.backward_G() + self.accelerator.backward(self.loss) + self.clip_norm([self.model]) + self.step([self.optimizer]) + + return loss_logs + + def generate_src_mask(self, T, length): + B = len(length) + src_mask = torch.ones(B, T) + for i in range(B): + for j in range(length[i], T): + src_mask[i, j] = 0 + return src_mask + + def train_mode(self): + self.model.train() + if self.model_ema: + self.model_ema.train() + + def eval_mode(self): + self.model.eval() + if self.model_ema: + self.model_ema.eval() + + def save(self, file_name, total_it): + state = { + "opt_encoder": self.optimizer.state_dict(), + "total_it": total_it, + "encoder": self.accelerator.unwrap_model(self.model).state_dict(), + } + if self.model_ema: + state["model_ema"] = self.accelerator.unwrap_model( + self.model_ema + ).module.state_dict() + torch.save(state, file_name) + return + + def load(self, model_dir): + checkpoint = torch.load(model_dir, map_location=self.device) + self.optimizer.load_state_dict(checkpoint["opt_encoder"]) + if self.model_ema: + self.model_ema.load_state_dict(checkpoint["model_ema"], strict=True) + self.model.load_state_dict(checkpoint["encoder"], strict=True) + + return checkpoint.get("total_it", 0) + + def train(self, train_loader): + + it = 0 + if self.opt.is_continue: + model_path = pjoin(self.opt.model_dir, self.opt.continue_ckpt) + it = self.load(model_path) + self.accelerator.print(f"continue train from {it} iters in {model_path}") + start_time = time.time() + + logs = OrderedDict() + self.dataset = train_loader.dataset + self.model, self.mse_criterion, self.optimizer, train_loader, self.model_ema = ( + self.accelerator.prepare( + self.model, + self.mse_criterion, + self.optimizer, + train_loader, + self.model_ema, + ) + ) + + num_epochs = (self.opt.num_train_steps - it) // len(train_loader) + 1 + self.accelerator.print(f"need to train for {num_epochs} epochs....") + + for epoch in range(0, num_epochs): + self.train_mode() + for i, batch_data in enumerate(train_loader): + self.forward(batch_data) + log_dict = self.update() + it += 1 + + if self.model_ema and it % self.opt.model_ema_steps == 0: + self.accelerator.unwrap_model(self.model_ema).update_parameters( + self.model + ) + + # update logger + for k, v in log_dict.items(): + if k not in logs: + 
logs[k] = v + else: + logs[k] += v + + if it % self.opt.log_every == 0: + mean_loss = OrderedDict({}) + for tag, value in logs.items(): + mean_loss[tag] = value / self.opt.log_every + logs = OrderedDict() + print_current_loss( + self.accelerator, start_time, it, mean_loss, epoch, inner_iter=i + ) + if self.accelerator.is_main_process: + self.writer.add_scalar("loss", mean_loss["loss_mot_rec"], it) + self.accelerator.wait_for_everyone() + + if ( + it % self.opt.save_interval == 0 + and self.accelerator.is_main_process + ): # Save model + self.save(pjoin(self.opt.model_dir, "latest.tar").format(it), it) + self.accelerator.wait_for_everyone() + + if (self.scheduler is not None) and ( + it % self.opt.update_lr_steps == 0 + ): + self.scheduler.step() + + # Save the last checkpoint if it wasn't already saved. + if it % self.opt.save_interval != 0 and self.accelerator.is_main_process: + self.save(pjoin(self.opt.model_dir, "latest.tar"), it) + + self.accelerator.wait_for_everyone() + self.accelerator.print("FINISH") diff --git a/utils/__init__.py b/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/utils/constants.py b/utils/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..6007802cd629e905bbcd819e4e41183e3f01d0ec --- /dev/null +++ b/utils/constants.py @@ -0,0 +1,176 @@ +SMPL_FOOT_R = [8, 11] +SMPL_FOOT_L = [7, 10] +SMPL_FACE_FORWARD_JOINTS = [2, 1, 17, 16] + +# Define a kinematic tree for the skeletal struture +SMPL_BODY_CHAIN = [ + [0, 2, 5, 8, 11], + [0, 1, 4, 7, 10], + [0, 3, 6, 9, 12, 15], + [9, 14, 17, 19, 21], + [9, 13, 16, 18, 20], +] +SMPL_LEFT_HAND_CHAIN = [ + [20, 22, 23, 24], + [20, 34, 35, 36], + [20, 25, 26, 27], + [20, 31, 32, 33], + [20, 28, 29, 30], +] +SMPL_RIGHT_HAND_CHAIN = [ + [21, 43, 44, 45], + [21, 46, 47, 48], + [21, 40, 41, 42], + [21, 37, 38, 39], + [21, 49, 50, 51], +] + +SMPL_BODY_BONES = [ + -0.0018, + -0.2233, + 0.0282, + 0.0695, + -0.0914, + -0.0068, + -0.0677, + -0.0905, + -0.0043, + -0.0025, + 0.1090, + -0.0267, + 0.0343, + -0.3752, + -0.0045, + -0.0383, + -0.3826, + -0.0089, + 0.0055, + 0.1352, + 0.0011, + -0.0136, + -0.3980, + -0.0437, + 0.0158, + -0.3984, + -0.0423, + 0.0015, + 0.0529, + 0.0254, + 0.0264, + -0.0558, + 0.1193, + -0.0254, + -0.0481, + 0.1233, + -0.0028, + 0.2139, + -0.0429, + 0.0788, + 0.1217, + -0.0341, + -0.0818, + 0.1188, + -0.0386, + 0.0052, + 0.0650, + 0.0513, + 0.0910, + 0.0305, + -0.0089, + -0.0960, + 0.0326, + -0.0091, + 0.2596, + -0.0128, + -0.0275, + -0.2537, + -0.0133, + -0.0214, + 0.2492, + 0.0090, + -0.0012, + -0.2553, + 0.0078, + -0.0056, + 0.0840, + -0.0082, + -0.0149, + -0.0846, + -0.0061, + -0.0103, +] + +SMPL_HYBRIK = [ + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 0, + 1, + 1, + 1, + 1, + 0, + 0, + 0, + 0, + 0, + 0, +] + +SMPL_BODY_PARENTS = [ + 0, + 0, + 0, + 0, + 1, + 2, + 3, + 4, + 5, + 6, + 7, + 8, + 9, + 9, + 9, + 12, + 13, + 14, + 16, + 17, + 18, + 19, +] + +SMPL_BODY_CHILDS = [ + -1, + 4, + 5, + 6, + 7, + 8, + 9, + 10, + 11, + -1, + -2, + -2, + 15, + 16, + 17, + -2, + 18, + 19, + 20, + 21, + -2, + -2, +] diff --git a/utils/ema.py b/utils/ema.py new file mode 100644 index 0000000000000000000000000000000000000000..947c9aef9c2d0086718914a68c064660d0f15b40 --- /dev/null +++ b/utils/ema.py @@ -0,0 +1,13 @@ +import torch + +class ExponentialMovingAverage(torch.optim.swa_utils.AveragedModel): + """Maintains moving averages of model parameters using an exponential decay. 
+ ``ema_avg = decay * avg_model_param + (1 - decay) * model_param`` + `torch.optim.swa_utils.AveragedModel