diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index d0ab6405a8623874dc3bbb367bf6aec8663f1e7b..5949ff012ba817317b6427759dcb7f0cfb55fb4e 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,86 @@ ---- -title: LaVie -emoji: πŸƒ -colorFrom: gray -colorTo: green -sdk: gradio -sdk_version: 4.7.1 -app_file: app.py -pinned: false ---- - -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +# LaVie: High-Quality Video Generation with Cascaded Latent Diffusion Models + +This repository is the official PyTorch implementation of [LaVie](https://arxiv.org/abs/2309.15103). + +**LaVie** is a Text-to-Video (T2V) generation framework, and main part of video generation system [Vchitect](http://vchitect.intern-ai.org.cn/). + +[![arXiv](https://img.shields.io/badge/arXiv-2307.04725-b31b1b.svg)](https://arxiv.org/abs/2309.15103) +[![Project Page](https://img.shields.io/badge/Project-Website-green)](https://vchitect.github.io/LaVie-project/) + + + + +## Installation +``` +conda env create -f environment.yml +conda activate lavie +``` + +## Download Pre-Trained models +Download [pre-trained models](https://huggingface.co/YaohuiW/LaVie/tree/main), [stable diffusion 1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main), [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/tree/main) to `./pretrained_models`. You should be able to see the following: +``` +β”œβ”€β”€ pretrained_models +β”‚ β”œβ”€β”€ lavie_base.pt +β”‚ β”œβ”€β”€ lavie_interpolation.pt +β”‚ β”œβ”€β”€ lavie_vsr.pt +β”‚ β”œβ”€β”€ stable-diffusion-v1-4 +β”‚ β”‚ β”œβ”€β”€ ... +└── └── stable-diffusion-x4-upscaler + β”œβ”€β”€ ... +``` + +## Inference +The inference contains **Base T2V**, **Video Interpolation** and **Video Super-Resolution** three steps. We provide several options to generate videos: +* **Step1**: 320 x 512 resolution, 16 frames +* **Step1+Step2**: 320 x 512 resolution, 61 frames +* **Step1+Step3**: 1280 x 2048 resolution, 16 frames +* **Step1+Step2+Step3**: 1280 x 2048 resolution, 61 frames + +Feel free to try different options:) + + +### Step1. Base T2V +Run following command to generate videos from base T2V model. +``` +cd base +python pipelines/sample.py --config configs/sample.yaml +``` +Edit `text_prompt` in `configs/sample.yaml` to change prompt, results will be saved under `./res/base`. + +### Step2 (optional). Video Interpolation +Run following command to conduct video interpolation. +``` +cd interpolation +python sample.py --config configs/sample.yaml +``` +The default input video path is `./res/base`, results will be saved under `./res/interpolation`. In `configs/sample.yaml`, you could modify default `input_folder` with `YOUR_INPUT_FOLDER` in `configs/sample.yaml`. Input videos should be named as `prompt1.mp4`, `prompt2.mp4`, ... and put under `YOUR_INPUT_FOLDER`. Launching the code will process all the input videos in `input_folder`. + + +### Step3 (optional). Video Super-Resolution +Run following command to conduct video super-resolution. +``` +cd vsr +python sample.py --config configs/sample.yaml +``` +The default input video path is `./res/base` and results will be saved under `./res/vsr`. You could modify default `input_path` with `YOUR_INPUT_FOLDER` in `configs/sample.yaml`. Smiliar to Step2, input videos should be named as `prompt1.mp4`, `prompt2.mp4`, ... and put under `YOUR_INPUT_FOLDER`. Launching the code will process all the input videos in `input_folder`. + + +## BibTex +```bibtex +@article{wang2023lavie, + title={LAVIE: High-Quality Video Generation with Cascaded Latent Diffusion Models}, + author={Wang, Yaohui and Chen, Xinyuan and Ma, Xin and Zhou, Shangchen and Huang, Ziqi and Wang, Yi and Yang, Ceyuan and He, Yinan and Yu, Jiashuo and Yang, Peiqing and others}, + journal={arXiv preprint arXiv:2309.15103}, + year={2023} +} +``` + +## Acknowledgements +The code is buit upon [diffusers](https://github.com/huggingface/diffusers) and [Stable Diffusion](https://github.com/CompVis/stable-diffusion), we thank all the contributors for open-sourcing. + + +## License +The code is licensed under Apache-2.0, model weights are fully open for academic research and also allow **free** commercial usage. To apply for a commercial license, please fill in the [application form](). diff --git a/base/__pycache__/download.cpython-311.pyc b/base/__pycache__/download.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8f9de5e7eb4b7a4bfd816fc575ea4888229239de Binary files /dev/null and b/base/__pycache__/download.cpython-311.pyc differ diff --git a/base/app.py b/base/app.py new file mode 100644 index 0000000000000000000000000000000000000000..14c66df82b058e77c83c9cd1106e5df64aa939c0 --- /dev/null +++ b/base/app.py @@ -0,0 +1,175 @@ +import gradio as gr +from text_to_video import model_t2v_fun,setup_seed +from omegaconf import OmegaConf +import torch +import imageio +import os +import cv2 +import pandas as pd +import torchvision +import random +config_path = "/mnt/petrelfs/zhouyan/project/lavie-release/base/configs/sample.yaml" +args = OmegaConf.load("/mnt/petrelfs/zhouyan/project/lavie-release/base/configs/sample.yaml") +device = "cuda" if torch.cuda.is_available() else "cpu" +# ------- get model --------------- +model_t2V = model_t2v_fun(args) +model_t2V.to(device) +if device == "cuda": + model_t2V.enable_xformers_memory_efficient_attention() + +# model_t2V.enable_xformers_memory_efficient_attention() +css = """ +h1 { + text-align: center; +} +#component-0 { + max-width: 730px; + margin: auto; +} +""" + +def infer(prompt, seed_inp, ddim_steps,cfg): + if seed_inp!=-1: + setup_seed(seed_inp) + else: + seed_inp = random.choice(range(10000000)) + setup_seed(seed_inp) + videos = model_t2V(prompt, video_length=16, height = 320, width= 512, num_inference_steps=ddim_steps, guidance_scale=cfg).video + print(videos[0].shape) + if not os.path.exists(args.output_folder): + os.mkdir(args.output_folder) + torchvision.io.write_video(args.output_folder + prompt[0:30].replace(' ', '_') + '-'+str(seed_inp)+'-'+str(ddim_steps)+'-'+str(cfg)+ '-.mp4', videos[0], fps=8) + # imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8) + # video = cv2.VideoCapture(args.output_folder + prompt.replace(' ', '_') + '.mp4') + # video = imageio.get_reader(args.output_folder + prompt.replace(' ', '_') + '.mp4', 'ffmpeg') + + + # video = model_t2V(prompt, seed_inp, ddim_steps) + + return args.output_folder + prompt[0:30].replace(' ', '_') + '-'+str(seed_inp)+'-'+str(ddim_steps)+'-'+str(cfg)+ '-.mp4' + +print(1) + +# def clean(): +# return gr.Image.update(value=None, visible=False), gr.Video.update(value=None) +def clean(): + return gr.Video.update(value=None) + +title = """ +
+
+

+ InternΒ·Vchitect (Text-to-Video) +

+
+

+ Apply InternΒ·Vchitect to generate a video +

+
+""" + +# print(1) +with gr.Blocks(css='style.css') as demo: + gr.Markdown("
LaVie: Text-to-Video generation
") + with gr.Column(): + with gr.Row(elem_id="col-container"): + # inputs = [prompt, seed_inp, ddim_steps] + # outputs = [video_out] + with gr.Column(): + + prompt = gr.Textbox(value="a teddy bear walking on the street", label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in", min_width=200, lines=2) + + ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=50, step=1) + seed_inp = gr.Slider(value=-1,label="seed (for random generation, use -1)",show_label=True,minimum=-1,maximum=2147483647) + cfg = gr.Number(label="guidance_scale",value=7) + # seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=400, elem_id="seed-in") + + # with gr.Row(): + # # control_task = gr.Dropdown(label="Task", choices=["Text-2-video", "Image-2-video"], value="Text-2-video", multiselect=False, elem_id="controltask-in") + # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1) + # seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=123456, elem_id="seed-in") + + # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1) + # ex = gr.Examples( + # examples = [['a corgi walking in the park at sunrise, oil painting style',400,50,7], + # ['a cut teddy bear reading a book in the park, oil painting style, high quality',700,50,7], + # ['an epic tornado attacking above a glowing city at night, the tornado is made of smoke, highly detailed',230,50,7], + # ['a jar filled with fire, 4K video, 3D rendered, well-rendered',400,50,7], + # ['a teddy bear walking in the park, oil painting style, high quality',400,50,7], + # ['a teddy bear walking on the street, 2k, high quality',100,50,7], + # ['a panda taking a selfie, 2k, high quality',400,50,7], + # ['a polar bear playing drum kit in NYC Times Square, 4k, high resolution',400,50,7], + # ['jungle river at sunset, ultra quality',400,50,7], + # ['a shark swimming in clear Carribean ocean, 2k, high quality',400,50,7], + # ['A steam train moving on a mountainside by Vincent van Gogh',230,50,7], + # ['a confused grizzly bear in calculus class',1000,50,7]], + # fn = infer, + # inputs=[prompt, seed_inp, ddim_steps,cfg], + # # outputs=[video_out], + # cache_examples=False, + # examples_per_page = 6 + # ) + # ex.dataset.headers = [""] + + with gr.Column(): + submit_btn = gr.Button("Generate video") + clean_btn = gr.Button("Clean video") + # submit_btn = gr.Button("Generate video", size='sm') + # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512) + video_out = gr.Video(label="Video result", elem_id="video-output") + # with gr.Row(): + # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512) + # submit_btn = gr.Button("Generate video", size='sm') + + + # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512) + inputs = [prompt, seed_inp, ddim_steps,cfg] + outputs = [video_out] + # gr.Examples( + # value = [['An astronaut riding a horse',123,50], + # ['a panda eating bamboo on a rock',123,50], + # ['Spiderman is surfing',123,50]], + # label = "example of sampling", + # show_label = True, + # headers = ['prompt','seed','steps'], + # datatype = ['str','number','number'], + # row_count=4, + # col_count=(3,"fixed") + # ) + ex = gr.Examples( + examples = [['a corgi walking in the park at sunrise, oil painting style',400,50,7], + ['a cut teddy bear reading a book in the park, oil painting style, high quality',700,50,7], + ['an epic tornado attacking above a glowing city at night, the tornado is made of smoke, highly detailed',230,50,7], + ['a jar filled with fire, 4K video, 3D rendered, well-rendered',400,50,7], + ['a teddy bear walking in the park, oil painting style, high quality',400,50,7], + ['a teddy bear walking on the street, 2k, high quality',100,50,7], + ['a panda taking a selfie, 2k, high quality',400,50,7], + ['a polar bear playing drum kit in NYC Times Square, 4k, high resolution',400,50,7], + ['jungle river at sunset, ultra quality',400,50,7], + ['a shark swimming in clear Carribean ocean, 2k, high quality',400,50,7], + ['A steam train moving on a mountainside by Vincent van Gogh',230,50,7], + ['a confused grizzly bear in calculus class',1000,50,7]], + fn = infer, + inputs=[prompt, seed_inp, ddim_steps,cfg], + outputs=[video_out], + cache_examples=True, + ) + ex.dataset.headers = [""] + + # control_task.change(change_task_options, inputs=[control_task], outputs=[canny_opt, hough_opt, normal_opt], queue=False) + # submit_btn.click(clean, inputs=[], outputs=[video_out], queue=False) + clean_btn.click(clean, inputs=[], outputs=[video_out], queue=False) + submit_btn.click(infer, inputs, outputs) + # share_button.click(None, [], [], _js=share_js) + + print(2) +demo.queue(max_size=12).launch(server_name="0.0.0.0", server_port=7860) + + diff --git a/base/app.sh b/base/app.sh new file mode 100644 index 0000000000000000000000000000000000000000..0d9381ee026325b890a58ee74289374f94e36b30 --- /dev/null +++ b/base/app.sh @@ -0,0 +1 @@ +srun -p aigc-video --gres=gpu:1 -n1 -N1 python app.py \ No newline at end of file diff --git a/base/configs/sample.yaml b/base/configs/sample.yaml new file mode 100644 index 0000000000000000000000000000000000000000..0fbf8cdda3070e17160bb51491114f837ecf2392 --- /dev/null +++ b/base/configs/sample.yaml @@ -0,0 +1,28 @@ +# path: +output_folder: "/mnt/petrelfs/share_data/zhouyan/gradio/lavie" +pretrained_path: "/mnt/petrelfs/zhouyan/models" + +# model config: +model: UNet +video_length: 16 +image_size: [320, 512] + +# beta schedule +beta_start: 0.0001 +beta_end: 0.02 +beta_schedule: "linear" + +# model speedup +use_compile: False +use_fp16: True + +# sample config: +seed: 3 +run_time: 0 +guidance_scale: 7.0 +sample_method: 'ddpm' +num_sampling_steps: 250 +text_prompt: [ + 'a teddy bear walking on the street, high quality, 2k', + + ] diff --git a/base/download.py b/base/download.py new file mode 100644 index 0000000000000000000000000000000000000000..26de159d098d2086a35f1504477eb5c01a35f540 --- /dev/null +++ b/base/download.py @@ -0,0 +1,18 @@ +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import os + + +def find_model(model_name): + """ + Finds a pre-trained model, downloading it if necessary. Alternatively, loads a model from a local path. + """ + checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) + if "ema" in checkpoint: # supports checkpoints from train.py + print('Ema existing!') + checkpoint = checkpoint["ema"] + return checkpoint diff --git a/base/gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2a75ce05ceb622035785710d5b4e161875326613 Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d21b92f22c837fa2011f50a0e9a71774727f6a22 Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..6e29f47c91bc1fcfe35e2e068fafc368a81c01b4 Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..18bb707b269f4ce942033d55ecacec385fe1da33 Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2bf26ef94098567f37c8bfdbecc17e26eec37a99 Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2fdbde0250b7511f0ec1e2b60bba38cd2bccc673 Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..2518cdc98689cbc489213f220c41f873e29d221b Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..ae6ef5c9c0ff955fc82ca1a1b5c6d3489b93f7d5 Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..81575e7cbeb1ede1bd9741bf91b8648721c3273c Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..d6754f2846e6b66f5109079a4d78c6e560a15d2c Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..88de28546db43a883e9cd2ed10c1ae58d1aaf402 Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4 new file mode 100644 index 0000000000000000000000000000000000000000..68ad51852fe928376b0bb47ad00df422caf443a6 Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4 differ diff --git a/base/gradio_cached_examples/14/log.csv b/base/gradio_cached_examples/14/log.csv new file mode 100644 index 0000000000000000000000000000000000000000..6bb703837fe42ccdc80fb5802c2304c5d790c286 --- /dev/null +++ b/base/gradio_cached_examples/14/log.csv @@ -0,0 +1,13 @@ +Video result,flag,username,timestamp +"{""video"":{""path"":""gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_corgi_walking_in_the_park_at-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:13.139609 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:23.543257 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviean_epic_tornado_attacking_abov-230-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:33.942899 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_jar_filled_with_fire,_4K_vid-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:44.348969 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:54.765554 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:05.255612 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_panda_taking_a_selfie,_2k,_h-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:15.694357 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:26.121546 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviejungle_river_at_sunset,_ultra_-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:36.540682 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:46.992686 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:57.458758 +"{""video"":{""path"":""gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:03:07.878403 diff --git a/base/huggingface-t2v/.DS_Store b/base/huggingface-t2v/.DS_Store new file mode 100644 index 0000000000000000000000000000000000000000..ec6bded75c322a3b51f99901c1eb510901d2bd75 Binary files /dev/null and b/base/huggingface-t2v/.DS_Store differ diff --git a/base/huggingface-t2v/__init__.py b/base/huggingface-t2v/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/base/huggingface-t2v/requirements.txt b/base/huggingface-t2v/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/base/models/__init__.py b/base/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f93daa1bf5131bd1251a44e99aeeb0d127ea71f1 --- /dev/null +++ b/base/models/__init__.py @@ -0,0 +1,33 @@ +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from .unet import UNet3DConditionModel +from torch.optim.lr_scheduler import LambdaLR + +def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit + from torch.optim.lr_scheduler import LambdaLR + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'warmup': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + +def get_models(args, sd_path): + + if 'UNet' in args.model: + return UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet") + else: + raise '{} Model Not Supported!'.format(args.model) + diff --git a/base/models/__pycache__/__init__.cpython-311.pyc b/base/models/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..208ad63a8430425ee32dad7d7738d750a294fd8b Binary files /dev/null and b/base/models/__pycache__/__init__.cpython-311.pyc differ diff --git a/base/models/__pycache__/attention.cpython-311.pyc b/base/models/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5b4c56a663f3de4d5a2b04fffaf188ce0167437a Binary files /dev/null and b/base/models/__pycache__/attention.cpython-311.pyc differ diff --git a/base/models/__pycache__/resnet.cpython-311.pyc b/base/models/__pycache__/resnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..80e4daa3a256d39ec31cc468afc5f2b0fbd374ea Binary files /dev/null and b/base/models/__pycache__/resnet.cpython-311.pyc differ diff --git a/base/models/__pycache__/unet.cpython-311.pyc b/base/models/__pycache__/unet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..454d8468602a267fe00425b659445343a76233dc Binary files /dev/null and b/base/models/__pycache__/unet.cpython-311.pyc differ diff --git a/base/models/__pycache__/unet_blocks.cpython-311.pyc b/base/models/__pycache__/unet_blocks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a47de4825f787c06109b504f0312ca58af82a483 Binary files /dev/null and b/base/models/__pycache__/unet_blocks.cpython-311.pyc differ diff --git a/base/models/attention.py b/base/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..fb5b0ff4a7fb9a8d667a9b502017208293ab2609 --- /dev/null +++ b/base/models/attention.py @@ -0,0 +1,707 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from dataclasses import dataclass +from typing import Optional + +import math +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm +from rotary_embedding_torch import RotaryEmbedding +from typing import Callable, Optional +from einops import rearrange, repeat + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +def exists(x): + return x is not None + + +class CrossAttention(nn.Module): + r""" + copy from diffuser 0.11.1 + A cross attention layer. + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + use_relative_position: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + self.use_relative_position = use_relative_position + if self.use_relative_position: + self.rotary_emb = RotaryEmbedding(min(32, dim_head)) + + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def reshape_for_scores(self, tensor): + # split heads and dims + # tensor should be [b (h w)] f (d nd) + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).contiguous() + return tensor + + def same_batch_dim_to_heads(self, tensor): + batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d + tensor = tensor.reshape(batch_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + + # print('before reshpape query shape', query.shape) + dim = query.shape[-1] + if not self.use_relative_position: + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # print('after reshape query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if not self.use_relative_position: + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous() + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length, + use_image_num=use_image_num, + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous() + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_first_frame = use_first_frame + + # Spatial-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Text Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Temp + self.attn_temp = TemporalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + rotary_emb=rotary_emb, + ) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None): + + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Temporal Attention + if self.training: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous() + hidden_states_video = hidden_states[:, :video_length, :] + hidden_states_image = hidden_states[:, video_length:, :] + norm_hidden_states_video = ( + self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video) + ) + hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video + hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1) + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous() + else: + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous() + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous() + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + +class TemporalAttention(CrossAttention): + def __init__(self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + rotary_emb=None): + super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups) + # relative time positional embeddings + self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet + self.rotary_emb = rotary_emb + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device) + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + dim = query.shape[-1] + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + # reshape for adding time positional bais + query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + + if exists(self.rotary_emb): + query = self.rotary_emb.rotate_queries_or_keys(query) + key = self.rotary_emb.rotate_queries_or_keys(key) + + attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key) + + attention_scores = attention_scores + time_rel_pos_bias + + if attention_mask is not None: + # add attention mask + attention_scores = attention_scores + attention_mask + + attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach() + + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + # print(attention_probs[0][0]) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value) + hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)') + return hidden_states + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames \ No newline at end of file diff --git a/base/models/clip.py b/base/models/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..2879f919e9081a373582e4734cded621bab8245d --- /dev/null +++ b/base/models/clip.py @@ -0,0 +1,120 @@ +import numpy +import torch.nn as nn +from transformers import CLIPTokenizer, CLIPTextModel + +import transformers +transformers.logging.set_verbosity_error() + +""" +Will encounter following warning: +- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task +or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). +- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model +that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). + +https://github.com/CompVis/stable-diffusion/issues/97 +according to this issue, this warning is safe. + +This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion. +You can safely ignore the warning, it is not an error. + +This clip usage is from U-ViT and same with Stable Diffusion. +""" + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77): + def __init__(self, path, device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer") + self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder') + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class TextEmbedder(nn.Module): + """ + Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. + """ + def __init__(self, path, dropout_prob=0.1): + super().__init__() + self.text_encodder = FrozenCLIPEmbedder(path=path) + self.dropout_prob = dropout_prob + + def token_drop(self, text_prompts, force_drop_ids=None): + """ + Drops text to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob + else: + # TODO + drop_ids = force_drop_ids == 1 + labels = list(numpy.where(drop_ids, "", text_prompts)) + # print(labels) + return labels + + def forward(self, text_prompts, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + text_prompts = self.token_drop(text_prompts, force_drop_ids) + embeddings = self.text_encodder(text_prompts) + return embeddings + + +if __name__ == '__main__': + + r""" + Returns: + + Examples from CLIPTextModel: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base', + dropout_prob=0.00001).to(device) + + text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]] + output = text_encoder(text_prompts=text_prompt, train=False) + print(output.shape) diff --git a/base/models/resnet.py b/base/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c1307ad9a0fe26b6b1eca913b024c188df2dd86e --- /dev/null +++ b/base/models/resnet.py @@ -0,0 +1,212 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/base/models/temporal_attention.py b/base/models/temporal_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6c417bc9e98409c22297a4878da689e38d188a24 --- /dev/null +++ b/base/models/temporal_attention.py @@ -0,0 +1,388 @@ +import torch +from torch import nn +from typing import Optional +from rotary_embedding_torch import RotaryEmbedding +from dataclasses import dataclass +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +import torch.nn.functional as F +from einops import rearrange, repeat +import math + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +def exists(x): + return x is not None + +class CrossAttention(nn.Module): + r""" + copy from diffuser 0.11.1 + A cross attention layer. + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + use_relative_position: bool = False, + ): + super().__init__() + # print('num head', heads) + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False # No use xformers for temporal attention + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def reshape_for_scores(self, tensor): + # split heads and dims + # tensor should be [b (h w)] f (d nd) + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).contiguous() + return tensor + + def same_batch_dim_to_heads(self, tensor): + batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d + tensor = tensor.reshape(batch_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + + # print('before reshpape query shape', query.shape) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # print('after reshape query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + hidden_states = self._attention(query, key, value, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + attention_probs = attention_probs.to(value.dtype) + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + +class TemporalAttention(CrossAttention): + def __init__(self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + rotary_emb=None): + super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups) + # relative time positional embeddings + self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet + self.rotary_emb = rotary_emb + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device) + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + dim = query.shape[-1] + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + if exists(self.rotary_emb): + query = self.rotary_emb.rotate_queries_or_keys(query) + key = self.rotary_emb.rotate_queries_or_keys(key) + + attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key) + attention_scores = attention_scores + time_rel_pos_bias + + if attention_mask is not None: + # add attention mask + attention_scores = attention_scores + attention_mask + + attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach() + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + attention_probs = attention_probs.to(value.dtype) + hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value) + hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)') + return hidden_states + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames diff --git a/base/models/transformer_3d.py b/base/models/transformer_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..f4b0aba347ea5408955d4646737306c8e945bf16 --- /dev/null +++ b/base/models/transformer_3d.py @@ -0,0 +1,367 @@ + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Any, Dict, Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput, deprecate +from diffusers.models.embeddings import PatchEmbed +from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear +from diffusers.models.modeling_utils import ModelMixin +from einops import rearrange, repeat + +try: + from attention import BasicTransformerBlock +except: + from .attention import BasicTransformerBlock + +@dataclass +class Transformer3DModelOutput(BaseOutput): + """ + The output of [`Transformer2DModel`]. + + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability + distributions for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +class Transformer3DModel(ModelMixin, ConfigMixin): + """ + A 2D Transformer model for image-like data. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + The number of channels in the input and output (specify if the input is **continuous**). + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use. + sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**). + This is fixed during training since it is used to learn a number of position embeddings. + num_vector_embeds (`int`, *optional*): + The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**). + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): + The number of diffusion steps used during training. Pass if at least one of the norm_layers is + `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are + added to the hidden states. + + During inference, you can denoise for up to but not more steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the `TransformerBlocks` attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + out_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + patch_size: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + norm_type: str = "layer_norm", + norm_elementwise_affine: bool = True, + rotary_emb=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = (in_channels is not None) and (patch_size is None) + self.is_input_vectorized = num_vector_embeds is not None + self.is_input_patches = in_channels is not None and patch_size is not None + + if norm_type == "layer_norm" and num_embeds_ada_norm is not None: + deprecation_message = ( + f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or" + " incorrectly set to `'layer_norm'`.Make sure to set `norm_type` to `'ada_norm'` in the config." + " Please make sure to update the config accordingly as leaving `norm_type` might led to incorrect" + " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it" + " would be very nice if you could open a Pull request for the `transformer/config.json` file" + ) + deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False) + norm_type = "ada_norm" + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif self.is_input_vectorized and self.is_input_patches: + raise ValueError( + f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make" + " sure that either `num_vector_embeds` or `num_patches` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches: + raise ValueError( + f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:" + f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = LoRACompatibleLinear(in_channels, inner_dim) + else: + self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + elif self.is_input_patches: + assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size" + + self.height = sample_size + self.width = sample_size + + self.patch_size = patch_size + self.pos_embed = PatchEmbed( + height=sample_size, + width=sample_size, + patch_size=patch_size, + in_channels=in_channels, + embed_dim=inner_dim, + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + norm_type=norm_type, + norm_elementwise_affine=norm_elementwise_affine, + rotary_emb=rotary_emb, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + self.out_channels = in_channels if out_channels is None else out_channels + if self.is_input_continuous: + # TODO: should use out_channels for continuous projections + if use_linear_projection: + self.proj_out = LoRACompatibleLinear(inner_dim, in_channels) + else: + self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + elif self.is_input_patches: + self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6) + self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim) + self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels) + + def forward( + self, + hidden_states: torch.Tensor, + encoder_hidden_states: Optional[torch.Tensor] = None, + timestep: Optional[torch.LongTensor] = None, + class_labels: Optional[torch.LongTensor] = None, + cross_attention_kwargs: Dict[str, Any] = None, + attention_mask: Optional[torch.Tensor] = None, + encoder_attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + use_image_num=None, + ): + """ + The [`Transformer2DModel`] forward method. + + Args: + hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous): + Input `hidden_states`. + encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.LongTensor`, *optional*): + Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`. + class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*): + Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in + `AdaLayerZeroNorm`. + encoder_attention_mask ( `torch.Tensor`, *optional*): + Cross-attention mask applied to `encoder_hidden_states`. Two formats supported: + + * Mask `(batch, sequence_length)` True = keep, False = discard. + * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard. + + If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format + above. This bias will be added to the cross-attention scores. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain + tuple. + + Returns: + If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a + `tuple` where the first element is the sample tensor. + """ + # ensure attention_mask is a bias, and give it a singleton query_tokens dimension. + # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward. + # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias. + # expects mask of shape: + # [batch, key_tokens] + # adds singleton query_tokens dimension: + # [batch, 1, key_tokens] + # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes: + # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn) + # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn) + if attention_mask is not None and attention_mask.ndim == 2: + # assume that mask is expressed as: + # (1 = keep, 0 = discard) + # convert mask into a bias that can be added to attention scores: + # (keep = +0, discard = -10000.0) + attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # convert encoder_attention_mask to a bias the same way we do for attention_mask + if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2: + encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0 + encoder_attention_mask = encoder_attention_mask.unsqueeze(1) + + # 1. Input + if self.is_input_continuous: # True + + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + if self.training: + video_length = hidden_states.shape[2] - use_image_num + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + encoder_hidden_states_length = encoder_hidden_states.shape[1] + encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...] + encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous() + encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...] + encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1) + encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous() + else: + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous() + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous() + + batch, _, height, width = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + elif self.is_input_patches: + hidden_states = self.pos_embed(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + attention_mask=attention_mask, + encoder_hidden_states=encoder_hidden_states, + encoder_attention_mask=encoder_attention_mask, + timestep=timestep, + cross_attention_kwargs=cross_attention_kwargs, + class_labels=class_labels, + video_length=video_length, + use_image_num=use_image_num, + ) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous() + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + elif self.is_input_patches: + # TODO: cleanup! + conditioning = self.transformer_blocks[0].norm1.emb( + timestep, class_labels, hidden_dtype=hidden_states.dtype + ) + shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1) + hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None] + hidden_states = self.proj_out_2(hidden_states) + + # unpatchify + height = width = int(hidden_states.shape[1] ** 0.5) + hidden_states = hidden_states.reshape( + shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels) + ) + hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states) + output = hidden_states.reshape( + shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size) + ) + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous() + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) diff --git a/base/models/unet.py b/base/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..bfb2f473142783e7e6cf8f8fb37c41c56fb2f8de --- /dev/null +++ b/base/models/unet.py @@ -0,0 +1,617 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import math +import json +import torch +import einops +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + +try: + from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from .resnet import InflatedConv3d +except: + from unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from resnet import InflatedConv3d + +from rotary_embedding_torch import RotaryEmbedding + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, # 64 + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + + # print(use_first_frame) + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + # print(only_cross_attention) + # print(type(only_cross_attention)) + # exit() + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + # print(only_cross_attention) + # exit() + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + # print(attention_head_dim) + # exit() + + rotary_emb = RotaryEmbedding(32) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + # relative time positional embeddings + self.use_relative_position = use_relative_position + if self.use_relative_position: + self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor = None, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + use_image_num: int = 0, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + # print(emb.shape) # torch.Size([3, 1280]) + # print(class_emb.shape) # torch.Size([3, 1280]) + emb = emb + class_emb + + if self.use_relative_position: + frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device) + else: + frame_rel_pos_bias = None + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + use_image_num=use_image_num, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num, + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + use_image_num=use_image_num, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + # print(sample.shape) + + if not return_dict: + return (sample,) + sample = UNet3DConditionOutput(sample=sample) + return sample + + def forward_with_cfg(self, + x, + t, + encoder_hidden_states = None, + class_labels: Optional[torch.Tensor] = None, + cfg_scale=4.0, + use_fp16=False): + """ + Forward, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + if use_fp16: + combined = combined.to(dtype=torch.float16) + model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + eps, rest = model_out[:, :4], model_out[:, 4:] + # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + config["use_first_frame"] = False + + from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin + + + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + for k, v in model.state_dict().items(): + # print(k) + if '_temp' in k: + state_dict.update({k: v}) + if 'attn_fcross' in k: # conpy parms of attn1 to attn_fcross + k = k.replace('attn_fcross', 'attn1') + state_dict.update({k: state_dict[k]}) + if 'norm_fcross' in k: + k = k.replace('norm_fcross', 'norm1') + state_dict.update({k: state_dict[k]}) + + model.load_state_dict(state_dict) + + return model + +if __name__ == '__main__': + import torch + # from xformers.ops import MemoryEfficientAttentionFlashAttentionOp + + device = "cuda" if torch.cuda.is_available() else "cpu" + + pretrained_model_path = "/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-v1-4/" # p cluster + unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device) + # unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp) + unet.enable_xformers_memory_efficient_attention() + unet.enable_gradient_checkpointing() + + unet.train() + + use_image_num = 5 + noisy_latents = torch.randn((2, 4, 16 + use_image_num, 32, 32)).to(device) + bsz = noisy_latents.shape[0] + timesteps = torch.randint(0, 1000, (bsz,)).to(device) + timesteps = timesteps.long() + encoder_hidden_states = torch.randn((bsz, 1 + use_image_num, 77, 768)).to(device) + # class_labels = torch.randn((bsz, )).to(device) + + + model_pred = unet(sample=noisy_latents, timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, + class_labels=None, + use_image_num=use_image_num).sample + print(model_pred.shape) \ No newline at end of file diff --git a/base/models/unet_blocks.py b/base/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..849c10539c7039840c93631c5201069119d3c306 --- /dev/null +++ b/base/models/unet_blocks.py @@ -0,0 +1,648 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +from torch import nn + +try: + from .attention import Transformer3DModel + from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +except: + from attention import Transformer3DModel + from resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, +): + # print(down_block_type) + # print(use_first_frame) + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False, + ): + super().__init__() + resnets = [] + attentions = [] + + # print(use_first_frame) + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def create_custom_forward_attn(module, return_dict=None, use_image_num=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict, use_image_num=use_image_num) + else: + return module(*inputs, use_image_num=use_image_num) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=False + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + use_image_num=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + def create_custom_forward_attn(module, return_dict=None, use_image_num=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict, use_image_num=use_image_num) + else: + return module(*inputs, use_image_num=use_image_num) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/base/models/utils.py b/base/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76e205d1709568885e58b56888f62c8805ad4f91 --- /dev/null +++ b/base/models/utils.py @@ -0,0 +1,215 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch + +import numpy as np +import torch.nn as nn + +from einops import repeat + + +################################################################################# +# Unet Utils # +################################################################################# + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +# class HybridConditioner(nn.Module): + +# def __init__(self, c_concat_config, c_crossattn_config): +# super().__init__() +# self.concat_conditioner = instantiate_from_config(c_concat_config) +# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + +# def forward(self, c_concat, c_crossattn): +# c_concat = self.concat_conditioner(c_concat) +# c_crossattn = self.crossattn_conditioner(c_crossattn) +# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params \ No newline at end of file diff --git a/base/pipelines/__pycache__/pipeline_videogen.cpython-311.pyc b/base/pipelines/__pycache__/pipeline_videogen.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..caa940925ab961c63315edfdc6eaffc8faa230c2 Binary files /dev/null and b/base/pipelines/__pycache__/pipeline_videogen.cpython-311.pyc differ diff --git a/base/pipelines/pipeline_videogen.py b/base/pipelines/pipeline_videogen.py new file mode 100644 index 0000000000000000000000000000000000000000..97031fc4d25b714b2c666c6d17bf3e16895376f0 --- /dev/null +++ b/base/pipelines/pipeline_videogen.py @@ -0,0 +1,677 @@ + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, Dict, List, Optional, Union +import einops +import torch +from packaging import version +from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer + +from diffusers.configuration_utils import FrozenDict +from diffusers.models import AutoencoderKL +from diffusers.schedulers import KarrasDiffusionSchedulers +from diffusers.utils import ( + deprecate, + is_accelerate_available, + is_accelerate_version, + logging, + #randn_tensor, + replace_example_docstring, + BaseOutput, +) + +try: + from diffusers.utils import randn_tensor +except: + from diffusers.utils.torch_utils import randn_tensor + + +from diffusers.pipeline_utils import DiffusionPipeline +from dataclasses import dataclass + +import os, sys +sys.path.append(os.path.split(sys.path[0])[0]) +from models.unet import UNet3DConditionModel + +import numpy as np + +@dataclass +class StableDiffusionPipelineOutput(BaseOutput): + video: torch.Tensor + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +EXAMPLE_DOC_STRING = """ + Examples: + ```py + >>> import torch + >>> from diffusers import StableDiffusionPipeline + + >>> pipe = StableDiffusionPipeline.from_pretrained("runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16) + >>> pipe = pipe.to("cuda") + + >>> prompt = "a photo of an astronaut riding a horse on mars" + >>> image = pipe(prompt).images[0] + ``` +""" + + +class VideoGenPipeline(DiffusionPipeline): + r""" + Pipeline for text-to-image generation using Stable Diffusion. + + This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the + library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) + + Args: + vae ([`AutoencoderKL`]): + Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations. + text_encoder ([`CLIPTextModel`]): + Frozen text-encoder. Stable Diffusion uses the text portion of + [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically + the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant. + tokenizer (`CLIPTokenizer`): + Tokenizer of class + [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer). + unet ([`UNet2DConditionModel`]): Conditional U-Net architecture to denoise the encoded image latents. + scheduler ([`SchedulerMixin`]): + A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of + [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`]. + safety_checker ([`StableDiffusionSafetyChecker`]): + Classification module that estimates whether generated images could be considered offensive or harmful. + Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details. + feature_extractor ([`CLIPFeatureExtractor`]): + Model that extracts features from generated images to be used as inputs for the `safety_checker`. + """ + _optional_components = ["safety_checker", "feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet3DConditionModel, + scheduler: KarrasDiffusionSchedulers, + ): + super().__init__() + + if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`" + f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " + "to update the config accordingly as leaving `steps_offset` might led to incorrect results" + " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub," + " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`" + " file" + ) + deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["steps_offset"] = 1 + scheduler._internal_dict = FrozenDict(new_config) + + if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: + deprecation_message = ( + f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`." + " `clip_sample` should be set to False in the configuration file. Please make sure to update the" + " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in" + " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very" + " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file" + ) + deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(scheduler.config) + new_config["clip_sample"] = False + scheduler._internal_dict = FrozenDict(new_config) + + + + is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse( + version.parse(unet.config._diffusers_version).base_version + ) < version.parse("0.9.0.dev0") + is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64 + if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64: + deprecation_message = ( + "The configuration file of the unet has set the default `sample_size` to smaller than" + " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the" + " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-" + " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5" + " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the" + " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`" + " in the config might lead to incorrect results in future versions. If you have downloaded this" + " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for" + " the `unet/config.json` file" + ) + deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False) + new_config = dict(unet.config) + new_config["sample_size"] = 64 + unet._internal_dict = FrozenDict(new_config) + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + scheduler=scheduler, + ) + self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) + # self.register_to_config(requires_safety_checker=requires_safety_checker) + + def enable_vae_slicing(self): + r""" + Enable sliced VAE decoding. + + When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several + steps. This is useful to save some memory and allow larger batch sizes. + """ + self.vae.enable_slicing() + + def disable_vae_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_slicing() + + def enable_vae_tiling(self): + r""" + Enable tiled VAE decoding. + + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in + several steps. This is useful to save a large amount of memory and to allow the processing of larger images. + """ + self.vae.enable_tiling() + + def disable_vae_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.vae.disable_tiling() + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + Note that offloading happens on a submodule basis. Memory savings are higher than with + `enable_model_cpu_offload`, but performance is lower. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"): + from accelerate import cpu_offload + else: + raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + cpu_offload(cpu_offloaded_model, device) + + # if self.safety_checker is not None: + # cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + self.final_offload_hook = hook + + @property + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + def decode_latents(self, latents): + video_length = latents.shape[2] + latents = 1 / 0.18215 * latents + latents = einops.rearrange(latents, "b c f h w -> (b f) c h w") + video = self.vae.decode(latents).sample + video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length) + video = ((video / 2 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().contiguous() + return video + + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + def check_inputs( + self, + prompt, + height, + width, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if height % 8 != 0 or width % 8 != 0: + raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor) + if isinstance(generator, list) and len(generator) != batch_size: + raise ValueError( + f"You have passed a list of generators of length {len(generator)}, but requested an effective batch" + f" size of {batch_size}. Make sure the batch size matches the length of the generators." + ) + + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + @replace_example_docstring(EXAMPLE_DOC_STRING) + def __call__( + self, + prompt: Union[str, List[str]] = None, + height: Optional[int] = None, + width: Optional[int] = None, + video_length: int = 16, + num_inference_steps: int = 50, + guidance_scale: float = 7.5, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + output_type: Optional[str] = "pil", + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + cross_attention_kwargs: Optional[Dict[str, Any]] = None, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The height in pixels of the generated image. + width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor): + The width in pixels of the generated image. + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead. + Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator` or `List[torch.Generator]`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + output_type (`str`, *optional*, defaults to `"pil"`): + The output format of the generate image. Choose between + [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + cross_attention_kwargs (`dict`, *optional*): + A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under + `self.processor` in + [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py). + + Examples: + + Returns: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`: + [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple. + When returning a tuple, the first element is a list with the generated images, and the second element is a + list of `bool`s denoting whether the corresponding generated image likely represents "not-safe-for-work" + (nsfw) content, according to the `safety_checker`. + """ + # 0. Default height and width to unet + height = height or self.unet.config.sample_size * self.vae_scale_factor + width = width or self.unet.config.sample_size * self.vae_scale_factor + + # 1. Check inputs. Raise error if not correct + self.check_inputs( + prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds + ) + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Prepare timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 5. Prepare latent variables + num_channels_latents = self.unet.config.in_channels + latents = self.prepare_latents( + batch_size * num_images_per_prompt, + num_channels_latents, + video_length, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) + + # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 7. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + + # predict the noise residual + noise_pred = self.unet( + latent_model_input, + t, + encoder_hidden_states=prompt_embeds, + # cross_attention_kwargs=cross_attention_kwargs, + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + + # 8. Post-processing + video = self.decode_latents(latents) + + return StableDiffusionPipelineOutput(video=video) diff --git a/base/pipelines/sample.py b/base/pipelines/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..e2a2a89e311982e6596cece7f767015e9b9ec709 --- /dev/null +++ b/base/pipelines/sample.py @@ -0,0 +1,88 @@ +import os +import torch +import argparse +import torchvision + +from pipeline_videogen import VideoGenPipeline + +from download import find_model +from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler +from diffusers.models import AutoencoderKL +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from omegaconf import OmegaConf + +import os, sys +sys.path.append(os.path.split(sys.path[0])[0]) +from models import get_models +import imageio + +def main(args): + #torch.manual_seed(args.seed) + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + + sd_path = args.pretrained_path + "/stable-diffusion-v1-4" + unet = get_models(args, sd_path).to(device, dtype=torch.float16) + state_dict = find_model(args.pretrained_path + "/lavie_base.pt") + unet.load_state_dict(state_dict) + + vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device) + tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer") + text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge + + # set eval mode + unet.eval() + vae.eval() + text_encoder_one.eval() + + if args.sample_method == 'ddim': + scheduler = DDIMScheduler.from_pretrained(sd_path, + subfolder="scheduler", + beta_start=args.beta_start, + beta_end=args.beta_end, + beta_schedule=args.beta_schedule) + elif args.sample_method == 'eulerdiscrete': + scheduler = EulerDiscreteScheduler.from_pretrained(sd_path, + subfolder="scheduler", + beta_start=args.beta_start, + beta_end=args.beta_end, + beta_schedule=args.beta_schedule) + elif args.sample_method == 'ddpm': + scheduler = DDPMScheduler.from_pretrained(sd_path, + subfolder="scheduler", + beta_start=args.beta_start, + beta_end=args.beta_end, + beta_schedule=args.beta_schedule) + else: + raise NotImplementedError + + videogen_pipeline = VideoGenPipeline(vae=vae, + text_encoder=text_encoder_one, + tokenizer=tokenizer_one, + scheduler=scheduler, + unet=unet).to(device) + videogen_pipeline.enable_xformers_memory_efficient_attention() + + if not os.path.exists(args.output_folder): + os.makedirs(args.output_folder) + + video_grids = [] + for prompt in args.text_prompt: + print('Processing the ({}) prompt'.format(prompt)) + videos = videogen_pipeline(prompt, + video_length=args.video_length, + height=args.image_size[0], + width=args.image_size[1], + num_inference_steps=args.num_sampling_steps, + guidance_scale=args.guidance_scale).video + imageio.mimwrite(args.output_folder + prompt.replace(' ', '_') + '.mp4', videos[0], fps=8, quality=9) # highest quality is 10, lowest is 0 + + print('save path {}'.format(args.output_folder)) + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="") + args = parser.parse_args() + + main(OmegaConf.load(args.config)) + diff --git a/base/pipelines/sample.sh b/base/pipelines/sample.sh new file mode 100644 index 0000000000000000000000000000000000000000..c8dd5fdada1cf921e5c250bd357e765fd9e42b5a --- /dev/null +++ b/base/pipelines/sample.sh @@ -0,0 +1,2 @@ +export CUDA_VISIBLE_DEVICES=6 +python pipelines/sample.py --config configs/sample.yaml \ No newline at end of file diff --git a/base/text_to_video/__init__.py b/base/text_to_video/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b5a194bb33bf22c33de528da45b372f7e2ffaef6 --- /dev/null +++ b/base/text_to_video/__init__.py @@ -0,0 +1,44 @@ +import os +import torch +import argparse +import torchvision + +from pipelines.pipeline_videogen import VideoGenPipeline + +from download import find_model +from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler +from diffusers.models import AutoencoderKL +from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection +from omegaconf import OmegaConf + +import os, sys +sys.path.append(os.path.split(sys.path[0])[0]) +from models import get_models +import imageio + +config_path = "/mnt/petrelfs/zhouyan/project/lavie-release/base/configs/sample.yaml" +args = OmegaConf.load("/mnt/petrelfs/zhouyan/project/lavie-release/base/configs/sample.yaml") + +device = "cuda" if torch.cuda.is_available() else "cpu" + +def model_t2v_fun(args): + sd_path = args.pretrained_path + "/stable-diffusion-v1-4" + unet = get_models(args, sd_path).to(device, dtype=torch.float16) + # state_dict = find_model(args.pretrained_path + "/lavie_base.pt") + state_dict = find_model("/mnt/petrelfs/share_data/wangyaohui/lavie/pretrained_models/lavie_base.pt") + unet.load_state_dict(state_dict) + + vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device) + tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer") + text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge + unet.eval() + vae.eval() + text_encoder_one.eval() + scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler", beta_start=args.beta_start, beta_end=args.beta_end, beta_schedule=args.beta_schedule) + return VideoGenPipeline(vae=vae, text_encoder=text_encoder_one, tokenizer=tokenizer_one, scheduler=scheduler, unet=unet) + +def setup_seed(seed): + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + diff --git a/base/text_to_video/__pycache__/__init__.cpython-311.pyc b/base/text_to_video/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..2564eedb859cc053903c2f055f4e2e237d1fa03e Binary files /dev/null and b/base/text_to_video/__pycache__/__init__.cpython-311.pyc differ diff --git a/base/try.py b/base/try.py new file mode 100644 index 0000000000000000000000000000000000000000..5cffd2393f0658f057f30acd2060a21ad1dc8825 --- /dev/null +++ b/base/try.py @@ -0,0 +1,5 @@ +import gradio as gr + +with gr.Blocks() as demo: + prompt = gr.Textbox(label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in") +demo.launch(server_name="0.0.0.0") \ No newline at end of file diff --git a/environment.yml b/environment.yml new file mode 100644 index 0000000000000000000000000000000000000000..7af32037d0e2dc010f6ccf4ec0970ad66029ff70 --- /dev/null +++ b/environment.yml @@ -0,0 +1,27 @@ +name: lavie +channels: + - pytorch + - nvidia +dependencies: + - python=3.11.3 + - pytorch=2.0.1 + - pytorch-cuda=11.7 + - torchvision=0.15.2 + - pip: + - accelerate==0.19.0 + - av==10.0.0 + - decord==0.6.0 + - diffusers[torch]==0.16.0 + - einops==0.6.1 + - ffmpeg==1.4 + - imageio==2.31.1 + - imageio-ffmpeg==0.4.9 + - pandas==2.0.1 + - timm==0.6.13 + - tqdm==4.65.0 + - transformers==4.28.1 + - xformers==0.0.20 + - omegaconf==2.3.0 + - natsort==8.4.0 + - rotary_embedding_torch + - gradio==4.3.0 \ No newline at end of file diff --git a/interpolation/configs/sample.yaml b/interpolation/configs/sample.yaml new file mode 100644 index 0000000000000000000000000000000000000000..c00d36aaec25efe196ac47d167042c08f3717a57 --- /dev/null +++ b/interpolation/configs/sample.yaml @@ -0,0 +1,36 @@ +args: + input_folder: "../res/base/" + pretrained_path: "../pretrained_models" + output_folder: "../res/interpolation/" + seed_list: + - 3418 + + fps_list: + - 24 + + # model config: + model: TSR + num_frames: 61 + image_size: [320, 512] + num_sampling_steps: 50 + vae: mse + use_timecross_transformer: False + frame_interval: 1 + + # sample config: + seed: 0 + cfg_scale: 4.0 + run_time: 12 + use_compile: False + enable_xformers_memory_efficient_attention: True + num_sample: 1 + + additional_prompt: ", 4k." + negative_prompt: "None" + do_classifier_free_guidance: True + use_ddim_sample_loop: True + + researve_frame: 3 + mask_type: "tsr" + use_concat: True + copy_no_mask: True diff --git a/interpolation/datasets/__init__.py b/interpolation/datasets/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..1332521f6336d8ce42f026e0aeba66fdd8138a0a --- /dev/null +++ b/interpolation/datasets/__init__.py @@ -0,0 +1 @@ +from datasets import video_transforms diff --git a/interpolation/datasets/video_transforms.py b/interpolation/datasets/video_transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..463a10f082288f486a546bbff17b136824723589 --- /dev/null +++ b/interpolation/datasets/video_transforms.py @@ -0,0 +1,109 @@ +import torch +import random +import numbers +from torchvision.transforms import RandomCrop, RandomResizedCrop + +def _is_tensor_video_clip(clip): + if not torch.is_tensor(clip): + raise TypeError("clip should be Tensor. Got %s" % type(clip)) + + if not clip.ndimension() == 4: + raise ValueError("clip should be 4D. Got %dD" % clip.dim()) + + return True + + +def to_tensor(clip): + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + _is_tensor_video_clip(clip) + if not clip.dtype == torch.uint8: + raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype)) + # return clip.float().permute(3, 0, 1, 2) / 255.0 + return clip.float() / 255.0 + + +def resize(clip, target_size, interpolation_mode): + if len(target_size) != 2: + raise ValueError(f"target size should be tuple (height, width), instead got {target_size}") + return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False) + + +class ToTensorVideo: + """ + Convert tensor data type from uint8 to float, divide value by 255.0 and + permute the dimensions of clip tensor + """ + + def __init__(self): + pass + + def __call__(self, clip): + """ + Args: + clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W) + Return: + clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W) + """ + return to_tensor(clip) + + def __repr__(self) -> str: + return self.__class__.__name__ + + +class ResizeVideo: + ''' + Resize to the specified size + ''' + def __init__( + self, + size, + interpolation_mode="bilinear", + ): + if isinstance(size, tuple): + if len(size) != 2: + raise ValueError(f"size should be tuple (height, width), instead got {size}") + self.size = size + else: + self.size = (size, size) + + self.interpolation_mode = interpolation_mode + + + def __call__(self, clip): + """ + Args: + clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W) + Returns: + torch.tensor: scale resized video clip. + size is (T, C, h, w) + """ + clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode) + return clip_resize + + def __repr__(self) -> str: + return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}" + + +class TemporalRandomCrop(object): + """Temporally crop the given frame indices at a random location. + + Args: + size (int): Desired length of frames will be seen in the model. + """ + + def __init__(self, size): + self.size = size + + def __call__(self, total_frames): + rand_end = max(0, total_frames - self.size - 1) + begin_index = random.randint(0, rand_end) + end_index = min(begin_index + self.size, total_frames) + return begin_index, end_index + diff --git a/interpolation/diffusion/__init__.py b/interpolation/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9dbf6cf0bd6b9d1a8f65e0a31e9a84cacc03189 --- /dev/null +++ b/interpolation/diffusion/__init__.py @@ -0,0 +1,47 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +def create_diffusion( + timestep_respacing, + noise_schedule="linear", + use_kl=False, + sigma_small=False, + predict_xstart=False, + # learn_sigma=True, + learn_sigma=False, # for unet + rescale_learned_sigmas=False, + diffusion_steps=1000 +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps) + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=( + gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X + ), + model_var_type=( + ( + gd.ModelVarType.FIXED_LARGE + if not sigma_small + else gd.ModelVarType.FIXED_SMALL + ) + if not learn_sigma + else gd.ModelVarType.LEARNED_RANGE + ), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/interpolation/diffusion/diffusion_utils.py b/interpolation/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060 --- /dev/null +++ b/interpolation/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/interpolation/diffusion/gaussian_diffusion.py b/interpolation/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..854c5c4daceace3826b786d98220c9d1b10611c8 --- /dev/null +++ b/interpolation/diffusion/gaussian_diffusion.py @@ -0,0 +1,1000 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "quad": + betas = ( + np.linspace( + beta_start ** 0.5, + beta_end ** 0.5, + num_diffusion_timesteps, + dtype=np.float64, + ) + ** 2 + ) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name == "linear": + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + "linear", + beta_start=scale * 0.0001, + beta_end=scale * 0.02, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None, + mask=None, x_start=None, use_concat=False, + copy_no_mask=False, ): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, F, C = x.shape[:3] + assert t.shape == (B,) + # model_output = model(x, t, **model_kwargs) + if copy_no_mask: + if use_concat: + try: + model_output = model(th.concat([x, x_start], dim=1), t, **model_kwargs).sample + except: + # print(f'x.shape = {x.shape}, x_start.shape = {x_start.shape}') + # ) + # x.shape = torch.Size([2, 4, 61, 32, 32]), x_start.shape = torch.Size([2, 4, 61, 32, 32] + # print(f'x[0,0,:,0,0] = {x[0,0,:,0,0]}, \nx_start[0,0,:,0,0] = {x_start[0,0,:,0,0]}') + model_output = model(th.concat([x, x_start], dim=1), t, **model_kwargs) + else: + try: + model_output = model(x, t, **model_kwargs).sample # for tav unet + except: + model_output = model(x, t, **model_kwargs) + else: + if use_concat: + try: + model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs).sample + except: + model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs) + else: + try: + model_output = model(x, t, **model_kwargs).sample # for tav unet + except: + model_output = model(x, t, **model_kwargs) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, F, C * 2, *x.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output, mask=mask, x_start=x_start, use_concat=use_concat) + ) + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps, mask=None, x_start=None, use_concat=False): # (x_t=x, t=t, eps=model_output) + assert x_t.shape == eps.shape + if not use_concat: + if mask is not None: + if x_start is None: + return ( + (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps )* mask + x_t * (1-mask) + ) + else: + # breakpoint() + if (t == 0).any(): + print('t=0') + x_unknown = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t \ + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + return x_start * (1-mask) + x_unknown * mask + else: + x_known = self.q_sample(x_start, t-1) + x_unknown = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t \ + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + return ( + x_known * (1-mask) + x_unknown * mask + ) + else: + return ( + (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) + ) + else: + return ( + (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps ) + ) + + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + mask=None, + x_start=None, + use_concat=False + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + mask=None, + x_start=None, + use_concat=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + mask=mask, + x_start=x_start, + use_concat=use_concat + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + mask=None, + x_start=None, + use_concat=False + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): # loop + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + mask=None, + x_start=None, + use_concat=False, + copy_no_mask=False, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + mask=mask, + x_start=x_start, + use_concat=use_concat, + copy_no_mask=copy_no_mask, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + mask=None, + x_start=None, + use_concat=False, + copy_no_mask=False, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + mask=mask, + x_start=x_start, + use_concat=use_concat, + copy_no_mask=copy_no_mask, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + mask=None, + x_start=None, + use_concat=False, + copy_no_mask=False, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + mask=mask, + x_start=x_start, + use_concat=use_concat, + copy_no_mask=copy_no_mask, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None, t_head=None, copy_no_mask=False): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + # mask could be here + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1) + # mask is used for (0,0,0,1,1,1,...) which means the diffusion model can see the first 3 frames of the input video + # print(f'training_losses(): mask = {mask}') # None + + if mask is not None: + x_t = x_t*mask + x_start*(1-mask) + + # noise augmentation + if copy_no_mask: + if t_head is not None: + noise_aug = self.q_sample(x_start[:, 4:], t_head) # noise aug on copied_video + x_t = th.cat([x_t[:, :4], noise_aug], dim=1) + else: + if t_head is not None: + noise_aug = self.q_sample(x_start[:, 5:], t_head) # b, 4, f, h, w + noise_aug = noise_aug * (x_start[:, 4].unsqueeze(1).expand(-1, 4, -1, -1, -1) == 0) # use mask to zero out augmented noises + x_t = th.cat([x_t[:, :5], noise_aug], dim=1) + terms = {} + # for i in [0,1,2,3,4,5,6,7]: + # print(f'x_t[0,{i},:,0,0] = {x_t[0,i,:,0,0]}') + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: + # print(f'self.loss_type = {self.loss_type}') # LossType.MSE + # model_output = model(x_t, t, **model_kwargs) + try: + model_output = model(x_t, t, **model_kwargs).sample # for tav unet + except: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, F, C = x_t.shape[:3] + assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + + # print(f'self.model_mean_type = {self.model_mean_type}') # ModelMeanType.EPSILON + target = { + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + # assert model_output.shape == target.shape == x_start.shape + # if mask is not None: + # nonzero_idx = th.nonzero(1-mask) + terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2) + # else: + # terms["mse"] = mean_flat((target - model_output) ** 2) + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/interpolation/diffusion/respace.py b/interpolation/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..31c0092380367db477f154a39bf172ff64712fc4 --- /dev/null +++ b/interpolation/diffusion/respace.py @@ -0,0 +1,130 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import torch +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + # @torch.compile + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/interpolation/diffusion/timestep_sampler.py b/interpolation/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f --- /dev/null +++ b/interpolation/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/interpolation/download.py b/interpolation/download.py new file mode 100644 index 0000000000000000000000000000000000000000..b7daa0660d0c2eea1e730cbe88585de2e0723dcf --- /dev/null +++ b/interpolation/download.py @@ -0,0 +1,22 @@ +# All rights reserved. + +# This source code is licensed under the license found in the +# LICENSE file in the root directory of this source tree. + +import torch +import os + + +pretrained_models = {''} + + +def find_model(model_name): + """ + Finds a pre-trained model, downloading it if necessary. Alternatively, loads a model from a local path. + """ + assert os.path.isfile(model_name), f'Could not find checkpoint at {model_name}' + checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage) + if "ema" in checkpoint: # supports checkpoints from train.py + checkpoint = checkpoint["ema"] + return checkpoint + diff --git a/interpolation/models/__init__.py b/interpolation/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..507091b8c274533eca620577bae301541d5a91f6 --- /dev/null +++ b/interpolation/models/__init__.py @@ -0,0 +1,33 @@ +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from .unet import UNet3DConditionModel +from torch.optim.lr_scheduler import LambdaLR + +def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit + from torch.optim.lr_scheduler import LambdaLR + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'warmup': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + +def get_models(args, ckpt_path): + + if 'TSR' in args.model: + return UNet3DConditionModel.from_pretrained_2d(ckpt_path, subfolder="unet", use_concat=args.use_concat, copy_no_mask=args.copy_no_mask) + else: + raise '{} Model Not Supported!'.format(args.model) + diff --git a/interpolation/models/attention.py b/interpolation/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..1c05879560bd827c2096e557dcaa17fadc8c1bb0 --- /dev/null +++ b/interpolation/models/attention.py @@ -0,0 +1,665 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from dataclasses import dataclass +from typing import Optional + +import math +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm + +from einops import rearrange, repeat + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class CrossAttention(nn.Module): + r""" + copy from diffuser 0.11.1 + A cross attention layer. + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + use_relative_position: bool = False, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + # print(use_relative_position) + self.use_relative_position = use_relative_position + if self.use_relative_position: + # print(dim_head) + # print(heads) + # adopt https://github.com/huggingface/transformers/blob/8a817e1ecac6a420b1bdc701fcc33535a3b96ff5/src/transformers/models/bert/modeling_bert.py#L265 + self.max_position_embeddings = 32 + self.distance_embedding = nn.Embedding(2 * self.max_position_embeddings - 1, dim_head) + + self.dropout = nn.Dropout(dropout) + + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def reshape_for_scores(self, tensor): + # split heads and dims + # tensor should be [b (h w)] f (d nd) + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).contiguous() + return tensor + + def same_batch_dim_to_heads(self, tensor): + batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d + tensor = tensor.reshape(batch_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + # if self.use_relative_position: + # print('before attention query shape', query.shape) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # if self.use_relative_position: + # print('before attention query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + if self.use_relative_position: + query = self.reshape_for_scores(self.reshape_batch_dim_to_heads(query)) + key = self.reshape_for_scores(self.reshape_batch_dim_to_heads(key)) + value = self.reshape_for_scores(self.reshape_batch_dim_to_heads(value)) + + # torch.baddbmm only accepte 3-D tensor + # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm + attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2)) + + # print('attention_scores shape', attention_scores.shape) + + # print(query.shape) # [b (h w)] nd f d + query_length, key_length = query.shape[2], key.shape[2] + # print('query shape', query.shape) + # print('key shape', key.shape) + position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) # hidden_states.device + position_ids_r = torch.arange(key_length, dtype=torch.long, device=key.device).view(1, -1) # hidden_states.device + distance = position_ids_l - position_ids_r + # print('distance shape', distance.shape) + positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1) + positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility + # print('positional_embedding shape', positional_embedding.shape) + relative_position_scores_query = torch.einsum("bhld, lrd -> bhlr", query, positional_embedding) + relative_position_scores_key = torch.einsum("bhrd, lrd -> bhlr", key, positional_embedding) + # print('relative_position_scores_key shape', relative_position_scores_key.shape) + attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key + # print(attention_scores.shape) + + attention_scores = attention_scores / math.sqrt(self.dim_head) + + # Normalize the attention scores to probabilities. + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.matmul(attention_probs, value) + # print(hidden_states.shape) + hidden_states = self.same_batch_dim_to_heads(hidden_states) + # print(hidden_states.shape) + # exit() + + else: + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + # print(attention_probs.shape) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + # print(attention_probs.shape) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + # print(hidden_states.shape) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + # print(hidden_states.shape) + # exit() + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + # TODO attention_mask + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + # print(only_cross_attention) + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_first_frame = use_first_frame + + # SC-Attn + if use_first_frame: + self.attn1 = SparseCausalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + # print(cross_attention_dim) + else: + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + # Temp-Attn + self.attn_temp = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + use_relative_position=use_relative_position + ) + nn.init.zeros_(self.attn_temp.to_out[0].weight.data) + self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op=None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + if self.use_first_frame: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + # Temporal-Attention + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + norm_hidden_states = ( + self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states) + ) + hidden_states = self.attn_temp(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + +class SparseCausalAttention(CrossAttention): + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + former_frame_index = torch.arange(video_length) - 1 + former_frame_index[0] = 0 + + key = rearrange(key, "(b f) d c -> b f d c", f=video_length) + key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) + key = rearrange(key, "b f d c -> (b f) d c") + + value = rearrange(value, "(b f) d c -> b f d c", f=video_length) + value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) + value = rearrange(value, "b f d c -> (b f) d c") + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states diff --git a/interpolation/models/clip.py b/interpolation/models/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..0d77f4d1a725a5d7ed0c4e10a69602e80b91fec1 --- /dev/null +++ b/interpolation/models/clip.py @@ -0,0 +1,124 @@ +import numpy +import torch.nn as nn +from transformers import CLIPTokenizer, CLIPTextModel + +import transformers +transformers.logging.set_verbosity_error() + +""" +Will encounter following warning: +- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task +or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). +- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model +that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). + +https://github.com/CompVis/stable-diffusion/issues/97 +according to this issue, this warning is safe. + +This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion. +You can safely ignore the warning, it is not an error. + +This clip usage is from U-ViT and same with Stable Diffusion. +""" + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, sd_path, device="cuda", max_length=77): + super().__init__() + self.tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer", use_fast=False) + self.transformer = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder") + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.transformer = self.transformer.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True, + return_overflowing_tokens=False, padding="max_length", return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.transformer(input_ids=tokens) + + z = outputs.last_hidden_state + return z + + def encode(self, text): + return self(text) + + +class TextEmbedder(nn.Module): + """ + Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. + """ + def __init__(self, args, dropout_prob=0.1): + super().__init__() + self.text_encodder = FrozenCLIPEmbedder(args) + self.dropout_prob = dropout_prob + + def token_drop(self, text_prompts, force_drop_ids=None): + """ + Drops text to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob + else: + # TODO + drop_ids = force_drop_ids == 1 + labels = list(numpy.where(drop_ids, "None", text_prompts)) + # print(labels) + return labels + + def forward(self, text_prompts, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + text_prompts = self.token_drop(text_prompts, force_drop_ids) + embeddings = self.text_encodder(text_prompts) + return embeddings + + +if __name__ == '__main__': + + r""" + Returns: + + Examples from CLIPTextModel: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + text_encoder = TextEmbedder(dropout_prob=0.00001).to(device) + text_encoder1 = FrozenCLIPEmbedder().to(device) + + text_prompt = ["a photo of a cat", "a photo of a dog", 'a photo of a dog human'] + # text_prompt = ('None', 'None', 'None') + output = text_encoder(text_prompts=text_prompt, train=True) + output1 = text_encoder1(text_prompt) + # print(output) + print(output.shape) + print(output1.shape) + print((output==output1).all()) diff --git a/interpolation/models/resnet.py b/interpolation/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..c1307ad9a0fe26b6b1eca913b024c188df2dd86e --- /dev/null +++ b/interpolation/models/resnet.py @@ -0,0 +1,212 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) \ No newline at end of file diff --git a/interpolation/models/unet.py b/interpolation/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..66b161f18f6552a4a9f7838c461c9ed10c26c7bf --- /dev/null +++ b/interpolation/models/unet.py @@ -0,0 +1,576 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import json + +import torch +import torch.nn as nn +import torch.utils.checkpoint + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput, logging +from diffusers.models.embeddings import TimestepEmbedding, Timesteps + +try: + from diffusers.models.modeling_utils import ModelMixin +except: + from diffusers.modeling_utils import ModelMixin # 0.11.1 + +try: + from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from .resnet import InflatedConv3d +except: + from unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from resnet import InflatedConv3d + + + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DConditionModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + sample_size: Optional[int] = None, # 64 + in_channels: int = 4, + out_channels: int = 4, + center_input_sample: bool = False, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + down_block_types: Tuple[str] = ( + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D", + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = False, + block_out_channels: Tuple[int] = (320, 640, 1280, 1280), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1280, + attention_head_dim: Union[int, Tuple[int]] = 8, + dual_cross_attention: bool = False, + use_linear_projection: bool = False, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = None, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + + # print(use_first_frame) + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1)) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + # print(only_cross_attention) + # exit() + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + # print(attention_head_dim) + # exit() + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + self.down_blocks.append(down_block) + + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + encoder_hidden_states: torch.Tensor = None, + class_labels: Optional[torch.Tensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ) -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + # print(emb.shape) # torch.Size([3, 1280]) + # print(class_emb.shape) # torch.Size([3, 1280]) + emb = emb + class_emb + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block in self.down_blocks: + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + # up + for i, upsample_block in enumerate(self.up_blocks): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + # print(sample.shape) + + if not return_dict: + return (sample,) + sample = UNet3DConditionOutput(sample=sample) + return sample + + def forward_with_cfg(self, + x, + t, + encoder_hidden_states = None, + class_labels: Optional[torch.Tensor] = None, + cfg_scale=4.0): + """ + Forward, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + # eps, rest = model_out[:, :4], model_out[:, 4:] + eps, rest = model_out[:, :4], model_out[:, 4:] # b c f h w + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + @classmethod + def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False, copy_no_mask=False): + if subfolder is not None: + pretrained_model_path = os.path.join(pretrained_model_path, subfolder) + + + config_file = os.path.join(pretrained_model_path, 'config.json') + if not os.path.isfile(config_file): + raise RuntimeError(f"{config_file} does not exist") + with open(config_file, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + config["down_block_types"] = [ + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "DownBlock3D" + ] + config["up_block_types"] = [ + "UpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D" + ] + + config["use_first_frame"] = True + + if copy_no_mask: + config["in_channels"] = 8 + else: + if use_concat: + config["in_channels"] = 9 + + + from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin + + + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + + + if use_concat: + new_state_dict = {} + conv_in_weight = state_dict["conv_in.weight"] + + print(f'from_pretrained_2d copy_no_mask = {copy_no_mask}') + if copy_no_mask: + new_conv_in_channel = 8 + new_conv_in_list = [0, 1, 2, 3, 4, 5, 6, 7] + else: + new_conv_in_channel = 9 + new_conv_in_list = [0, 1, 2, 3, 4, 5, 6, 7, 8] + new_conv_weight = torch.zeros((conv_in_weight.shape[0], new_conv_in_channel, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype) + + for i, j in zip([0, 1, 2, 3], new_conv_in_list): + new_conv_weight[:, j] = conv_in_weight[:, i] + new_state_dict["conv_in.weight"] = new_conv_weight + new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"] + for k, v in model.state_dict().items(): + # print(k) + if '_temp.' in k: + new_state_dict.update({k: v}) + elif 'conv_in' in k: + continue + else: + new_state_dict[k] = v + # # tmp + # if 'class_embedding' in k: + # state_dict.update({k: v}) + # breakpoint() + model.load_state_dict(new_state_dict) + else: + for k, v in model.state_dict().items(): + # print(k) + if '_temp.' in k: + state_dict.update({k: v}) + model.load_state_dict(state_dict) + return model + +if __name__ == '__main__': + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + pretrained_model_path = "/nvme/maxin/work/large-dit-video/pretrained/stable-diffusion-v1-4/" # 43 + unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device) + + noisy_latents = torch.randn((3, 4, 16, 32, 32)).to(device) + bsz = noisy_latents.shape[0] + timesteps = torch.randint(0, 1000, (bsz,)).to(device) + timesteps = timesteps.long() + encoder_hidden_states = torch.randn((bsz, 77, 768)).to(device) + class_labels = torch.randn((bsz, )).to(device) + + model_pred = unet(sample=noisy_latents, timestep=timesteps, + encoder_hidden_states=encoder_hidden_states, + class_labels=class_labels).sample + print(model_pred.shape) \ No newline at end of file diff --git a/interpolation/models/unet_blocks.py b/interpolation/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..0bfed25db240043dcc42b7e4459f5ca52d3cd902 --- /dev/null +++ b/interpolation/models/unet_blocks.py @@ -0,0 +1,619 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +from torch import nn + +try: + from .attention import Transformer3DModel + from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +except: + from attention import Transformer3DModel + from resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, +): + # print(down_block_type) + # print(use_first_frame) + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + ): + super().__init__() + resnets = [] + attentions = [] + + # print(use_first_frame) + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states diff --git a/interpolation/models/utils.py b/interpolation/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..b6cfac1ff6c2c99c87920372482251dc2b2fce34 --- /dev/null +++ b/interpolation/models/utils.py @@ -0,0 +1,215 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch + +import numpy as np +import torch.nn as nn + +from einops import repeat + + +################################################################################# +# Unet Utils # +################################################################################# + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +# class HybridConditioner(nn.Module): + +# def __init__(self, c_concat_config, c_crossattn_config): +# super().__init__() +# self.concat_conditioner = instantiate_from_config(c_concat_config) +# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + +# def forward(self, c_concat, c_crossattn): +# c_concat = self.concat_conioner(c_concat) +# c_crossattn = self.crossattn_conditioner(c_crossattn) +# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params \ No newline at end of file diff --git a/interpolation/sample.py b/interpolation/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..f4413e8407338e5ee5fe0df6625f42aef6d9671a --- /dev/null +++ b/interpolation/sample.py @@ -0,0 +1,312 @@ +""" +we introduce a temporal interpolation network to enhance the smoothness of generated videos and synthesize richer temporal details. +This network takes a 16-frame base video as input and produces an upsampled output consisting of 61 frames. +""" + +import os +import sys +import math +try: + import utils + + from diffusion import create_diffusion + from download import find_model +except: + sys.path.append(os.path.split(sys.path[0])[0]) + + import utils + + from diffusion import create_diffusion + from download import find_model + +import torch +import argparse +import torchvision + +from einops import rearrange +from models import get_models +from torchvision.utils import save_image +from diffusers.models import AutoencoderKL +from models.clip import TextEmbedder +from omegaconf import OmegaConf +from PIL import Image +import numpy as np +from torchvision import transforms +sys.path.append("..") +from datasets import video_transforms +from decord import VideoReader +from utils import mask_generation, mask_generation_before +from natsort import natsorted +from diffusers.utils.import_utils import is_xformers_available + +torch.backends.cuda.matmul.allow_tf32 = True +torch.backends.cudnn.allow_tf32 = True + + +def get_input(args): + input_path = args.input_path + transform_video = transforms.Compose([ + video_transforms.ToTensorVideo(), # TCHW + # video_transforms.CenterCropResizeVideo((args.image_h, args.image_w)), + video_transforms.ResizeVideo((args.image_h, args.image_w)), + transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True) + ]) + temporal_sample_func = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval) + if input_path is not None: + print(f'loading video from {input_path}') + if os.path.isdir(input_path): + file_list = os.listdir(input_path) + video_frames = [] + for file in file_list: + if file.endswith('jpg') or file.endswith('png'): + image = torch.as_tensor(np.array(Image.open(file), dtype=np.uint8, copy=True)).unsqueeze(0) + video_frames.append(image) + else: + continue + n = 0 + video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w + video_frames = transform_video(video_frames) + return video_frames, n + elif os.path.isfile(input_path): + _, full_file_name = os.path.split(input_path) + file_name, extention = os.path.splitext(full_file_name) + if extention == '.mp4': + video_reader = VideoReader(input_path) + total_frames = len(video_reader) + start_frame_ind, end_frame_ind = temporal_sample_func(total_frames) + frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, args.num_frames, dtype=int) + video_frames = torch.from_numpy(video_reader.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous() + video_frames = transform_video(video_frames) + n = 3 + del video_reader + return video_frames, n + else: + raise TypeError(f'{extention} is not supported !!') + else: + raise ValueError('Please check your path input!!') + else: + print('given video is None, using text to video') + video_frames = torch.zeros(16,3,args.latent_h,args.latent_w,dtype=torch.uint8) + args.mask_type = 'all' + video_frames = transform_video(video_frames) + n = 0 + return video_frames, n + + +def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,): + + + b,f,c,h,w=video_input.shape + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + + z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w + + masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous() + masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215) + masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous() + mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1) + + + masked_video = torch.cat([masked_video] * 2) if args.do_classifier_free_guidance else masked_video + mask = torch.cat([mask] * 2) if args.do_classifier_free_guidance else mask + z = torch.cat([z] * 2) if args.do_classifier_free_guidance else z + + prompt_all = [prompt] + [args.negative_prompt] if args.do_classifier_free_guidance else [prompt] + text_prompt = text_encoder(text_prompts=prompt_all, train=False) + model_kwargs = dict(encoder_hidden_states=text_prompt, class_labels=None) + + if args.use_ddim_sample_loop: + samples = diffusion.ddim_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \ + progress=True, device=device, mask=mask, x_start=masked_video, use_concat=args.use_concat + ) + else: + samples = diffusion.p_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \ + progress=True, device=device, mask=mask, x_start=masked_video, use_concat=args.use_concat + ) # torch.Size([2, 4, 16, 32, 32]) + samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32] + + video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32] + video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256] + return video_clip + + +def auto_inpainting_copy_no_mask(args, video_input, prompt, vae, text_encoder, diffusion, model, device,): + + b,f,c,h,w=video_input.shape + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + + video_input = rearrange(video_input, 'b f c h w -> (b f) c h w').contiguous() + video_input = vae.encode(video_input).latent_dist.sample().mul_(0.18215) + video_input = rearrange(video_input, '(b f) c h w -> b c f h w', b=b).contiguous() + + lr_indice = torch.IntTensor([i for i in range(0,62,4)]).to(device) + copied_video = torch.index_select(video_input, 2, lr_indice) + copied_video = torch.repeat_interleave(copied_video, 4, dim=2) + copied_video = copied_video[:,:,1:-2,:,:] + copied_video = torch.cat([copied_video] * 2) if args.do_classifier_free_guidance else copied_video + + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w + z = torch.cat([z] * 2) if args.do_classifier_free_guidance else z + + prompt_all = [prompt] + [args.negative_prompt] if args.do_classifier_free_guidance else [prompt] + text_prompt = text_encoder(text_prompts=prompt_all, train=False) + model_kwargs = dict(encoder_hidden_states=text_prompt, class_labels=None) + + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + if args.use_ddim_sample_loop: + samples = diffusion.ddim_sample_loop( + model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \ + progress=True, device=device, mask=None, x_start=copied_video, use_concat=args.use_concat, copy_no_mask=args.copy_no_mask, + ) + else: + raise ValueError(f'We only have ddim sampling implementation for now') + + samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32] + + video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32] + video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256] + return video_clip + + + +def main(args): + + for seed in args.seed_list: + + args.seed = seed + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + # print(f'torch.seed() = {torch.seed()}') + + print('sampling begins') + torch.set_grad_enabled(False) + device = "cuda" if torch.cuda.is_available() else "cpu" + # device = "cpu" + + ckpt_path = args.pretrained_path + "/lavie_interpolation.pt" + sd_path = args.pretrained_path + "/stable-diffusion-v1-4" + for ckpt in [ckpt_path]: + + ckpt_num = str(ckpt_path).zfill(7) + + # Load model: + latent_h = args.image_size[0] // 8 + latent_w = args.image_size[1] // 8 + args.image_h = args.image_size[0] + args.image_w = args.image_size[1] + args.latent_h = latent_h + args.latent_w = latent_w + print(f'args.copy_no_mask = {args.copy_no_mask}') + model = get_models(args, sd_path).to(device) + + if args.use_compile: + model = torch.compile(model) + if args.enable_xformers_memory_efficient_attention: + if is_xformers_available(): + model.enable_xformers_memory_efficient_attention() + # model.enable_vae_slicing() # ziqi added + else: + raise ValueError("xformers is not available. Make sure it is installed correctly") + + # Auto-download a pre-trained model or load a custom checkpoint from train.py: + print(f'loading model from {ckpt_path}') + + # load ckpt + state_dict = find_model(ckpt_path) + + print(f'state_dict["conv_in.weight"].shape = {state_dict["conv_in.weight"].shape}') # [320, 8, 3, 3] + print('loading succeed') + # model.load_state_dict(state_dict) + + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + model.eval() # important! + diffusion = create_diffusion(str(args.num_sampling_steps)) + vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(device) + text_encoder = TextEmbedder(sd_path).to(device) + + video_list = os.listdir(args.input_folder) + args.input_path_list = [os.path.join(args.input_folder, video) for video in video_list] + for input_path in args.input_path_list: + + args.input_path = input_path + + print(f'=======================================') + if not args.input_path.endswith('.mp4'): + print(f'Skipping {args.input_path}') + continue + + print(f'args.input_path = {args.input_path}') + + torch.manual_seed(args.seed) + torch.cuda.manual_seed(args.seed) + + # Labels to condition the model with (feel free to change): + video_name = args.input_path.split('/')[-1].split('.mp4')[0] + args.prompt = [video_name] + print(f'args.prompt = {args.prompt}') + prompts = args.prompt + class_name = [p + args.additional_prompt for p in prompts] + + if not os.path.exists(os.path.join(args.output_folder)): + os.makedirs(os.path.join(args.output_folder)) + video_input, researve_frames = get_input(args) # f,c,h,w + video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w + if args.copy_no_mask: + pass + else: + mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w + + if args.copy_no_mask: + pass + else: + if args.mask_type == 'tsr': + masked_video = video_input * (mask == 0) + else: + masked_video = video_input * (mask == 0) + + all_video = [] + if researve_frames != 0: + all_video.append(video_input) + for idx, prompt in enumerate(class_name): + if idx == 0: + if args.copy_no_mask: + video_clip = auto_inpainting_copy_no_mask(args, video_input, prompt, vae, text_encoder, diffusion, model, device,) + video_clip_ = video_clip.unsqueeze(0) + all_video.append(video_clip_) + else: + video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,) + video_clip_ = video_clip.unsqueeze(0) + all_video.append(video_clip_) + else: + raise NotImplementedError + masked_video = video_input * (mask == 0) + video_clip = auto_inpainting_copy_no_mask(args, video_clip.unsqueeze(0), masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,) + video_clip_ = video_clip.unsqueeze(0) + all_video.append(video_clip_[:, 3:]) + video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1) + for fps in args.fps_list: + save_path = args.output_folder + if not os.path.exists(os.path.join(save_path)): + os.makedirs(os.path.join(save_path)) + local_save_path = os.path.join(save_path, f'{video_name}.mp4') + print(f'save in {local_save_path}') + torchvision.io.write_video(local_save_path, video_, fps=fps) + + + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, required=True) + args = parser.parse_args() + main(**OmegaConf.load(args.config)) + + diff --git a/interpolation/utils.py b/interpolation/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..3219c0af49cb1bd039abe6899aab0ec85a4a4b1e --- /dev/null +++ b/interpolation/utils.py @@ -0,0 +1,371 @@ +import os +import math +import torch +import logging +import subprocess +import numpy as np +import torch.distributed as dist + +# from torch._six import inf +from torch import inf +from PIL import Image +from typing import Union, Iterable +from collections import OrderedDict + + +_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]] + +################################################################################# +# Training Helper Functions # +################################################################################# + +################################################################################# +# Training Clip Gradients # +################################################################################# + +def get_grad_norm( + parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor: + r""" + Copy from torch.nn.utils.clip_grad_norm_ + + Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + device = grads[0].device + if norm_type == inf: + norms = [g.detach().abs().max().to(device) for g in grads] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + return total_norm + +def clip_grad_norm_( + parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0, + error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor: + r""" + Copy from torch.nn.utils.clip_grad_norm_ + + Clips gradient norm of an iterable of parameters. + + The norm is computed over all gradients together, as if they were + concatenated into a single vector. Gradients are modified in-place. + + Args: + parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a + single Tensor that will have gradients normalized + max_norm (float or int): max norm of the gradients + norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for + infinity norm. + error_if_nonfinite (bool): if True, an error is thrown if the total + norm of the gradients from :attr:`parameters` is ``nan``, + ``inf``, or ``-inf``. Default: False (will switch to True in the future) + + Returns: + Total norm of the parameter gradients (viewed as a single vector). + """ + if isinstance(parameters, torch.Tensor): + parameters = [parameters] + grads = [p.grad for p in parameters if p.grad is not None] + max_norm = float(max_norm) + norm_type = float(norm_type) + if len(grads) == 0: + return torch.tensor(0.) + device = grads[0].device + if norm_type == inf: + norms = [g.detach().abs().max().to(device) for g in grads] + total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms)) + else: + total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + # print(total_norm) + + if clip_grad: + if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()): + raise RuntimeError( + f'The total norm of order {norm_type} for gradients from ' + '`parameters` is non-finite, so it cannot be clipped. To disable ' + 'this error and scale the gradients by the non-finite norm anyway, ' + 'set `error_if_nonfinite=False`') + clip_coef = max_norm / (total_norm + 1e-6) + # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so + # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization + # when the gradients do not reside in CPU memory. + clip_coef_clamped = torch.clamp(clip_coef, max=1.0) + for g in grads: + g.detach().mul_(clip_coef_clamped.to(g.device)) + # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type) + # print(gradient_cliped) + return total_norm + +################################################################################# +# Training Logger # +################################################################################# + +def create_logger(logging_dir): + """ + Create a logger that writes to a log file and stdout. + """ + if dist.get_rank() == 0: # real logger + logging.basicConfig( + level=logging.INFO, + # format='[\033[34m%(asctime)s\033[0m] %(message)s', + format='[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + +def create_accelerate_logger(logging_dir, is_main_process=False): + """ + Create a logger that writes to a log file and stdout. + """ + if is_main_process: # real logger + logging.basicConfig( + level=logging.INFO, + # format='[\033[34m%(asctime)s\033[0m] %(message)s', + format='[%(asctime)s] %(message)s', + datefmt='%Y-%m-%d %H:%M:%S', + handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")] + ) + logger = logging.getLogger(__name__) + else: # dummy logger (does nothing) + logger = logging.getLogger(__name__) + logger.addHandler(logging.NullHandler()) + return logger + + +def create_tensorboard(tensorboard_dir): + """ + Create a tensorboard that saves losses. + """ + if dist.get_rank() == 0: # real tensorboard + # tensorboard + writer = SummaryWriter(tensorboard_dir) + + return writer + +def write_tensorboard(writer, *args): + ''' + write the loss information to a tensorboard file. + Only for pytorch DDP mode. + ''' + if dist.get_rank() == 0: # real tensorboard + writer.add_scalar(args[0], args[1], args[2]) + +################################################################################# +# EMA Update/ DDP Training Utils # +################################################################################# + +@torch.no_grad() +def update_ema(ema_model, model, decay=0.9999): + """ + Step the EMA model towards the current model. + """ + ema_params = OrderedDict(ema_model.named_parameters()) + model_params = OrderedDict(model.named_parameters()) + + for name, param in model_params.items(): + # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed + ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay) + +def requires_grad(model, flag=True): + """ + Set requires_grad flag for all parameters in a model. + """ + for p in model.parameters(): + p.requires_grad = flag + +def cleanup(): + """ + End DDP training. + """ + dist.destroy_process_group() + + +def setup_distributed(backend="nccl", port=None): + """Initialize distributed training environment. + support both slurm and torch.distributed.launch + see torch.distributed.init_process_group() for more details + """ + num_gpus = torch.cuda.device_count() + + print(f'Hahahahahaha') + if "SLURM_JOB_ID" in os.environ: + rank = int(os.environ["SLURM_PROCID"]) + world_size = int(os.environ["SLURM_NTASKS"]) + node_list = os.environ["SLURM_NODELIST"] + addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1") + # specify master port + if port is not None: + os.environ["MASTER_PORT"] = str(port) + elif "MASTER_PORT" not in os.environ: + # os.environ["MASTER_PORT"] = "29566" + os.environ["MASTER_PORT"] = str(29566 + num_gpus) + if "MASTER_ADDR" not in os.environ: + os.environ["MASTER_ADDR"] = addr + os.environ["WORLD_SIZE"] = str(world_size) + os.environ["LOCAL_RANK"] = str(rank % num_gpus) + os.environ["RANK"] = str(rank) + else: + rank = int(os.environ["RANK"]) + world_size = int(os.environ["WORLD_SIZE"]) + + # torch.cuda.set_device(rank % num_gpus) + + print(f'before dist.init_process_group') + + dist.init_process_group( + backend=backend, + world_size=world_size, + rank=rank, + ) + print(f'after dist.init_process_group') + +################################################################################# +# Testing Utils # +################################################################################# + +def save_video_grid(video, nrow=None): + b, t, h, w, c = video.shape + + if nrow is None: + nrow = math.ceil(math.sqrt(b)) + ncol = math.ceil(b / nrow) + padding = 1 + video_grid = torch.zeros((t, (padding + h) * nrow + padding, + (padding + w) * ncol + padding, c), dtype=torch.uint8) + + print(video_grid.shape) + for i in range(b): + r = i // ncol + c = i % ncol + start_r = (padding + h) * r + start_c = (padding + w) * c + video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i] + + return video_grid + + +################################################################################# +# MMCV Utils # +################################################################################# + + +def collect_env(): + # Copyright (c) OpenMMLab. All rights reserved. + from mmcv.utils import collect_env as collect_base_env + from mmcv.utils import get_git_hash + """Collect the information of the running environments.""" + + env_info = collect_base_env() + env_info['MMClassification'] = get_git_hash()[:7] + + for name, val in env_info.items(): + print(f'{name}: {val}') + + print(torch.cuda.get_arch_list()) + print(torch.version.cuda) + +################################################################################# +# Long video generation Utils # +################################################################################# + +def mask_generation(mask_type, shape, dtype, device): + b, c, f, h, w = shape + if mask_type.startswith('random'): + num = float(mask_type.split('random')[-1]) + mask_f = torch.ones(1, 1, f, 1, 1, dtype=dtype, device=device) + indices = torch.randperm(f, device=device)[:int(f*num)] + mask_f[0, 0, indices, :, :] = 0 + mask = mask_f.expand(b, c, -1, h, w) + elif mask_type.startswith('first'): + num = int(mask_type.split('first')[-1]) + mask_f = torch.cat([torch.zeros(1, 1, num, 1, 1, dtype=dtype, device=device), + torch.ones(1, 1, f-num, 1, 1, dtype=dtype, device=device)], dim=2) + mask = mask_f.expand(b, c, -1, h, w) + else: + raise ValueError(f"Invalid mask type: {mask_type}") + return mask + + + +def mask_generation_before(mask_type, shape, dtype, device): + b, f, c, h, w = shape + if mask_type.startswith('random'): + num = float(mask_type.split('random')[-1]) + mask_f = torch.ones(1, f, 1, 1, 1, dtype=dtype, device=device) + indices = torch.randperm(f, device=device)[:int(f*num)] + mask_f[0, indices, :, :, :] = 0 + mask = mask_f.expand(b, -1, c, h, w) + elif mask_type.startswith('first'): + num = int(mask_type.split('first')[-1]) + mask_f = torch.cat([torch.zeros(1, num, 1, 1, 1, dtype=dtype, device=device), + torch.ones(1, f-num, 1, 1, 1, dtype=dtype, device=device)], dim=1) + mask = mask_f.expand(b, -1, c, h, w) + elif mask_type.startswith('uniform'): + p = float(mask_type.split('uniform')[-1]) + mask_f = torch.ones(1, f, 1, 1, 1, dtype=dtype, device=device) + mask_f[0, torch.rand(f, device=device) < p, :, :, :] = 0 + print(f'mask_f: = {mask_f}') + mask = mask_f.expand(b, -1, c, h, w) + print(f'mask.shape: = {mask.shape}, mask: = {mask}') + elif mask_type.startswith('all'): + mask = torch.ones(b,f,c,h,w,dtype=dtype,device=device) + elif mask_type.startswith('onelast'): + num = int(mask_type.split('onelast')[-1]) + mask_one = torch.zeros(1,1,1,1,1, dtype=dtype, device=device) + mask_mid = torch.ones(1,f-2*num,1,1,1,dtype=dtype, device=device) + mask_last = torch.zeros_like(mask_one) + mask = torch.cat([mask_one]*num + [mask_mid] + [mask_last]*num, dim=1) + # breakpoint() + mask = mask.expand(b, -1, c, h, w) + elif mask_type.startswith('interpolate'): + mask_f = [] + for i in range(4): + mask_zero = torch.zeros(1,1,1,1,1, dtype=dtype, device=device) + mask_f.append(mask_zero) + mask_one = torch.ones(1,3,1,1,1, dtype=dtype, device=device) + mask_f.append(mask_one) + mask = torch.cat(mask_f, dim=1) + print(f'mask={mask}') + elif mask_type.startswith('tsr'): + mask_f = [] + mask_zero = torch.zeros(1,1,1,1,1, dtype=dtype, device=device) + mask_one = torch.ones(1,3,1,1,1, dtype=dtype, device=device) + for i in range(15): + mask_f.append(mask_zero) # not masked + mask_f.append(mask_one) # masked + mask_f.append(mask_zero) # not masked + mask = torch.cat(mask_f, dim=1) + # print(f'before mask.shape = {mask.shape}, mask = {mask}') # [1, 61, 1, 1, 1] + mask = mask.expand(b, -1, c, h, w) + # print(f'after mask.shape = {mask.shape}, mask = {mask}') # [4, 61, 3, 256, 256] + else: + raise ValueError(f"Invalid mask type: {mask_type}") + + return mask diff --git a/pretrained_models/put_pre-trained_weights_here.txt b/pretrained_models/put_pre-trained_weights_here.txt new file mode 100644 index 0000000000000000000000000000000000000000..d7bf9943847113cf93c998a6fec2e8d6a8324404 --- /dev/null +++ b/pretrained_models/put_pre-trained_weights_here.txt @@ -0,0 +1 @@ +put pre-trianed weights here. \ No newline at end of file diff --git a/vsr/configs/sample.yaml b/vsr/configs/sample.yaml new file mode 100644 index 0000000000000000000000000000000000000000..96234f563b8f12e7c816bcdc1d1e52d084b029b7 --- /dev/null +++ b/vsr/configs/sample.yaml @@ -0,0 +1,6 @@ +pretrained_path: "../pretrained_models" +input_path: "../res/base" +output_path: "../res/vsr" +noise_level: 50 +guidance_scale: 5 +inference_steps: 50 diff --git a/vsr/configs/unet_3d_config.json b/vsr/configs/unet_3d_config.json new file mode 100644 index 0000000000000000000000000000000000000000..ccebb1884d60b9f309b0b3fd45f9d284034e1728 --- /dev/null +++ b/vsr/configs/unet_3d_config.json @@ -0,0 +1,66 @@ +{ + "_class_name": "UNet3DVSRModel", + "_diffusers_version": "0.9.0.dev0", + "_name_or_path": "hf-models/stable-diffusion-x4-upscaler/unet", + "act_fn": "silu", + "attention_head_dim": 8, + "block_out_channels": [ + 256, + 512, + 512, + 1024 + ], + "center_input_sample": false, + "cross_attention_dim": 1024, + "down_block_types": [ + "DownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D" + ], + "downsample_padding": 1, + "dual_cross_attention": false, + "flip_sin_to_cos": true, + "freq_shift": 0, + "in_channels": 7, + "layers_per_block": 2, + "mid_block_scale_factor": 1, + "norm_eps": 1e-05, + "norm_num_groups": 32, + "num_class_embeds": 1000, + "only_cross_attention": [ + true, + true, + true, + false + ], + "out_channels": 4, + "sample_size": 128, + "up_block_types": [ + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "UpBlock3D" + ], + "use_linear_projection": true, + + "down_temporal_idx": [0, 1, 2, 3], + "mid_temporal": true, + "up_temporal_idx": [0, 1, 2, 3], + "temporal_module_config": { + "num_attention_layers": 1, + "attention_block_types": [ + "", + "" + ], + "cross_frame_attention_mode": "0_i-1_i", + "temporal_shift_fold_div": 2, + "temporal_shift_direction": "right", + "use_dcn_warpping": false, + "use_deformable_conv": true, + "attention_dim_div": 2 + }, + "use_first_frame": false, + "video_condition": false, + "freeze_pretrained_2d_upsampler": true +} diff --git a/vsr/configs/vae_config.json b/vsr/configs/vae_config.json new file mode 100644 index 0000000000000000000000000000000000000000..3dca8898737c4bbb92ccd502b080507483ad1728 --- /dev/null +++ b/vsr/configs/vae_config.json @@ -0,0 +1,28 @@ +{ + "_class_name": "AutoencoderKL", + "_diffusers_version": "0.9.0.dev0", + "_name_or_path": "hf-models/stable-diffusion-x4-upscaler/vae", + "act_fn": "silu", + "block_out_channels": [ + 128, + 256, + 512 + ], + "down_block_types": [ + "DownEncoderBlock2D", + "DownEncoderBlock2D", + "DownEncoderBlock2D" + ], + "in_channels": 3, + "latent_channels": 4, + "layers_per_block": 2, + "norm_num_groups": 32, + "out_channels": 3, + "sample_size": 256, + "up_block_types": [ + "UpDecoderBlock2D", + "UpDecoderBlock2D", + "UpDecoderBlock2D" + ], + "scaling_factor": 0.08333 +} diff --git a/vsr/diffusion/__init__.py b/vsr/diffusion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d171551ccbeff3f7d38a6e89624815f7ebd1e4db --- /dev/null +++ b/vsr/diffusion/__init__.py @@ -0,0 +1,54 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from . import gaussian_diffusion as gd +from .respace import SpacedDiffusion, space_timesteps + + +# !important +def create_diffusion( + timestep_respacing="", + noise_schedule="linear", # 'linear' for training + use_kl=False, + rescale_learned_sigmas=False, + prediction_type='v_prediction', + variance_type='fixed_small', + beta_start=0.0001, + beta_end=0.02, + # beta_start=0.00085, + # beta_end=0.012, + diffusion_steps=1000 +): + betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps, beta_start=beta_start, beta_end=beta_end) + if prediction_type == 'epsilon': + model_mean_type = gd.ModelMeanType.EPSILON # EPSILON type for stable-diffusion-2-1 512 + elif prediction_type == 'xstart': + model_mean_type = gd.ModelMeanType.START_X + elif prediction_type == 'v_prediction': + model_mean_type = gd.ModelMeanType.PREVIOUS_V # gd.ModelMeanType.PREVIOUS_V for stable-diffusion-2-1 768/x4-upscaler + + if variance_type == 'fixed_small': + model_var_type = gd.ModelVarType.FIXED_SMALL + elif variance_type == 'fixed_large': + model_var_type = gd.ModelVarType.FIXED_LARGE + elif variance_type == 'learned_range': + model_var_type = gd.ModelVarType.LEARNED_RANGE + + if use_kl: + loss_type = gd.LossType.RESCALED_KL + elif rescale_learned_sigmas: + loss_type = gd.LossType.RESCALED_MSE + else: + loss_type = gd.LossType.MSE + if timestep_respacing is None or timestep_respacing == "": + timestep_respacing = [diffusion_steps] + return SpacedDiffusion( + use_timesteps=space_timesteps(diffusion_steps, timestep_respacing), + betas=betas, + model_mean_type=(model_mean_type), + model_var_type=(model_var_type), + loss_type=loss_type + # rescale_timesteps=rescale_timesteps, + ) diff --git a/vsr/diffusion/diffusion_utils.py b/vsr/diffusion/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060 --- /dev/null +++ b/vsr/diffusion/diffusion_utils.py @@ -0,0 +1,88 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +import torch as th +import numpy as np + + +def normal_kl(mean1, logvar1, mean2, logvar2): + """ + Compute the KL divergence between two gaussians. + Shapes are automatically broadcasted, so batches can be compared to + scalars, among other use cases. + """ + tensor = None + for obj in (mean1, logvar1, mean2, logvar2): + if isinstance(obj, th.Tensor): + tensor = obj + break + assert tensor is not None, "at least one argument must be a Tensor" + + # Force variances to be Tensors. Broadcasting helps convert scalars to + # Tensors, but it does not work for th.exp(). + logvar1, logvar2 = [ + x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) + for x in (logvar1, logvar2) + ] + + return 0.5 * ( + -1.0 + + logvar2 + - logvar1 + + th.exp(logvar1 - logvar2) + + ((mean1 - mean2) ** 2) * th.exp(-logvar2) + ) + + +def approx_standard_normal_cdf(x): + """ + A fast approximation of the cumulative distribution function of the + standard normal. + """ + return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3)))) + + +def continuous_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a continuous Gaussian distribution. + :param x: the targets + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + centered_x = x - means + inv_stdv = th.exp(-log_scales) + normalized_x = centered_x * inv_stdv + log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x) + return log_probs + + +def discretized_gaussian_log_likelihood(x, *, means, log_scales): + """ + Compute the log-likelihood of a Gaussian distribution discretizing to a + given image. + :param x: the target images. It is assumed that this was uint8 values, + rescaled to the range [-1, 1]. + :param means: the Gaussian mean Tensor. + :param log_scales: the Gaussian log stddev Tensor. + :return: a tensor like x of log probabilities (in nats). + """ + assert x.shape == means.shape == log_scales.shape + centered_x = x - means + inv_stdv = th.exp(-log_scales) + plus_in = inv_stdv * (centered_x + 1.0 / 255.0) + cdf_plus = approx_standard_normal_cdf(plus_in) + min_in = inv_stdv * (centered_x - 1.0 / 255.0) + cdf_min = approx_standard_normal_cdf(min_in) + log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12)) + log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12)) + cdf_delta = cdf_plus - cdf_min + log_probs = th.where( + x < -0.999, + log_cdf_plus, + th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))), + ) + assert log_probs.shape == x.shape + return log_probs diff --git a/vsr/diffusion/gaussian_diffusion.py b/vsr/diffusion/gaussian_diffusion.py new file mode 100644 index 0000000000000000000000000000000000000000..be197ae8b899ecb5e763226b8c56de335fa123d6 --- /dev/null +++ b/vsr/diffusion/gaussian_diffusion.py @@ -0,0 +1,923 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + + +import math + +import numpy as np +import torch as th +import enum + +from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +class ModelMeanType(enum.Enum): + """ + Which type of output the model predicts. + """ + + PREVIOUS_V = enum.auto() # v-parameterization for VSR; (see section 2.4 https://imagen.research.google/video/paper.pdf) + PREVIOUS_X = enum.auto() # the model predicts x_{t-1} + START_X = enum.auto() # the model predicts x_0 + EPSILON = enum.auto() # the model predicts epsilon + + +class ModelVarType(enum.Enum): + """ + What is used as the model's output variance. + The LEARNED_RANGE option has been added to allow the model to predict + values between FIXED_SMALL and FIXED_LARGE, making its job easier. + """ + + LEARNED = enum.auto() + FIXED_SMALL = enum.auto() + FIXED_LARGE = enum.auto() + LEARNED_RANGE = enum.auto() + + +class LossType(enum.Enum): + MSE = enum.auto() # use raw MSE loss (and KL when learning variances) + RESCALED_MSE = ( + enum.auto() + ) # use raw MSE loss (with RESCALED_KL when learning variances) + KL = enum.auto() # use the variational lower-bound + RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB + + def is_vb(self): + return self == LossType.KL or self == LossType.RESCALED_KL + + +def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac): + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + warmup_time = int(num_diffusion_timesteps * warmup_frac) + betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64) + return betas + + +def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps): + """ + This is the deprecated API for creating beta schedules. + See get_named_beta_schedule() for the new library of schedules. + """ + if beta_schedule == "scaled_linear": + betas = (np.linspace(beta_start**0.5, beta_end**0.5, num_diffusion_timesteps, dtype=np.float64)** 2) + elif beta_schedule == "linear": + betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "warmup10": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1) + elif beta_schedule == "warmup50": + betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5) + elif beta_schedule == "const": + betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64) + elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1 + betas = 1.0 / np.linspace( + num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64 + ) + else: + raise NotImplementedError(beta_schedule) + assert betas.shape == (num_diffusion_timesteps,) + return betas + + +def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, beta_start, beta_end): + """ + Get a pre-defined beta schedule for the given name. + The beta schedule library consists of beta schedules which remain similar + in the limit of num_diffusion_timesteps. + Beta schedules may be added, but should not be removed or changed once + they are committed to maintain backwards compatibility. + """ + if schedule_name in ["linear", "scaled_linear"]: + # Linear schedule from Ho et al, extended to work for any number of + # diffusion steps. + scale = 1000 / num_diffusion_timesteps + return get_beta_schedule( + schedule_name, + beta_start=scale * beta_start, + beta_end=scale * beta_end, + num_diffusion_timesteps=num_diffusion_timesteps, + ) + elif schedule_name == "squaredcos_cap_v2": + return betas_for_alpha_bar( + num_diffusion_timesteps, + lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2, + ) + else: + raise NotImplementedError(f"unknown beta schedule: {schedule_name}") + + +def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999): + """ + Create a beta schedule that discretizes the given alpha_t_bar function, + which defines the cumulative product of (1-beta) over time from t = [0,1]. + :param num_diffusion_timesteps: the number of betas to produce. + :param alpha_bar: a lambda that takes an argument t from 0 to 1 and + produces the cumulative product of (1-beta) up to that + part of the diffusion process. + :param max_beta: the maximum beta to use; use values lower than 1 to + prevent singularities. + """ + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return np.array(betas) + + +class GaussianDiffusion: + """ + Utilities for training and sampling diffusion models. + Original ported from this codebase: + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42 + :param betas: a 1-D numpy array of betas for each diffusion timestep, + starting at T and going to 1. + """ + + def __init__( + self, + *, + betas, + model_mean_type, + model_var_type, + loss_type + ): + + self.model_mean_type = model_mean_type + self.model_var_type = model_var_type + self.loss_type = loss_type + + # Use float64 for accuracy. + betas = np.array(betas, dtype=np.float64) + self.betas = betas + assert len(betas.shape) == 1, "betas must be 1-D" + assert (betas > 0).all() and (betas <= 1).all() + + self.num_timesteps = int(betas.shape[0]) + + alphas = 1.0 - betas + self.alphas_cumprod = np.cumprod(alphas, axis=0) + self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1]) + self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0) + assert self.alphas_cumprod_prev.shape == (self.num_timesteps,) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod) + self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod) + self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod) + self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod) + self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1) + + # calculations for posterior q(x_{t-1} | x_t, x_0) + self.posterior_variance = ( + betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain + self.posterior_log_variance_clipped = np.log( + np.append(self.posterior_variance[1], self.posterior_variance[1:]) + ) if len(self.posterior_variance) > 1 else np.array([]) + + self.posterior_mean_coef1 = ( + betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod) + ) + self.posterior_mean_coef2 = ( + (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod) + ) + + def q_mean_variance(self, x_start, t): + """ + Get the distribution q(x_t | x_0). + :param x_start: the [N x C x ...] tensor of noiseless inputs. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :return: A tuple (mean, variance, log_variance), all of x_start's shape. + """ + mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape) + log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape) + return mean, variance, log_variance + + def q_sample(self, x_start, t, noise=None): + """ + Diffuse the data for a given number of diffusion steps. + In other words, sample from q(x_t | x_0). + :param x_start: the initial data batch. + :param t: the number of diffusion steps (minus 1). Here, 0 means one step. + :param noise: if specified, the split-out normal noise. + :return: A noisy version of x_start. + """ + if noise is None: + noise = th.randn_like(x_start) + assert noise.shape == x_start.shape + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise + ) + + def q_posterior_mean_variance(self, x_start, x_t, t): + """ + Compute the mean and variance of the diffusion posterior: + q(x_{t-1} | x_t, x_0) + """ + assert x_start.shape == x_t.shape + posterior_mean = ( + _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start + + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t + ) + posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape) + posterior_log_variance_clipped = _extract_into_tensor( + self.posterior_log_variance_clipped, t, x_t.shape + ) + assert ( + posterior_mean.shape[0] + == posterior_variance.shape[0] + == posterior_log_variance_clipped.shape[0] + == x_start.shape[0] + ) + return posterior_mean, posterior_variance, posterior_log_variance_clipped + + def get_v(self, x_start, noise, t): + # v-prediction parameterization + # training loss type for stable-diffusion-2-1 768/x4-upscaler + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * noise - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start + ) + + def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None): + """ + Apply the model to get p(x_{t-1} | x_t), as well as a prediction of + the initial x, x_0. + :param model: the model, which takes a signal and a batch of timesteps + as input. + :param x: the [N x C x ...] tensor at time t. + :param t: a 1-D Tensor of timesteps. + :param clip_denoised: if True, clip the denoised signal into [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. Applies before + clip_denoised. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict with the following keys: + - 'mean': the model mean output. + - 'variance': the model variance output. + - 'log_variance': the log of 'variance'. + - 'pred_xstart': the prediction for x_0. + """ + if model_kwargs is None: + model_kwargs = {} + + B, F, C = x.shape[:3] + assert t.shape == (B,) + # model_output = model(x, t, **model_kwargs) + try: + model_output = model(x, t, **model_kwargs).sample # for tav unet + # print(model_output.shape) + except: + model_output = model(x, t, **model_kwargs) + + # for v prediction + # if self.model_mean_type == ModelMeanType.PREVIOUS_V: + # model_output = self._predict_eps_from_z_and_v(x, t, model_output) + if isinstance(model_output, tuple): + model_output, extra = model_output + else: + extra = None + + if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]: + assert model_output.shape == (B, F, C * 2, *x.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape) + max_log = _extract_into_tensor(np.log(self.betas), t, x.shape) + # The model_var_values is [-1, 1] for [min_var, max_var]. + frac = (model_var_values + 1) / 2 + model_log_variance = frac * max_log + (1 - frac) * min_log + model_variance = th.exp(model_log_variance) + else: + model_variance, model_log_variance = { + # for fixedlarge, we set the initial (log-)variance like so + # to get a better decoder log likelihood. + ModelVarType.FIXED_LARGE: ( + np.append(self.posterior_variance[1], self.betas[1:]), + np.log(np.append(self.posterior_variance[1], self.betas[1:])), + ), + ModelVarType.FIXED_SMALL: ( + self.posterior_variance, + self.posterior_log_variance_clipped, + ), + }[self.model_var_type] + model_variance = _extract_into_tensor(model_variance, t, x.shape) + model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape) + + def process_xstart(x): + if denoised_fn is not None: + x = denoised_fn(x) + if clip_denoised: + return x.clamp(-1, 1) + return x + + if self.model_mean_type == ModelMeanType.START_X: + pred_xstart = process_xstart(model_output) + else: + if self.model_mean_type == ModelMeanType.PREVIOUS_V: + pred_xstart = process_xstart( + self._predict_xstart_from_z_and_v(x_t=x, t=t, v=model_output) + ) + else: + pred_xstart = process_xstart( + self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output) + ) + + if self.model_mean_type == ModelMeanType.PREVIOUS_V: + eps = self._predict_eps_from_z_and_v(x_t=x, t=t, v=model_output) + else: + eps = model_output + + model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t) + + assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape + return { + "eps": eps, + "mean": model_mean, + "variance": model_variance, + "log_variance": model_log_variance, + "pred_xstart": pred_xstart, + "extra": extra, + } + + def _predict_xstart_from_eps(self, x_t, t, eps): + assert x_t.shape == eps.shape + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t + - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps + ) + + def _predict_eps_from_xstart(self, x_t, t, pred_xstart): + return ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) + + # for v prediction + def _predict_xstart_from_z_and_v(self, x_t, t, v): + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t - + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v + ) + + # for v prediction + def _predict_eps_from_z_and_v(self, x_t, t, v): + return ( + _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v + + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t + ) + + + def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute the mean for the previous step, given a function cond_fn that + computes the gradient of a conditional log probability with respect to + x. In particular, cond_fn computes grad(log(p(y|x))), and we want to + condition on y. + This uses the conditioning strategy from Sohl-Dickstein et al. (2015). + """ + gradient = cond_fn(x, t, **model_kwargs) + new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float() + return new_mean + + def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None): + """ + Compute what the p_mean_variance output would have been, should the + model's score function be conditioned by cond_fn. + See condition_mean() for details on cond_fn. + Unlike condition_mean(), this instead uses the conditioning strategy + from Song et al (2020). + """ + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + + eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"]) + eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs) + + out = p_mean_var.copy() + out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps) + out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t) + return out + + def p_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + ): + """ + Sample x_{t-1} from the model at the given timestep. + :param model: the model to sample from. + :param x: the current tensor at x_{t-1}. + :param t: the value of t, starting at 0 for the first diffusion step. + :param clip_denoised: if True, clip the x_start prediction to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - 'sample': a random sample from the model. + - 'pred_xstart': a prediction of x_0. + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + noise = th.randn_like(x) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + if cond_fn is not None: + out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs) + sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def p_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model. + :param model: the model module. + :param shape: the shape of the samples, (N, C, H, W). + :param noise: if specified, the noise from the encoder to sample. + Should be of the same shape as `shape`. + :param clip_denoised: if True, clip x_start predictions to [-1, 1]. + :param denoised_fn: if not None, a function which applies to the + x_start prediction before it is used to sample. + :param cond_fn: if not None, this is a gradient function that acts + similarly to the model. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param device: if specified, the device to create the samples on. + If not specified, use a model parameter's device. + :param progress: if True, show a tqdm progress bar. + :return: a non-differentiable batch of samples. + """ + final = None + for sample in self.p_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + ): + final = sample + return final["sample"] + + def p_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + ): + """ + Generate samples from the model and yield intermediate samples from + each timestep of diffusion. + Arguments are the same as p_sample_loop(). + Returns a generator over dicts, where each dict is the return value of + p_sample(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.p_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + ) + yield out + img = out["sample"] + + def ddim_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t-1} from the model using DDIM. + Same usage as p_sample(). + """ + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + if self.model_mean_type == ModelMeanType.PREVIOUS_V: + eps = out["eps"] + else: + eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"]) + + alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape) + alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape) + sigma = ( + eta + * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) + * th.sqrt(1 - alpha_bar / alpha_bar_prev) + ) + # Equation 12. + noise = th.randn_like(x) + mean_pred = ( + out["pred_xstart"] * th.sqrt(alpha_bar_prev) + + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps + ) + nonzero_mask = ( + (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) + ) # no noise when t == 0 + sample = mean_pred + nonzero_mask * sigma * noise + return {"sample": sample, "pred_xstart": out["pred_xstart"]} + + def ddim_reverse_sample( + self, + model, + x, + t, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + eta=0.0, + ): + """ + Sample x_{t+1} from the model using DDIM reverse ODE. + """ + assert eta == 0.0, "Reverse ODE only for deterministic path" + out = self.p_mean_variance( + model, + x, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + model_kwargs=model_kwargs, + ) + if cond_fn is not None: + out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs) + # Usually our model outputs epsilon, but we re-derive it + # in case we used x_start or x_prev prediction. + eps = ( + _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x + - out["pred_xstart"] + ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape) + alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape) + + # Equation 12. reversed + mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps + + return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]} + + def ddim_sample_loop( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Generate samples from the model using DDIM. + Same usage as p_sample_loop(). + """ + final = None + for sample in self.ddim_sample_loop_progressive( + model, + shape, + noise=noise, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + device=device, + progress=progress, + eta=eta, + ): + final = sample + return final["sample"] + + def ddim_sample_loop_progressive( + self, + model, + shape, + noise=None, + clip_denoised=True, + denoised_fn=None, + cond_fn=None, + model_kwargs=None, + device=None, + progress=False, + eta=0.0, + ): + """ + Use DDIM to sample from the model and yield intermediate samples from + each timestep of DDIM. + Same usage as p_sample_loop_progressive(). + """ + if device is None: + device = next(model.parameters()).device + assert isinstance(shape, (tuple, list)) + if noise is not None: + img = noise + else: + img = th.randn(*shape, device=device) + indices = list(range(self.num_timesteps))[::-1] + + if progress: + # Lazy import so that we don't depend on tqdm. + from tqdm.auto import tqdm + + indices = tqdm(indices) + + for i in indices: + t = th.tensor([i] * shape[0], device=device) + with th.no_grad(): + out = self.ddim_sample( + model, + img, + t, + clip_denoised=clip_denoised, + denoised_fn=denoised_fn, + cond_fn=cond_fn, + model_kwargs=model_kwargs, + eta=eta, + ) + yield out + img = out["sample"] + + def _vb_terms_bpd( + self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None + ): + """ + Get a term for the variational lower-bound. + The resulting units are bits (rather than nats, as one might expect). + This allows for comparison to other papers. + :return: a dict with the following keys: + - 'output': a shape [N] tensor of NLLs or KLs. + - 'pred_xstart': the x_0 predictions. + """ + true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + ) + out = self.p_mean_variance( + model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs + ) + kl = normal_kl( + true_mean, true_log_variance_clipped, out["mean"], out["log_variance"] + ) + kl = mean_flat(kl) / np.log(2.0) + + decoder_nll = -discretized_gaussian_log_likelihood( + x_start, means=out["mean"], log_scales=0.5 * out["log_variance"] + ) + assert decoder_nll.shape == x_start.shape + decoder_nll = mean_flat(decoder_nll) / np.log(2.0) + + # At the first timestep return the decoder NLL, + # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t)) + output = th.where((t == 0), decoder_nll, kl) + return {"output": output, "pred_xstart": out["pred_xstart"]} + + + def training_losses(self, model, x_start, t, loss_mask=None, model_kwargs=None, noise=None): + """ + Compute training losses for a single timestep. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param t: a batch of timestep indices. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :param noise: if specified, the specific Gaussian noise to try to remove. + :return: a dict with the key "loss" containing a tensor of shape [N]. + Some mean or variance settings may also have other keys. + """ + if model_kwargs is None: + model_kwargs = {} + if noise is None: + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start, t, noise=noise) + + terms = {} + + if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL: + terms["loss"] = self._vb_terms_bpd( + model=model, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + model_kwargs=model_kwargs, + )["output"] + if self.loss_type == LossType.RESCALED_KL: + terms["loss"] *= self.num_timesteps + elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: # loss_type=LossType.MSE by default + try: + model_output = model(x_t, t, **model_kwargs).sample # for tav unet + except: + model_output = model(x_t, t, **model_kwargs) + + if self.model_var_type in [ + ModelVarType.LEARNED, + ModelVarType.LEARNED_RANGE, + ]: + B, F, C = x_t.shape[:3] + assert model_output.shape == (B, F, C * 2, *x_t.shape[3:]) + model_output, model_var_values = th.split(model_output, C, dim=2) + # Learn the variance using the variational bound, but don't let + # it affect our mean prediction. + frozen_out = th.cat([model_output.detach(), model_var_values], dim=2) + terms["vb"] = self._vb_terms_bpd( + model=lambda *args, r=frozen_out: r, + x_start=x_start, + x_t=x_t, + t=t, + clip_denoised=False, + )["output"] + if self.loss_type == LossType.RESCALED_MSE: + # Divide by 1000 for equivalence with initial implementation. + # Without a factor of 1/1000, the VB term hurts the MSE term. + terms["vb"] *= self.num_timesteps / 1000.0 + print("=======") + + target = { + ModelMeanType.PREVIOUS_V: self.get_v(x_start=x_start, noise=noise, t=t), + ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance( + x_start=x_start, x_t=x_t, t=t + )[0], + ModelMeanType.START_X: x_start, + ModelMeanType.EPSILON: noise, + }[self.model_mean_type] + assert model_output.shape == target.shape == x_start.shape + if loss_mask is not None: + terms["mse"] = mean_flat(((target - model_output) ** 2) * loss_mask) + else: + terms["mse"] = mean_flat((target - model_output) ** 2) + + if "vb" in terms: + terms["loss"] = terms["mse"] + terms["vb"] + else: + terms["loss"] = terms["mse"] + else: + raise NotImplementedError(self.loss_type) + + return terms + + def _prior_bpd(self, x_start): + """ + Get the prior KL term for the variational lower-bound, measured in + bits-per-dim. + This term can't be optimized, as it only depends on the encoder. + :param x_start: the [N x C x ...] tensor of inputs. + :return: a batch of [N] KL values (in bits), one per batch element. + """ + batch_size = x_start.shape[0] + t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device) + qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t) + kl_prior = normal_kl( + mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0 + ) + return mean_flat(kl_prior) / np.log(2.0) + + def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None): + """ + Compute the entire variational lower-bound, measured in bits-per-dim, + as well as other related quantities. + :param model: the model to evaluate loss on. + :param x_start: the [N x C x ...] tensor of inputs. + :param clip_denoised: if True, clip denoised samples. + :param model_kwargs: if not None, a dict of extra keyword arguments to + pass to the model. This can be used for conditioning. + :return: a dict containing the following keys: + - total_bpd: the total variational lower-bound, per batch element. + - prior_bpd: the prior term in the lower-bound. + - vb: an [N x T] tensor of terms in the lower-bound. + - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep. + - mse: an [N x T] tensor of epsilon MSEs for each timestep. + """ + device = x_start.device + batch_size = x_start.shape[0] + + vb = [] + xstart_mse = [] + mse = [] + for t in list(range(self.num_timesteps))[::-1]: + t_batch = th.tensor([t] * batch_size, device=device) + noise = th.randn_like(x_start) + x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise) + # Calculate VLB term at the current timestep + with th.no_grad(): + out = self._vb_terms_bpd( + model, + x_start=x_start, + x_t=x_t, + t=t_batch, + clip_denoised=clip_denoised, + model_kwargs=model_kwargs, + ) + vb.append(out["output"]) + xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2)) + eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"]) + mse.append(mean_flat((eps - noise) ** 2)) + + vb = th.stack(vb, dim=1) + xstart_mse = th.stack(xstart_mse, dim=1) + mse = th.stack(mse, dim=1) + + prior_bpd = self._prior_bpd(x_start) + total_bpd = vb.sum(dim=1) + prior_bpd + return { + "total_bpd": total_bpd, + "prior_bpd": prior_bpd, + "vb": vb, + "xstart_mse": xstart_mse, + "mse": mse, + } + + +def _extract_into_tensor(arr, timesteps, broadcast_shape): + """ + Extract values from a 1-D numpy array for a batch of indices. + :param arr: the 1-D numpy array. + :param timesteps: a tensor of indices into the array to extract. + :param broadcast_shape: a larger shape of K dimensions with the batch + dimension equal to the length of timesteps. + :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims. + """ + res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float() + while len(res.shape) < len(broadcast_shape): + res = res[..., None] + return res + th.zeros(broadcast_shape, device=timesteps.device) diff --git a/vsr/diffusion/respace.py b/vsr/diffusion/respace.py new file mode 100644 index 0000000000000000000000000000000000000000..31c0092380367db477f154a39bf172ff64712fc4 --- /dev/null +++ b/vsr/diffusion/respace.py @@ -0,0 +1,130 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +import torch +import numpy as np +import torch as th + +from .gaussian_diffusion import GaussianDiffusion + + +def space_timesteps(num_timesteps, section_counts): + """ + Create a list of timesteps to use from an original diffusion process, + given the number of timesteps we want to take from equally-sized portions + of the original process. + For example, if there's 300 timesteps and the section counts are [10,15,20] + then the first 100 timesteps are strided to be 10 timesteps, the second 100 + are strided to be 15 timesteps, and the final 100 are strided to be 20. + If the stride is a string starting with "ddim", then the fixed striding + from the DDIM paper is used, and only one section is allowed. + :param num_timesteps: the number of diffusion steps in the original + process to divide up. + :param section_counts: either a list of numbers, or a string containing + comma-separated numbers, indicating the step count + per section. As a special case, use "ddimN" where N + is a number of steps to use the striding from the + DDIM paper. + :return: a set of diffusion steps from the original process to use. + """ + if isinstance(section_counts, str): + if section_counts.startswith("ddim"): + desired_count = int(section_counts[len("ddim") :]) + for i in range(1, num_timesteps): + if len(range(0, num_timesteps, i)) == desired_count: + return set(range(0, num_timesteps, i)) + raise ValueError( + f"cannot create exactly {num_timesteps} steps with an integer stride" + ) + section_counts = [int(x) for x in section_counts.split(",")] + size_per = num_timesteps // len(section_counts) + extra = num_timesteps % len(section_counts) + start_idx = 0 + all_steps = [] + for i, section_count in enumerate(section_counts): + size = size_per + (1 if i < extra else 0) + if size < section_count: + raise ValueError( + f"cannot divide section of {size} steps into {section_count}" + ) + if section_count <= 1: + frac_stride = 1 + else: + frac_stride = (size - 1) / (section_count - 1) + cur_idx = 0.0 + taken_steps = [] + for _ in range(section_count): + taken_steps.append(start_idx + round(cur_idx)) + cur_idx += frac_stride + all_steps += taken_steps + start_idx += size + return set(all_steps) + + +class SpacedDiffusion(GaussianDiffusion): + """ + A diffusion process which can skip steps in a base diffusion process. + :param use_timesteps: a collection (sequence or set) of timesteps from the + original diffusion process to retain. + :param kwargs: the kwargs to create the base diffusion process. + """ + + def __init__(self, use_timesteps, **kwargs): + self.use_timesteps = set(use_timesteps) + self.timestep_map = [] + self.original_num_steps = len(kwargs["betas"]) + + base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa + last_alpha_cumprod = 1.0 + new_betas = [] + for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod): + if i in self.use_timesteps: + new_betas.append(1 - alpha_cumprod / last_alpha_cumprod) + last_alpha_cumprod = alpha_cumprod + self.timestep_map.append(i) + kwargs["betas"] = np.array(new_betas) + super().__init__(**kwargs) + + def p_mean_variance( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().p_mean_variance(self._wrap_model(model), *args, **kwargs) + + # @torch.compile + def training_losses( + self, model, *args, **kwargs + ): # pylint: disable=signature-differs + return super().training_losses(self._wrap_model(model), *args, **kwargs) + + def condition_mean(self, cond_fn, *args, **kwargs): + return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs) + + def condition_score(self, cond_fn, *args, **kwargs): + return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs) + + def _wrap_model(self, model): + if isinstance(model, _WrappedModel): + return model + return _WrappedModel( + model, self.timestep_map, self.original_num_steps + ) + + def _scale_timesteps(self, t): + # Scaling is done by the wrapped model. + return t + + +class _WrappedModel: + def __init__(self, model, timestep_map, original_num_steps): + self.model = model + self.timestep_map = timestep_map + # self.rescale_timesteps = rescale_timesteps + self.original_num_steps = original_num_steps + + def __call__(self, x, ts, **kwargs): + map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype) + new_ts = map_tensor[ts] + # if self.rescale_timesteps: + # new_ts = new_ts.float() * (1000.0 / self.original_num_steps) + return self.model(x, new_ts, **kwargs) diff --git a/vsr/diffusion/scheduling_ddim.py b/vsr/diffusion/scheduling_ddim.py new file mode 100644 index 0000000000000000000000000000000000000000..8564006cac2d477a645178c01bb8ac18b871eacd --- /dev/null +++ b/vsr/diffusion/scheduling_ddim.py @@ -0,0 +1,462 @@ +# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion +# and https://github.com/hojonathanho/diffusion + +import math +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import numpy as np +import torch + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput + +try: + from diffusers.utils import randn_tensor +except: + from diffusers.utils.torch_utils import randn_tensor + +from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin + + +@dataclass +# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM +class DDIMSchedulerOutput(BaseOutput): + """ + Output class for the scheduler's step function output. + + Args: + prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the + denoising loop. + pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images): + The predicted denoised sample (x_{0}) based on the model output from the current timestep. + `pred_original_sample` can be used to preview progress or for guidance. + """ + + prev_sample: torch.FloatTensor + pred_original_sample: Optional[torch.FloatTensor] = None + + +# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar +def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor: + """ + Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of + (1-beta) over time from t = [0,1]. + + Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up + to that part of the diffusion process. + + + Args: + num_diffusion_timesteps (`int`): the number of betas to produce. + max_beta (`float`): the maximum beta to use; use values lower than 1 to + prevent singularities. + + Returns: + betas (`np.ndarray`): the betas used by the scheduler to step the model outputs + """ + + def alpha_bar(time_step): + return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2 + + betas = [] + for i in range(num_diffusion_timesteps): + t1 = i / num_diffusion_timesteps + t2 = (i + 1) / num_diffusion_timesteps + betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta)) + return torch.tensor(betas, dtype=torch.float32) + + +class DDIMScheduler(SchedulerMixin, ConfigMixin): + """ + Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising + diffusion probabilistic models (DDPMs) with non-Markovian guidance. + + [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__` + function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`. + [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and + [`~SchedulerMixin.from_pretrained`] functions. + + For more details, see the original paper: https://arxiv.org/abs/2010.02502 + + Args: + num_train_timesteps (`int`): number of diffusion steps used to train the model. + beta_start (`float`): the starting `beta` value of inference. + beta_end (`float`): the final `beta` value. + beta_schedule (`str`): + the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from + `linear`, `scaled_linear`, or `squaredcos_cap_v2`. + trained_betas (`np.ndarray`, optional): + option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc. + clip_sample (`bool`, default `True`): + option to clip predicted sample for numerical stability. + clip_sample_range (`float`, default `1.0`): + the maximum magnitude for sample clipping. Valid only when `clip_sample=True`. + set_alpha_to_one (`bool`, default `True`): + each diffusion step uses the value of alphas product at that step and at the previous one. For the final + step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`, + otherwise it uses the value of alpha at step 0. + steps_offset (`int`, default `0`): + an offset added to the inference steps. You can use a combination of `offset=1` and + `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in + stable diffusion. + prediction_type (`str`, default `epsilon`, optional): + prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion + process), `sample` (directly predicting the noisy sample`) or `v_prediction` (see section 2.4 + https://imagen.research.google/video/paper.pdf) + thresholding (`bool`, default `False`): + whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487). + Note that the thresholding method is unsuitable for latent-space diffusion models (such as + stable-diffusion). + dynamic_thresholding_ratio (`float`, default `0.995`): + the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen + (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`. + sample_max_value (`float`, default `1.0`): + the threshold value for dynamic thresholding. Valid only when `thresholding=True`. + """ + + _compatibles = [e.name for e in KarrasDiffusionSchedulers] + order = 1 + + @register_to_config + def __init__( + self, + num_train_timesteps: int = 1000, + beta_start: float = 0.0001, + beta_end: float = 0.02, + beta_schedule: str = "linear", + trained_betas: Optional[Union[np.ndarray, List[float]]] = None, + clip_sample: bool = True, + set_alpha_to_one: bool = True, + steps_offset: int = 0, + prediction_type: str = "epsilon", + thresholding: bool = False, + dynamic_thresholding_ratio: float = 0.995, + clip_sample_range: float = 1.0, + sample_max_value: float = 1.0, + ): + if trained_betas is not None: + self.betas = torch.tensor(trained_betas, dtype=torch.float32) + elif beta_schedule == "linear": + self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32) + elif beta_schedule == "scaled_linear": + # this schedule is very specific to the latent diffusion model. + self.betas = ( + torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2 + ) + elif beta_schedule == "squaredcos_cap_v2": + # Glide cosine schedule + self.betas = betas_for_alpha_bar(num_train_timesteps) + else: + raise NotImplementedError(f"{beta_schedule} does is not implemented for {self.__class__}") + + self.alphas = 1.0 - self.betas + self.alphas_cumprod = torch.cumprod(self.alphas, dim=0) + + # At every step in ddim, we are looking into the previous alphas_cumprod + # For the final step, there is no previous alphas_cumprod because we are already at 0 + # `set_alpha_to_one` decides whether we set this parameter simply to one or + # whether we use the final alpha of the "non-previous" one. + self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0] + + # standard deviation of the initial noise distribution + self.init_noise_sigma = 1.0 + + # setable values + self.num_inference_steps = None + self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64)) + + def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor: + """ + Ensures interchangeability with schedulers that need to scale the denoising model input depending on the + current timestep. + + Args: + sample (`torch.FloatTensor`): input sample + timestep (`int`, optional): current timestep + + Returns: + `torch.FloatTensor`: scaled input sample + """ + return sample + + def _get_variance(self, timestep, prev_timestep): + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + beta_prod_t = 1 - alpha_prod_t + beta_prod_t_prev = 1 - alpha_prod_t_prev + + variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev) + + return variance + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample + def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor: + """ + "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the + prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by + s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing + pixels from saturation at each step. We find that dynamic thresholding results in significantly better + photorealism as well as better image-text alignment, especially when using very large guidance weights." + + https://arxiv.org/abs/2205.11487 + """ + dtype = sample.dtype + batch_size, channels, height, width = sample.shape + + if dtype not in (torch.float32, torch.float64): + sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half + + # Flatten sample for doing quantile calculation along each image + sample = sample.reshape(batch_size, channels * height * width) + + abs_sample = sample.abs() # "a certain percentile absolute pixel value" + + s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1) + s = torch.clamp( + s, min=1, max=self.config.sample_max_value + ) # When clamped to min=1, equivalent to standard clipping to [-1, 1] + + s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0 + sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s" + + sample = sample.reshape(batch_size, channels, height, width) + sample = sample.to(dtype) + + return sample + + # def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + # """ + # Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + # Args: + # num_inference_steps (`int`): + # the number of diffusion steps used when generating samples with a pre-trained model. + # """ + + # if num_inference_steps > self.config.num_train_timesteps: + # raise ValueError( + # f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + # f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + # f" maximal {self.config.num_train_timesteps} timesteps." + # ) + + # self.num_inference_steps = num_inference_steps + # step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # # creates integer timesteps by multiplying by ratio + # # casting to int to avoid issues when num_inference_step is power of 3 + # timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64) + # self.timesteps = torch.from_numpy(timesteps).to(device) + # self.timesteps += self.config.steps_offset + + def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None): + """ + Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference. + + Args: + num_inference_steps (`int`): + the number of diffusion steps used when generating samples with a pre-trained model. + """ + + if num_inference_steps > self.config.num_train_timesteps: + raise ValueError( + f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:" + f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle" + f" maximal {self.config.num_train_timesteps} timesteps." + ) + + self.num_inference_steps = num_inference_steps + step_ratio = self.config.num_train_timesteps // self.num_inference_steps + # creates integer timesteps by multiplying by ratio + # casting to int to avoid issues when num_inference_step is power of 3 + timesteps = np.linspace(self.config.steps_offset, self.config.num_train_timesteps, num_inference_steps) + timesteps = timesteps.round()[::-1].copy().astype(np.int64) + self.timesteps = torch.from_numpy(timesteps).to(device) + self.timesteps += self.config.steps_offset + + def step( + self, + model_output: torch.FloatTensor, + timestep: int, + sample: torch.FloatTensor, + eta: float = 0.0, + use_clipped_model_output: bool = False, + generator=None, + variance_noise: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + ) -> Union[DDIMSchedulerOutput, Tuple]: + """ + Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion + process from the learned model outputs (most often the predicted noise). + + Args: + model_output (`torch.FloatTensor`): direct output from learned diffusion model. + timestep (`int`): current discrete timestep in the diffusion chain. + sample (`torch.FloatTensor`): + current instance of sample being created by diffusion process. + eta (`float`): weight of noise for added noise in diffusion step. + use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped + predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when + `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would + coincide with the one provided as input and `use_clipped_model_output` will have not effect. + generator: random number generator. + variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we + can directly provide the noise for the variance itself. This is useful for methods such as + CycleDiffusion. (https://arxiv.org/abs/2210.05559) + return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class + + Returns: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`: + [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + + """ + if self.num_inference_steps is None: + raise ValueError( + "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler" + ) + + # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf + # Ideally, read DDIM paper in-detail understanding + + # Notation ( -> + # - pred_noise_t -> e_theta(x_t, t) + # - pred_original_sample -> f_theta(x_t, t) or x_0 + # - std_dev_t -> sigma_t + # - eta -> Ξ· + # - pred_sample_direction -> "direction pointing to x_t" + # - pred_prev_sample -> "x_t-1" + + # 1. get previous step value (=t-1) + prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps + + # 2. compute alphas, betas + alpha_prod_t = self.alphas_cumprod[timestep] + alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod + + beta_prod_t = 1 - alpha_prod_t + + # 3. compute predicted original sample from predicted noise also called + # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + # print('===========', self.config.prediction_type) + # self.config.prediction_type = "v_prediction" + if self.config.prediction_type == "epsilon": + pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5) + pred_epsilon = model_output + elif self.config.prediction_type == "sample": + pred_original_sample = model_output + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + elif self.config.prediction_type == "v_prediction": + pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output + pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample + else: + raise ValueError( + f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or" + " `v_prediction`" + ) + + # 4. Clip or threshold "predicted x_0" + if self.config.thresholding: + pred_original_sample = self._threshold_sample(pred_original_sample) + elif self.config.clip_sample: + pred_original_sample = pred_original_sample.clamp( + -self.config.clip_sample_range, self.config.clip_sample_range + ) + + # 5. compute variance: "sigma_t(Ξ·)" -> see formula (16) + # Οƒ_t = sqrt((1 βˆ’ Ξ±_tβˆ’1)/(1 βˆ’ Ξ±_t)) * sqrt(1 βˆ’ Ξ±_t/Ξ±_tβˆ’1) + variance = self._get_variance(timestep, prev_timestep) + std_dev_t = eta * variance ** (0.5) + + if use_clipped_model_output: + # the pred_epsilon is always re-derived from the clipped x_0 in Glide + pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5) + + # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon + + # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf + prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction + + if eta > 0: + if variance_noise is not None and generator is not None: + raise ValueError( + "Cannot pass both generator and variance_noise. Please make sure that either `generator` or" + " `variance_noise` stays `None`." + ) + + if variance_noise is None: + variance_noise = randn_tensor( + model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype + ) + variance = std_dev_t * variance_noise + + prev_sample = prev_sample + variance + + if not return_dict: + return (prev_sample,) + + return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample) + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise + def add_noise( + self, + original_samples: torch.FloatTensor, + noise: torch.FloatTensor, + timesteps: torch.IntTensor, + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as original_samples + alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype) + timesteps = timesteps.to(original_samples.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(original_samples.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise + return noisy_samples + + # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity + def get_velocity( + self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor + ) -> torch.FloatTensor: + # Make sure alphas_cumprod and timestep have same device and dtype as sample + alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype) + timesteps = timesteps.to(sample.device) + + sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5 + sqrt_alpha_prod = sqrt_alpha_prod.flatten() + while len(sqrt_alpha_prod.shape) < len(sample.shape): + sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1) + + sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5 + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten() + while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape): + sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1) + + velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample + return velocity + + def __len__(self): + return self.config.num_train_timesteps diff --git a/vsr/diffusion/timestep_sampler.py b/vsr/diffusion/timestep_sampler.py new file mode 100644 index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f --- /dev/null +++ b/vsr/diffusion/timestep_sampler.py @@ -0,0 +1,150 @@ +# Modified from OpenAI's diffusion repos +# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py +# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion +# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py + +from abc import ABC, abstractmethod + +import numpy as np +import torch as th +import torch.distributed as dist + + +def create_named_schedule_sampler(name, diffusion): + """ + Create a ScheduleSampler from a library of pre-defined samplers. + :param name: the name of the sampler. + :param diffusion: the diffusion object to sample for. + """ + if name == "uniform": + return UniformSampler(diffusion) + elif name == "loss-second-moment": + return LossSecondMomentResampler(diffusion) + else: + raise NotImplementedError(f"unknown schedule sampler: {name}") + + +class ScheduleSampler(ABC): + """ + A distribution over timesteps in the diffusion process, intended to reduce + variance of the objective. + By default, samplers perform unbiased importance sampling, in which the + objective's mean is unchanged. + However, subclasses may override sample() to change how the resampled + terms are reweighted, allowing for actual changes in the objective. + """ + + @abstractmethod + def weights(self): + """ + Get a numpy array of weights, one per diffusion step. + The weights needn't be normalized, but must be positive. + """ + + def sample(self, batch_size, device): + """ + Importance-sample timesteps for a batch. + :param batch_size: the number of timesteps. + :param device: the torch device to save to. + :return: a tuple (timesteps, weights): + - timesteps: a tensor of timestep indices. + - weights: a tensor of weights to scale the resulting losses. + """ + w = self.weights() + p = w / np.sum(w) + indices_np = np.random.choice(len(p), size=(batch_size,), p=p) + indices = th.from_numpy(indices_np).long().to(device) + weights_np = 1 / (len(p) * p[indices_np]) + weights = th.from_numpy(weights_np).float().to(device) + return indices, weights + + +class UniformSampler(ScheduleSampler): + def __init__(self, diffusion): + self.diffusion = diffusion + self._weights = np.ones([diffusion.num_timesteps]) + + def weights(self): + return self._weights + + +class LossAwareSampler(ScheduleSampler): + def update_with_local_losses(self, local_ts, local_losses): + """ + Update the reweighting using losses from a model. + Call this method from each rank with a batch of timesteps and the + corresponding losses for each of those timesteps. + This method will perform synchronization to make sure all of the ranks + maintain the exact same reweighting. + :param local_ts: an integer Tensor of timesteps. + :param local_losses: a 1D Tensor of losses. + """ + batch_sizes = [ + th.tensor([0], dtype=th.int32, device=local_ts.device) + for _ in range(dist.get_world_size()) + ] + dist.all_gather( + batch_sizes, + th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device), + ) + + # Pad all_gather batches to be the maximum batch size. + batch_sizes = [x.item() for x in batch_sizes] + max_bs = max(batch_sizes) + + timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes] + loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes] + dist.all_gather(timestep_batches, local_ts) + dist.all_gather(loss_batches, local_losses) + timesteps = [ + x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs] + ] + losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]] + self.update_with_all_losses(timesteps, losses) + + @abstractmethod + def update_with_all_losses(self, ts, losses): + """ + Update the reweighting using losses from a model. + Sub-classes should override this method to update the reweighting + using losses from the model. + This method directly updates the reweighting without synchronizing + between workers. It is called by update_with_local_losses from all + ranks with identical arguments. Thus, it should have deterministic + behavior to maintain state across workers. + :param ts: a list of int timesteps. + :param losses: a list of float losses, one per timestep. + """ + + +class LossSecondMomentResampler(LossAwareSampler): + def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001): + self.diffusion = diffusion + self.history_per_term = history_per_term + self.uniform_prob = uniform_prob + self._loss_history = np.zeros( + [diffusion.num_timesteps, history_per_term], dtype=np.float64 + ) + self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int) + + def weights(self): + if not self._warmed_up(): + return np.ones([self.diffusion.num_timesteps], dtype=np.float64) + weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1)) + weights /= np.sum(weights) + weights *= 1 - self.uniform_prob + weights += self.uniform_prob / len(weights) + return weights + + def update_with_all_losses(self, ts, losses): + for t, loss in zip(ts, losses): + if self._loss_counts[t] == self.history_per_term: + # Shift out the oldest loss term. + self._loss_history[t, :-1] = self._loss_history[t, 1:] + self._loss_history[t, -1] = loss + else: + self._loss_history[t, self._loss_counts[t]] = loss + self._loss_counts[t] += 1 + + def _warmed_up(self): + return (self._loss_counts == self.history_per_term).all() diff --git a/vsr/models/__init__.py b/vsr/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..23fe8e3be56d23ed269e7e9dfd89987eb68e47ea --- /dev/null +++ b/vsr/models/__init__.py @@ -0,0 +1,28 @@ +from .unet import UNet3DVSRModel +from torch.optim.lr_scheduler import LambdaLR + +def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit + from torch.optim.lr_scheduler import LambdaLR + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'warmup': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + +def get_models(): + config_path = "./configs/unet_3d_config.json" + pretrained_model_path = "./pretrained_models/upscaler4x/unet/diffusion_pytorch_model.bin" + return UNet3DVSRModel.from_pretrained_2d(config_path, pretrained_model_path) + + \ No newline at end of file diff --git a/vsr/models/attention.py b/vsr/models/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..6b4f4514752d23e916d50c68e18a7a672a3b1fcf --- /dev/null +++ b/vsr/models/attention.py @@ -0,0 +1,826 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +from dataclasses import dataclass +from typing import Optional + +import math +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm + +try: + from .resnet import ResnetBlock3DCNN +except: + from resnet import ResnetBlock3DCNN + +from rotary_embedding_torch import RotaryEmbedding +from typing import Callable, Optional +from einops import rearrange, repeat + +@dataclass +class Transformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + +def exists(x): + return x is not None + + +class CrossAttention(nn.Module): + r""" + copy from diffuser 0.11.1 + A cross attention layer. + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + use_relative_position: bool = False, + ): + super().__init__() + # print('num head', heads) + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + self.dim_head = dim_head + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + # print(use_relative_position) + self.use_relative_position = use_relative_position + if self.use_relative_position: + self.rotary_emb = RotaryEmbedding(min(32, dim_head)) + # # print(dim_head) + # # print(heads) + # # adopt https://github.com/huggingface/transformers/blob/8a817e1ecac6a420b1bdc701fcc33535a3b96ff5/src/transformers/models/bert/modeling_bert.py#L265 + # self.max_position_embeddings = 32 + # self.distance_embedding = nn.Embedding(2 * self.max_position_embeddings - 1, dim_head) + + # self.dropout = nn.Dropout(dropout) + + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def reshape_for_scores(self, tensor): + # split heads and dims + # tensor should be [b (h w)] f (d nd) + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).contiguous() + return tensor + + def same_batch_dim_to_heads(self, tensor): + batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d + tensor = tensor.reshape(batch_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + + # print('before reshpape query shape', query.shape) + dim = query.shape[-1] + if not self.use_relative_position: + query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d + # print('after reshape query shape', query.shape) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if not self.use_relative_position: + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + + if attention_mask is not None: + # print('attention_mask', attention_mask.shape) + # print('attention_scores', attention_scores.shape) + # exit() + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + # print(attention_probs.shape) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + # print(attention_probs.shape) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + # print(hidden_states.shape) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + # print(hidden_states.shape) + # exit() + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class Transformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + # add 3D CNN for VSR + # if only_cross_attention == False: # x8 down + # self.resblock_temporal = ResnetBlock3DCNN(in_channels=in_channels, kernel=(3,3,3), temb_channels=None) + # else: + # self.resblock_temporal = ResnetBlock3DCNN(in_channels=in_channels, kernel=(5,1,1), temb_channels=None) + + self.resblock_temporal = ResnetBlock3DCNN(in_channels=in_channels, kernel=(3,1,1), temb_channels=None) + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + + # 3D CNN for VSR + hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length) + hidden_states = self.resblock_temporal(hidden_states) + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w", f=video_length) + + residual = hidden_states + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return Transformer3DModelOutput(sample=output) + + +class BasicTransformerBlock(nn.Module): + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + use_first_frame: bool = False, + use_relative_position: bool = False, + rotary_emb: bool = None, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + # print(only_cross_attention) + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_first_frame = use_first_frame # False for VSR + + # SC-Attn + if use_first_frame and only_cross_attention == False: + self.attn1 = SparseCausalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + # print(cross_attention_dim) + else: + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) + else: + self.attn2 = None + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + + # Temp-Attn for VSR + self.attn_temporal = TemporalAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=None, + upcast_attention=upcast_attention, + rotary_emb=rotary_emb, + ) + nn.init.zeros_(self.attn_temporal.to_out[0].weight.data) + self.norm_temporal = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + # Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # SparseCausal-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: # Cross-Attention + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: # Self-Attention + if self.use_first_frame: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + + if self.attn2 is not None: + # Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + + # Temporal-Attention for VSR + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length).contiguous() + norm_hidden_states = ( + self.norm_temporal(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temporal(hidden_states) + ) + hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous() + + # Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + return hidden_states + + +class SparseCausalAttention(CrossAttention): + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + former_frame_index = torch.arange(video_length) - 1 + former_frame_index[0] = 0 + + key = rearrange(key, "(b f) d c -> b f d c", f=video_length) + key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2) + key = rearrange(key, "b f d c -> (b f) d c") + + value = rearrange(value, "(b f) d c -> b f d c", f=video_length) + value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2) + value = rearrange(value, "b f d c -> (b f) d c") + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + +class TemporalAttention(CrossAttention): + def __init__(self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + rotary_emb=None): + super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups) + # relative time positional embeddings + self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet + self.rotary_emb = rotary_emb + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device) + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) # [b (h w)] f (nd * d) + dim = query.shape[-1] + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + + def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + # reshape for adding time positional bais + query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads + # print('query shape', query.shape) + # print('key shape', key.shape) + # print('value shape', value.shape) + + # torch.baddbmm only accepte 3-D tensor + # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm + # attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2)) + if exists(self.rotary_emb): + query = self.rotary_emb.rotate_queries_or_keys(query) + key = self.rotary_emb.rotate_queries_or_keys(key) + + attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key) + # print('attention_scores shape', attention_scores.shape) + # print('time_rel_pos_bias shape', time_rel_pos_bias.shape) + # print('attention_mask shape', attention_mask.shape) + + attention_scores = attention_scores + time_rel_pos_bias + # print(attention_scores.shape) + + # bert from huggin face + # attention_scores = attention_scores / math.sqrt(self.dim_head) + + # # Normalize the attention scores to probabilities. + # attention_probs = nn.functional.softmax(attention_scores, dim=-1) + + if attention_mask is not None: + # add attention mask + attention_scores = attention_scores + attention_mask + + # vdm + attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach() + + attention_probs = nn.functional.softmax(attention_scores, dim=-1) + # print(attention_probs[0][0]) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + # hidden_states = torch.matmul(attention_probs, value) + hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value) + # print(hidden_states.shape) + # hidden_states = self.same_batch_dim_to_heads(hidden_states) + hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)') + # print(hidden_states.shape) + # exit() + return hidden_states + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames diff --git a/vsr/models/autoencoder_kl.py b/vsr/models/autoencoder_kl.py new file mode 100644 index 0000000000000000000000000000000000000000..1983c3f1cefdbccf47659d79e4b4180fb9cc6455 --- /dev/null +++ b/vsr/models/autoencoder_kl.py @@ -0,0 +1,334 @@ +# Copyright 2023 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from dataclasses import dataclass +from typing import Optional, Tuple, Union + +import torch +import torch.nn as nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.utils import BaseOutput + +try: + from diffusers.utils import apply_forward_hook +except: + from diffusers.utils.accelerate_utils import apply_forward_hook + +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder + + +@dataclass +class AutoencoderKLOutput(BaseOutput): + """ + Output of AutoencoderKL encoding method. + + Args: + latent_dist (`DiagonalGaussianDistribution`): + Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`. + `DiagonalGaussianDistribution` allows for sampling latents from the distribution. + """ + + latent_dist: "DiagonalGaussianDistribution" + + +class AutoencoderKL(ModelMixin, ConfigMixin): + r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma + and Max Welling. + + This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library + implements for all the model (such as downloading or saving, etc.) + + Parameters: + in_channels (int, *optional*, defaults to 3): Number of channels in the input image. + out_channels (int, *optional*, defaults to 3): Number of channels in the output. + down_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types. + up_block_types (`Tuple[str]`, *optional*, defaults to : + obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types. + block_out_channels (`Tuple[int]`, *optional*, defaults to : + obj:`(64,)`): Tuple of block output channels. + act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use. + latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space. + sample_size (`int`, *optional*, defaults to `32`): TODO + scaling_factor (`float`, *optional*, defaults to 0.18215): + The component-wise standard deviation of the trained latent space computed using the first batch of the + training set. This is used to scale the latent space to have unit variance when training the diffusion + model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the + diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1 + / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image + Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper. + """ + + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + in_channels: int = 3, + out_channels: int = 3, + down_block_types: Tuple[str] = ("DownEncoderBlock2D",), + up_block_types: Tuple[str] = ("UpDecoderBlock2D",), + block_out_channels: Tuple[int] = (64,), + layers_per_block: int = 1, + act_fn: str = "silu", + latent_channels: int = 4, + norm_num_groups: int = 32, + sample_size: int = 32, + scaling_factor: float = 0.18215, + ): + super().__init__() + + # pass init params to Encoder + self.encoder = Encoder( + in_channels=in_channels, + out_channels=latent_channels, + down_block_types=down_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + act_fn=act_fn, + norm_num_groups=norm_num_groups, + double_z=True, + ) + + # pass init params to Decoder + self.decoder = Decoder( + in_channels=latent_channels, + out_channels=out_channels, + up_block_types=up_block_types, + block_out_channels=block_out_channels, + layers_per_block=layers_per_block, + norm_num_groups=norm_num_groups, + act_fn=act_fn, + ) + + self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1) + self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1) + + self.use_slicing = False + self.use_tiling = False + + # only relevant if vae tiling is enabled + self.tile_sample_min_size = self.config.sample_size + sample_size = ( + self.config.sample_size[0] + if isinstance(self.config.sample_size, (list, tuple)) + else self.config.sample_size + ) + self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1))) + self.tile_overlap_factor = 0.25 + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (Encoder, Decoder)): + module.gradient_checkpointing = value + + def enable_tiling(self, use_tiling: bool = True): + r""" + Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to + compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow + the processing of larger images. + """ + self.use_tiling = use_tiling + + def disable_tiling(self): + r""" + Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to + computing decoding in one step. + """ + self.enable_tiling(False) + + def enable_slicing(self): + r""" + Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to + compute decoding in several steps. This is useful to save some memory and allow larger batch sizes. + """ + self.use_slicing = True + + def disable_slicing(self): + r""" + Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing + decoding in one step. + """ + self.use_slicing = False + + @apply_forward_hook + def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size): + return self.tiled_encode(x, return_dict=return_dict) + + h = self.encoder(x) + moments = self.quant_conv(h) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size): + return self.tiled_decode(z, return_dict=return_dict) + + z = self.post_quant_conv(z) + dec = self.decoder(z) + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + @apply_forward_hook + def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + if self.use_slicing and z.shape[0] > 1: + decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)] + decoded = torch.cat(decoded_slices) + else: + decoded = self._decode(z).sample + + if not return_dict: + return (decoded,) + + return DecoderOutput(sample=decoded) + + def blend_v(self, a, b, blend_extent): + for y in range(min(a.shape[2], b.shape[2], blend_extent)): + b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent) + return b + + def blend_h(self, a, b, blend_extent): + for x in range(min(a.shape[3], b.shape[3], blend_extent)): + b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent) + return b + + def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput: + r"""Encode a batch of images using a tiled encoder. + + Args: + When this option is enabled, the VAE will split the input tensor into tiles to compute encoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled encoding is: + different from non-tiled encoding due to each tile using a different encoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + look of the output, but they should be much less noticeable. + x (`torch.FloatTensor`): Input batch of images. return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`AutoencoderKLOutput`] instead of a plain tuple. + """ + overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor) + row_limit = self.tile_latent_min_size - blend_extent + + # Split the image into 512x512 tiles and encode them separately. + rows = [] + for i in range(0, x.shape[2], overlap_size): + row = [] + for j in range(0, x.shape[3], overlap_size): + tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size] + tile = self.encoder(tile) + tile = self.quant_conv(tile) + row.append(tile) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + moments = torch.cat(result_rows, dim=2) + posterior = DiagonalGaussianDistribution(moments) + + if not return_dict: + return (posterior,) + + return AutoencoderKLOutput(latent_dist=posterior) + + def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]: + r"""Decode a batch of images using a tiled decoder. + + Args: + When this option is enabled, the VAE will split the input tensor into tiles to compute decoding in several + steps. This is useful to keep memory use constant regardless of image size. The end result of tiled decoding is: + different from non-tiled decoding due to each tile using a different decoder. To avoid tiling artifacts, the + tiles overlap and are blended together to form a smooth output. You may still see tile-sized changes in the + look of the output, but they should be much less noticeable. + z (`torch.FloatTensor`): Input batch of latent vectors. return_dict (`bool`, *optional*, defaults to + `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor)) + blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor) + row_limit = self.tile_sample_min_size - blend_extent + + # Split z into overlapping 64x64 tiles and decode them separately. + # The tiles have an overlap to avoid seams between tiles. + rows = [] + for i in range(0, z.shape[2], overlap_size): + row = [] + for j in range(0, z.shape[3], overlap_size): + tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size] + tile = self.post_quant_conv(tile) + decoded = self.decoder(tile) + row.append(decoded) + rows.append(row) + result_rows = [] + for i, row in enumerate(rows): + result_row = [] + for j, tile in enumerate(row): + # blend the above tile and the left tile + # to the current tile and add the current tile to the result row + if i > 0: + tile = self.blend_v(rows[i - 1][j], tile, blend_extent) + if j > 0: + tile = self.blend_h(row[j - 1], tile, blend_extent) + result_row.append(tile[:, :, :row_limit, :row_limit]) + result_rows.append(torch.cat(result_row, dim=3)) + + dec = torch.cat(result_rows, dim=2) + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) + + def forward( + self, + sample: torch.FloatTensor, + sample_posterior: bool = False, + return_dict: bool = True, + generator: Optional[torch.Generator] = None, + ) -> Union[DecoderOutput, torch.FloatTensor]: + r""" + Args: + sample (`torch.FloatTensor`): Input sample. + sample_posterior (`bool`, *optional*, defaults to `False`): + Whether to sample from the posterior. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`DecoderOutput`] instead of a plain tuple. + """ + x = sample + posterior = self.encode(x).latent_dist + if sample_posterior: + z = posterior.sample(generator=generator) + else: + z = posterior.mode() + dec = self.decode(z).sample + + if not return_dict: + return (dec,) + + return DecoderOutput(sample=dec) diff --git a/vsr/models/clip.py b/vsr/models/clip.py new file mode 100644 index 0000000000000000000000000000000000000000..13ec05118f729cc27e6b7b4107a300ecfbdf2d77 --- /dev/null +++ b/vsr/models/clip.py @@ -0,0 +1,127 @@ +import numpy +import torch.nn as nn +from transformers import CLIPTokenizer, CLIPTextModel +from diffusers import StableDiffusionUpscalePipeline + +""" +Will encounter following warning: +- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task +or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model). +- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model +that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model). + +https://github.com/CompVis/stable-diffusion/issues/97 +according to this issue, this warning is safe. + +This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion. +You can safely ignore the warning, it is not an error. + +This clip usage is from U-ViT and same with Stable Diffusion. +""" + +class AbstractEncoder(nn.Module): + def __init__(self): + super().__init__() + + def encode(self, *args, **kwargs): + raise NotImplementedError + + +class FrozenCLIPEmbedder(AbstractEncoder): + """Uses the CLIP transformer encoder for text (from Hugging Face)""" + def __init__(self, device="cuda", max_length=77): + super().__init__() + # self.tokenizer = CLIPTokenizer.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K') + # self.text_encoder = CLIPTextModel.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K') + # TBD: change to https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/config.json + # model_id = "stabilityai/stable-diffusion-x4-upscaler" # For VSR + # upscale_pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id) + upscale_pipeline = StableDiffusionUpscalePipeline.from_pretrained('./pretrained_models/upscaler4x') + self.tokenizer = upscale_pipeline.tokenizer + self.text_encoder = upscale_pipeline.text_encoder + self.device = device + self.max_length = max_length + self.freeze() + + def freeze(self): + self.text_encoder = self.text_encoder.eval() + for param in self.parameters(): + param.requires_grad = False + + def forward(self, text): + batch_encoding = self.tokenizer(text, truncation=True, padding="max_length", max_length=self.max_length, + return_tensors="pt") + tokens = batch_encoding["input_ids"].to(self.device) + outputs = self.text_encoder(input_ids=tokens) + + # return outputs.last_hidden_state + return outputs[0] + + def encode(self, text): + return self(text) + + +class TextEmbedder(nn.Module): + """ + Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance. + """ + def __init__(self, dropout_prob=0.1): + super().__init__() + self.text_encodder = FrozenCLIPEmbedder() + self.dropout_prob = dropout_prob + + def token_drop(self, text_prompts, force_drop_ids=None): + """ + Drops text to enable classifier-free guidance. + """ + if force_drop_ids is None: + drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob + else: + drop_ids = force_drop_ids == 1 + labels = list(numpy.where(drop_ids, "None", text_prompts)) + # print(labels) + return labels + + def forward(self, text_prompts, train, force_drop_ids=None): + use_dropout = self.dropout_prob > 0 + if (train and use_dropout) or (force_drop_ids is not None): + text_prompts = self.token_drop(text_prompts, force_drop_ids) + embeddings = self.text_encodder(text_prompts) + return embeddings + + +if __name__ == '__main__': + + r""" + Returns: + + Examples from CLIPTextModel: + + ```python + >>> from transformers import AutoTokenizer, CLIPTextModel + + >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32") + >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32") + + >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt") + + >>> outputs = model(**inputs) + >>> last_hidden_state = outputs.last_hidden_state + >>> pooled_output = outputs.pooler_output # pooled (EOS token) states + ```""" + + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + + text_encoder = TextEmbedder(dropout_prob=0.00001).to(device) + text_encoder1 = FrozenCLIPEmbedder().to(device) + + text_prompt = ["a photo of a cat", "a photo of a dog", 'a photo of a dog human'] + # text_prompt = ('None', 'None', 'None') + output = text_encoder(text_prompts=text_prompt, train=True) + output1 = text_encoder1(text_prompt) + # print(output) + print(output.shape) + print(output1.shape) + print((output==output1).all()) \ No newline at end of file diff --git a/vsr/models/diffusers_attention.py b/vsr/models/diffusers_attention.py new file mode 100644 index 0000000000000000000000000000000000000000..c104074eeab835cec0a9ce95f1b5753d05a67566 --- /dev/null +++ b/vsr/models/diffusers_attention.py @@ -0,0 +1,983 @@ +# Copyright 2022 The HuggingFace Team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import math +from dataclasses import dataclass +from typing import Optional + +import torch +import torch.nn.functional as F +from torch import nn + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.embeddings import ImagePositionalEmbeddings +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward, AdaLayerNorm +from diffusers.models.attention_processor import Attention + + +@dataclass +class Transformer2DModelOutput(BaseOutput): + """ + Args: + sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete): + Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions + for the unnoised latent pixels. + """ + + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class Transformer2DModel(ModelMixin, ConfigMixin): + """ + Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual + embeddings) inputs. + + When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard + transformer action. Finally, reshape to image. + + When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional + embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict + classes of unnoised image. + + Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised + image do not contain a prediction for the masked pixel as the unnoised image cannot be masked. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + @register_to_config + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + use_linear_projection: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # 1. Transformer2DModel can process both standard continous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)` + # Define whether input is continuous or discrete depending on configuration + self.is_input_continuous = in_channels is not None + self.is_input_vectorized = num_vector_embeds is not None + + if self.is_input_continuous and self.is_input_vectorized: + raise ValueError( + f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is None." + ) + elif not self.is_input_continuous and not self.is_input_vectorized: + raise ValueError( + f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make" + " sure that either `in_channels` or `num_vector_embeds` is not None." + ) + + # 2. Define input layers + if self.is_input_continuous: + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size" + assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed" + + self.height = sample_size + self.width = sample_size + self.num_vector_embeds = num_vector_embeds + self.num_latent_pixels = self.height * self.width + + self.latent_image_embedding = ImagePositionalEmbeddings( + num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width + ) + + # 3. Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + BasicTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if self.is_input_continuous: + if use_linear_projection: + self.proj_out = nn.Linear(in_channels, inner_dim) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + elif self.is_input_vectorized: + self.norm_out = nn.LayerNorm(inner_dim) + self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + # 1. Input + if self.is_input_continuous: + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + elif self.is_input_vectorized: + hidden_states = self.latent_image_embedding(hidden_states) + + # 2. Blocks + for block in self.transformer_blocks: + hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep) + + # 3. Output + if self.is_input_continuous: + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + elif self.is_input_vectorized: + hidden_states = self.norm_out(hidden_states) + logits = self.out(hidden_states) + # (batch, self.num_vector_embeds - 1, self.num_latent_pixels) + logits = logits.permute(0, 2, 1) + + # log(p(x_0)) + output = F.log_softmax(logits.double(), dim=1).float() + + if not return_dict: + return (output,) + + return Transformer2DModelOutput(sample=output) + + +class AttentionBlock(nn.Module): + """ + An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted + to the N-d case. + https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66. + Uses three q, k, v linear layers to compute attention. + + Parameters: + channels (`int`): The number of channels in the input and output. + num_head_channels (`int`, *optional*): + The number of channels in each head. If None, then `num_heads` = 1. + norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm. + rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by. + eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm. + """ + + # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore + + def __init__( + self, + channels: int, + num_head_channels: Optional[int] = None, + norm_num_groups: int = 32, + rescale_output_factor: float = 1.0, + eps: float = 1e-5, + ): + super().__init__() + self.channels = channels + + self.num_heads = channels // num_head_channels if num_head_channels is not None else 1 + self.num_head_size = num_head_channels + self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True) + + # define q,k,v as linear layers + self.query = nn.Linear(channels, channels) + self.key = nn.Linear(channels, channels) + self.value = nn.Linear(channels, channels) + + self.rescale_output_factor = rescale_output_factor + self.proj_attn = nn.Linear(channels, channels, 1) + + self._use_memory_efficient_attention_xformers = False + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: None): + if not is_xformers_available(): + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.num_heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def forward(self, hidden_states): + residual = hidden_states + batch, channel, height, width = hidden_states.shape + + # norm + hidden_states = self.group_norm(hidden_states) + + hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2) + + # proj to q, k, v + query_proj = self.query(hidden_states) + key_proj = self.key(hidden_states) + value_proj = self.value(hidden_states) + + scale = 1 / math.sqrt(self.channels / self.num_heads) + + query_proj = self.reshape_heads_to_batch_dim(query_proj) + key_proj = self.reshape_heads_to_batch_dim(key_proj) + value_proj = self.reshape_heads_to_batch_dim(value_proj) + + if self._use_memory_efficient_attention_xformers: + # Memory efficient attention + hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None) + hidden_states = hidden_states.to(query_proj.dtype) + else: + attention_scores = torch.baddbmm( + torch.empty( + query_proj.shape[0], + query_proj.shape[1], + key_proj.shape[1], + dtype=query_proj.dtype, + device=query_proj.device, + ), + query_proj, + key_proj.transpose(-1, -2), + beta=0, + alpha=scale, + ) + attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype) + hidden_states = torch.bmm(attention_probs, value_proj) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + + # compute next hidden_states + hidden_states = self.proj_attn(hidden_states) + + hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width) + + # res connect and rescale + hidden_states = (hidden_states + residual) / self.rescale_output_factor + return hidden_states + + +class BasicTransformerBlock(nn.Module): + r""" + A basic Transformer block. + + Parameters: + dim (`int`): The number of channels in the input and output. + num_attention_heads (`int`): The number of heads to use for multi-head attention. + attention_head_dim (`int`): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm (: + obj: `int`, *optional*): The number of diffusion steps used during training. See `Transformer2DModel`. + attention_bias (: + obj: `bool`, *optional*, defaults to `False`): Configure if the attentions should contain a bias parameter. + """ + + def __init__( + self, + dim: int, + num_attention_heads: int, + attention_head_dim: int, + dropout=0.0, + cross_attention_dim: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + attention_bias: bool = False, + only_cross_attention: bool = False, + upcast_attention: bool = False, + ): + super().__init__() + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + + # 1. Self-Attn + self.attn1 = CrossAttention( + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + cross_attention_dim=cross_attention_dim if only_cross_attention else None, + upcast_attention=upcast_attention, + ) # is a self-attention + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + + # 2. Cross-Attn + if cross_attention_dim is not None: + self.attn2 = CrossAttention( + query_dim=dim, + cross_attention_dim=cross_attention_dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + ) # is self-attn if encoder_hidden_states is none + else: + self.attn2 = None + + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + if cross_attention_dim is not None: + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + else: + self.norm2 = None + + # 3. Feed-forward + self.norm3 = nn.LayerNorm(dim) + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + if self.attn2 is not None: + self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None): + # 1. Self-Attention + norm_hidden_states = ( + self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + ) + + if self.only_cross_attention: + hidden_states = ( + self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states + ) + else: + hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states + + if self.attn2 is not None: + # 2. Cross-Attention + norm_hidden_states = ( + self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + ) + hidden_states = ( + self.attn2( + norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + + hidden_states + ) + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class CrossAttention(nn.Module): + r""" + A cross attention layer. + + Parameters: + query_dim (`int`): The number of channels in the query. + cross_attention_dim (`int`, *optional*): + The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`. + heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention. + dim_head (`int`, *optional*, defaults to 64): The number of channels in each head. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + bias (`bool`, *optional*, defaults to False): + Set to `True` for the query, key, and value linear layers to contain a bias parameter. + """ + + def __init__( + self, + query_dim: int, + cross_attention_dim: Optional[int] = None, + heads: int = 8, + dim_head: int = 64, + dropout: float = 0.0, + bias=False, + upcast_attention: bool = False, + upcast_softmax: bool = False, + added_kv_proj_dim: Optional[int] = None, + norm_num_groups: Optional[int] = None, + ): + super().__init__() + inner_dim = dim_head * heads + cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim + self.upcast_attention = upcast_attention + self.upcast_softmax = upcast_softmax + + self.scale = dim_head**-0.5 + + self.heads = heads + # for slice_size > 0 the attention score computation + # is split across the batch axis to save memory + # You can set slice_size with `set_attention_slice` + self.sliceable_head_dim = heads + self._slice_size = None + self._use_memory_efficient_attention_xformers = False + self.added_kv_proj_dim = added_kv_proj_dim + + if norm_num_groups is not None: + self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True) + else: + self.group_norm = None + + self.to_q = nn.Linear(query_dim, inner_dim, bias=bias) + self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias) + + if self.added_kv_proj_dim is not None: + self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim) + + self.to_out = nn.ModuleList([]) + self.to_out.append(nn.Linear(inner_dim, query_dim)) + self.to_out.append(nn.Dropout(dropout)) + + def reshape_heads_to_batch_dim(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size) + return tensor + + def reshape_batch_dim_to_heads(self, tensor): + batch_size, seq_len, dim = tensor.shape + head_size = self.heads + tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim) + tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size) + return tensor + + def set_attention_slice(self, slice_size): + if slice_size is not None and slice_size > self.sliceable_head_dim: + raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.") + + self._slice_size = slice_size + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None): + batch_size, sequence_length, _ = hidden_states.shape + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + key = self.to_k(hidden_states) + value = self.to_v(hidden_states) + encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states) + encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj) + encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj) + + key = torch.concat([encoder_hidden_states_key_proj, key], dim=1) + value = torch.concat([encoder_hidden_states_value_proj, value], dim=1) + else: + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + return hidden_states + + def _attention(self, query, key, value, attention_mask=None): + if self.upcast_attention: + query = query.float() + key = key.float() + + attention_scores = torch.baddbmm( + torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device), + query, + key.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attention_scores = attention_scores + attention_mask + + if self.upcast_softmax: + attention_scores = attention_scores.float() + + attention_probs = attention_scores.softmax(dim=-1) + + # cast back to the original dtype + attention_probs = attention_probs.to(value.dtype) + + # compute attention output + hidden_states = torch.bmm(attention_probs, value) + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask): + batch_size_attention = query.shape[0] + hidden_states = torch.zeros( + (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype + ) + slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0] + for i in range(hidden_states.shape[0] // slice_size): + start_idx = i * slice_size + end_idx = (i + 1) * slice_size + + query_slice = query[start_idx:end_idx] + key_slice = key[start_idx:end_idx] + + if self.upcast_attention: + query_slice = query_slice.float() + key_slice = key_slice.float() + + attn_slice = torch.baddbmm( + torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device), + query_slice, + key_slice.transpose(-1, -2), + beta=0, + alpha=self.scale, + ) + + if attention_mask is not None: + attn_slice = attn_slice + attention_mask[start_idx:end_idx] + + if self.upcast_softmax: + attn_slice = attn_slice.float() + + attn_slice = attn_slice.softmax(dim=-1) + + # cast back to the original dtype + attn_slice = attn_slice.to(value.dtype) + attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx]) + + hidden_states[start_idx:end_idx] = attn_slice + + # reshape hidden_states + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + def _memory_efficient_attention_xformers(self, query, key, value, attention_mask): + query = query.contiguous() + key = key.contiguous() + value = value.contiguous() + hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask) + hidden_states = self.reshape_batch_dim_to_heads(hidden_states) + return hidden_states + + +class FeedForward(nn.Module): + r""" + A feed-forward layer. + + Parameters: + dim (`int`): The number of channels in the input. + dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`. + mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension. + dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + """ + + def __init__( + self, + dim: int, + dim_out: Optional[int] = None, + mult: int = 4, + dropout: float = 0.0, + activation_fn: str = "geglu", + ): + super().__init__() + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim) + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out)) + + def forward(self, hidden_states): + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states + + +class GELU(nn.Module): + r""" + GELU activation function + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +# feedforward +class GEGLU(nn.Module): + r""" + A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202. + + Parameters: + dim_in (`int`): The number of channels in the input. + dim_out (`int`): The number of channels in the output. + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2) + + def gelu(self, gate): + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + """ + The approximate form of Gaussian Error Linear Unit (GELU) + + For more details, see section 2: https://arxiv.org/abs/1606.08415 + """ + + def __init__(self, dim_in: int, dim_out: int): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out) + + def forward(self, x): + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + emb = self.linear(self.silu(self.emb(timestep))) + scale, shift = torch.chunk(emb, 2) + x = self.norm(x) * (1 + scale) + shift + return x + + +class DualTransformer2DModel(nn.Module): + """ + Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference. + + Parameters: + num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention. + attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head. + in_channels (`int`, *optional*): + Pass if the input is continuous. The number of channels in the input and output. + num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use. + dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use. + cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use. + sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images. + Note that this is fixed at training time as it is used for learning a number of position embeddings. See + `ImagePositionalEmbeddings`. + num_vector_embeds (`int`, *optional*): + Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels. + Includes the class for the masked latent pixel. + activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward. + num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`. + The number of diffusion steps used during training. Note that this is fixed at training time as it is used + to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for + up to but not more than steps than `num_embeds_ada_norm`. + attention_bias (`bool`, *optional*): + Configure if the TransformerBlocks' attention should contain a bias parameter. + """ + + def __init__( + self, + num_attention_heads: int = 16, + attention_head_dim: int = 88, + in_channels: Optional[int] = None, + num_layers: int = 1, + dropout: float = 0.0, + norm_num_groups: int = 32, + cross_attention_dim: Optional[int] = None, + attention_bias: bool = False, + sample_size: Optional[int] = None, + num_vector_embeds: Optional[int] = None, + activation_fn: str = "geglu", + num_embeds_ada_norm: Optional[int] = None, + ): + super().__init__() + self.transformers = nn.ModuleList( + [ + Transformer2DModel( + num_attention_heads=num_attention_heads, + attention_head_dim=attention_head_dim, + in_channels=in_channels, + num_layers=num_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + sample_size=sample_size, + num_vector_embeds=num_vector_embeds, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + ) + for _ in range(2) + ] + ) + + # Variables that can be set by a pipeline: + + # The ratio of transformer1 to transformer2's output states to be combined during inference + self.mix_ratio = 0.5 + + # The shape of `encoder_hidden_states` is expected to be + # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)` + self.condition_lengths = [77, 257] + + # Which transformer to use to encode which condition. + # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])` + self.transformer_index_for_condition = [1, 0] + + def forward( + self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True + ): + """ + Args: + hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`. + When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input + hidden_states + encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*): + Conditional embeddings for cross attention layer. If not given, cross-attention defaults to + self-attention. + timestep ( `torch.long`, *optional*): + Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step. + attention_mask (`torch.FloatTensor`, *optional*): + Optional attention mask to be applied in CrossAttention + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`] + if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample + tensor. + """ + input_states = hidden_states + + encoded_states = [] + tokens_start = 0 + # attention_mask is not used yet + for i in range(2): + # for each of the two transformers, pass the corresponding condition tokens + condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]] + transformer_index = self.transformer_index_for_condition[i] + encoded_state = self.transformers[transformer_index]( + input_states, + encoder_hidden_states=condition_state, + timestep=timestep, + return_dict=False, + )[0] + encoded_states.append(encoded_state - input_states) + tokens_start += self.condition_lengths[i] + + output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio) + output_states = output_states + input_states + + if not return_dict: + return (output_states,) + + return Transformer2DModelOutput(sample=output_states) \ No newline at end of file diff --git a/vsr/models/pipeline_stable_diffusion_upscale_video_3d.py b/vsr/models/pipeline_stable_diffusion_upscale_video_3d.py new file mode 100644 index 0000000000000000000000000000000000000000..e83ca6bf4c7cfa75c260e3d1cc088f8bdc5eb240 --- /dev/null +++ b/vsr/models/pipeline_stable_diffusion_upscale_video_3d.py @@ -0,0 +1,780 @@ + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from typing import Any, Callable, List, Optional, Union + +import numpy as np +import math +import PIL +import torch +import torch.nn.functional as F +from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer + +from diffusers.loaders import TextualInversionLoaderMixin +from diffusers.models import AutoencoderKL, UNet2DConditionModel +from diffusers.schedulers import DDPMScheduler +# from diffusers.schedulers import DDIMScheduler +from diffusion.scheduling_ddim import DDIMScheduler + +from diffusers.utils import deprecate, is_accelerate_available, is_accelerate_version, logging + +try: + from diffusers.utils import randn_tensor +except: + from diffusers.utils.torch_utils import randn_tensor + +from diffusers.pipelines.pipeline_utils import DiffusionPipeline +from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput + +from einops import rearrange + +# from datasets.data_utils import filter2D +# from datasets.degradations import random_mixed_kernels, bivariate_Gaussian + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + + +def preprocess(image): + if isinstance(image, torch.Tensor): + return image + elif isinstance(image, PIL.Image.Image): + image = [image] + + if isinstance(image[0], PIL.Image.Image): + w, h = image[0].size + w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64 + + image = [np.array(i.resize((w, h)))[None, :] for i in image] + image = np.concatenate(image, axis=0) + image = np.array(image).astype(np.float32) / 255.0 + image = image.transpose(0, 3, 1, 2) + image = 2.0 * image - 1.0 + image = torch.from_numpy(image) + elif isinstance(image[0], torch.Tensor): + image = torch.cat(image, dim=0) + return image + + +class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin): + _optional_components = ["feature_extractor"] + + def __init__( + self, + vae: AutoencoderKL, + text_encoder: CLIPTextModel, + tokenizer: CLIPTokenizer, + unet: UNet2DConditionModel, + low_res_scheduler: DDPMScheduler, + # scheduler: KarrasDiffusionSchedulers, + scheduler: DDIMScheduler, + feature_extractor: Optional[CLIPImageProcessor] = None, + max_noise_level: int = 350, + ): + super().__init__() + + if hasattr( + vae, "config" + ): # check if vae has a config attribute `scaling_factor` and if it is set to 0.08333, else set it to 0.08333 and deprecate + is_vae_scaling_factor_set_to_0_08333 = ( + hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333 + ) + if not is_vae_scaling_factor_set_to_0_08333: + deprecation_message = ( + "The configuration file of the vae does not contain `scaling_factor` or it is set to" + f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned" + " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to" + " 0.08333 Please make sure to update the config accordingly, as not doing so might lead to" + " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging" + " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file" + ) + deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False) + vae.register_to_config(scaling_factor=0.08333) + # TODO: remove + print(f'=============vae.config.scaling_factor: {vae.config.scaling_factor}==================') + + self.register_modules( + vae=vae, + text_encoder=text_encoder, + tokenizer=tokenizer, + unet=unet, + low_res_scheduler=low_res_scheduler, + scheduler=scheduler, + feature_extractor=feature_extractor, + ) + self.register_to_config(max_noise_level=max_noise_level) + + def enable_sequential_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet, + text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a + `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called. + """ + if is_accelerate_available(): + from accelerate import cpu_offload + else: + raise ImportError("Please install accelerate via `pip install accelerate`") + + device = torch.device(f"cuda:{gpu_id}") + + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + cpu_offload(cpu_offloaded_model, device) + + def enable_model_cpu_offload(self, gpu_id=0): + r""" + Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared + to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward` + method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with + `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`. + """ + if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"): + from accelerate import cpu_offload_with_hook + else: + raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.") + + device = torch.device(f"cuda:{gpu_id}") + + if self.device.type != "cpu": + self.to("cpu", silence_dtype_warnings=True) + torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist) + + hook = None + for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]: + if cpu_offloaded_model is not None: + _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook) + + # We'll offload the last model manually. + self.final_offload_hook = hook + + @property + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device + def _execution_device(self): + r""" + Returns the device on which the pipeline's models will be executed. After calling + `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module + hooks. + """ + if not hasattr(self.unet, "_hf_hook"): + return self.device + for module in self.unet.modules(): + if ( + hasattr(module, "_hf_hook") + and hasattr(module._hf_hook, "execution_device") + and module._hf_hook.execution_device is not None + ): + return torch.device(module._hf_hook.execution_device) + return self.device + + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt + def _encode_prompt( + self, + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt=None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + ): + r""" + Encodes the prompt into text encoder hidden states. + + Args: + prompt (`str` or `List[str]`, *optional*): + prompt to be encoded + device: (`torch.device`): + torch device + num_images_per_prompt (`int`): + number of images that should be generated per prompt + do_classifier_free_guidance (`bool`): + whether to use classifier free guidance or not + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is + less than `1`). + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + """ + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + if prompt_embeds is None: + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + prompt = self.maybe_convert_prompt(prompt, self.tokenizer) + + text_inputs = self.tokenizer( + prompt, + padding="max_length", + max_length=self.tokenizer.model_max_length, + truncation=True, + return_tensors="pt", + ) + text_input_ids = text_inputs.input_ids + untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids + + if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal( + text_input_ids, untruncated_ids + ): + removed_text = self.tokenizer.batch_decode( + untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1] + ) + logger.warning( + "The following part of your input was truncated because CLIP can only handle sequences up to" + f" {self.tokenizer.model_max_length} tokens: {removed_text}" + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = text_inputs.attention_mask.to(device) + else: + attention_mask = None + + prompt_embeds = self.text_encoder( + text_input_ids.to(device), + attention_mask=attention_mask, + ) + prompt_embeds = prompt_embeds[0] + + prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + bs_embed, seq_len, _ = prompt_embeds.shape + # duplicate text embeddings for each generation per prompt, using mps friendly method + prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1) + prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1) + + # get unconditional embeddings for classifier free guidance + if do_classifier_free_guidance and negative_prompt_embeds is None: + uncond_tokens: List[str] + if negative_prompt is None: + uncond_tokens = [""] * batch_size + elif type(prompt) is not type(negative_prompt): + raise TypeError( + f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !=" + f" {type(prompt)}." + ) + elif isinstance(negative_prompt, str): + uncond_tokens = [negative_prompt] + elif batch_size != len(negative_prompt): + raise ValueError( + f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:" + f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches" + " the batch size of `prompt`." + ) + else: + uncond_tokens = negative_prompt + + # textual inversion: procecss multi-vector tokens if necessary + if isinstance(self, TextualInversionLoaderMixin): + uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer) + + max_length = prompt_embeds.shape[1] + uncond_input = self.tokenizer( + uncond_tokens, + padding="max_length", + max_length=max_length, + truncation=True, + return_tensors="pt", + ) + + if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask: + attention_mask = uncond_input.attention_mask.to(device) + else: + attention_mask = None + + negative_prompt_embeds = self.text_encoder( + uncond_input.input_ids.to(device), + attention_mask=attention_mask, + ) + negative_prompt_embeds = negative_prompt_embeds[0] + + if do_classifier_free_guidance: + # duplicate unconditional embeddings for each generation per prompt, using mps friendly method + seq_len = negative_prompt_embeds.shape[1] + + negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device) + + negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1) + negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1) + + # For classifier free guidance, we need to do two forward passes. + # Here we concatenate the unconditional and text embeddings into a single batch + # to avoid doing two forward passes + prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds]) + + return prompt_embeds + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs + def prepare_extra_step_kwargs(self, generator, eta): + # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature + # eta (Ξ·) is only used with the DDIMScheduler, it will be ignored for other schedulers. + # eta corresponds to Ξ· in DDIM paper: https://arxiv.org/abs/2010.02502 + # and should be between [0, 1] + + accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys()) + extra_step_kwargs = {} + if accepts_eta: + extra_step_kwargs["eta"] = eta + + # check if the scheduler accepts generator + accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys()) + if accepts_generator: + extra_step_kwargs["generator"] = generator + return extra_step_kwargs + + # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents + def decode_latents(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = (image / 2 + 0.5).clamp(0, 1) + # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16 + image = image.cpu().permute(0, 2, 3, 1).float().numpy() + return image + + def decode_latents_vsr(self, latents): + latents = 1 / self.vae.config.scaling_factor * latents + image = self.vae.decode(latents).sample + image = image.clamp(-1, 1).cpu() + return image + + def check_inputs( + self, + prompt, + image, + noise_level, + callback_steps, + negative_prompt=None, + prompt_embeds=None, + negative_prompt_embeds=None, + ): + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + if prompt is not None and prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to" + " only forward one of the two." + ) + elif prompt is None and prompt_embeds is None: + raise ValueError( + "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined." + ) + elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)): + raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}") + + if negative_prompt is not None and negative_prompt_embeds is not None: + raise ValueError( + f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:" + f" {negative_prompt_embeds}. Please make sure to only forward one of the two." + ) + + if prompt_embeds is not None and negative_prompt_embeds is not None: + if prompt_embeds.shape != negative_prompt_embeds.shape: + raise ValueError( + "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but" + f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`" + f" {negative_prompt_embeds.shape}." + ) + + if ( + not isinstance(image, torch.Tensor) + and not isinstance(image, PIL.Image.Image) + and not isinstance(image, list) + ): + raise ValueError( + f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}" + ) + + # verify batch size of prompt and image are same if image is a list or tensor + if isinstance(image, list) or isinstance(image, torch.Tensor): + if isinstance(prompt, str): + batch_size = 1 + else: + batch_size = len(prompt) + if isinstance(image, list): + image_batch_size = len(image) + else: + image_batch_size = image.shape[0] + if batch_size != image_batch_size: + raise ValueError( + f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}." + " Please make sure that passed `prompt` matches the batch size of `image`." + ) + + # check noise level + if noise_level > self.config.max_noise_level: + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}") + + if (callback_steps is None) or ( + callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0) + ): + raise ValueError( + f"`callback_steps` has to be a positive integer but is {callback_steps} of type" + f" {type(callback_steps)}." + ) + + def prepare_latents_3d(self, batch_size, num_channels_latents, seq_len, height, width, dtype, device, generator, latents=None): + shape = (batch_size, num_channels_latents, seq_len, height, width) + if latents is None: + latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + else: + if latents.shape != shape: + raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}") + latents = latents.to(device) + + # scale the initial noise by the standard deviation required by the scheduler + latents = latents * self.scheduler.init_noise_sigma + return latents + + def get_timesteps(self, num_inference_steps, strength, device): + # get the original timestep using init_timestep + init_timestep = min(int(num_inference_steps * strength), num_inference_steps) + t_start = max(num_inference_steps - init_timestep, 0) + timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :] + + return timesteps, num_inference_steps - t_start + + def prepare_latents_inversion(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None): + + image = image.to(device=device, dtype=dtype) + batch_size = batch_size * num_images_per_prompt + + b = image.shape[0] + image = rearrange(image, 'b c t h w -> (b t) c h w').contiguous() + image = F.interpolate(image, scale_factor=4, mode='bicubic') + image = image.to(dtype=torch.float32) + init_latents = self.vae.encode(image).latent_dist.sample(generator) + torch.cuda.empty_cache() + init_latents = rearrange(init_latents, '(b t) c h w -> b c t h w', b=b).contiguous() + + init_latents = self.vae.config.scaling_factor * init_latents + init_latents = init_latents.to(dtype=torch.float16) + + # add noise + shape = init_latents.shape + noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype) + # get latents + init_latents = self.scheduler.add_noise(init_latents, noise, timestep) + # DEBUG + # init_latents = noise + print('timestep', timestep) + + # scale the initial noise by the standard deviation required by the scheduler + latents = init_latents * self.scheduler.init_noise_sigma + return latents + + @torch.no_grad() + def __call__( + self, + prompt: Union[str, List[str]] = None, + image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None, + num_inference_steps: int = 75, + guidance_scale: float = 9.0, + noise_level: int = 20, + negative_prompt: Optional[Union[str, List[str]]] = None, + num_images_per_prompt: Optional[int] = 1, + eta: float = 0.0, + generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None, + latents: Optional[torch.FloatTensor] = None, + prompt_embeds: Optional[torch.FloatTensor] = None, + negative_prompt_embeds: Optional[torch.FloatTensor] = None, + return_dict: bool = True, + callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None, + callback_steps: int = 1, + ): + r""" + Function invoked when calling the pipeline for generation. + + Args: + prompt (`str` or `List[str]`, *optional*): + The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. + instead. + image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`): + `Image`, or tensor representing an image batch which will be upscaled. * + num_inference_steps (`int`, *optional*, defaults to 50): + The number of denoising steps. More denoising steps usually lead to a higher quality image at the + expense of slower inference. + guidance_scale (`float`, *optional*, defaults to 7.5): + Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598). + `guidance_scale` is defined as `w` of equation 2. of [Imagen + Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale > + 1`. Higher guidance scale encourages to generate images that are closely linked to the text `prompt`, + usually at the expense of lower image quality. + negative_prompt (`str` or `List[str]`, *optional*): + The prompt or prompts not to guide the image generation. If not defined, one has to pass + `negative_prompt_embeds`. instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` + is less than `1`). + num_images_per_prompt (`int`, *optional*, defaults to 1): + The number of images to generate per prompt. + eta (`float`, *optional*, defaults to 0.0): + Corresponds to parameter eta (Ξ·) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to + [`schedulers.DDIMScheduler`], will be ignored for others. + generator (`torch.Generator`, *optional*): + One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html) + to make generation deterministic. + latents (`torch.FloatTensor`, *optional*): + Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image + generation. Can be used to tweak the same generation with different prompts. If not provided, a latents + tensor will ge generated by sampling using the supplied random `generator`. + prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not + provided, text embeddings will be generated from `prompt` input argument. + negative_prompt_embeds (`torch.FloatTensor`, *optional*): + Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt + weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input + argument. + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a + plain tuple. + callback (`Callable`, *optional*): + A function that will be called every `callback_steps` steps during inference. The function will be + called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`. + callback_steps (`int`, *optional*, defaults to 1): + The frequency at which the `callback` function will be called. If not specified, the callback will be + called at every step. + + Examples: + ```py + >>> import requests + >>> from PIL import Image + >>> from io import BytesIO + >>> from diffusers import StableDiffusionUpscalePipeline + >>> import torch + + >>> # load model and scheduler + >>> model_id = "stabilityai/stable-diffusion-x4-upscaler" + >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained( + ... model_id, revision="fp16", torch_dtype=torch.float16 + ... ) + >>> pipeline = pipeline.to("cuda") + + >>> # let's download an image + >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png" + >>> response = requests.get(url) + >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB") + >>> low_res_img = low_res_img.resize((128, 128)) + >>> prompt = "a white cat" + + >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0] + >>> upscaled_image.save("upsampled_cat.png") + ``` + """ + + # 1. Check inputs + self.check_inputs( + prompt, + image, + noise_level, + callback_steps, + negative_prompt, + prompt_embeds, + negative_prompt_embeds, + ) + + if image is None: + raise ValueError("`image` input cannot be undefined.") + + # 2. Define call parameters + if prompt is not None and isinstance(prompt, str): + batch_size = 1 + elif prompt is not None and isinstance(prompt, list): + batch_size = len(prompt) + else: + batch_size = prompt_embeds.shape[0] + + device = self._execution_device + # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2) + # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1` + # corresponds to doing no classifier free guidance. + do_classifier_free_guidance = guidance_scale > 1.0 + + # 3. Encode input prompt + prompt_embeds = self._encode_prompt( + prompt, + device, + num_images_per_prompt, + do_classifier_free_guidance, + negative_prompt, + prompt_embeds=prompt_embeds, + negative_prompt_embeds=negative_prompt_embeds, + ) + + # 4. Preprocess image + # image = preprocess(image) + image = image.to(dtype=prompt_embeds.dtype, device=device) + + # 5. Add noise to image + noise_level = torch.tensor([noise_level], dtype=torch.long, device=device) + noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype) + image = self.low_res_scheduler.add_noise(image, noise, noise_level) + # image = image.clamp(-1, 1) + + # debug + # image = rearrange(image, 'b c t h w -> (b t) c h w').contiguous().cpu() + # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) + + batch_multiplier = 2 if do_classifier_free_guidance else 1 + image = torch.cat([image] * batch_multiplier * num_images_per_prompt) + # TODO: + # noise_level = noise_level*0 + noise_level = torch.cat([noise_level] * image.shape[0]) + + ####################### Random Noise ######################## + # 5. set timesteps + self.scheduler.set_timesteps(num_inference_steps, device=device) + timesteps = self.scheduler.timesteps + + # 6. Prepare latent variables + seq_len, height, width = image.shape[2:] + # TODO: for downsample_2x + # height, width = height//2, width//2 + num_channels_latents = self.vae.config.latent_channels + latents = self.prepare_latents_3d( + batch_size * num_images_per_prompt, + num_channels_latents, + seq_len, + height, + width, + prompt_embeds.dtype, + device, + generator, + latents, + ) # b c t h w + # print('latents', latents.shape) + + ####################### Random Noise + Latent ######################## + # # 5. Prepare timesteps + # self.scheduler.set_timesteps(num_inference_steps, device=device) + # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength=1, device=device) + # # DEBUG + # # timesteps = self.scheduler.timesteps + # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt) + + # # 6. Prepare latent variables + # # b c t h w + # b = image.shape[0] + # num_channels_latents = self.vae.config.latent_channels + # latents = self.prepare_latents_inversion( + # image[:b//2], + # latent_timestep, + # batch_size, + # num_images_per_prompt, + # prompt_embeds.dtype, + # device, + # generator, + # ) + # print('latents', latents.shape) + + # 7. Check that sizes of image and latents match + num_channels_image = image.shape[1] + if num_channels_latents + num_channels_image != self.unet.config.in_channels: + raise ValueError( + f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects" + f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +" + f" `num_channels_image`: {num_channels_image} " + f" = {num_channels_latents+num_channels_image}. Please verify the config of" + " `pipeline.unet` or your `image` input." + ) + + # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline + extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta) + + # 9. Denoising loop + num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order + with self.progress_bar(total=num_inference_steps) as progress_bar: + for i, t in enumerate(timesteps): + torch.cuda.empty_cache() # delete for VSR + # expand the latents if we are doing classifier free guidance + latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents + + # concat latents, mask, masked_image_latents in the channel dimension + latent_model_input = self.scheduler.scale_model_input(latent_model_input, t) + #latent_model_input = torch.cat([latent_model_input, image], dim=1) + # print(f'========== latent_model_input: {latent_model_input.shape} ============') + # print(f'========== image: {image.shape} ============') + noise_pred = self.unet( + latent_model_input, t, image, encoder_hidden_states=prompt_embeds, class_labels=noise_level + ).sample + + # perform guidance + if do_classifier_free_guidance: + noise_pred_uncond, noise_pred_text = noise_pred.chunk(2) + noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond) + + # compute the previous noisy sample x_t -> x_t-1 + latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample + + # call the callback, if provided + if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0): + progress_bar.update() + if callback is not None and i % callback_steps == 0: + callback(i, t, latents) + + del latent_model_input, noise_pred + + + # 10. Post-processing + # make sure the VAE is in float32 mode, as it overflows in float16 + self.vae.to(dtype=torch.float32) + + # TODO(Patrick, William) - clean up when attention is refactored + use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention") + use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers + # if xformers or torch_2_0 is used attention block does not need + # to be in float32 which can save lots of memory + if not use_torch_2_0_attn and not use_xformers: + self.vae.post_quant_conv.to(latents.dtype) + self.vae.decoder.conv_in.to(latents.dtype) + self.vae.decoder.mid_block.to(latents.dtype) + else: + latents = latents.float() + + # 11. Convert to frames + short_seq = 4 + # b c t h w + latents = rearrange(latents, 'b c t h w -> (b t) c h w').contiguous() + if latents.shape[0] > short_seq: # for VSR + image = [] + for start_f in range(0, latents.shape[0], short_seq): + torch.cuda.empty_cache() # delete for VSR + end_f = min(latents.shape[0], start_f + short_seq) + image_ = self.decode_latents_vsr(latents[start_f:end_f]) + image.append(image_) + del image_ + image = torch.cat(image, dim=0) + else: + image = self.decode_latents_vsr(latents) + + # Offload last model to CPU + if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None: + self.final_offload_hook.offload() + + if not return_dict: + return (image, None) + + return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None) diff --git a/vsr/models/resnet.py b/vsr/models/resnet.py new file mode 100644 index 0000000000000000000000000000000000000000..490fb323dc43c67acce260a2ff9a7e80f021f76c --- /dev/null +++ b/vsr/models/resnet.py @@ -0,0 +1,316 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) +import torch +import torch.nn as nn +import torch.nn.functional as F + +from einops import rearrange + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + +class Mish(torch.nn.Module): + def forward(self, hidden_states): + return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states)) + +class InflatedConv3d(nn.Conv2d): + def forward(self, x): + video_length = x.shape[2] + + x = rearrange(x, "b c f h w -> (b f) c h w") + x = super().forward(x) + x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length) + + return x + + +class Upsample3D(nn.Module): + def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.use_conv_transpose = use_conv_transpose + self.name = name + + conv = None + if use_conv_transpose: + raise NotImplementedError + elif use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1) + + if name == "conv": + self.conv = conv + else: + self.Conv2d_0 = conv + + def forward(self, hidden_states, output_size=None): + assert hidden_states.shape[1] == self.channels + + if self.use_conv_transpose: + raise NotImplementedError + + # Cast to float32 to as 'upsample_nearest2d_out_frame' op does not support bfloat16 + dtype = hidden_states.dtype + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(torch.float32) + + # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984 + if hidden_states.shape[0] >= 64: + hidden_states = hidden_states.contiguous() + + # if `output_size` is passed we force the interpolation output + # size and do not make use of `scale_factor=2` + if output_size is None: + hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest") + else: + hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest") + + # If the input is bfloat16, we cast back to bfloat16 + if dtype == torch.bfloat16: + hidden_states = hidden_states.to(dtype) + + if self.use_conv: + if self.name == "conv": + hidden_states = self.conv(hidden_states) + else: + hidden_states = self.Conv2d_0(hidden_states) + + return hidden_states + + +class Downsample3D(nn.Module): + def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"): + super().__init__() + self.channels = channels + self.out_channels = out_channels or channels + self.use_conv = use_conv + self.padding = padding + stride = 2 + self.name = name + + if use_conv: + conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding) + else: + raise NotImplementedError + + if name == "conv": + self.Conv2d_0 = conv + self.conv = conv + elif name == "Conv2d_0": + self.conv = conv + else: + self.conv = conv + + def forward(self, hidden_states): + assert hidden_states.shape[1] == self.channels + if self.use_conv and self.padding == 0: + raise NotImplementedError + + assert hidden_states.shape[1] == self.channels + hidden_states = self.conv(hidden_states) + + return hidden_states + + +class ResnetBlock3D(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, input_tensor, temb): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor + + +class ResnetBlock3DCNN(nn.Module): + def __init__( + self, + *, + in_channels, + out_channels=None, + kernel=(3,1,1), + conv_shortcut=False, + dropout=0.0, + temb_channels=512, + groups=32, + groups_out=None, + pre_norm=True, + eps=1e-6, + non_linearity="swish", + time_embedding_norm="default", + output_scale_factor=1.0, + use_in_shortcut=None, + ): + super().__init__() + self.pre_norm = pre_norm + self.pre_norm = True + self.in_channels = in_channels + out_channels = in_channels if out_channels is None else out_channels + self.out_channels = out_channels + self.use_conv_shortcut = conv_shortcut + self.time_embedding_norm = time_embedding_norm + self.output_scale_factor = output_scale_factor + + if groups_out is None: + groups_out = groups + + self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True) + + padding = ((kernel[i]-1)//2 for i in range(3)) + self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=kernel, stride=(1,1,1), padding=padding) + + if temb_channels is not None: + if self.time_embedding_norm == "default": + time_emb_proj_out_channels = out_channels + elif self.time_embedding_norm == "scale_shift": + time_emb_proj_out_channels = out_channels * 2 + else: + raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ") + + self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels) + else: + self.time_emb_proj = None + + self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True) + self.dropout = torch.nn.Dropout(dropout) + self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=(3,1,1), stride=(1,1,1), padding=(1,0,0)) + + if non_linearity == "swish": + self.nonlinearity = lambda x: F.silu(x) + elif non_linearity == "mish": + self.nonlinearity = Mish() + elif non_linearity == "silu": + self.nonlinearity = nn.SiLU() + + self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut + + self.conv_shortcut = None + if self.use_in_shortcut: + self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=(1,1,1), stride=(1,1,1), padding=(0,0,0)) + + def forward(self, input_tensor, temb=None): + hidden_states = input_tensor + + hidden_states = self.norm1(hidden_states) + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.conv1(hidden_states) + + if temb is not None: + temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None] + + if temb is not None and self.time_embedding_norm == "default": + hidden_states = hidden_states + temb + + hidden_states = self.norm2(hidden_states) + + if temb is not None and self.time_embedding_norm == "scale_shift": + scale, shift = torch.chunk(temb, 2, dim=1) + hidden_states = hidden_states * (1 + scale) + shift + + hidden_states = self.nonlinearity(hidden_states) + + hidden_states = self.dropout(hidden_states) + hidden_states = self.conv2(hidden_states) + + if self.conv_shortcut is not None: + input_tensor = self.conv_shortcut(input_tensor) + + output_tensor = (input_tensor + hidden_states) / self.output_scale_factor + + return output_tensor \ No newline at end of file diff --git a/vsr/models/temporal_module.py b/vsr/models/temporal_module.py new file mode 100644 index 0000000000000000000000000000000000000000..6a8ffb7aa8da232507831bf6fb241340b0f8bc9d --- /dev/null +++ b/vsr/models/temporal_module.py @@ -0,0 +1,684 @@ +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import torch +import numpy as np +import torch.nn.functional as F +from torch import nn +import torchvision +# from torch_utils.ops import grid_sample_gradfix + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.utils import BaseOutput +from diffusers.utils.import_utils import is_xformers_available +from diffusers.models.attention import FeedForward +# from diffusers.models.attention_processor import Attention + +try: + from .diffusers_attention import CrossAttention + from .resnet import Downsample3D, Upsample3D, InflatedConv3d, ResnetBlock3D, ResnetBlock3DCNN + +except: + from diffusers_attention import CrossAttention + from resnet import Downsample3D, Upsample3D, InflatedConv3d, ResnetBlock3D, ResnetBlock3DCNN + +from einops import rearrange, repeat +import math + +import pdb + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def grid_sample_align(input, grid): + return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=True) + + +@dataclass +class TemporalTransformer3DModelOutput(BaseOutput): + sample: torch.FloatTensor + + +if is_xformers_available(): + import xformers + import xformers.ops +else: + xformers = None + + +class EmptyTemporalModule3D(nn.Module): + def __init__(self): + super().__init__() + + def forward(self, hidden_states, condition_video=None, encoder_hidden_states=None, timesteps=None, temb=None, attention_mask=None): + return hidden_states + + +class TemporalModule3D(nn.Module): + def __init__( + self, + in_channels=None, + out_channels=None, + + num_attention_layers=None, + num_attention_head=8, + attention_head_dim=None, + cross_attention_dim=768, + temb_channels=512, + + dropout=0., + attention_bias=False, + activation_fn="geglu", + only_cross_attention=False, + upcast_attention=False, + + norm_num_groups=8, + use_linear_projection=True, + use_scale_shift=False, # set True always produce nan loss, I don't know why + + attention_block_types: Tuple[str]=None, + cross_frame_attention_mode=None, + temporal_shift_fold_div=None, + temporal_shift_direction=None, + + use_dcn_warpping=None, + use_deformable_conv=None, + + attention_dim_div: int = None, + video_condition=False, + ): + super().__init__() + assert len(attention_block_types) == 2 + + self.use_scale_shift = use_scale_shift + self.video_condition = video_condition + + self.non_linearity = nn.SiLU() + + # 1. 3d cnn + if self.video_condition: + video_condition_dim = int(out_channels//4) + self.v_cond_conv = ResnetBlock3D(in_channels=3, out_channels=video_condition_dim, temb_channels=temb_channels, groups=3, groups_out=32) + self.resblocks_3d_t = ResnetBlock3DCNN(in_channels=in_channels+video_condition_dim, out_channels=in_channels, kernel=(5,1,1), temb_channels=temb_channels) + else: + self.resblocks_3d_t = ResnetBlock3DCNN(in_channels=in_channels, out_channels=in_channels, kernel=(5,1,1), temb_channels=temb_channels) + + self.resblocks_3d_s = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, groups=32, groups_out=32) + + # 2. transformer blocks + if not (attention_block_types[0]=='' and attention_block_types[1]==''): + attentions = TemporalTransformer3DModel( + num_attention_heads=num_attention_head, + attention_head_dim=attention_head_dim if attention_head_dim is not None else in_channels // num_attention_head // attention_dim_div, + + in_channels=in_channels, + num_layers=num_attention_layers, + dropout=dropout, + norm_num_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attention_bias=attention_bias, + activation_fn=activation_fn, + num_embeds_ada_norm=1000, # adaptive norm for timestep embedding injection + use_linear_projection=use_linear_projection, + + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_shift_fold_div=temporal_shift_fold_div, + temporal_shift_direction=temporal_shift_direction, + + use_dcn_warpping=use_dcn_warpping, + use_deformable_conv=use_deformable_conv, + ) + self.attentions = nn.ModuleList([attentions]) + + if use_scale_shift: + self.scale_shift_conv = zero_module(InflatedConv3d(in_channels=in_channels, out_channels=in_channels * 2, kernel_size=1, stride=1, padding=0)) + else: + self.shift_conv = zero_module(InflatedConv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1, padding=0)) + + + def forward(self, hidden_states, condition_video=None, encoder_hidden_states=None, timesteps=None, temb=None, attention_mask=None): + input_tensor = hidden_states + + if self.video_condition: + # obtain video attention + assert condition_video is not None + if isinstance(condition_video, dict): + condition_video = condition_video[hidden_states.shape[-1]] + hidden_condition = self.v_cond_conv(condition_video, temb) + hidden_states = torch.cat([hidden_states, hidden_condition], dim=1) + + # 3DCNN + hidden_states = self.resblocks_3d_t(hidden_states, temb) + hidden_states = self.resblocks_3d_s(hidden_states, temb) + + if hasattr(self, "attentions"): + for attn in self.attentions: + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timesteps).sample + + if self.use_scale_shift: + hidden_states = self.scale_shift_conv(hidden_states) + scale, shift = torch.chunk(hidden_states, chunks=2, dim=1) + hidden_states = (1 + scale) * input_tensor + shift + else: + hidden_states = self.shift_conv(hidden_states) + hidden_states = input_tensor + hidden_states + + return hidden_states + + +class TemporalTransformer3DModel(ModelMixin, ConfigMixin): + @register_to_config + def __init__( + self, + num_attention_heads=None, + attention_head_dim=None, + in_channels=None, + num_layers=None, + dropout=None, + norm_num_groups=None, + cross_attention_dim=None, + attention_bias=None, + activation_fn=None, + num_embeds_ada_norm=None, + use_linear_projection=None, + only_cross_attention=None, + upcast_attention=None, + + attention_block_types=None, + cross_frame_attention_mode=None, + temporal_shift_fold_div=None, + temporal_shift_direction=None, + + use_dcn_warpping=None, + use_deformable_conv=None, + ): + super().__init__() + self.use_linear_projection = use_linear_projection + self.num_attention_heads = num_attention_heads + self.attention_head_dim = attention_head_dim + inner_dim = num_attention_heads * attention_head_dim + + # Define input layers + self.in_channels = in_channels + + self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True) + if use_linear_projection: + self.proj_in = nn.Linear(in_channels, inner_dim) + else: + self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0) + + # Define transformers blocks + self.transformer_blocks = nn.ModuleList( + [ + TemporalTransformerBlock( + inner_dim, + num_attention_heads, + attention_head_dim, + dropout=dropout, + cross_attention_dim=cross_attention_dim, + activation_fn=activation_fn, + num_embeds_ada_norm=num_embeds_ada_norm, + attention_bias=attention_bias, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + + attention_block_types=attention_block_types, + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_shift_fold_div=temporal_shift_fold_div, + temporal_shift_direction=temporal_shift_direction, + + use_dcn_warpping=use_dcn_warpping, + use_deformable_conv=use_deformable_conv, + ) + for d in range(num_layers) + ] + ) + + # 4. Define output layers + if use_linear_projection: + self.proj_out = nn.Linear(inner_dim, in_channels) + else: + self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0) + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True): + # Input + assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}." + video_length = hidden_states.shape[2] + hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w") + if encoder_hidden_states is not None: + encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length) + + batch, channel, height, weight = hidden_states.shape + residual = hidden_states + + hidden_states = self.norm(hidden_states) + if not self.use_linear_projection: + hidden_states = self.proj_in(hidden_states) + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + else: + inner_dim = hidden_states.shape[1] + hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim) + hidden_states = self.proj_in(hidden_states) + + # Blocks + for block in self.transformer_blocks: + hidden_states = block( + hidden_states, + encoder_hidden_states=encoder_hidden_states, + timestep=timestep, + video_length=video_length + ) + + # Output + if not self.use_linear_projection: + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + hidden_states = self.proj_out(hidden_states) + else: + hidden_states = self.proj_out(hidden_states) + hidden_states = ( + hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous() + ) + + output = hidden_states + residual + + output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length) + if not return_dict: + return (output,) + + return TemporalTransformer3DModelOutput(sample=output) + + +class TemporalTransformerBlock(nn.Module): + def __init__( + self, + dim=None, + num_attention_heads=None, + attention_head_dim=None, + dropout=None, + cross_attention_dim=None, + activation_fn=None, + num_embeds_ada_norm=None, + attention_bias=None, + only_cross_attention=None, + upcast_attention=None, + + attention_block_types=None, + cross_frame_attention_mode=None, + temporal_shift_fold_div=None, + temporal_shift_direction=None, + + use_dcn_warpping=None, + use_deformable_conv=None, + ): + super().__init__() + assert len(attention_block_types) == 2 + + self.only_cross_attention = only_cross_attention + self.use_ada_layer_norm = num_embeds_ada_norm is not None + self.use_dcn_warpping = use_dcn_warpping + + # 1. Spatial-Attn (self) + if not attention_block_types[0] == '': + self.attn_spatial = VersatileSelfAttention( + attention_mode=attention_block_types[0], + + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_shift_fold_div=temporal_shift_fold_div, + temporal_shift_direction=temporal_shift_direction, + ) + nn.init.zeros_(self.attn_spatial.to_out[0].weight.data) + self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + # 2. Temporal-Attn (self) + self.attn_temporal = VersatileSelfAttention( + attention_mode=attention_block_types[1], + + query_dim=dim, + heads=num_attention_heads, + dim_head=attention_head_dim, + dropout=dropout, + bias=attention_bias, + upcast_attention=upcast_attention, + + cross_frame_attention_mode=cross_frame_attention_mode, + temporal_shift_fold_div=temporal_shift_fold_div, + temporal_shift_direction=temporal_shift_direction, + ) + nn.init.zeros_(self.attn_temporal.to_out[0].weight.data) + self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim) + + self.dcn_module = WarpModule( + in_channels=dim, + use_deformable_conv=use_deformable_conv, + ) if use_dcn_warpping else None + + # 3. Feed-forward + self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn) + self.norm3 = nn.LayerNorm(dim) + + + def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: None): + if not is_xformers_available(): + print("Here is how to install it") + raise ModuleNotFoundError( + "Refer to https://github.com/facebookresearch/xformers for more information on how to install" + " xformers", + name="xformers", + ) + elif not torch.cuda.is_available(): + raise ValueError( + "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only" + " available for GPU " + ) + else: + try: + # Make sure we can run the memory efficient attention + _ = xformers.ops.memory_efficient_attention( + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + torch.randn((1, 2, 40), device="cuda"), + ) + except Exception as e: + raise e + if hasattr(self, "attn_spatial"): + self.attn_spatial._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers + + def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None): + # 1. Spatial-Attention + if hasattr(self, "attn_spatial") and hasattr(self, "norm1"): + norm_hidden_states = self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states) + hidden_states = self.attn_spatial(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + + # 2. Temporal-Attention + norm_hidden_states = self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states) + if not self.use_dcn_warpping: + hidden_states = self.attn_temporal(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states + else: + hidden_states = self.dcn_module( + hidden_states, + offset_hidden_states=self.attn_temporal(norm_hidden_states, attention_mask=attention_mask, video_length=video_length), + ) + + # 3. Feed-forward + hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states + + return hidden_states + + +class VersatileSelfAttention(CrossAttention): + def __init__( + self, + attention_mode=None, + cross_frame_attention_mode=None, + temporal_shift_fold_div=None, + temporal_shift_direction=None, + temporal_position_encoding=False, + temporal_position_encoding_max_len=24, + *args, **kwargs + ): + super().__init__(*args, **kwargs) + assert attention_mode in ("Temporal", "Spatial", "CrossFrame", "SpatialTemporalShift", None) + assert cross_frame_attention_mode in ("0_i-1", "i-1_i", "0_i-1_i", "i-1_i_i+1", None) + + self.attention_mode = attention_mode + + self.cross_frame_attention_mode = cross_frame_attention_mode + + self.temporal_shift_fold_div = temporal_shift_fold_div + self.temporal_shift_direction = temporal_shift_direction + + self.pos_encoder = PositionalEncoding( + kwargs["query_dim"], + dropout=0., + max_len=temporal_position_encoding_max_len + ) if temporal_position_encoding else None + + def temporal_token_concat(self, tensor, video_length): + # print("### temporal token concat") + current_frame_index = torch.arange(video_length) + former_frame_index = current_frame_index - 1 + former_frame_index[0] = 0 + + later_frame_index = current_frame_index + 1 + later_frame_index[-1] = -1 + + # (b f) d c + tensor = rearrange(tensor, "(b f) d c -> b f d c", f=video_length) + + if self.cross_frame_attention_mode == "0_i-1": + tensor = torch.cat([tensor[:, [0] * video_length], tensor[:, former_frame_index]], dim=2) + elif self.cross_frame_attention_mode == "i-1_i": + tensor = torch.cat([tensor[:, former_frame_index], tensor[:, current_frame_index]], dim=2) + elif self.cross_frame_attention_mode == "0_i-1_i": + tensor = torch.cat([tensor[:, [0] * video_length], tensor[:, former_frame_index], tensor[:, current_frame_index]], dim=2) + elif self.cross_frame_attention_mode == "i-1_i_i+1": + tensor = torch.cat([tensor[:, former_frame_index], tensor[:, current_frame_index], tensor[:, later_frame_index]], dim=2) + else: + raise NotImplementedError + + tensor = rearrange(tensor, "b f d c -> (b f) d c") + return tensor + + def temporal_shift(self, tensor, video_length): + # print("### temporal shift") + # (b f) d c + tensor = rearrange(tensor, "(b f) d c -> b f d c", f=video_length) + n_channels = tensor.shape[-1] + fold = n_channels // self.temporal_shift_fold_div + + if self.temporal_shift_direction != "right": + raise NotImplementedError + + tensor_out = torch.zeros_like(tensor) + tensor_out[:, 1:, :, :fold] = tensor[:, :-1, :, :fold] + tensor_out[:, :, :, fold:] = tensor[:, :, :, fold:] + + tensor_out = rearrange(tensor_out, "b f d c -> (b f) d c") + return tensor_out + + def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None): + # pdb.set_trace() + batch_size, sequence_length, _ = hidden_states.shape + assert encoder_hidden_states is None + + # (b f) d c + if self.attention_mode == "Temporal": + # print("### temporal reshape") + d = hidden_states.shape[1] + hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length) + + if self.pos_encoder is not None: + hidden_states = self.pos_encoder(hidden_states) + + encoder_hidden_states = encoder_hidden_states + + if self.group_norm is not None: + hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2) + + query = self.to_q(hidden_states) + dim = query.shape[-1] + query = self.reshape_heads_to_batch_dim(query) + + if self.added_kv_proj_dim is not None: + raise NotImplementedError + + encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states + key = self.to_k(encoder_hidden_states) + value = self.to_v(encoder_hidden_states) + + if self.attention_mode == "SpatialTemporalShift": + key = self.temporal_shift(key, video_length=video_length) + value = self.temporal_shift(value, video_length=video_length) + elif self.attention_mode == "CrossFrame": + key = self.temporal_token_concat(key, video_length=video_length) + value = self.temporal_token_concat(value, video_length=video_length) + + key = self.reshape_heads_to_batch_dim(key) + value = self.reshape_heads_to_batch_dim(value) + + if attention_mask is not None: + if attention_mask.shape[-1] != query.shape[1]: + target_length = query.shape[1] + attention_mask = F.pad(attention_mask, (0, target_length), value=0.0) + attention_mask = attention_mask.repeat_interleave(self.heads, dim=0) + + # attention, what we cannot get enough of + if self._use_memory_efficient_attention_xformers: + hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask) + # Some versions of xformers return output in fp32, cast it back to the dtype of the input + hidden_states = hidden_states.to(query.dtype) + else: + if self._slice_size is None or query.shape[0] // self._slice_size == 1: + hidden_states = self._attention(query, key, value, attention_mask) + else: + hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask) + + # linear proj + hidden_states = self.to_out[0](hidden_states) + + # dropout + hidden_states = self.to_out[1](hidden_states) + + if self.attention_mode == "Temporal": + hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d) + + return hidden_states + + +class WarpModule(nn.Module): + def __init__( + self, + in_channels=None, + use_deformable_conv=None, + ): + super().__init__() + self.use_deformable_conv = use_deformable_conv + + self.conv = None + self.dcn_weight = None + if use_deformable_conv: + self.conv = nn.Conv2d(in_channels*2, 27, kernel_size=3, stride=1, padding=1) + self.dcn_weight = nn.Parameter(torch.randn(in_channels, in_channels, 3, 3) / np.sqrt(in_channels * 3 * 3)) + self.alpha = nn.Parameter(torch.zeros(1, in_channels, 1, 1)) + else: + self.conv = zero_module(nn.Conv2d(in_channels, 2, kernel_size=3, stride=1, padding=1)) + + def forward(self, hidden_states, offset_hidden_states): + # (b f) d c + spatial_dim = hidden_states.shape[1] + size = int(spatial_dim ** 0.5) + assert size ** 2 == spatial_dim + + hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=size) + offset_hidden_states = rearrange(offset_hidden_states, "b (h w) c -> b c h w", h=size) + + concat_hidden_states = torch.cat([hidden_states, offset_hidden_states], dim=1) + + input_tensor = hidden_states + if self.use_deformable_conv: + offset_x, offset_y, offsets_mask = torch.chunk(self.conv(concat_hidden_states), chunks=3, dim=1) + offsets_mask = offsets_mask.sigmoid() * 2 + offsets = torch.cat([offset_x, offset_y], dim=1) + hidden_states = torchvision.ops.deform_conv2d( + hidden_states, + offset=offsets, + weight=self.dcn_weight, + mask=offsets_mask, + padding=1, + ) + hidden_states = self.alpha * hidden_states + input_tensor + + else: + offsets = self.conv(concat_hidden_states) + hidden_states = self.optical_flow_warping(hidden_states, offsets) + + hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c") + return hidden_states + + @staticmethod + def optical_flow_warping(x, flo): + """ + warp an image/tensor (im2) back to im1, according to the optical flow + + x: [B, C, H, W] (im2) + flo: [B, 2, H, W] flow + pad_mode (optional): ref to https://pytorch.org/docs/stable/nn.functional.html#grid-sample + "zeros": use 0 for out-of-bound grid locations, + "border": use border values for out-of-bound grid locations + """ + dtype = x.dtype + if dtype != torch.float32: + x = x.to(torch.float32) + B, C, H, W = x.size() + # mesh grid + xx = torch.arange(0, W).view(1, -1).repeat(H, 1) + yy = torch.arange(0, H).view(-1, 1).repeat(1, W) + xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1) + yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1) + grid = torch.cat((xx, yy), 1).float().to(flo.device) + + vgrid = grid + flo + + # scale grid to [-1,1] + vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0 + vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0 + + vgrid = vgrid.permute(0, 2, 3, 1) + # output = grid_sample_gradfix.grid_sample_align(x, vgrid) + output = grid_sample_align(x, vgrid) + #output = torch.nn.functional.grid_sample(x, vgrid, padding_mode='zeros', mode='bilinear', align_corners=True) + + mask = torch.ones_like(x) + # mask = grid_sample_gradfix.grid_sample_align(mask, vgrid) + mask = grid_sample_align(x, vgrid) + #mask = torch.nn.functional.grid_sample(mask, vgrid, padding_mode='zeros', mode='bilinear', align_corners=True) + + mask[mask < 0.9999] = 0 + mask[mask > 0] = 1 + results = output * mask + if dtype != torch.float32: + results = results.to(dtype) + return results + + +class AdaLayerNorm(nn.Module): + """ + Norm layer modified to incorporate timestep embeddings. + """ + def __init__(self, embedding_dim, num_embeddings): + super().__init__() + self.emb = nn.Embedding(num_embeddings, embedding_dim) + self.silu = nn.SiLU() + self.linear = nn.Linear(embedding_dim, embedding_dim * 2) + self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False) + + def forward(self, x, timestep): + timestep = repeat(timestep, "b -> (b r)", r=x.shape[0] // timestep.shape[0]) + + emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1) # (b f) 1 2d + scale, shift = torch.chunk(emb, 2, dim=-1) + x = self.norm(x) * (1 + scale) + shift + return x + diff --git a/vsr/models/unet.py b/vsr/models/unet.py new file mode 100644 index 0000000000000000000000000000000000000000..f30bc66acd049728a20fc97c2d56c6edea09df02 --- /dev/null +++ b/vsr/models/unet.py @@ -0,0 +1,654 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py + +from dataclasses import dataclass +from typing import List, Optional, Tuple, Union + +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import json +import math +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn import functional as F +import einops + +from diffusers.configuration_utils import ConfigMixin, register_to_config +from diffusers.models.modeling_utils import ModelMixin +from diffusers.models.embeddings import TimestepEmbedding, Timesteps +from diffusers.utils import BaseOutput, logging + +try: + from .unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from .resnet import InflatedConv3d + from .temporal_module import TemporalModule3D, EmptyTemporalModule3D +except: + from unet_blocks import ( + CrossAttnDownBlock3D, + CrossAttnUpBlock3D, + DownBlock3D, + UNetMidBlock3DCrossAttn, + UpBlock3D, + get_down_block, + get_up_block, + ) + from resnet import InflatedConv3d + from temporal_module import TemporalModule3D, EmptyTemporalModule3D + +from rotary_embedding_torch import RotaryEmbedding + +logger = logging.get_logger(__name__) # pylint: disable=invalid-name + +def zero_module(module): + for p in module.parameters(): + nn.init.zeros_(p) + return module + +class RelativePositionBias(nn.Module): + def __init__( + self, + heads=8, + num_buckets=32, + max_distance=128, + ): + super().__init__() + self.num_buckets = num_buckets + self.max_distance = max_distance + self.relative_attention_bias = nn.Embedding(num_buckets, heads) + + @staticmethod + def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128): + ret = 0 + n = -relative_position + + num_buckets //= 2 + ret += (n < 0).long() * num_buckets + n = torch.abs(n) + + max_exact = num_buckets // 2 + is_small = n < max_exact + + val_if_large = max_exact + ( + torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact) + ).long() + val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1)) + + ret += torch.where(is_small, n, val_if_large) + return ret + + def forward(self, n, device): + q_pos = torch.arange(n, dtype = torch.long, device = device) + k_pos = torch.arange(n, dtype = torch.long, device = device) + rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1') + rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance) + values = self.relative_attention_bias(rp_bucket) + return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames + +@dataclass +class UNet3DConditionOutput(BaseOutput): + sample: torch.FloatTensor + + +class UNet3DVSRModel(ModelMixin, ConfigMixin): + _supports_gradient_checkpointing = True + + @register_to_config + def __init__( + self, + ### Temporal Module Additional Kwargs ### + down_temporal_idx = (0,1,2), + mid_temporal = False, + up_temporal_idx = (0,1,2), + video_condition = True, + temporal_module_config = None, + + sample_size: Optional[int] = None, # 80 + in_channels: int = 7, + out_channels: int = 4, + center_input_sample: bool = False, + max_noise_level: int = 350, + flip_sin_to_cos: bool = True, + freq_shift: int = 0, + attention_head_dim: Union[int, Tuple[int]] = 8, + block_out_channels: Tuple[int] = ( + 256, + 512, + 512, + 1024 + ), + down_block_types: Tuple[str] = ( + "DownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D", + "CrossAttnDownBlock3D" + ), + mid_block_type: str = "UNetMidBlock3DCrossAttn", + up_block_types: Tuple[str] = ( + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "CrossAttnUpBlock3D", + "UpBlock3D" + ), + only_cross_attention: Union[bool, Tuple[bool]] = ( + True, + True, + True, + False + ), + layers_per_block: int = 2, + downsample_padding: int = 1, + mid_block_scale_factor: float = 1, + act_fn: str = "silu", + norm_num_groups: int = 32, + norm_eps: float = 1e-5, + cross_attention_dim: int = 1024, + dual_cross_attention: bool = False, + use_linear_projection: bool = True, + class_embed_type: Optional[str] = None, + num_class_embeds: Optional[int] = 1000, + upcast_attention: bool = False, + resnet_time_scale_shift: str = "default", + use_first_frame: bool = False, + use_relative_position: bool = False, + ): + super().__init__() + + self.sample_size = sample_size + time_embed_dim = block_out_channels[0] * 4 + + # input + self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=1) + + # time + self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift) + timestep_input_dim = block_out_channels[0] + + self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + + # class embedding + if class_embed_type is None and num_class_embeds is not None: + self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) # VSR for noise level + elif class_embed_type == "timestep": + self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim) + elif class_embed_type == "identity": + self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim) + else: + self.class_embedding = None + + self.down_blocks = nn.ModuleList([]) + self.mid_block = None + self.up_blocks = nn.ModuleList([]) + self.video_condition = video_condition + + # Temporal Modules + self.down_temporal_blocks = nn.ModuleList([]) + self.mid_temporal_block = None + self.up_temporal_blocks = nn.ModuleList([]) + + if isinstance(only_cross_attention, bool): + only_cross_attention = [only_cross_attention] * len(down_block_types) + + + if isinstance(attention_head_dim, int): + attention_head_dim = (attention_head_dim,) * len(down_block_types) + + + self.temporal_rotary_emb = RotaryEmbedding(32) + + # down + output_channel = block_out_channels[0] + for i, down_block_type in enumerate(down_block_types): + input_channel = output_channel + output_channel = block_out_channels[i] + is_final_block = i == len(block_out_channels) - 1 + + down_block = get_down_block( + down_block_type, + num_layers=layers_per_block, + in_channels=input_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + add_downsample=not is_final_block, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[i], + downsample_padding=downsample_padding, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=self.temporal_rotary_emb, + ) + self.down_blocks.append(down_block) + + # Down Sample Temporal Modules + down_temporal_block = TemporalModule3D( + in_channels=output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + video_condition=video_condition, + **temporal_module_config, + ) if i in down_temporal_idx else EmptyTemporalModule3D() + self.down_temporal_blocks.append(down_temporal_block) + # mid + if mid_block_type == "UNetMidBlock3DCrossAttn": + self.mid_block = UNetMidBlock3DCrossAttn( + in_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + output_scale_factor=mid_block_scale_factor, + resnet_time_scale_shift=resnet_time_scale_shift, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attention_head_dim[-1], + resnet_groups=norm_num_groups, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=self.temporal_rotary_emb, + ) + else: + raise ValueError(f"unknown mid_block_type : {mid_block_type}") + + self.mid_temporal_block = TemporalModule3D( + in_channels=block_out_channels[-1], + out_channels=block_out_channels[-1], + temb_channels=time_embed_dim, + video_condition=video_condition, + **temporal_module_config, + ) if mid_temporal else EmptyTemporalModule3D() + + # count how many layers upsample the videos + self.num_upsamplers = 0 + + # up + reversed_block_out_channels = list(reversed(block_out_channels)) + reversed_attention_head_dim = list(reversed(attention_head_dim)) + only_cross_attention = list(reversed(only_cross_attention)) + output_channel = reversed_block_out_channels[0] + for i, up_block_type in enumerate(up_block_types): + is_final_block = i == len(block_out_channels) - 1 + + prev_output_channel = output_channel + output_channel = reversed_block_out_channels[i] + input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)] + + # add upsample block for all BUT final layer + if not is_final_block: + add_upsample = True + self.num_upsamplers += 1 + else: + add_upsample = False + + up_block = get_up_block( + up_block_type, + num_layers=layers_per_block + 1, + in_channels=input_channel, + out_channels=output_channel, + prev_output_channel=prev_output_channel, + temb_channels=time_embed_dim, + add_upsample=add_upsample, + resnet_eps=norm_eps, + resnet_act_fn=act_fn, + resnet_groups=norm_num_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=reversed_attention_head_dim[i], + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention[i], + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=self.temporal_rotary_emb, + ) + self.up_blocks.append(up_block) + prev_output_channel = output_channel + + up_temporal_block = TemporalModule3D( + in_channels=output_channel, + out_channels=output_channel, + temb_channels=time_embed_dim, + video_condition=video_condition, + **temporal_module_config, + ) if i in up_temporal_idx else EmptyTemporalModule3D() + self.up_temporal_blocks.append(up_temporal_block) + + # out + self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps) + self.conv_act = nn.SiLU() + self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1) + + def set_attention_slice(self, slice_size): + r""" + Enable sliced attention computation. + + When this option is enabled, the attention module will split the input tensor in slices, to compute attention + in several steps. This is useful to save some memory in exchange for a small speed decrease. + + Args: + slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`): + When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If + `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is + provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim` + must be a multiple of `slice_size`. + """ + sliceable_head_dims = [] + + def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module): + if hasattr(module, "set_attention_slice"): + sliceable_head_dims.append(module.sliceable_head_dim) + + for child in module.children(): + fn_recursive_retrieve_slicable_dims(child) + + # retrieve number of attention layers + for module in self.children(): + fn_recursive_retrieve_slicable_dims(module) + + num_slicable_layers = len(sliceable_head_dims) + + if slice_size == "auto": + # half the attention head size is usually a good trade-off between + # speed and memory + slice_size = [dim // 2 for dim in sliceable_head_dims] + elif slice_size == "max": + # make smallest slice possible + slice_size = num_slicable_layers * [1] + + slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size + + if len(slice_size) != len(sliceable_head_dims): + raise ValueError( + f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different" + f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}." + ) + + for i in range(len(slice_size)): + size = slice_size[i] + dim = sliceable_head_dims[i] + if size is not None and size > dim: + raise ValueError(f"size {size} has to be smaller or equal to {dim}.") + + # Recursively walk through all the children. + # Any children which exposes the set_attention_slice method + # gets the message + def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]): + if hasattr(module, "set_attention_slice"): + module.set_attention_slice(slice_size.pop()) + + for child in module.children(): + fn_recursive_set_attention_slice(child, slice_size) + + reversed_slice_size = list(reversed(slice_size)) + for module in self.children(): + fn_recursive_set_attention_slice(module, reversed_slice_size) + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)): + module.gradient_checkpointing = value + + def forward( + self, + sample: torch.FloatTensor, + timestep: Union[torch.Tensor, float, int], + low_res: torch.FloatTensor, + # encoder_hidden_states: torch.Tensor, + encoder_hidden_states = None, + class_labels: Optional[torch.Tensor] = 20, + low_res_clean: Optional[torch.FloatTensor] = None, + attention_mask: Optional[torch.Tensor] = None, + return_dict: bool = True, + ): # -> Union[UNet3DConditionOutput, Tuple]: + r""" + Args: + sample (`torch.FloatTensor`): (batch, channel, seq_length, height, width) noisy inputs tensor + timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps + encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states + class_labels: noise level + return_dict (`bool`, *optional*, defaults to `True`): + Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple. + + Returns: + [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`: + [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When + returning a tuple, the first element is the sample tensor. + """ + # By default samples have to be AT least a multiple of the overall upsampling factor. + # The overall upsampling factor is equal to 2 ** (# num of upsampling layears). + # However, the upsampling interpolation output size can be forced to fit any upsampling size + # on the fly if necessary. + default_overall_up_factor = 2**self.num_upsamplers + + # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor` + forward_upsample_size = False + upsample_size = None + + if self.video_condition: + low_res_dict = {} + low_res_dict[low_res.shape[-1]] = low_res + for s in [1/2., 1/4., 1/8.]: + low_res_ds = F.interpolate(low_res, scale_factor=(1, s, s), mode='area') + low_res_dict[low_res_ds.shape[-1]] = low_res_ds + else: + low_res_dict = None + + sample = torch.cat([sample, low_res], dim=1) # concat on C: 4+3=7 + + #print(f'==============={sample.shape}================') + if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]): + logger.info("Forward upsample size to force interpolation output size.") + forward_upsample_size = True + + # prepare attention_mask + if attention_mask is not None: + attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0 + attention_mask = attention_mask.unsqueeze(1) + + # center input if necessary + if self.config.center_input_sample: + sample = 2 * sample - 1.0 + + # time + timesteps = timestep + if not torch.is_tensor(timesteps): + # This would be a good case for the `match` statement (Python 3.10+) + is_mps = sample.device.type == "mps" + if isinstance(timestep, float): + dtype = torch.float32 if is_mps else torch.float64 + else: + dtype = torch.int32 if is_mps else torch.int64 + timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device) + elif len(timesteps.shape) == 0: + timesteps = timesteps[None].to(sample.device) + + # broadcast to batch dimension in a way that's compatible with ONNX/Core ML + timesteps = timesteps.expand(sample.shape[0]) + + t_emb = self.time_proj(timesteps) + + # timesteps does not contain any weights and will always return f32 tensors + # but time_embedding might actually be running in fp16. so we need to cast here. + # there might be better ways to encapsulate this. + t_emb = t_emb.to(dtype=self.dtype) + emb = self.time_embedding(t_emb) + + if self.class_embedding is not None: + if class_labels is None: + raise ValueError("class_labels should be provided when num_class_embeds > 0") + # check noise level + if torch.any(class_labels > self.config.max_noise_level): + raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {class_labels}") + + if self.config.class_embed_type == "timestep": + class_labels = self.time_proj(class_labels) + + class_emb = self.class_embedding(class_labels).to(dtype=self.dtype) + emb = emb + class_emb + + + # pre-process + sample = self.conv_in(sample) + + # down + down_block_res_samples = (sample,) + for downsample_block, down_temporal_block in zip(self.down_blocks, self.down_temporal_blocks): + if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention: + sample, res_samples = downsample_block( + hidden_states=sample, + temb=emb, + encoder_hidden_states=encoder_hidden_states, + attention_mask=attention_mask, + ) + else: + sample, res_samples = downsample_block(hidden_states=sample, temb=emb) + + down_block_res_samples += res_samples + + # 1. temporal modeling during down sample + sample = down_temporal_block( + hidden_states=sample, + condition_video=low_res_dict, + encoder_hidden_states=encoder_hidden_states, + timesteps=timesteps, + temb=emb, + ) + # mid + sample = self.mid_block( + sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask + ) + # 2. temporal modeling at mid block + sample = self.mid_temporal_block( + hidden_states=sample, + condition_video=low_res_dict, + encoder_hidden_states=encoder_hidden_states, + timesteps=timesteps, + temb=emb, + ) + + # up + for i, (upsample_block, up_temporal_block) in enumerate(zip(self.up_blocks, self.up_temporal_blocks)): + is_final_block = i == len(self.up_blocks) - 1 + + res_samples = down_block_res_samples[-len(upsample_block.resnets) :] + down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)] + + # if we have not reached the final block and need to forward the + # upsample size, we do it here + if not is_final_block and forward_upsample_size: + upsample_size = down_block_res_samples[-1].shape[2:] + + if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention: + sample = upsample_block( + hidden_states=sample, + temb=emb, + res_hidden_states_tuple=res_samples, + encoder_hidden_states=encoder_hidden_states, + upsample_size=upsample_size, + attention_mask=attention_mask, + ) + else: + sample = upsample_block( + hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size + ) + + # 3. temporal modeling during up sample + sample = up_temporal_block( + hidden_states=sample, + condition_video=low_res_dict, + encoder_hidden_states=encoder_hidden_states, + timesteps=timesteps, + temb=emb, + ) + + # post-process + sample = self.conv_norm_out(sample) + sample = self.conv_act(sample) + sample = self.conv_out(sample) + # print(sample.shape) + + if not return_dict: + return (sample,) + + return UNet3DConditionOutput(sample=sample) + + def forward_with_cfg(self, + x, + t, + low_res, + encoder_hidden_states = None, + class_labels: Optional[torch.Tensor] = 20, + cfg_scale=4.0, + use_fp16=False): + """ + Forward, but also batches the unconditional forward pass for classifier-free guidance. + """ + # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb + half = x[: len(x) // 2] + combined = torch.cat([half, half], dim=0) + if use_fp16: + combined = combined.to(dtype=torch.float16) + model_out = self.forward(combined, t, low_res, encoder_hidden_states, class_labels).sample + # For exact reproducibility reasons, we apply classifier-free guidance on only + # three channels by default. The standard approach to cfg applies it to all channels. + # This can be done by uncommenting the following line and commenting-out the line following that. + eps, rest = model_out[:, :4], model_out[:, 4:] + # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w + cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0) + half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps) + eps = torch.cat([half_eps, half_eps], dim=0) + return torch.cat([eps, rest], dim=1) + + + @classmethod + def from_pretrained_2d(cls, config_path, pretrained_model_path): + if not os.path.isfile(config_path): + raise RuntimeError(f"{config_path} does not exist") + with open(config_path, "r") as f: + config = json.load(f) + config["_class_name"] = cls.__name__ + freeze_pretrained_2d_upsampler = config["freeze_pretrained_2d_upsampler"] + + model = cls.from_config(config) + model_file = os.path.join(pretrained_model_path) + if not os.path.isfile(model_file): + raise RuntimeError(f"{model_file} does not exist") + state_dict = torch.load(model_file, map_location="cpu") + for k, v in model.state_dict().items(): + if 'temporal' in k: + print(f'New layers: {k}') + state_dict.update({k: v}) + + model.load_state_dict(state_dict, strict=True) + + if freeze_pretrained_2d_upsampler: + print("Freeze pretrained 2d upsampler!") + for k, v in model.named_parameters(): + if not 'temporal' in k: + v.requires_grad = False + return model + +if __name__ == '__main__': + import torch + + device = "cuda" if torch.cuda.is_available() else "cpu" + config_path = "./configs/unet_3d_config.json" + # pretrained_model_path = "./pretrained_models/unet_diffusion_pytorch_model.bin" + # unet = UNet3DVSRModel.from_pretrained_2d(config_path, pretrained_model_path).to(device) diff --git a/vsr/models/unet_blocks.py b/vsr/models/unet_blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..066637ff78439e61864e5893a122b32f07e58f89 --- /dev/null +++ b/vsr/models/unet_blocks.py @@ -0,0 +1,629 @@ +# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py +import os +import sys +sys.path.append(os.path.split(sys.path[0])[0]) + +import torch +from torch import nn + +try: + from .attention import Transformer3DModel + from .resnet import Downsample3D, ResnetBlock3D, Upsample3D +except: + from attention import Transformer3DModel + from resnet import Downsample3D, ResnetBlock3D, Upsample3D + + +def get_down_block( + down_block_type, + num_layers, + in_channels, + out_channels, + temb_channels, + add_downsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + downsample_padding=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=None, +): + # print(down_block_type) + # print(use_first_frame) + down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type + if down_block_type == "DownBlock3D": + return DownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif down_block_type == "CrossAttnDownBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D") + return CrossAttnDownBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + add_downsample=add_downsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + downsample_padding=downsample_padding, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{down_block_type} does not exist.") + + +def get_up_block( + up_block_type, + num_layers, + in_channels, + out_channels, + prev_output_channel, + temb_channels, + add_upsample, + resnet_eps, + resnet_act_fn, + attn_num_head_channels, + resnet_groups=None, + cross_attention_dim=None, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + resnet_time_scale_shift="default", + use_first_frame=False, + use_relative_position=False, + rotary_emb=None, +): + up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type + if up_block_type == "UpBlock3D": + return UpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + resnet_time_scale_shift=resnet_time_scale_shift, + ) + elif up_block_type == "CrossAttnUpBlock3D": + if cross_attention_dim is None: + raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D") + return CrossAttnUpBlock3D( + num_layers=num_layers, + in_channels=in_channels, + out_channels=out_channels, + prev_output_channel=prev_output_channel, + temb_channels=temb_channels, + add_upsample=add_upsample, + resnet_eps=resnet_eps, + resnet_act_fn=resnet_act_fn, + resnet_groups=resnet_groups, + cross_attention_dim=cross_attention_dim, + attn_num_head_channels=attn_num_head_channels, + dual_cross_attention=dual_cross_attention, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + resnet_time_scale_shift=resnet_time_scale_shift, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + raise ValueError(f"{up_block_type} does not exist.") + + +class UNetMidBlock3DCrossAttn(nn.Module): + def __init__( + self, + in_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + output_scale_factor=1.0, + cross_attention_dim=1280, + dual_cross_attention=False, + use_linear_projection=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=None, + ): + super().__init__() + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32) + + # there is always at least one resnet + resnets = [ + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ] + attentions = [] + + for _ in range(num_layers): + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + in_channels // attn_num_head_channels, + in_channels=in_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=in_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + hidden_states = self.resnets[0](hidden_states, temb) + for attn, resnet in zip(self.attentions, self.resnets[1:]): + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + hidden_states = resnet(hidden_states, temb) + + return hidden_states + + +class CrossAttnDownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + downsample_padding=1, + add_downsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=None, + ): + super().__init__() + resnets = [] + attentions = [] + + # print(use_first_frame) + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None): + output_states = () + + for resnet, attn in zip(self.resnets, self.attentions): + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class DownBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_downsample=True, + downsample_padding=1, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + in_channels = in_channels if i == 0 else out_channels + resnets.append( + ResnetBlock3D( + in_channels=in_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_downsample: + self.downsamplers = nn.ModuleList( + [ + Downsample3D( + out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op" + ) + ] + ) + else: + self.downsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, temb=None): + output_states = () + + for resnet in self.resnets: + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + output_states += (hidden_states,) + + if self.downsamplers is not None: + for downsampler in self.downsamplers: + hidden_states = downsampler(hidden_states) + + output_states += (hidden_states,) + + return hidden_states, output_states + + +class CrossAttnUpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + out_channels: int, + prev_output_channel: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + attn_num_head_channels=1, + cross_attention_dim=1280, + output_scale_factor=1.0, + add_upsample=True, + dual_cross_attention=False, + use_linear_projection=False, + only_cross_attention=False, + upcast_attention=False, + use_first_frame=False, + use_relative_position=False, + rotary_emb=None + ): + super().__init__() + resnets = [] + attentions = [] + + self.has_cross_attention = True + self.attn_num_head_channels = attn_num_head_channels + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + if dual_cross_attention: + raise NotImplementedError + attentions.append( + Transformer3DModel( + attn_num_head_channels, + out_channels // attn_num_head_channels, + in_channels=out_channels, + num_layers=1, + cross_attention_dim=cross_attention_dim, + norm_num_groups=resnet_groups, + use_linear_projection=use_linear_projection, + only_cross_attention=only_cross_attention, + upcast_attention=upcast_attention, + use_first_frame=use_first_frame, + use_relative_position=use_relative_position, + rotary_emb=rotary_emb, + ) + ) + + self.attentions = nn.ModuleList(attentions) + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward( + self, + hidden_states, + res_hidden_states_tuple, + temb=None, + encoder_hidden_states=None, + upsample_size=None, + attention_mask=None, + ): + for resnet, attn in zip(self.resnets, self.attentions): + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module, return_dict=None): + def custom_forward(*inputs): + if return_dict is not None: + return module(*inputs, return_dict=return_dict) + else: + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + hidden_states = torch.utils.checkpoint.checkpoint( + create_custom_forward(attn, return_dict=False), + hidden_states, + encoder_hidden_states, + )[0] + else: + hidden_states = resnet(hidden_states, temb) + hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states + + +class UpBlock3D(nn.Module): + def __init__( + self, + in_channels: int, + prev_output_channel: int, + out_channels: int, + temb_channels: int, + dropout: float = 0.0, + num_layers: int = 1, + resnet_eps: float = 1e-6, + resnet_time_scale_shift: str = "default", + resnet_act_fn: str = "swish", + resnet_groups: int = 32, + resnet_pre_norm: bool = True, + output_scale_factor=1.0, + add_upsample=True, + ): + super().__init__() + resnets = [] + + for i in range(num_layers): + res_skip_channels = in_channels if (i == num_layers - 1) else out_channels + resnet_in_channels = prev_output_channel if i == 0 else out_channels + + resnets.append( + ResnetBlock3D( + in_channels=resnet_in_channels + res_skip_channels, + out_channels=out_channels, + temb_channels=temb_channels, + eps=resnet_eps, + groups=resnet_groups, + dropout=dropout, + time_embedding_norm=resnet_time_scale_shift, + non_linearity=resnet_act_fn, + output_scale_factor=output_scale_factor, + pre_norm=resnet_pre_norm, + ) + ) + + self.resnets = nn.ModuleList(resnets) + + if add_upsample: + self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)]) + else: + self.upsamplers = None + + self.gradient_checkpointing = False + + def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None): + for resnet in self.resnets: + # pop res hidden states + res_hidden_states = res_hidden_states_tuple[-1] + res_hidden_states_tuple = res_hidden_states_tuple[:-1] + hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1) + + if self.training and self.gradient_checkpointing: + + def create_custom_forward(module): + def custom_forward(*inputs): + return module(*inputs) + + return custom_forward + + hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb) + else: + hidden_states = resnet(hidden_states, temb) + + if self.upsamplers is not None: + for upsampler in self.upsamplers: + hidden_states = upsampler(hidden_states, upsample_size) + + return hidden_states \ No newline at end of file diff --git a/vsr/models/upscaling.py b/vsr/models/upscaling.py new file mode 100644 index 0000000000000000000000000000000000000000..f0f62ecd358a839b9a7c6a2d0e403ed9150ed2d3 --- /dev/null +++ b/vsr/models/upscaling.py @@ -0,0 +1,95 @@ +import torch +import torch.nn as nn +import numpy as np +from functools import partial +from inspect import isfunction + +def exists(x): + return x is not None + +def default(val, d): + if exists(val): + return val + return d() if isfunction(d) else d + +def extract_into_tensor(a, t, x_shape): + b, *_ = t.shape + out = a.gather(-1, t) + return out.reshape(b, *((1,) * (len(x_shape) - 1))) + +def make_beta_schedule(n_timestep, linear_start=1e-4, linear_end=2e-2): + betas = ( + torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2 + ) + return betas.numpy() + +class AbstractLowScaleModel(nn.Module): + # for concatenating a downsampled image to the latent representation + def __init__(self, noise_schedule_config=None): + super(AbstractLowScaleModel, self).__init__() + if noise_schedule_config is not None: + self.register_schedule(**noise_schedule_config) + + def register_schedule(self, timesteps=1000, linear_start=1e-4, linear_end=2e-2): + betas = make_beta_schedule(timesteps, linear_start=linear_start, linear_end=linear_end) + alphas = 1. - betas + alphas_cumprod = np.cumprod(alphas, axis=0) + alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1]) + + timesteps, = betas.shape + self.num_timesteps = int(timesteps) + self.linear_start = linear_start + self.linear_end = linear_end + assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep' + + to_torch = partial(torch.tensor, dtype=torch.float32) + + self.register_buffer('betas', to_torch(betas)) + self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod)) + self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev)) + + # calculations for diffusion q(x_t | x_{t-1}) and others + self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod))) + self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod))) + self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod))) + self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod))) + self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1))) + + def q_sample(self, x_start, t, noise=None): + noise = default(noise, lambda: torch.randn_like(x_start)) + return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start + + extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise) + + def forward(self, x): + return x, None + + def decode(self, x): + return x + + +class SimpleImageConcat(AbstractLowScaleModel): + # no noise level conditioning + def __init__(self): + super(SimpleImageConcat, self).__init__(noise_schedule_config=None) + self.max_noise_level = 0 + + def forward(self, x): + # fix to constant noise level + return x, torch.zeros(x.shape[0], device=x.device).long() + + +class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel): + def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False): + super().__init__(noise_schedule_config=noise_schedule_config) + self.max_noise_level = max_noise_level + + def forward(self, x, noise_level=None): + if noise_level is None: + noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long() + else: + assert isinstance(noise_level, torch.Tensor) + z = self.q_sample(x, noise_level) + return z, noise_level + + + diff --git a/vsr/models/utils.py b/vsr/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..76e205d1709568885e58b56888f62c8805ad4f91 --- /dev/null +++ b/vsr/models/utils.py @@ -0,0 +1,215 @@ +# adopted from +# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py +# and +# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py +# and +# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py +# +# thanks! + + +import os +import math +import torch + +import numpy as np +import torch.nn as nn + +from einops import repeat + + +################################################################################# +# Unet Utils # +################################################################################# + +def checkpoint(func, inputs, params, flag): + """ + Evaluate a function without caching intermediate activations, allowing for + reduced memory at the expense of extra compute in the backward pass. + :param func: the function to evaluate. + :param inputs: the argument sequence to pass to `func`. + :param params: a sequence of parameters `func` depends on but does not + explicitly take as arguments. + :param flag: if False, disable gradient checkpointing. + """ + if flag: + args = tuple(inputs) + tuple(params) + return CheckpointFunction.apply(func, len(inputs), *args) + else: + return func(*inputs) + + +class CheckpointFunction(torch.autograd.Function): + @staticmethod + def forward(ctx, run_function, length, *args): + ctx.run_function = run_function + ctx.input_tensors = list(args[:length]) + ctx.input_params = list(args[length:]) + + with torch.no_grad(): + output_tensors = ctx.run_function(*ctx.input_tensors) + return output_tensors + + @staticmethod + def backward(ctx, *output_grads): + ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors] + with torch.enable_grad(): + # Fixes a bug where the first op in run_function modifies the + # Tensor storage in place, which is not allowed for detach()'d + # Tensors. + shallow_copies = [x.view_as(x) for x in ctx.input_tensors] + output_tensors = ctx.run_function(*shallow_copies) + input_grads = torch.autograd.grad( + output_tensors, + ctx.input_tensors + ctx.input_params, + output_grads, + allow_unused=True, + ) + del ctx.input_tensors + del ctx.input_params + del output_tensors + return (None, None) + input_grads + + +def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False): + """ + Create sinusoidal timestep embeddings. + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + if not repeat_only: + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + else: + embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous() + return embedding + + +def zero_module(module): + """ + Zero out the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().zero_() + return module + + +def scale_module(module, scale): + """ + Scale the parameters of a module and return it. + """ + for p in module.parameters(): + p.detach().mul_(scale) + return module + + +def mean_flat(tensor): + """ + Take the mean over all non-batch dimensions. + """ + return tensor.mean(dim=list(range(1, len(tensor.shape)))) + + +def normalization(channels): + """ + Make a standard normalization layer. + :param channels: number of input channels. + :return: an nn.Module for normalization. + """ + return GroupNorm32(32, channels) + + +# PyTorch 1.7 has SiLU, but we support PyTorch 1.5. +class SiLU(nn.Module): + def forward(self, x): + return x * torch.sigmoid(x) + + +class GroupNorm32(nn.GroupNorm): + def forward(self, x): + return super().forward(x.float()).type(x.dtype) + +def conv_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D convolution module. + """ + if dims == 1: + return nn.Conv1d(*args, **kwargs) + elif dims == 2: + return nn.Conv2d(*args, **kwargs) + elif dims == 3: + return nn.Conv3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +def linear(*args, **kwargs): + """ + Create a linear module. + """ + return nn.Linear(*args, **kwargs) + + +def avg_pool_nd(dims, *args, **kwargs): + """ + Create a 1D, 2D, or 3D average pooling module. + """ + if dims == 1: + return nn.AvgPool1d(*args, **kwargs) + elif dims == 2: + return nn.AvgPool2d(*args, **kwargs) + elif dims == 3: + return nn.AvgPool3d(*args, **kwargs) + raise ValueError(f"unsupported dimensions: {dims}") + + +# class HybridConditioner(nn.Module): + +# def __init__(self, c_concat_config, c_crossattn_config): +# super().__init__() +# self.concat_conditioner = instantiate_from_config(c_concat_config) +# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config) + +# def forward(self, c_concat, c_crossattn): +# c_concat = self.concat_conditioner(c_concat) +# c_crossattn = self.crossattn_conditioner(c_crossattn) +# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]} + + +def noise_like(shape, device, repeat=False): + repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1))) + noise = lambda: torch.randn(shape, device=device) + return repeat_noise() if repeat else noise() + +def count_flops_attn(model, _x, y): + """ + A counter for the `thop` package to count the operations in an + attention operation. + Meant to be used like: + macs, params = thop.profile( + model, + inputs=(inputs, timestamps), + custom_ops={QKVAttention: QKVAttention.count_flops}, + ) + """ + b, c, *spatial = y[0].shape + num_spatial = int(np.prod(spatial)) + # We perform two matmuls with the same number of ops. + # The first computes the weight matrix, the second computes + # the combination of the value vectors. + matmul_ops = 2 * b * (num_spatial ** 2) * c + model.total_ops += torch.DoubleTensor([matmul_ops]) + +def count_params(model, verbose=False): + total_params = sum(p.numel() for p in model.parameters()) + if verbose: + print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.") + return total_params \ No newline at end of file diff --git a/vsr/sample.py b/vsr/sample.py new file mode 100644 index 0000000000000000000000000000000000000000..a0f948b622a7f17a35ec2debbbd55f68e200ab3b --- /dev/null +++ b/vsr/sample.py @@ -0,0 +1,151 @@ +import io +import os +import sys +import argparse +o_path = os.getcwd() +sys.path.append(o_path) + +import torch +import time +import json +import numpy as np +import imageio +import torchvision +from einops import rearrange + +from models.autoencoder_kl import AutoencoderKL +from models.unet import UNet3DVSRModel +from models.pipeline_stable_diffusion_upscale_video_3d import StableDiffusionUpscalePipeline +from diffusers import DDIMScheduler +from omegaconf import OmegaConf + + +def main(args) + + device = "cuda" + + # ---------------------- load models ---------------------- + pipeline = StableDiffusionUpscalePipeline.from_pretrained(args.pretrained_path + '/stable-diffusion-x4-upscaler', torch_dtype=torch.float16) + + # vae + pipeline.vae = AutoencoderKL.from_config("configs/vae_config.json") + pretrained_model = args.pretrained_path + "/stable-diffusion-x4-upscaler/vae/diffusion_pytorch_model.bin" + pipeline.vae.load_state_dict(torch.load(pretrained_model, map_location="cpu")) + + # unet + config_path = "./configs/unet_3d_config.json" + with open(config_path, "r") as f: + config = json.load(f) + config['video_condition'] = False + pipeline.unet = UNet3DVSRModel.from_config(config) + + pretrained_model = args.pretrained_path + "/lavie_vsr.pt" + checkpoint = torch.load(pretrained_model, map_location="cpu")['ema'] + + pipeline.unet.load_state_dict(checkpoint, True) + pipeline.unet = pipeline.unet.half() + pipeline.unet.eval() # important! + + # DDIMScheduler + with open(args.pretrained_path + '/stable-diffusion-x4-upscaler/scheduler/scheduler_config.json', "r") as f: + config = json.load(f) + config["beta_schedule"] = "linear" + pipeline.scheduler = DDIMScheduler.from_config(config) + + pipeline = pipeline.to("cuda") + + # ---------------------- load user's prompt ---------------------- + # input + video_root = args.input_path + video_list = sorted(os.listdir(video_root)) + print('video num:', len(video_list)) + + # output + save_root = args.output_path + os.makedirs(save_root, exist_ok=True) + + # inference params + noise_level = args.noise_level + guidance_scale = args.guidance_scale + num_inference_steps = args.inference_steps + + # ---------------------- start inferencing ---------------------- + for i, video_name in enumerate(video_list): + video_name = video_name.replace('.mp4', '') + print(f'[{i+1}/{len(video_list)}]: ', video_name) + + lr_path = f"{video_root}/{video_name}.mp4" + save_path = f"{save_root}/{video_name}.mp4" + + prompt = video_name + print('Prompt: ', prompt) + + negative_prompt = "blur, worst quality" + + vframes, aframes, info = torchvision.io.read_video(filename=lr_path, pts_unit='sec', output_format='TCHW') # RGB + vframes = vframes / 255. + vframes = (vframes - 0.5) * 2 # T C H W [-1, 1] + t, _, h, w = vframes.shape + vframes = vframes.unsqueeze(dim=0) # 1 T C H W + vframes = rearrange(vframes, 'b t c h w -> b c t h w').contiguous() # 1 C T H W + print('Input_shape:', vframes.shape, 'Noise_level:', noise_level, 'Guidance_scale:', guidance_scale) + + fps = info['video_fps'] + generator = torch.Generator(device=device).manual_seed(10) + + torch.cuda.synchronize() + start_time = time.time() + + with torch.no_grad(): + short_seq = 8 + vframes_seq = vframes.shape[2] + if vframes_seq > short_seq: # for VSR + upscaled_video_list = [] + for start_f in range(0, vframes_seq, short_seq): + print(f'Processing: [{start_f}-{start_f + short_seq}/{vframes_seq}]') + torch.cuda.empty_cache() # delete for VSR + end_f = min(vframes_seq, start_f + short_seq) + + upscaled_video_ = pipeline( + prompt, + image=vframes[:,:,start_f:end_f], + generator=generator, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + noise_level=noise_level, + negative_prompt=negative_prompt, + ).images # T C H W [-1, 1] + upscaled_video_list.append(upscaled_video_) + upscaled_video = torch.cat(upscaled_video_list, dim=0) + else: + upscaled_video = pipeline( + prompt, + image=vframes, + generator=generator, + num_inference_steps=num_inference_steps, + guidance_scale=guidance_scale, + noise_level=noise_level, + negative_prompt=negative_prompt, + ).images # T C H W [-1, 1] + + torch.cuda.synchronize() + run_time = time.time() - start_time + + print('Output:', upscaled_video.shape) + + # save video + upscaled_video = (upscaled_video / 2 + 0.5).clamp(0, 1) * 255 + upscaled_video = upscaled_video.permute(0, 2, 3, 1).to(torch.uint8) + upscaled_video = upscaled_video.numpy().astype(np.uint8) + imageio.mimwrite(save_path, upscaled_video, fps=fps, quality=9) # Highest quality is 10, lowest is 0 + + print(f'Save upscaled video "{video_name}" in {save_path}, time (sec): {run_time} \n') + print(f'\nAll results are saved in {save_path}') + +if __name__ == '__main__': + + parser = argparse.ArgumentParser() + parser.add_argument("--config", type=str, default="") + args = parser.parse_args() + + main(OmegaConf.load(args.config))