diff --git a/LICENSE b/LICENSE
new file mode 100644
index 0000000000000000000000000000000000000000..261eeb9e9f8b2b4b0d119366dda99c6fd7d35c64
--- /dev/null
+++ b/LICENSE
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
diff --git a/README.md b/README.md
index d0ab6405a8623874dc3bbb367bf6aec8663f1e7b..5949ff012ba817317b6427759dcb7f0cfb55fb4e 100644
--- a/README.md
+++ b/README.md
@@ -1,12 +1,86 @@
----
-title: LaVie
-emoji: π
-colorFrom: gray
-colorTo: green
-sdk: gradio
-sdk_version: 4.7.1
-app_file: app.py
-pinned: false
----
-
-Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
+# LaVie: High-Quality Video Generation with Cascaded Latent Diffusion Models
+
+This repository is the official PyTorch implementation of [LaVie](https://arxiv.org/abs/2309.15103).
+
+**LaVie** is a Text-to-Video (T2V) generation framework and the main component of the video generation system [Vchitect](http://vchitect.intern-ai.org.cn/).
+
+[![arXiv](https://img.shields.io/badge/arXiv-2309.15103-b31b1b.svg)](https://arxiv.org/abs/2309.15103)
+[![Project Page](https://img.shields.io/badge/Project-Website-green)](https://vchitect.github.io/LaVie-project/)
+
+
+
+
+## Installation
+```
+conda env create -f environment.yml
+conda activate lavie
+```
+
+## Download Pre-Trained models
+Download the [pre-trained models](https://huggingface.co/YaohuiW/LaVie/tree/main), [Stable Diffusion 1.4](https://huggingface.co/CompVis/stable-diffusion-v1-4/tree/main) and the [stable-diffusion-x4-upscaler](https://huggingface.co/stabilityai/stable-diffusion-x4-upscaler/tree/main) into `./pretrained_models`. You should end up with the following layout:
+```
+├── pretrained_models
+│   ├── lavie_base.pt
+│   ├── lavie_interpolation.pt
+│   ├── lavie_vsr.pt
+│   ├── stable-diffusion-v1-4
+│   │   ├── ...
+│   └── stable-diffusion-x4-upscaler
+│       ├── ...
+```
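+
+If you prefer to script the downloads, here is a minimal sketch using `huggingface_hub` (the local paths are an assumption; any layout matching the tree above works):
+
+```python
+from huggingface_hub import snapshot_download
+
+# LaVie checkpoints (lavie_base.pt, lavie_interpolation.pt, lavie_vsr.pt)
+snapshot_download("YaohuiW/LaVie", local_dir="pretrained_models")
+# Stable Diffusion 1.4 and the x4 upscaler, each into its own sub-folder
+snapshot_download("CompVis/stable-diffusion-v1-4", local_dir="pretrained_models/stable-diffusion-v1-4")
+snapshot_download("stabilityai/stable-diffusion-x4-upscaler", local_dir="pretrained_models/stable-diffusion-x4-upscaler")
+```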
+
+## Inference
+Inference consists of three steps: **Base T2V**, **Video Interpolation** and **Video Super-Resolution**. We provide several options to generate videos:
+* **Step1**: 320 x 512 resolution, 16 frames
+* **Step1+Step2**: 320 x 512 resolution, 61 frames
+* **Step1+Step3**: 1280 x 2048 resolution, 16 frames
+* **Step1+Step2+Step3**: 1280 x 2048 resolution, 61 frames
+
+Feel free to try different options:)
+
+
+### Step1. Base T2V
+Run the following command to generate videos with the base T2V model.
+```
+cd base
+python pipelines/sample.py --config configs/sample.yaml
+```
+Edit `text_prompt` in `configs/sample.yaml` to change the prompt; results will be saved under `./res/base`.
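+
+If you would rather change the prompt programmatically, here is a small sketch with OmegaConf (already used by `base/app.py`); `my_sample.yaml` is just an illustrative file name:
+
+```python
+from omegaconf import OmegaConf
+
+cfg = OmegaConf.load("configs/sample.yaml")
+cfg.text_prompt = ["a corgi walking in the park at sunrise, oil painting style"]
+OmegaConf.save(cfg, "configs/my_sample.yaml")
+# then run: python pipelines/sample.py --config configs/my_sample.yaml
+```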
+
+### Step2 (optional). Video Interpolation
+Run the following command to perform video interpolation.
+```
+cd interpolation
+python sample.py --config configs/sample.yaml
+```
+The default input video path is `./res/base`, and results will be saved under `./res/interpolation`. You can replace the default `input_folder` with `YOUR_INPUT_FOLDER` in `configs/sample.yaml`. Input videos should be named `prompt1.mp4`, `prompt2.mp4`, ... and placed under `YOUR_INPUT_FOLDER`; launching the code will process all input videos in `input_folder`.
+
+
+### Step3 (optional). Video Super-Resolution
+Run the following command to perform video super-resolution.
+```
+cd vsr
+python sample.py --config configs/sample.yaml
+```
+The default input video path is `./res/base`, and results will be saved under `./res/vsr`. You can replace the default `input_path` with `YOUR_INPUT_FOLDER` in `configs/sample.yaml`. Similar to Step 2, input videos should be named `prompt1.mp4`, `prompt2.mp4`, ... and placed under `YOUR_INPUT_FOLDER`; launching the code will process all input videos in `input_folder`.
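+
+If your Step 1 outputs live somewhere else, the sketch below shows one way to prepare `YOUR_INPUT_FOLDER` with the expected `promptN.mp4` naming (all paths are placeholders):
+
+```python
+import os, shutil
+
+src = "res/base"   # videos produced by Step 1
+dst = "my_inputs"  # YOUR_INPUT_FOLDER
+os.makedirs(dst, exist_ok=True)
+videos = sorted(f for f in os.listdir(src) if f.endswith(".mp4"))
+for i, name in enumerate(videos, start=1):
+    shutil.copy(os.path.join(src, name), os.path.join(dst, f"prompt{i}.mp4"))
+```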
+
+
+## BibTex
+```bibtex
+@article{wang2023lavie,
+ title={LAVIE: High-Quality Video Generation with Cascaded Latent Diffusion Models},
+ author={Wang, Yaohui and Chen, Xinyuan and Ma, Xin and Zhou, Shangchen and Huang, Ziqi and Wang, Yi and Yang, Ceyuan and He, Yinan and Yu, Jiashuo and Yang, Peiqing and others},
+ journal={arXiv preprint arXiv:2309.15103},
+ year={2023}
+}
+```
+
+## Acknowledgements
+The code is built upon [diffusers](https://github.com/huggingface/diffusers) and [Stable Diffusion](https://github.com/CompVis/stable-diffusion); we thank all the contributors for open-sourcing.
+
+
+## License
+The code is licensed under Apache-2.0. The model weights are fully open for academic research and also allow **free** commercial usage. To apply for a commercial license, please fill in the [application form]().
diff --git a/base/__pycache__/download.cpython-311.pyc b/base/__pycache__/download.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..8f9de5e7eb4b7a4bfd816fc575ea4888229239de
Binary files /dev/null and b/base/__pycache__/download.cpython-311.pyc differ
diff --git a/base/app.py b/base/app.py
new file mode 100644
index 0000000000000000000000000000000000000000..14c66df82b058e77c83c9cd1106e5df64aa939c0
--- /dev/null
+++ b/base/app.py
@@ -0,0 +1,175 @@
+import gradio as gr
+from text_to_video import model_t2v_fun,setup_seed
+from omegaconf import OmegaConf
+import torch
+import imageio
+import os
+import cv2
+import pandas as pd
+import torchvision
+import random
+config_path = "/mnt/petrelfs/zhouyan/project/lavie-release/base/configs/sample.yaml"
+args = OmegaConf.load(config_path)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+# ------- get model ---------------
+model_t2V = model_t2v_fun(args)
+model_t2V.to(device)
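+# xFormers memory-efficient attention reduces GPU memory use during sampling (CUDA only).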
+if device == "cuda":
+ model_t2V.enable_xformers_memory_efficient_attention()
+
+# model_t2V.enable_xformers_memory_efficient_attention()
+css = """
+h1 {
+ text-align: center;
+}
+#component-0 {
+ max-width: 730px;
+ margin: auto;
+}
+"""
+
+def infer(prompt, seed_inp, ddim_steps, cfg):
+    # Use the user-provided seed, or draw a random one when the default of -1 is kept.
+    if seed_inp == -1:
+        seed_inp = random.choice(range(10000000))
+    setup_seed(seed_inp)
+    videos = model_t2V(prompt, video_length=16, height=320, width=512, num_inference_steps=ddim_steps, guidance_scale=cfg).video
+    print(videos[0].shape)
+    os.makedirs(args.output_folder, exist_ok=True)
+    # Keep the original naming scheme: <output_folder><first 30 chars of prompt>-<seed>-<steps>-<cfg>-.mp4
+    out_path = args.output_folder + prompt[0:30].replace(' ', '_') + '-' + str(seed_inp) + '-' + str(ddim_steps) + '-' + str(cfg) + '-.mp4'
+    torchvision.io.write_video(out_path, videos[0], fps=8)
+
+    return out_path
+
+print(1)
+
+# def clean():
+# return gr.Image.update(value=None, visible=False), gr.Video.update(value=None)
+def clean():
+ return gr.Video.update(value=None)
+
+title = """
+
+
+
+    Intern·Vchitect (Text-to-Video)
+
+
+
+    Apply Intern·Vchitect to generate a video
+
+
+"""
+
+# print(1)
+with gr.Blocks(css='style.css') as demo:
+ gr.Markdown("LaVie: Text-to-Video generation")
+ with gr.Column():
+ with gr.Row(elem_id="col-container"):
+ # inputs = [prompt, seed_inp, ddim_steps]
+ # outputs = [video_out]
+ with gr.Column():
+
+ prompt = gr.Textbox(value="a teddy bear walking on the street", label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in", min_width=200, lines=2)
+
+ ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=50, step=1)
+ seed_inp = gr.Slider(value=-1,label="seed (for random generation, use -1)",show_label=True,minimum=-1,maximum=2147483647)
+ cfg = gr.Number(label="guidance_scale",value=7)
+ # seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=400, elem_id="seed-in")
+
+ # with gr.Row():
+ # # control_task = gr.Dropdown(label="Task", choices=["Text-2-video", "Image-2-video"], value="Text-2-video", multiselect=False, elem_id="controltask-in")
+ # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1)
+ # seed_inp = gr.Slider(label="Seed", minimum=0, maximum=2147483647, step=1, value=123456, elem_id="seed-in")
+
+ # ddim_steps = gr.Slider(label='Steps', minimum=50, maximum=300, value=250, step=1)
+
+ with gr.Column():
+ submit_btn = gr.Button("Generate video")
+ clean_btn = gr.Button("Clean video")
+ # submit_btn = gr.Button("Generate video", size='sm')
+ # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512)
+ video_out = gr.Video(label="Video result", elem_id="video-output")
+ # with gr.Row():
+ # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512)
+ # submit_btn = gr.Button("Generate video", size='sm')
+
+
+ # video_out = gr.Video(label="Video result", elem_id="video-output", height=320, width=512)
+ inputs = [prompt, seed_inp, ddim_steps,cfg]
+ outputs = [video_out]
+ # gr.Examples(
+ # value = [['An astronaut riding a horse',123,50],
+ # ['a panda eating bamboo on a rock',123,50],
+ # ['Spiderman is surfing',123,50]],
+ # label = "example of sampling",
+ # show_label = True,
+ # headers = ['prompt','seed','steps'],
+ # datatype = ['str','number','number'],
+ # row_count=4,
+ # col_count=(3,"fixed")
+ # )
+ ex = gr.Examples(
+ examples = [['a corgi walking in the park at sunrise, oil painting style',400,50,7],
+ ['a cut teddy bear reading a book in the park, oil painting style, high quality',700,50,7],
+ ['an epic tornado attacking above a glowing city at night, the tornado is made of smoke, highly detailed',230,50,7],
+ ['a jar filled with fire, 4K video, 3D rendered, well-rendered',400,50,7],
+ ['a teddy bear walking in the park, oil painting style, high quality',400,50,7],
+ ['a teddy bear walking on the street, 2k, high quality',100,50,7],
+ ['a panda taking a selfie, 2k, high quality',400,50,7],
+ ['a polar bear playing drum kit in NYC Times Square, 4k, high resolution',400,50,7],
+ ['jungle river at sunset, ultra quality',400,50,7],
+ ['a shark swimming in clear Carribean ocean, 2k, high quality',400,50,7],
+ ['A steam train moving on a mountainside by Vincent van Gogh',230,50,7],
+ ['a confused grizzly bear in calculus class',1000,50,7]],
+ fn = infer,
+ inputs=[prompt, seed_inp, ddim_steps,cfg],
+ outputs=[video_out],
+ cache_examples=True,
+ )
+ ex.dataset.headers = [""]
+
+ # control_task.change(change_task_options, inputs=[control_task], outputs=[canny_opt, hough_opt, normal_opt], queue=False)
+ # submit_btn.click(clean, inputs=[], outputs=[video_out], queue=False)
+ clean_btn.click(clean, inputs=[], outputs=[video_out], queue=False)
+ submit_btn.click(infer, inputs, outputs)
+ # share_button.click(None, [], [], _js=share_js)
+
+ print(2)
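+# Bind to all interfaces so the demo is reachable when launched on a remote/cluster node
+# (see base/app.sh for a Slurm launch example).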
+demo.queue(max_size=12).launch(server_name="0.0.0.0", server_port=7860)
+
+
diff --git a/base/app.sh b/base/app.sh
new file mode 100644
index 0000000000000000000000000000000000000000..0d9381ee026325b890a58ee74289374f94e36b30
--- /dev/null
+++ b/base/app.sh
@@ -0,0 +1 @@
+srun -p aigc-video --gres=gpu:1 -n1 -N1 python app.py
\ No newline at end of file
diff --git a/base/configs/sample.yaml b/base/configs/sample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..0fbf8cdda3070e17160bb51491114f837ecf2392
--- /dev/null
+++ b/base/configs/sample.yaml
@@ -0,0 +1,28 @@
+# path:
+output_folder: "/mnt/petrelfs/share_data/zhouyan/gradio/lavie"
+pretrained_path: "/mnt/petrelfs/zhouyan/models"
+
+# model config:
+model: UNet
+video_length: 16
+image_size: [320, 512]
+
+# beta schedule
+beta_start: 0.0001
+beta_end: 0.02
+beta_schedule: "linear"
+
+# model speedup
+use_compile: False
+use_fp16: True
+
+# sample config:
+seed: 3
+run_time: 0
+guidance_scale: 7.0
+sample_method: 'ddpm'
+num_sampling_steps: 250
+text_prompt: [
+ 'a teddy bear walking on the street, high quality, 2k',
+
+ ]
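+
+# Note: the Gradio demo (base/app.py) overrides the prompt, seed, guidance_scale and the number
+# of sampling steps at runtime from its UI controls.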
diff --git a/base/download.py b/base/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..26de159d098d2086a35f1504477eb5c01a35f540
--- /dev/null
+++ b/base/download.py
@@ -0,0 +1,18 @@
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import os
+
+
+def find_model(model_name):
+    """
+    Loads a model checkpoint from a local path. If the checkpoint contains an
+    EMA copy of the weights (key "ema"), that copy is returned; otherwise the
+    checkpoint is returned as-is.
+    """
+    checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
+    if "ema" in checkpoint:  # supports checkpoints from train.py
+        print('Using EMA weights from checkpoint.')
+        checkpoint = checkpoint["ema"]
+    return checkpoint
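+
+
+# Usage sketch (illustrative paths/objects, not part of this file's API):
+#   state_dict = find_model("../pretrained_models/lavie_base.pt")
+#   unet.load_state_dict(state_dict)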
diff --git a/base/gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2a75ce05ceb622035785710d5b4e161875326613
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..d21b92f22c837fa2011f50a0e9a71774727f6a22
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..6e29f47c91bc1fcfe35e2e068fafc368a81c01b4
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..18bb707b269f4ce942033d55ecacec385fe1da33
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2bf26ef94098567f37c8bfdbecc17e26eec37a99
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2fdbde0250b7511f0ec1e2b60bba38cd2bccc673
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..2518cdc98689cbc489213f220c41f873e29d221b
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..ae6ef5c9c0ff955fc82ca1a1b5c6d3489b93f7d5
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..81575e7cbeb1ede1bd9741bf91b8648721c3273c
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..d6754f2846e6b66f5109079a4d78c6e560a15d2c
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..88de28546db43a883e9cd2ed10c1ae58d1aaf402
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4 b/base/gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4
new file mode 100644
index 0000000000000000000000000000000000000000..68ad51852fe928376b0bb47ad00df422caf443a6
Binary files /dev/null and b/base/gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4 differ
diff --git a/base/gradio_cached_examples/14/log.csv b/base/gradio_cached_examples/14/log.csv
new file mode 100644
index 0000000000000000000000000000000000000000..6bb703837fe42ccdc80fb5802c2304c5d790c286
--- /dev/null
+++ b/base/gradio_cached_examples/14/log.csv
@@ -0,0 +1,13 @@
+Video result,flag,username,timestamp
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/625fb799abdbcb60fe2f/laviea_corgi_walking_in_the_park_at-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_corgi_walking_in_the_park_at-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:13.139609
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/feeee8981f36b962bfe6/laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_cut_teddy_bear_reading_a_boo-700-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:23.543257
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/767a0718deb1983b3d43/laviean_epic_tornado_attacking_abov-230-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviean_epic_tornado_attacking_abov-230-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:33.942899
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/a15ccccc7c42e18cd062/laviea_jar_filled_with_fire_4K_vid-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_jar_filled_with_fire,_4K_vid-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:44.348969
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/c2e7acb8ce5cb0a52899/laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_teddy_bear_walking_in_the_pa-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:01:54.765554
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/75b535bd2f78c28d2789/laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_teddy_bear_walking_on_the_st-100-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:05.255612
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/42b8a418d77480fcc8fc/laviea_panda_taking_a_selfie_2k_h-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_panda_taking_a_selfie,_2k,_h-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:15.694357
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/e67b3c12db1c38afd2c4/laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_polar_bear_playing_drum_kit_-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:26.121546
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/cecd7ff29690b876a418/laviejungle_river_at_sunset_ultra_-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviejungle_river_at_sunset,_ultra_-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:36.540682
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/beedae14fa3a8f24e4ec/laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_shark_swimming_in_clear_Carr-400-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:46.992686
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/4b54d32b7e8f2a3cd333/lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""lavieA_steam_train_moving_on_a_moun-230-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:02:57.458758
+"{""video"":{""path"":""gradio_cached_examples/14/Video result/35727d2ebeb816c94d68/laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4"",""url"":null,""size"":null,""orig_name"":""laviea_confused_grizzly_bear_in_cal-1000-50-7-.mp4"",""mime_type"":null},""subtitles"":null}",,,2023-11-27 14:03:07.878403
diff --git a/base/huggingface-t2v/.DS_Store b/base/huggingface-t2v/.DS_Store
new file mode 100644
index 0000000000000000000000000000000000000000..ec6bded75c322a3b51f99901c1eb510901d2bd75
Binary files /dev/null and b/base/huggingface-t2v/.DS_Store differ
diff --git a/base/huggingface-t2v/__init__.py b/base/huggingface-t2v/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/base/huggingface-t2v/requirements.txt b/base/huggingface-t2v/requirements.txt
new file mode 100644
index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391
diff --git a/base/models/__init__.py b/base/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..f93daa1bf5131bd1251a44e99aeeb0d127ea71f1
--- /dev/null
+++ b/base/models/__init__.py
@@ -0,0 +1,33 @@
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+from .unet import UNet3DConditionModel
+from torch.optim.lr_scheduler import LambdaLR
+
+def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
+    def fn(step):
+        if warmup_steps > 0:
+            return min(step / warmup_steps, 1)
+        else:
+            return 1
+    return LambdaLR(optimizer, fn)
+
+
+def get_lr_scheduler(optimizer, name, **kwargs):
+ if name == 'warmup':
+ return customized_lr_scheduler(optimizer, **kwargs)
+ elif name == 'cosine':
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ return CosineAnnealingLR(optimizer, **kwargs)
+ else:
+ raise NotImplementedError(name)
+
+def get_models(args, sd_path):
+    if 'UNet' in args.model:
+        return UNet3DConditionModel.from_pretrained_2d(sd_path, subfolder="unet")
+    else:
+        raise NotImplementedError('{} Model Not Supported!'.format(args.model))
+
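+# Usage sketch (illustrative; `args.model` comes from a config such as configs/sample.yaml):
+#   unet = get_models(args, "path/to/stable-diffusion-v1-4")
+#   lr_scheduler = get_lr_scheduler(optimizer, 'warmup', warmup_steps=5000)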
diff --git a/base/models/__pycache__/__init__.cpython-311.pyc b/base/models/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..208ad63a8430425ee32dad7d7738d750a294fd8b
Binary files /dev/null and b/base/models/__pycache__/__init__.cpython-311.pyc differ
diff --git a/base/models/__pycache__/attention.cpython-311.pyc b/base/models/__pycache__/attention.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..5b4c56a663f3de4d5a2b04fffaf188ce0167437a
Binary files /dev/null and b/base/models/__pycache__/attention.cpython-311.pyc differ
diff --git a/base/models/__pycache__/resnet.cpython-311.pyc b/base/models/__pycache__/resnet.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..80e4daa3a256d39ec31cc468afc5f2b0fbd374ea
Binary files /dev/null and b/base/models/__pycache__/resnet.cpython-311.pyc differ
diff --git a/base/models/__pycache__/unet.cpython-311.pyc b/base/models/__pycache__/unet.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..454d8468602a267fe00425b659445343a76233dc
Binary files /dev/null and b/base/models/__pycache__/unet.cpython-311.pyc differ
diff --git a/base/models/__pycache__/unet_blocks.cpython-311.pyc b/base/models/__pycache__/unet_blocks.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..a47de4825f787c06109b504f0312ca58af82a483
Binary files /dev/null and b/base/models/__pycache__/unet_blocks.cpython-311.pyc differ
diff --git a/base/models/attention.py b/base/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..fb5b0ff4a7fb9a8d667a9b502017208293ab2609
--- /dev/null
+++ b/base/models/attention.py
@@ -0,0 +1,707 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+from dataclasses import dataclass
+from typing import Optional
+
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import FeedForward, AdaLayerNorm
+from rotary_embedding_torch import RotaryEmbedding
+from typing import Callable, Optional
+from einops import rearrange, repeat
+
+try:
+    from diffusers.models.modeling_utils import ModelMixin
+except ImportError:
+    from diffusers.modeling_utils import ModelMixin  # diffusers <= 0.11.1
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+def exists(x):
+ return x is not None
+
+
+class CrossAttention(nn.Module):
+ r"""
+    Copied from diffusers 0.11.1.
+ A cross attention layer.
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ self.dim_head = dim_head
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = False
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ self.use_relative_position = use_relative_position
+ if self.use_relative_position:
+ self.rotary_emb = RotaryEmbedding(min(32, dim_head))
+
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def reshape_for_scores(self, tensor):
+ # split heads and dims
+ # tensor should be [b (h w)] f (d nd)
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ return tensor
+
+ def same_batch_dim_to_heads(self, tensor):
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
+
+ # print('before reshpape query shape', query.shape)
+ dim = query.shape[-1]
+ if not self.use_relative_position:
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
+ # print('after reshape query shape', query.shape)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ if not self.use_relative_position:
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ # TODO attention_mask
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+
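+# Transformer3DModel wraps a stack of BasicTransformerBlocks for video latents shaped
+# (batch, channels, frames, height, width); frames are folded into the batch dimension so the
+# spatial and text cross-attention layers run per frame, mirroring the 2D diffusers transformer.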
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ rotary_emb: bool = None,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # Define input layers
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ # Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if use_linear_projection:
+ self.proj_out = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, use_image_num=None, return_dict: bool = True):
+ # Input
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
+
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ video_length=video_length,
+ use_image_num=use_image_num,
+ )
+
+ # Output
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+
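+# BasicTransformerBlock: spatial self-attention (attn1), optional text cross-attention (attn2),
+# temporal attention over frames (attn_temp, output projection zero-initialised), then feed-forward.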
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ rotary_emb: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+ self.use_first_frame = use_first_frame
+
+ # Spatial-Attn
+ self.attn1 = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ upcast_attention=upcast_attention,
+ )
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # Text Cross-Attn
+ if cross_attention_dim is not None:
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ else:
+ self.attn2 = None
+
+ if cross_attention_dim is not None:
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ else:
+ self.norm2 = None
+
+ # Temp
+ self.attn_temp = TemporalAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ upcast_attention=upcast_attention,
+ rotary_emb=rotary_emb,
+ )
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+
+
+ # Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.norm3 = nn.LayerNorm(dim)
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, op=None):
+
+        if not is_xformers_available():
+            raise ModuleNotFoundError(
+                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                " xformers",
+                name="xformers",
+            )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ if self.attn2 is not None:
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None, use_image_num=None):
+        # Spatial Self-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ )
+
+ if self.only_cross_attention:
+ hidden_states = (
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+ )
+ else:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num) + hidden_states
+
+ if self.attn2 is not None:
+ # Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+ # Temporal Attention
+ if self.training:
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
+ hidden_states_video = hidden_states[:, :video_length, :]
+ hidden_states_image = hidden_states[:, video_length:, :]
+ norm_hidden_states_video = (
+ self.norm_temp(hidden_states_video, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states_video)
+ )
+ hidden_states_video = self.attn_temp(norm_hidden_states_video) + hidden_states_video
+ hidden_states = torch.cat([hidden_states_video, hidden_states_image], dim=1)
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
+ else:
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length + use_image_num).contiguous()
+ norm_hidden_states = (
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+ )
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ return hidden_states
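+
+
+# Illustrative sketch (not part of the original model code): the temporal-attention step
+# above folds spatial tokens into the batch axis so attention runs over frames, then
+# restores the original layout. The helper below only demonstrates that shape round-trip
+# on dummy tensors; the sizes are assumptions chosen for the example.
+def _example_temporal_reshape(batch=2, frames=4, tokens=16, channels=8):
+    # hidden_states as seen inside the block: (batch * frames, tokens, channels)
+    hidden_states = torch.randn(batch * frames, tokens, channels)
+    d = hidden_states.shape[1]
+    # group by spatial token so attention mixes information across frames
+    temporal = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=frames)
+    assert temporal.shape == (batch * tokens, frames, channels)
+    # undo the grouping to recover the (batch * frames, tokens, channels) layout
+    restored = rearrange(temporal, "(b d) f c -> (b f) d c", d=d)
+    assert restored.shape == hidden_states.shape
+    return restored
+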
+
+class TemporalAttention(CrossAttention):
+ def __init__(self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ rotary_emb=None):
+ super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
+ # relative time positional embeddings
+ self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
+ self.rotary_emb = rotary_emb
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
+ dim = query.shape[-1]
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+ def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+        # reshape for adding the relative time positional bias
+        query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads)  # h: heads; d: dim_head
+        key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads)  # h: heads; d: dim_head
+        value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads)  # h: heads; d: dim_head
+
+ if exists(self.rotary_emb):
+ query = self.rotary_emb.rotate_queries_or_keys(query)
+ key = self.rotary_emb.rotate_queries_or_keys(key)
+
+ attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
+
+ attention_scores = attention_scores + time_rel_pos_bias
+
+ if attention_mask is not None:
+ # add attention mask
+ attention_scores = attention_scores + attention_mask
+
+ attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
+
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+ # print(attention_probs[0][0])
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
+ hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
+ return hidden_states
+
+class RelativePositionBias(nn.Module):
+ def __init__(
+ self,
+ heads=8,
+ num_buckets=32,
+ max_distance=128,
+ ):
+ super().__init__()
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
+ ret = 0
+ n = -relative_position
+
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).long()
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def forward(self, n, device):
+        q_pos = torch.arange(n, dtype=torch.long, device=device)
+        k_pos = torch.arange(n, dtype=torch.long, device=device)
+        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
+        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
+        values = self.relative_attention_bias(rp_bucket)
+        return rearrange(values, 'i j h -> h i j')  # num_heads, num_frames, num_frames
\ No newline at end of file
diff --git a/base/models/clip.py b/base/models/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..2879f919e9081a373582e4734cded621bab8245d
--- /dev/null
+++ b/base/models/clip.py
@@ -0,0 +1,120 @@
+import numpy
+import torch.nn as nn
+from transformers import CLIPTokenizer, CLIPTextModel
+
+import transformers
+transformers.logging.set_verbosity_error()
+
+"""
+Will encounter following warning:
+- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
+or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
+- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
+that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
+
+https://github.com/CompVis/stable-diffusion/issues/97
+according to this issue, this warning is safe.
+
+This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
+You can safely ignore the warning, it is not an error.
+
+This clip usage is from U-ViT and same with Stable Diffusion.
+"""
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ # def __init__(self, version="openai/clip-vit-huge-patch14", device="cuda", max_length=77):
+ def __init__(self, path, device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(path, subfolder="tokenizer")
+ self.transformer = CLIPTextModel.from_pretrained(path, subfolder='text_encoder')
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class TextEmbedder(nn.Module):
+ """
+    Embeds text prompts into vector representations. Also handles text dropout for classifier-free guidance.
+ """
+ def __init__(self, path, dropout_prob=0.1):
+ super().__init__()
+ self.text_encodder = FrozenCLIPEmbedder(path=path)
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, text_prompts, force_drop_ids=None):
+ """
+ Drops text to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
+ else:
+ # TODO
+ drop_ids = force_drop_ids == 1
+ labels = list(numpy.where(drop_ids, "", text_prompts))
+ # print(labels)
+ return labels
+
+ def forward(self, text_prompts, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
+ embeddings = self.text_encodder(text_prompts)
+ return embeddings
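+
+
+# Illustrative sketch (not part of the original module): `token_drop` implements the text
+# dropout used for classifier-free guidance by replacing a random subset of prompts with
+# the empty string. The helper below reproduces that mechanism on a plain Python list so
+# the behaviour can be inspected without loading a checkpoint; the prompts and the fixed
+# seed are assumptions made only for this example.
+def _example_classifier_free_dropout(dropout_prob=0.5, seed=0):
+    numpy.random.seed(seed)
+    prompts = ["a photo of a cat", "a photo of a dog", "a painting of a tree"]
+    drop_ids = numpy.random.uniform(0, 1, len(prompts)) < dropout_prob
+    # dropped prompts become "", exactly as in TextEmbedder.token_drop
+    return list(numpy.where(drop_ids, "", prompts))
+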
+
+
+if __name__ == '__main__':
+
+ r"""
+ Returns:
+
+ Examples from CLIPTextModel:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPTextModel
+
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ text_encoder = TextEmbedder(path='/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-2-1-base',
+ dropout_prob=0.00001).to(device)
+
+ text_prompt = [["a photo of a cat", "a photo of a cat"], ["a photo of a dog", "a photo of a cat"], ['a photo of a dog human', "a photo of a cat"]]
+ output = text_encoder(text_prompts=text_prompt, train=False)
+ print(output.shape)
diff --git a/base/models/resnet.py b/base/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1307ad9a0fe26b6b1eca913b024c188df2dd86e
--- /dev/null
+++ b/base/models/resnet.py
@@ -0,0 +1,212 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+
+
+class InflatedConv3d(nn.Conv2d):
+ def forward(self, x):
+ video_length = x.shape[2]
+
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ x = super().forward(x)
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+ return x
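+
+
+# Illustrative sketch (not part of the original module): InflatedConv3d reuses 2D
+# convolution weights on video tensors by folding the frame axis into the batch axis,
+# so each frame is convolved independently. The helper below just checks the expected
+# shapes on a dummy input; the sizes are assumptions chosen for the example.
+def _example_inflated_conv(batch=1, channels=4, frames=3, height=8, width=8):
+    conv = InflatedConv3d(channels, channels, kernel_size=3, padding=1)
+    video = torch.randn(batch, channels, frames, height, width)
+    out = conv(video)
+    # spatial convolution only: the frame axis is untouched
+    assert out.shape == video.shape
+    return out
+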
+
+
+class Upsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ raise NotImplementedError
+ elif use_conv:
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
+
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, hidden_states, output_size=None):
+ assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ raise NotImplementedError
+
+        # Cast to float32 because the 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
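+
+
+# Illustrative sketch (not part of the original module): Upsample3D interpolates with
+# scale_factor=[1.0, 2.0, 2.0], i.e. it doubles the spatial resolution while leaving the
+# frame axis unchanged. The helper below demonstrates that on a dummy tensor; the sizes
+# are assumptions chosen for the example.
+def _example_upsample3d(batch=1, channels=4, frames=3, height=8, width=8):
+    up = Upsample3D(channels, use_conv=True)
+    video = torch.randn(batch, channels, frames, height, width)
+    out = up(video)
+    assert out.shape == (batch, channels, frames, height * 2, width * 2)
+    return out
+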
+
+
+class Downsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ raise NotImplementedError
+
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states):
+ assert hidden_states.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ raise NotImplementedError
+
+ assert hidden_states.shape[1] == self.channels
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ time_emb_proj_out_channels = out_channels
+ elif self.time_embedding_norm == "scale_shift":
+ time_emb_proj_out_channels = out_channels * 2
+ else:
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
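+
+
+# Illustrative sketch (not part of the original module): a ResnetBlock3D forward pass on a
+# dummy video latent, with the timestep embedding injected additively (the default
+# `time_embedding_norm="default"` path; "scale_shift" would instead split the projection
+# into a scale and a shift). The sizes are assumptions chosen for the example.
+def _example_resnet_block(batch=1, channels=32, frames=3, height=8, width=8, temb_channels=16):
+    block = ResnetBlock3D(in_channels=channels, out_channels=channels, temb_channels=temb_channels)
+    video = torch.randn(batch, channels, frames, height, width)
+    temb = torch.randn(batch, temb_channels)
+    out = block(video, temb)
+    assert out.shape == video.shape
+    return out
+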
+
+
+class Mish(torch.nn.Module):
+ def forward(self, hidden_states):
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
\ No newline at end of file
diff --git a/base/models/temporal_attention.py b/base/models/temporal_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..6c417bc9e98409c22297a4878da689e38d188a24
--- /dev/null
+++ b/base/models/temporal_attention.py
@@ -0,0 +1,388 @@
+import torch
+from torch import nn
+from typing import Optional
+from rotary_embedding_torch import RotaryEmbedding
+from dataclasses import dataclass
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+import torch.nn.functional as F
+from einops import rearrange, repeat
+import math
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+def exists(x):
+ return x is not None
+
+class CrossAttention(nn.Module):
+ r"""
+    Copied from diffusers 0.11.1.
+ A cross attention layer.
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+ # print('num head', heads)
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ self.dim_head = dim_head
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+        self._use_memory_efficient_attention_xformers = False  # xformers is not used for temporal attention
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def reshape_for_scores(self, tensor):
+ # split heads and dims
+ # tensor should be [b (h w)] f (d nd)
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ return tensor
+
+ def same_batch_dim_to_heads(self, tensor):
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
+
+ # print('before reshpape query shape', query.shape)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
+ # print('after reshape query shape', query.shape)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ hidden_states = self._attention(query, key, value, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ attention_probs = attention_probs.to(value.dtype)
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ # TODO attention_mask
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
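+
+
+# Illustrative sketch (not part of the original module): a minimal CrossAttention forward
+# pass. With `encoder_hidden_states=None` the layer falls back to self-attention over the
+# input sequence; the sizes below are assumptions chosen for the example.
+def _example_cross_attention(batch=2, seq_len=16, heads=8, dim_head=40):
+    attn = CrossAttention(query_dim=heads * dim_head, heads=heads, dim_head=dim_head)
+    hidden_states = torch.randn(batch, seq_len, heads * dim_head)
+    out = attn(hidden_states)
+    assert out.shape == hidden_states.shape
+    return out
+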
+
+class TemporalAttention(CrossAttention):
+ def __init__(self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ rotary_emb=None):
+ super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
+ # relative time positional embeddings
+ self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
+ self.rotary_emb = rotary_emb
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
+ dim = query.shape[-1]
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+ def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+        query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads)  # h: heads; d: dim_head
+        key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads)  # h: heads; d: dim_head
+        value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads)  # h: heads; d: dim_head
+ if exists(self.rotary_emb):
+ query = self.rotary_emb.rotate_queries_or_keys(query)
+ key = self.rotary_emb.rotate_queries_or_keys(key)
+
+ attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
+ attention_scores = attention_scores + time_rel_pos_bias
+
+ if attention_mask is not None:
+ # add attention mask
+ attention_scores = attention_scores + attention_mask
+
+ attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ attention_probs = attention_probs.to(value.dtype)
+ hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
+ hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
+ return hidden_states
+
+class RelativePositionBias(nn.Module):
+ def __init__(
+ self,
+ heads=8,
+ num_buckets=32,
+ max_distance=128,
+ ):
+ super().__init__()
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
+ ret = 0
+ n = -relative_position
+
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).long()
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def forward(self, n, device):
+        q_pos = torch.arange(n, dtype=torch.long, device=device)
+        k_pos = torch.arange(n, dtype=torch.long, device=device)
+        rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
+        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
+        values = self.relative_attention_bias(rp_bucket)
+        return rearrange(values, 'i j h -> h i j')  # num_heads, num_frames, num_frames
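+
+
+# Illustrative sketch (not part of the original module): TemporalAttention attends over the
+# frame axis and adds a learned relative time positional bias of shape
+# (heads, num_frames, num_frames) to the attention scores. The sizes below are assumptions
+# chosen for the example; `rotary_emb` is left at None, so the rotary branch is skipped.
+def _example_temporal_attention(batch_tokens=4, frames=8, heads=8, dim_head=40):
+    attn = TemporalAttention(query_dim=heads * dim_head, heads=heads, dim_head=dim_head)
+    bias = attn.time_rel_pos_bias(frames, device="cpu")
+    assert bias.shape == (heads, frames, frames)
+
+    # one row per spatial token, attention runs across the frame axis
+    hidden_states = torch.randn(batch_tokens, frames, heads * dim_head)
+    out = attn(hidden_states)
+    assert out.shape == hidden_states.shape
+    return out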
diff --git a/base/models/transformer_3d.py b/base/models/transformer_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4b0aba347ea5408955d4646737306c8e945bf16
--- /dev/null
+++ b/base/models/transformer_3d.py
@@ -0,0 +1,367 @@
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Any, Dict, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.embeddings import ImagePositionalEmbeddings
+from diffusers.utils import BaseOutput, deprecate
+from diffusers.models.embeddings import PatchEmbed
+from diffusers.models.lora import LoRACompatibleConv, LoRACompatibleLinear
+from diffusers.models.modeling_utils import ModelMixin
+from einops import rearrange, repeat
+
+try:
+ from attention import BasicTransformerBlock
+except ImportError:
+ from .attention import BasicTransformerBlock
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+    """
+    The output of [`Transformer3DModel`].
+
+    Args:
+        sample (`torch.FloatTensor` of shape `(batch_size, num_channels, num_frames, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer3DModel`] is discrete):
+            The hidden states output conditioned on the `encoder_hidden_states` input. If discrete, returns probability
+            distributions for the unnoised latent pixels.
+    """
+
+ sample: torch.FloatTensor
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ """
+    A 3D Transformer model for video-like data.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ The number of channels in the input and output (specify if the input is **continuous**).
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of `encoder_hidden_states` dimensions to use.
+ sample_size (`int`, *optional*): The width of the latent images (specify if the input is **discrete**).
+ This is fixed during training since it is used to learn a number of position embeddings.
+ num_vector_embeds (`int`, *optional*):
+ The number of classes of the vector embeddings of the latent pixels (specify if the input is **discrete**).
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to use in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*):
+ The number of diffusion steps used during training. Pass if at least one of the norm_layers is
+ `AdaLayerNorm`. This is fixed during training since it is used to learn a number of embeddings that are
+ added to the hidden states.
+
+            During inference, you can denoise for up to but not more than `num_embeds_ada_norm` steps.
+ attention_bias (`bool`, *optional*):
+ Configure if the `TransformerBlocks` attention should contain a bias parameter.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ out_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ patch_size: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ norm_type: str = "layer_norm",
+ norm_elementwise_affine: bool = True,
+ rotary_emb=None,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Transformer3DModel can process both standard continuous video latents of shape `(batch_size, num_channels, num_frames, height, width)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = (in_channels is not None) and (patch_size is None)
+ self.is_input_vectorized = num_vector_embeds is not None
+ self.is_input_patches = in_channels is not None and patch_size is not None
+
+ if norm_type == "layer_norm" and num_embeds_ada_norm is not None:
+ deprecation_message = (
+ f"The configuration file of this model: {self.__class__} is outdated. `norm_type` is either not set or"
+                " incorrectly set to `'layer_norm'`. Make sure to set `norm_type` to `'ada_norm'` in the config."
+                " Please make sure to update the config accordingly as leaving `norm_type` might lead to incorrect"
+ " results in future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it"
+ " would be very nice if you could open a Pull request for the `transformer/config.json` file"
+ )
+ deprecate("norm_type!=num_embeds_ada_norm", "1.0.0", deprecation_message, standard_warn=False)
+ norm_type = "ada_norm"
+
+ if self.is_input_continuous and self.is_input_vectorized:
+ raise ValueError(
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is None."
+ )
+ elif self.is_input_vectorized and self.is_input_patches:
+ raise ValueError(
+ f"Cannot define both `num_vector_embeds`: {num_vector_embeds} and `patch_size`: {patch_size}. Make"
+ " sure that either `num_vector_embeds` or `num_patches` is None."
+ )
+ elif not self.is_input_continuous and not self.is_input_vectorized and not self.is_input_patches:
+ raise ValueError(
+ f"Has to define `in_channels`: {in_channels}, `num_vector_embeds`: {num_vector_embeds}, or patch_size:"
+ f" {patch_size}. Make sure that `in_channels`, `num_vector_embeds` or `num_patches` is not None."
+ )
+
+ # 2. Define input layers
+ if self.is_input_continuous:
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = LoRACompatibleLinear(in_channels, inner_dim)
+ else:
+ self.proj_in = LoRACompatibleConv(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = sample_size
+ self.width = sample_size
+ self.num_vector_embeds = num_vector_embeds
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
+ )
+ elif self.is_input_patches:
+ assert sample_size is not None, "Transformer2DModel over patched input must provide sample_size"
+
+ self.height = sample_size
+ self.width = sample_size
+
+ self.patch_size = patch_size
+ self.pos_embed = PatchEmbed(
+ height=sample_size,
+ width=sample_size,
+ patch_size=patch_size,
+ in_channels=in_channels,
+ embed_dim=inner_dim,
+ )
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ norm_type=norm_type,
+ norm_elementwise_affine=norm_elementwise_affine,
+ rotary_emb=rotary_emb,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ self.out_channels = in_channels if out_channels is None else out_channels
+ if self.is_input_continuous:
+ # TODO: should use out_channels for continuous projections
+ if use_linear_projection:
+ self.proj_out = LoRACompatibleLinear(inner_dim, in_channels)
+ else:
+ self.proj_out = LoRACompatibleConv(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+ elif self.is_input_patches:
+ self.norm_out = nn.LayerNorm(inner_dim, elementwise_affine=False, eps=1e-6)
+ self.proj_out_1 = nn.Linear(inner_dim, 2 * inner_dim)
+ self.proj_out_2 = nn.Linear(inner_dim, patch_size * patch_size * self.out_channels)
+
+ def forward(
+ self,
+ hidden_states: torch.Tensor,
+ encoder_hidden_states: Optional[torch.Tensor] = None,
+ timestep: Optional[torch.LongTensor] = None,
+ class_labels: Optional[torch.LongTensor] = None,
+ cross_attention_kwargs: Dict[str, Any] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ encoder_attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ use_image_num=None,
+ ):
+ """
+        The [`Transformer3DModel`] forward method.
+
+ Args:
+ hidden_states (`torch.LongTensor` of shape `(batch size, num latent pixels)` if discrete, `torch.FloatTensor` of shape `(batch size, channel, height, width)` if continuous):
+ Input `hidden_states`.
+ encoder_hidden_states ( `torch.FloatTensor` of shape `(batch size, sequence len, embed dims)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.LongTensor`, *optional*):
+ Used to indicate denoising step. Optional timestep to be applied as an embedding in `AdaLayerNorm`.
+ class_labels ( `torch.LongTensor` of shape `(batch size, num classes)`, *optional*):
+ Used to indicate class labels conditioning. Optional class labels to be applied as an embedding in
+ `AdaLayerZeroNorm`.
+ encoder_attention_mask ( `torch.Tensor`, *optional*):
+ Cross-attention mask applied to `encoder_hidden_states`. Two formats supported:
+
+ * Mask `(batch, sequence_length)` True = keep, False = discard.
+ * Bias `(batch, 1, sequence_length)` 0 = keep, -10000 = discard.
+
+ If `ndim == 2`: will be interpreted as a mask, then converted into a bias consistent with the format
+ above. This bias will be added to the cross-attention scores.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`Transformer3DModelOutput`] instead of a plain
+                tuple.
+
+        Returns:
+            If `return_dict` is True, a [`Transformer3DModelOutput`] is returned, otherwise a
+            `tuple` where the first element is the sample tensor.
+ """
+ # ensure attention_mask is a bias, and give it a singleton query_tokens dimension.
+ # we may have done this conversion already, e.g. if we came here via UNet2DConditionModel#forward.
+ # we can tell by counting dims; if ndim == 2: it's a mask rather than a bias.
+ # expects mask of shape:
+ # [batch, key_tokens]
+ # adds singleton query_tokens dimension:
+ # [batch, 1, key_tokens]
+ # this helps to broadcast it as a bias over attention scores, which will be in one of the following shapes:
+ # [batch, heads, query_tokens, key_tokens] (e.g. torch sdp attn)
+ # [batch * heads, query_tokens, key_tokens] (e.g. xformers or classic attn)
+ if attention_mask is not None and attention_mask.ndim == 2:
+ # assume that mask is expressed as:
+ # (1 = keep, 0 = discard)
+ # convert mask into a bias that can be added to attention scores:
+ # (keep = +0, discard = -10000.0)
+ attention_mask = (1 - attention_mask.to(hidden_states.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # convert encoder_attention_mask to a bias the same way we do for attention_mask
+ if encoder_attention_mask is not None and encoder_attention_mask.ndim == 2:
+ encoder_attention_mask = (1 - encoder_attention_mask.to(hidden_states.dtype)) * -10000.0
+ encoder_attention_mask = encoder_attention_mask.unsqueeze(1)
+
+ # 1. Input
+ if self.is_input_continuous: # True
+
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ if self.training:
+ video_length = hidden_states.shape[2] - use_image_num
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
+ encoder_hidden_states_length = encoder_hidden_states.shape[1]
+ encoder_hidden_states_video = encoder_hidden_states[:, :encoder_hidden_states_length - use_image_num, ...]
+ encoder_hidden_states_video = repeat(encoder_hidden_states_video, 'b m n c -> b (m f) n c', f=video_length).contiguous()
+ encoder_hidden_states_image = encoder_hidden_states[:, encoder_hidden_states_length - use_image_num:, ...]
+ encoder_hidden_states = torch.cat([encoder_hidden_states_video, encoder_hidden_states_image], dim=1)
+ encoder_hidden_states = rearrange(encoder_hidden_states, 'b m n c -> (b m) n c').contiguous()
+ else:
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w").contiguous()
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length).contiguous()
+
+ batch, _, height, width = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+ elif self.is_input_vectorized:
+ hidden_states = self.latent_image_embedding(hidden_states)
+ elif self.is_input_patches:
+ hidden_states = self.pos_embed(hidden_states)
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ attention_mask=attention_mask,
+ encoder_hidden_states=encoder_hidden_states,
+ encoder_attention_mask=encoder_attention_mask,
+ timestep=timestep,
+ cross_attention_kwargs=cross_attention_kwargs,
+ class_labels=class_labels,
+ video_length=video_length,
+ use_image_num=use_image_num,
+ )
+
+ # 3. Output
+ if self.is_input_continuous:
+ if not self.use_linear_projection:
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+
+ output = hidden_states + residual
+ elif self.is_input_vectorized:
+ hidden_states = self.norm_out(hidden_states)
+ logits = self.out(hidden_states)
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+ logits = logits.permute(0, 2, 1)
+
+ # log(p(x_0))
+ output = F.log_softmax(logits.double(), dim=1).float()
+ elif self.is_input_patches:
+ # TODO: cleanup!
+ conditioning = self.transformer_blocks[0].norm1.emb(
+ timestep, class_labels, hidden_dtype=hidden_states.dtype
+ )
+ shift, scale = self.proj_out_1(F.silu(conditioning)).chunk(2, dim=1)
+ hidden_states = self.norm_out(hidden_states) * (1 + scale[:, None]) + shift[:, None]
+ hidden_states = self.proj_out_2(hidden_states)
+
+ # unpatchify
+ height = width = int(hidden_states.shape[1] ** 0.5)
+ hidden_states = hidden_states.reshape(
+ shape=(-1, height, width, self.patch_size, self.patch_size, self.out_channels)
+ )
+ hidden_states = torch.einsum("nhwpqc->nchpwq", hidden_states)
+ output = hidden_states.reshape(
+ shape=(-1, self.out_channels, height * self.patch_size, width * self.patch_size)
+ )
+
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length + use_image_num).contiguous()
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
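+
+
+# Illustrative sketch (not part of the original module): the forward pass above converts a
+# binary key-padding mask (1 = keep, 0 = discard) into an additive bias (0 for keep,
+# -10000 for discard) with a singleton query dimension, so it can simply be added to the
+# attention scores. The helper below reproduces only that conversion; the name and sizes
+# are assumptions made for the example.
+def _example_mask_to_bias(dtype=torch.float32):
+    # binary mask: 1 = keep, 0 = discard, shape (batch, key_tokens)
+    mask = torch.tensor([[1, 1, 1, 0, 0],
+                         [1, 1, 1, 1, 0]])
+    bias = (1 - mask.to(dtype)) * -10000.0
+    bias = bias.unsqueeze(1)  # (batch, 1, key_tokens), broadcastable over query tokens
+    assert bias.shape == (2, 1, 5)
+    return bias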
diff --git a/base/models/unet.py b/base/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..bfb2f473142783e7e6cf8f8fb37c41c56fb2f8de
--- /dev/null
+++ b/base/models/unet.py
@@ -0,0 +1,617 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import math
+import json
+import torch
+import einops
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+
+try:
+ from diffusers.models.modeling_utils import ModelMixin
+except ImportError:
+ from diffusers.modeling_utils import ModelMixin # 0.11.1
+
+try:
+ from .unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+ )
+ from .resnet import InflatedConv3d
+except ImportError:
+ from unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+ )
+ from resnet import InflatedConv3d
+
+from rotary_embedding_torch import RotaryEmbedding
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+class RelativePositionBias(nn.Module):
+ def __init__(
+ self,
+ heads=8,
+ num_buckets=32,
+ max_distance=128,
+ ):
+ super().__init__()
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
+ ret = 0
+ n = -relative_position
+
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).long()
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def forward(self, n, device):
+        q_pos = torch.arange(n, dtype=torch.long, device=device)
+        k_pos = torch.arange(n, dtype=torch.long, device=device)
+        rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
+        rp_bucket = self._relative_position_bucket(rel_pos, num_buckets=self.num_buckets, max_distance=self.max_distance)
+        values = self.relative_attention_bias(rp_bucket)
+        return einops.rearrange(values, 'i j h -> h i j')  # num_heads, num_frames, num_frames
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None, # 64
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D",
+ ),
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+
+ # print(use_first_frame)
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ # print(only_cross_attention)
+ # print(type(only_cross_attention))
+ # exit()
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+ # print(only_cross_attention)
+ # exit()
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+ # print(attention_head_dim)
+ # exit()
+
+ rotary_emb = RotaryEmbedding(32)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
+ self.mid_block = UNetMidBlock3DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the videos
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
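+            # `prev_output_channel` comes from the block below, while `input_channel` mirrors the
+            # down path so each skip connection concatenates with matching channel counts
+            # (e.g. block_out_channels=(320, 640, 1280, 1280) reverses to (1280, 1280, 640, 320)).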
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ # relative time positional embeddings
+ self.use_relative_position = use_relative_position
+ if self.use_relative_position:
+ self.time_rel_pos_bias = RelativePositionBias(heads=8, max_distance=32) # realistically will not be able to generate that many frames of video... yet
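+            # A learned bias over relative frame distances, one table per attention head; `forward`
+            # turns it into `frame_rel_pos_bias` from the number of input frames.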
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_slicable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_slicable_dims(module)
+
+ num_slicable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_slicable_layers * [1]
+
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+        # Any child that exposes the set_attention_slice method
+        # receives the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor = None,
+ class_labels: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ use_image_num: int = 0,
+ return_dict: bool = True,
+ ) -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+ Args:
+            sample (`torch.FloatTensor`): (batch, channel, frame, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+            [`UNet3DConditionOutput`] or `tuple`:
+            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+ """
+        # By default, samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
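+        # e.g. with three upsampling layers the factor is 2**3 = 8, so inputs whose height or width
+        # is not divisible by 8 take the explicit `upsample_size` path below.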
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
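+        # The mask prepared above is additive: masked positions carry -10000 so they vanish after
+        # the softmax inside the attention layers, while kept positions contribute 0.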
+
+ # center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+ if self.use_relative_position:
+ frame_rel_pos_bias = self.time_rel_pos_bias(sample.shape[2], device=sample.device)
+ else:
+ frame_rel_pos_bias = None
+
+ # pre-process
+ sample = self.conv_in(sample)
+
+ # down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ use_image_num=use_image_num,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # mid
+ sample = self.mid_block(
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask, use_image_num=use_image_num,
+ )
+
+ # up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
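+            # The slices above pop exactly one skip tensor per resnet in this up block, pairing each
+            # resnet with the mirror activation pushed during the down pass.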
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ use_image_num=use_image_num,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+
+ if not return_dict:
+ return (sample,)
+ sample = UNet3DConditionOutput(sample=sample)
+ return sample
+
+ def forward_with_cfg(self,
+ x,
+ t,
+ encoder_hidden_states = None,
+ class_labels: Optional[torch.Tensor] = None,
+ cfg_scale=4.0,
+ use_fp16=False):
+ """
+ Forward, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
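+        # A sketch of the expected inputs (hypothetical tensors): the first half of `x` is duplicated
+        # below, and `encoder_hidden_states` is assumed to stack conditional then unconditional text
+        # embeddings, e.g.
+        #     x = torch.cat([latent, latent], dim=0)
+        #     text = torch.cat([cond_emb, uncond_emb], dim=0)
+        #     out = unet.forward_with_cfg(x, t, encoder_hidden_states=text, cfg_scale=7.5)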
+ half = x[: len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ if use_fp16:
+ combined = combined.to(dtype=torch.float16)
+ model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample
+        # Classifier-free guidance is applied to the first four (latent) channels only; any extra
+        # channels predicted by the model are passed through unchanged.
+        eps, rest = model_out[:, :4], model_out[:, 4:]
+        # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w (pixel-space variant)
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+ config["_class_name"] = cls.__name__
+ config["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+ config["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+
+ config["use_first_frame"] = False
+
+ from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
+
+
+ model = cls.from_config(config)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+        for k, v in model.state_dict().items():
+            if '_temp' in k:
+                # temporal layers have no 2D counterpart, so keep their freshly initialized weights
+                state_dict.update({k: v})
+            if 'attn_fcross' in k:  # copy params of attn1 to attn_fcross
+                state_dict.update({k: state_dict[k.replace('attn_fcross', 'attn1')]})
+            if 'norm_fcross' in k:  # copy params of norm1 to norm_fcross
+                state_dict.update({k: state_dict[k.replace('norm_fcross', 'norm1')]})
+
+ model.load_state_dict(state_dict)
+
+ return model
+
+if __name__ == '__main__':
+ import torch
+ # from xformers.ops import MemoryEfficientAttentionFlashAttentionOp
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ pretrained_model_path = "/mnt/petrelfs/maxin/work/pretrained/stable-diffusion-v1-4/" # p cluster
+ unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device)
+ # unet.enable_xformers_memory_efficient_attention(attention_op=MemoryEfficientAttentionFlashAttentionOp)
+ unet.enable_xformers_memory_efficient_attention()
+ unet.enable_gradient_checkpointing()
+
+ unet.train()
+
+ use_image_num = 5
+ noisy_latents = torch.randn((2, 4, 16 + use_image_num, 32, 32)).to(device)
+ bsz = noisy_latents.shape[0]
+ timesteps = torch.randint(0, 1000, (bsz,)).to(device)
+ timesteps = timesteps.long()
+ encoder_hidden_states = torch.randn((bsz, 1 + use_image_num, 77, 768)).to(device)
+ # class_labels = torch.randn((bsz, )).to(device)
+
+
+ model_pred = unet(sample=noisy_latents, timestep=timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ class_labels=None,
+ use_image_num=use_image_num).sample
+ print(model_pred.shape)
\ No newline at end of file
diff --git a/base/models/unet_blocks.py b/base/models/unet_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..849c10539c7039840c93631c5201069119d3c306
--- /dev/null
+++ b/base/models/unet_blocks.py
@@ -0,0 +1,648 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import torch
+from torch import nn
+
+try:
+ from .attention import Transformer3DModel
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
+except ImportError:  # fall back to absolute imports when the module is run as a script
+ from attention import Transformer3DModel
+ from resnet import Downsample3D, ResnetBlock3D, Upsample3D
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=False,
+):
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
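+    # e.g. both "UNetResDownBlock3D" and "DownBlock3D" resolve to the DownBlock3D class below.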
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "CrossAttnDownBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+ return CrossAttnDownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=False,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "CrossAttnUpBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+ return CrossAttnUpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ )
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class CrossAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
+ output_states = ()
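+        # Every intermediate hidden state is collected here so the matching up block can concatenate
+        # it back in as a skip connection.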
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ def create_custom_forward_attn(module, return_dict=None, use_image_num=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict, use_image_num=use_image_num)
+ else:
+ return module(*inputs, use_image_num=use_image_num)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=False
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ attention_mask=None,
+ use_image_num=None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ def create_custom_forward_attn(module, return_dict=None, use_image_num=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict, use_image_num=use_image_num)
+ else:
+ return module(*inputs, use_image_num=use_image_num)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward_attn(attn, return_dict=False, use_image_num=use_image_num),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, use_image_num=use_image_num).sample
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
diff --git a/base/models/utils.py b/base/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..76e205d1709568885e58b56888f62c8805ad4f91
--- /dev/null
+++ b/base/models/utils.py
@@ -0,0 +1,215 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+
+import numpy as np
+import torch.nn as nn
+
+from einops import repeat
+
+
+#################################################################################
+# Unet Utils #
+#################################################################################
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+# class HybridConditioner(nn.Module):
+
+# def __init__(self, c_concat_config, c_crossattn_config):
+# super().__init__()
+# self.concat_conditioner = instantiate_from_config(c_concat_config)
+# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+# def forward(self, c_concat, c_crossattn):
+# c_concat = self.concat_conditioner(c_concat)
+# c_crossattn = self.crossattn_conditioner(c_crossattn)
+# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += torch.DoubleTensor([matmul_ops])
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
\ No newline at end of file
diff --git a/base/pipelines/__pycache__/pipeline_videogen.cpython-311.pyc b/base/pipelines/__pycache__/pipeline_videogen.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..caa940925ab961c63315edfdc6eaffc8faa230c2
Binary files /dev/null and b/base/pipelines/__pycache__/pipeline_videogen.cpython-311.pyc differ
diff --git a/base/pipelines/pipeline_videogen.py b/base/pipelines/pipeline_videogen.py
new file mode 100644
index 0000000000000000000000000000000000000000..97031fc4d25b714b2c666c6d17bf3e16895376f0
--- /dev/null
+++ b/base/pipelines/pipeline_videogen.py
@@ -0,0 +1,677 @@
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, Dict, List, Optional, Union
+import einops
+import torch
+from packaging import version
+from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.configuration_utils import FrozenDict
+from diffusers.models import AutoencoderKL
+from diffusers.schedulers import KarrasDiffusionSchedulers
+from diffusers.utils import (
+ deprecate,
+ is_accelerate_available,
+ is_accelerate_version,
+ logging,
+ #randn_tensor,
+ replace_example_docstring,
+ BaseOutput,
+)
+
+try:
+ from diffusers.utils import randn_tensor
+except ImportError:  # randn_tensor moved to diffusers.utils.torch_utils in newer diffusers releases
+ from diffusers.utils.torch_utils import randn_tensor
+
+
+from diffusers import DiffusionPipeline
+from dataclasses import dataclass
+
+import os, sys
+sys.path.append(os.path.split(sys.path[0])[0])
+from models.unet import UNet3DConditionModel
+
+import numpy as np
+
+@dataclass
+class StableDiffusionPipelineOutput(BaseOutput):
+ video: torch.Tensor
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+EXAMPLE_DOC_STRING = """
+ Examples:
+ ```py
+        >>> import torch
+
+        >>> # A sketch: vae, text_encoder, tokenizer, unet and scheduler are assumed to be
+        >>> # loaded elsewhere (e.g. from a Stable Diffusion v1.4 checkpoint).
+        >>> pipe = VideoGenPipeline(vae=vae, text_encoder=text_encoder, tokenizer=tokenizer,
+        ...                         unet=unet, scheduler=scheduler)
+        >>> pipe = pipe.to("cuda")
+
+        >>> prompt = "a photo of an astronaut riding a horse on mars"
+        >>> video = pipe(prompt, video_length=16, height=512, width=512).video
+ ```
+"""
+
+
+class VideoGenPipeline(DiffusionPipeline):
+ r"""
+    Pipeline for text-to-video generation using Stable Diffusion components.
+
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the
+ library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.)
+
+ Args:
+ vae ([`AutoencoderKL`]):
+ Variational Auto-Encoder (VAE) Model to encode and decode images to and from latent representations.
+ text_encoder ([`CLIPTextModel`]):
+ Frozen text-encoder. Stable Diffusion uses the text portion of
+ [CLIP](https://huggingface.co/docs/transformers/model_doc/clip#transformers.CLIPTextModel), specifically
+ the [clip-vit-large-patch14](https://huggingface.co/openai/clip-vit-large-patch14) variant.
+ tokenizer (`CLIPTokenizer`):
+ Tokenizer of class
+ [CLIPTokenizer](https://huggingface.co/docs/transformers/v4.21.0/en/model_doc/clip#transformers.CLIPTokenizer).
+        unet ([`UNet3DConditionModel`]): Conditional 3D U-Net architecture to denoise the encoded video latents.
+ scheduler ([`SchedulerMixin`]):
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents. Can be one of
+ [`DDIMScheduler`], [`LMSDiscreteScheduler`], or [`PNDMScheduler`].
+ safety_checker ([`StableDiffusionSafetyChecker`]):
+ Classification module that estimates whether generated images could be considered offensive or harmful.
+ Please, refer to the [model card](https://huggingface.co/runwayml/stable-diffusion-v1-5) for details.
+ feature_extractor ([`CLIPFeatureExtractor`]):
+ Model that extracts features from generated images to be used as inputs for the `safety_checker`.
+ """
+ _optional_components = ["safety_checker", "feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet3DConditionModel,
+ scheduler: KarrasDiffusionSchedulers,
+ ):
+ super().__init__()
+
+ if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
+ f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure "
+ "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
+ " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
+ " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
+ " file"
+ )
+ deprecate("steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["steps_offset"] = 1
+ scheduler._internal_dict = FrozenDict(new_config)
+
+ if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True:
+ deprecation_message = (
+ f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
+ " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
+ " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
+ " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
+ " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
+ )
+ deprecate("clip_sample not set", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(scheduler.config)
+ new_config["clip_sample"] = False
+ scheduler._internal_dict = FrozenDict(new_config)
+
+
+
+ is_unet_version_less_0_9_0 = hasattr(unet.config, "_diffusers_version") and version.parse(
+ version.parse(unet.config._diffusers_version).base_version
+ ) < version.parse("0.9.0.dev0")
+ is_unet_sample_size_less_64 = hasattr(unet.config, "sample_size") and unet.config.sample_size < 64
+ if is_unet_version_less_0_9_0 and is_unet_sample_size_less_64:
+ deprecation_message = (
+ "The configuration file of the unet has set the default `sample_size` to smaller than"
+ " 64 which seems highly unlikely. If your checkpoint is a fine-tuned version of any of the"
+ " following: \n- CompVis/stable-diffusion-v1-4 \n- CompVis/stable-diffusion-v1-3 \n-"
+ " CompVis/stable-diffusion-v1-2 \n- CompVis/stable-diffusion-v1-1 \n- runwayml/stable-diffusion-v1-5"
+ " \n- runwayml/stable-diffusion-inpainting \n you should change 'sample_size' to 64 in the"
+ " configuration file. Please make sure to update the config accordingly as leaving `sample_size=32`"
+ " in the config might lead to incorrect results in future versions. If you have downloaded this"
+ " checkpoint from the Hugging Face Hub, it would be very nice if you could open a Pull request for"
+ " the `unet/config.json` file"
+ )
+ deprecate("sample_size<64", "1.0.0", deprecation_message, standard_warn=False)
+ new_config = dict(unet.config)
+ new_config["sample_size"] = 64
+ unet._internal_dict = FrozenDict(new_config)
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ scheduler=scheduler,
+ )
+ self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
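+        # e.g. a VAE whose config lists four block_out_channels gives a scale factor of 2**3 = 8,
+        # so a 512x512 frame corresponds to a 64x64 latent.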
+ # self.register_to_config(requires_safety_checker=requires_safety_checker)
+
+ def enable_vae_slicing(self):
+ r"""
+ Enable sliced VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
+ steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.vae.enable_slicing()
+
+ def disable_vae_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_slicing()
+
+ def enable_vae_tiling(self):
+ r"""
+ Enable tiled VAE decoding.
+
+ When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
+ several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
+ """
+ self.vae.enable_tiling()
+
+ def disable_vae_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.vae.disable_tiling()
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+ text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ Note that offloading happens on a submodule basis. Memory savings are higher than with
+ `enable_model_cpu_offload`, but performance is lower.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ cpu_offload(cpu_offloaded_model, device)
+
+ # if self.safety_checker is not None:
+ # cpu_offload(self.safety_checker, execution_device=device, offload_buffers=True)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ self.final_offload_hook = hook
+
+ @property
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+                `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ def decode_latents(self, latents):
+ video_length = latents.shape[2]
+ latents = 1 / 0.18215 * latents
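+        # 0.18215 is the scaling factor applied when Stable Diffusion latents are produced; dividing
+        # by it restores the VAE's native latent scale before decoding frame by frame.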
+ latents = einops.rearrange(latents, "b c f h w -> (b f) c h w")
+ video = self.vae.decode(latents).sample
+ video = einops.rearrange(video, "(b f) c h w -> b f h w c", f=video_length)
+ video = ((video / 2 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().contiguous()
+ return video
+
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+        # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+        # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ def check_inputs(
+ self,
+ prompt,
+ height,
+ width,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if height % 8 != 0 or width % 8 != 0:
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
+
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
+ if isinstance(generator, list) and len(generator) != batch_size:
+ raise ValueError(
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
+ )
+
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
+ @torch.no_grad()
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ height: Optional[int] = None,
+ width: Optional[int] = None,
+ video_length: int = 16,
+ num_inference_steps: int = 50,
+ guidance_scale: float = 7.5,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ output_type: Optional[str] = "pil",
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ cross_attention_kwargs: Optional[Dict[str, Any]] = None,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The height in pixels of the generated image.
+ width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
+ The width in pixels of the generated image.
+ num_inference_steps (`int`, *optional*, defaults to 50):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 7.5):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text `prompt`,
+ usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead.
+ Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ output_type (`str`, *optional*, defaults to `"pil"`):
+ The output format of the generated image. Choose between
+ [PIL](https://pillow.readthedocs.io/en/stable/): `PIL.Image.Image` or `np.array`.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+ cross_attention_kwargs (`dict`, *optional*):
+ A kwargs dictionary that if specified is passed along to the `AttnProcessor` as defined under
+ `self.processor` in
+ [diffusers.cross_attention](https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/cross_attention.py).
+
+ Examples:
+
+ Returns:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] or `tuple`:
+ [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] if `return_dict` is True, otherwise a `tuple`.
+ When returning a tuple, the first element is the generated video frames.
+ """
+ # 0. Default height and width to unet
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
+
+ # 1. Check inputs. Raise error if not correct
+ self.check_inputs(
+ prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
+ )
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Prepare timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 5. Prepare latent variables
+ num_channels_latents = self.unet.config.in_channels
+ latents = self.prepare_latents(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ video_length,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ )
+
+ # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 7. Denoising loop
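+ # schedulers with order > 1 register more timesteps than `num_inference_steps`; the surplus is
+ # treated as warmup steps in the progress-bar / callback bookkeeping below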
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+
+ # predict the noise residual
+ noise_pred = self.unet(
+ latent_model_input,
+ t,
+ encoder_hidden_states=prompt_embeds,
+ # cross_attention_kwargs=cross_attention_kwargs,
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+
+ # 8. Post-processing
+ video = self.decode_latents(latents)
+
+ if not return_dict:
+ return (video,)
+
+ return StableDiffusionPipelineOutput(video=video)
diff --git a/base/pipelines/sample.py b/base/pipelines/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..e2a2a89e311982e6596cece7f767015e9b9ec709
--- /dev/null
+++ b/base/pipelines/sample.py
@@ -0,0 +1,88 @@
+import os
+import torch
+import argparse
+import torchvision
+
+from pipeline_videogen import VideoGenPipeline
+
+from download import find_model
+from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
+from diffusers.models import AutoencoderKL
+from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
+from omegaconf import OmegaConf
+
+import os, sys
+sys.path.append(os.path.split(sys.path[0])[0])
+from models import get_models
+import imageio
+
+def main(args):
+ #torch.manual_seed(args.seed)
+ torch.set_grad_enabled(False)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
+ unet = get_models(args, sd_path).to(device, dtype=torch.float16)
+ state_dict = find_model(args.pretrained_path + "/lavie_base.pt")
+ unet.load_state_dict(state_dict)
+
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
+ tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
+ text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
+
+ # set eval mode
+ unet.eval()
+ vae.eval()
+ text_encoder_one.eval()
+
+ if args.sample_method == 'ddim':
+ scheduler = DDIMScheduler.from_pretrained(sd_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule)
+ elif args.sample_method == 'eulerdiscrete':
+ scheduler = EulerDiscreteScheduler.from_pretrained(sd_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule)
+ elif args.sample_method == 'ddpm':
+ scheduler = DDPMScheduler.from_pretrained(sd_path,
+ subfolder="scheduler",
+ beta_start=args.beta_start,
+ beta_end=args.beta_end,
+ beta_schedule=args.beta_schedule)
+ else:
+ raise NotImplementedError
+
+ videogen_pipeline = VideoGenPipeline(vae=vae,
+ text_encoder=text_encoder_one,
+ tokenizer=tokenizer_one,
+ scheduler=scheduler,
+ unet=unet).to(device)
+ videogen_pipeline.enable_xformers_memory_efficient_attention()
+
+ if not os.path.exists(args.output_folder):
+ os.makedirs(args.output_folder)
+
+ video_grids = []
+ for prompt in args.text_prompt:
+ print('Processing prompt: {}'.format(prompt))
+ videos = videogen_pipeline(prompt,
+ video_length=args.video_length,
+ height=args.image_size[0],
+ width=args.image_size[1],
+ num_inference_steps=args.num_sampling_steps,
+ guidance_scale=args.guidance_scale).video
+ imageio.mimwrite(os.path.join(args.output_folder, prompt.replace(' ', '_') + '.mp4'), videos[0], fps=8, quality=9) # highest quality is 10, lowest is 0
+
+ print('Save path: {}'.format(args.output_folder))
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="")
+ args = parser.parse_args()
+
+ main(OmegaConf.load(args.config))
+
diff --git a/base/pipelines/sample.sh b/base/pipelines/sample.sh
new file mode 100644
index 0000000000000000000000000000000000000000..c8dd5fdada1cf921e5c250bd357e765fd9e42b5a
--- /dev/null
+++ b/base/pipelines/sample.sh
@@ -0,0 +1,2 @@
+export CUDA_VISIBLE_DEVICES=6
+python pipelines/sample.py --config configs/sample.yaml
\ No newline at end of file
diff --git a/base/text_to_video/__init__.py b/base/text_to_video/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..b5a194bb33bf22c33de528da45b372f7e2ffaef6
--- /dev/null
+++ b/base/text_to_video/__init__.py
@@ -0,0 +1,44 @@
+import os
+import torch
+import argparse
+import torchvision
+
+from pipelines.pipeline_videogen import VideoGenPipeline
+
+from download import find_model
+from diffusers.schedulers import DDIMScheduler, DDPMScheduler, PNDMScheduler, EulerDiscreteScheduler
+from diffusers.models import AutoencoderKL
+from transformers import CLIPTokenizer, CLIPTextModel, CLIPTextModelWithProjection
+from omegaconf import OmegaConf
+
+import os, sys
+sys.path.append(os.path.split(sys.path[0])[0])
+from models import get_models
+import imageio
+
+config_path = "/mnt/petrelfs/zhouyan/project/lavie-release/base/configs/sample.yaml"
+args = OmegaConf.load(config_path)
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+def model_t2v_fun(args):
+ sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
+ unet = get_models(args, sd_path).to(device, dtype=torch.float16)
+ # state_dict = find_model(args.pretrained_path + "/lavie_base.pt")
+ state_dict = find_model("/mnt/petrelfs/share_data/wangyaohui/lavie/pretrained_models/lavie_base.pt")
+ unet.load_state_dict(state_dict)
+
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae", torch_dtype=torch.float16).to(device)
+ tokenizer_one = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer")
+ text_encoder_one = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder", torch_dtype=torch.float16).to(device) # huge
+ unet.eval()
+ vae.eval()
+ text_encoder_one.eval()
+ scheduler = DDIMScheduler.from_pretrained(sd_path, subfolder="scheduler", beta_start=args.beta_start, beta_end=args.beta_end, beta_schedule=args.beta_schedule)
+ return VideoGenPipeline(vae=vae, text_encoder=text_encoder_one, tokenizer=tokenizer_one, scheduler=scheduler, unet=unet)
+
+def setup_seed(seed):
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+
+
diff --git a/base/text_to_video/__pycache__/__init__.cpython-311.pyc b/base/text_to_video/__pycache__/__init__.cpython-311.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2564eedb859cc053903c2f055f4e2e237d1fa03e
Binary files /dev/null and b/base/text_to_video/__pycache__/__init__.cpython-311.pyc differ
diff --git a/base/try.py b/base/try.py
new file mode 100644
index 0000000000000000000000000000000000000000..5cffd2393f0658f057f30acd2060a21ad1dc8825
--- /dev/null
+++ b/base/try.py
@@ -0,0 +1,5 @@
+import gradio as gr
+
+with gr.Blocks() as demo:
+ prompt = gr.Textbox(label="Prompt", placeholder="enter prompt", show_label=True, elem_id="prompt-in")
+demo.launch(server_name="0.0.0.0")
\ No newline at end of file
diff --git a/environment.yml b/environment.yml
new file mode 100644
index 0000000000000000000000000000000000000000..7af32037d0e2dc010f6ccf4ec0970ad66029ff70
--- /dev/null
+++ b/environment.yml
@@ -0,0 +1,27 @@
+name: lavie
+channels:
+ - pytorch
+ - nvidia
+dependencies:
+ - python=3.11.3
+ - pytorch=2.0.1
+ - pytorch-cuda=11.7
+ - torchvision=0.15.2
+ - pip:
+ - accelerate==0.19.0
+ - av==10.0.0
+ - decord==0.6.0
+ - diffusers[torch]==0.16.0
+ - einops==0.6.1
+ - ffmpeg==1.4
+ - imageio==2.31.1
+ - imageio-ffmpeg==0.4.9
+ - pandas==2.0.1
+ - timm==0.6.13
+ - tqdm==4.65.0
+ - transformers==4.28.1
+ - xformers==0.0.20
+ - omegaconf==2.3.0
+ - natsort==8.4.0
+ - rotary_embedding_torch
+ - gradio==4.3.0
\ No newline at end of file
diff --git a/interpolation/configs/sample.yaml b/interpolation/configs/sample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..c00d36aaec25efe196ac47d167042c08f3717a57
--- /dev/null
+++ b/interpolation/configs/sample.yaml
@@ -0,0 +1,36 @@
+args:
+ input_folder: "../res/base/"
+ pretrained_path: "../pretrained_models"
+ output_folder: "../res/interpolation/"
+ seed_list:
+ - 3418
+
+ fps_list:
+ - 24
+
+ # model config:
+ model: TSR
+ num_frames: 61
+ image_size: [320, 512]
+ num_sampling_steps: 50
+ vae: mse
+ use_timecross_transformer: False
+ frame_interval: 1
+
+ # sample config:
+ seed: 0
+ cfg_scale: 4.0
+ run_time: 12
+ use_compile: False
+ enable_xformers_memory_efficient_attention: True
+ num_sample: 1
+
+ additional_prompt: ", 4k."
+ negative_prompt: "None"
+ do_classifier_free_guidance: True
+ use_ddim_sample_loop: True
+
+ researve_frame: 3
+ mask_type: "tsr"
+ use_concat: True
+ copy_no_mask: True
diff --git a/interpolation/datasets/__init__.py b/interpolation/datasets/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..1332521f6336d8ce42f026e0aeba66fdd8138a0a
--- /dev/null
+++ b/interpolation/datasets/__init__.py
@@ -0,0 +1 @@
+from datasets import video_transforms
diff --git a/interpolation/datasets/video_transforms.py b/interpolation/datasets/video_transforms.py
new file mode 100644
index 0000000000000000000000000000000000000000..463a10f082288f486a546bbff17b136824723589
--- /dev/null
+++ b/interpolation/datasets/video_transforms.py
@@ -0,0 +1,109 @@
+import torch
+import random
+import numbers
+from torchvision.transforms import RandomCrop, RandomResizedCrop
+
+def _is_tensor_video_clip(clip):
+ if not torch.is_tensor(clip):
+ raise TypeError("clip should be Tensor. Got %s" % type(clip))
+
+ if not clip.ndimension() == 4:
+ raise ValueError("clip should be 4D. Got %dD" % clip.dim())
+
+ return True
+
+
+def to_tensor(clip):
+ """
+ Convert tensor data type from uint8 to float and divide values by 255.0;
+ the (T, C, H, W) layout is left unchanged.
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+ """
+ _is_tensor_video_clip(clip)
+ if not clip.dtype == torch.uint8:
+ raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
+ # return clip.float().permute(3, 0, 1, 2) / 255.0
+ return clip.float() / 255.0
+
+
+def resize(clip, target_size, interpolation_mode):
+ if len(target_size) != 2:
+ raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
+ return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
+
+
+class ToTensorVideo:
+ """
+ Convert tensor data type from uint8 to float and divide values by 255.0;
+ the (T, C, H, W) layout is left unchanged.
+ """
+
+ def __init__(self):
+ pass
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
+ Return:
+ clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
+ """
+ return to_tensor(clip)
+
+ def __repr__(self) -> str:
+ return self.__class__.__name__
+
+
+class ResizeVideo:
+ '''
+ Resize to the specified size
+ '''
+ def __init__(
+ self,
+ size,
+ interpolation_mode="bilinear",
+ ):
+ if isinstance(size, tuple):
+ if len(size) != 2:
+ raise ValueError(f"size should be tuple (height, width), instead got {size}")
+ self.size = size
+ else:
+ self.size = (size, size)
+
+ self.interpolation_mode = interpolation_mode
+
+
+ def __call__(self, clip):
+ """
+ Args:
+ clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
+ Returns:
+ torch.tensor: the resized video clip, with size (T, C, h, w).
+ """
+ clip_resize = resize(clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
+ return clip_resize
+
+ def __repr__(self) -> str:
+ return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
+
+
+class TemporalRandomCrop(object):
+ """Temporally crop the given frame indices at a random location.
+
+ Args:
+ size (int): Desired length of frames will be seen in the model.
+ """
+
+ def __init__(self, size):
+ self.size = size
+
+ def __call__(self, total_frames):
+ rand_end = max(0, total_frames - self.size - 1)
+ begin_index = random.randint(0, rand_end)
+ end_index = min(begin_index + self.size, total_frames)
+ return begin_index, end_index
+
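+# A minimal usage sketch (illustrative only; shapes and sizes below are assumptions, not part of this module):
+#   import torch, torchvision
+#   transform = torchvision.transforms.Compose([ToTensorVideo(), ResizeVideo((320, 512))])
+#   clip = torch.randint(0, 256, (16, 3, 240, 426), dtype=torch.uint8)  # (T, C, H, W) uint8 frames
+#   clip = transform(clip)  # float clip in [0, 1], resized to (16, 3, 320, 512)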
diff --git a/interpolation/diffusion/__init__.py b/interpolation/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..c9dbf6cf0bd6b9d1a8f65e0a31e9a84cacc03189
--- /dev/null
+++ b/interpolation/diffusion/__init__.py
@@ -0,0 +1,47 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion, space_timesteps
+
+
+def create_diffusion(
+ timestep_respacing,
+ noise_schedule="linear",
+ use_kl=False,
+ sigma_small=False,
+ predict_xstart=False,
+ # learn_sigma=True,
+ learn_sigma=False, # for unet
+ rescale_learned_sigmas=False,
+ diffusion_steps=1000
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE
+ if timestep_respacing is None or timestep_respacing == "":
+ timestep_respacing = [diffusion_steps]
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(
+ gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X
+ ),
+ model_var_type=(
+ (
+ gd.ModelVarType.FIXED_LARGE
+ if not sigma_small
+ else gd.ModelVarType.FIXED_SMALL
+ )
+ if not learn_sigma
+ else gd.ModelVarType.LEARNED_RANGE
+ ),
+ loss_type=loss_type
+ # rescale_timesteps=rescale_timesteps,
+ )
diff --git a/interpolation/diffusion/diffusion_utils.py b/interpolation/diffusion/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060
--- /dev/null
+++ b/interpolation/diffusion/diffusion_utils.py
@@ -0,0 +1,88 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import torch as th
+import numpy as np
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def continuous_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a continuous Gaussian distribution.
+ :param x: the targets
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ normalized_x = centered_x * inv_stdv
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
+ return log_probs
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
diff --git a/interpolation/diffusion/gaussian_diffusion.py b/interpolation/diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..854c5c4daceace3826b786d98220c9d1b10611c8
--- /dev/null
+++ b/interpolation/diffusion/gaussian_diffusion.py
@@ -0,0 +1,1000 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+
+import math
+
+import numpy as np
+import torch as th
+import enum
+
+from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = (
+ enum.auto()
+ ) # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+ return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ """
+ This is the deprecated API for creating beta schedules.
+ See get_named_beta_schedule() for the new library of schedules.
+ """
+ if beta_schedule == "quad":
+ betas = (
+ np.linspace(
+ beta_start ** 0.5,
+ beta_end ** 0.5,
+ num_diffusion_timesteps,
+ dtype=np.float64,
+ )
+ ** 2
+ )
+ elif beta_schedule == "linear":
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "warmup10":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+ elif beta_schedule == "warmup50":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
+ """
+ Get a pre-defined beta schedule for the given name.
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name == "linear":
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ return get_beta_schedule(
+ "linear",
+ beta_start=scale * 0.0001,
+ beta_end=scale * 0.02,
+ num_diffusion_timesteps=num_diffusion_timesteps,
+ )
+ elif schedule_name == "squaredcos_cap_v2":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+ Original ported from this codebase:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type
+ ):
+
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ ) if len(self.posterior_variance) > 1 else np.array([])
+
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
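+ # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise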
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+ q(x_{t-1} | x_t, x_0)
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None,
+ mask=None, x_start=None, use_concat=False,
+ copy_no_mask=False, ):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, F, C = x.shape[:3]
+ assert t.shape == (B,)
+ # model_output = model(x, t, **model_kwargs)
+ if copy_no_mask:
+ if use_concat:
+ try:
+ model_output = model(th.concat([x, x_start], dim=1), t, **model_kwargs).sample
+ except:
+ # print(f'x.shape = {x.shape}, x_start.shape = {x_start.shape}')
+ # )
+ # x.shape = torch.Size([2, 4, 61, 32, 32]), x_start.shape = torch.Size([2, 4, 61, 32, 32]
+ # print(f'x[0,0,:,0,0] = {x[0,0,:,0,0]}, \nx_start[0,0,:,0,0] = {x_start[0,0,:,0,0]}')
+ model_output = model(th.concat([x, x_start], dim=1), t, **model_kwargs)
+ else:
+ try:
+ model_output = model(x, t, **model_kwargs).sample # for tav unet
+ except:
+ model_output = model(x, t, **model_kwargs)
+ else:
+ if use_concat:
+ try:
+ model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs).sample
+ except:
+ model_output = model(th.concat([x, mask, x_start], dim=1), t, **model_kwargs)
+ else:
+ try:
+ model_output = model(x, t, **model_kwargs).sample # for tav unet
+ except:
+ model_output = model(x, t, **model_kwargs)
+ if isinstance(model_output, tuple):
+ model_output, extra = model_output
+ else:
+ extra = None
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, F, C * 2, *x.shape[3:])
+ model_output, model_var_values = th.split(model_output, C, dim=2)
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output, mask=mask, x_start=x_start, use_concat=use_concat)
+ )
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ return {
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ "extra": extra,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps, mask=None, x_start=None, use_concat=False): # (x_t=x, t=t, eps=model_output)
+ assert x_t.shape == eps.shape
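+ # when a temporal mask is provided (and use_concat is False), only the masked ("unknown") frames
+ # take the model's x_0 prediction; the unmasked ("known") frames are kept from x_start (or re-noised to t-1)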
+ if not use_concat:
+ if mask is not None:
+ if x_start is None:
+ return (
+ (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps )* mask + x_t * (1-mask)
+ )
+ else:
+ # breakpoint()
+ if (t == 0).any():
+ print('t=0')
+ x_unknown = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t \
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ return x_start * (1-mask) + x_unknown * mask
+ else:
+ x_known = self.q_sample(x_start, t-1)
+ x_unknown = _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t \
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ return (
+ x_known * (1-mask) + x_unknown * mask
+ )
+ else:
+ return (
+ (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps )
+ )
+ else:
+ return (
+ (_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps )
+ )
+
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, **model_kwargs)
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+ See condition_mean() for details on cond_fn.
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ mask=None,
+ x_start=None,
+ use_concat=False
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ mask=None,
+ x_start=None,
+ use_concat=False,
+ ):
+ """
+ Generate samples from the model.
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ mask=None,
+ x_start=None,
+ use_concat=False
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad(): # loop
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ mask=None,
+ x_start=None,
+ use_concat=False,
+ copy_no_mask=False,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat,
+ copy_no_mask=copy_no_mask,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12 of the DDIM paper; with eta = 0 the sampling is deterministic.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12 of the DDIM paper, reversed.
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ mask=None,
+ x_start=None,
+ use_concat=False,
+ copy_no_mask=False,
+ ):
+ """
+ Generate samples from the model using DDIM.
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat,
+ copy_no_mask=copy_no_mask,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ mask=None,
+ x_start=None,
+ use_concat=False,
+ copy_no_mask=False,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ mask=mask,
+ x_start=x_start,
+ use_concat=use_concat,
+ copy_no_mask=copy_no_mask,
+ )
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+ def training_losses(self, model, x_start, t, model_kwargs=None, noise=None, mask=None, t_head=None, copy_no_mask=False):
+ """
+ Compute training losses for a single timestep.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ # mask could be here
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
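+ # only the first 4 channels (the video latents) are noised; the remaining channels are taken from
+ # x_start and carry the conditioning copy (and, in the masked variant, the mask channel)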
+ x_t = th.cat([x_t[:, :4], x_start[:, 4:]], dim=1)
+ # mask is used for (0,0,0,1,1,1,...) which means the diffusion model can see the first 3 frames of the input video
+ # print(f'training_losses(): mask = {mask}') # None
+
+ if mask is not None:
+ x_t = x_t*mask + x_start*(1-mask)
+
+ # noise augmentation
+ if copy_no_mask:
+ if t_head is not None:
+ noise_aug = self.q_sample(x_start[:, 4:], t_head) # noise aug on copied_video
+ x_t = th.cat([x_t[:, :4], noise_aug], dim=1)
+ else:
+ if t_head is not None:
+ noise_aug = self.q_sample(x_start[:, 5:], t_head) # b, 4, f, h, w
+ noise_aug = noise_aug * (x_start[:, 4].unsqueeze(1).expand(-1, 4, -1, -1, -1) == 0) # use mask to zero out augmented noises
+ x_t = th.cat([x_t[:, :5], noise_aug], dim=1)
+ terms = {}
+ # for i in [0,1,2,3,4,5,6,7]:
+ # print(f'x_t[0,{i},:,0,0] = {x_t[0,i,:,0,0]}')
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
+ # print(f'self.loss_type = {self.loss_type}') # LossType.MSE
+ # model_output = model(x_t, t, **model_kwargs)
+ try:
+ model_output = model(x_t, t, **model_kwargs).sample # for tav unet
+ except:
+ model_output = model(x_t, t, **model_kwargs)
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, F, C = x_t.shape[:3]
+ assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
+ model_output, model_var_values = th.split(model_output, C, dim=2)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+
+ # print(f'self.model_mean_type = {self.model_mean_type}') # ModelMeanType.EPSILON
+ target = {
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ # assert model_output.shape == target.shape == x_start.shape
+ # if mask is not None:
+ # nonzero_idx = th.nonzero(1-mask)
+ terms["mse"] = mean_flat((target[:,:4] - model_output) ** 2)
+ # else:
+ # terms["mse"] = mean_flat((target - model_output) ** 2)
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
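+
+
+if __name__ == "__main__":
+    # Minimal sketch (illustrative only): _extract_into_tensor gathers one
+    # schedule value per batch element and broadcasts it over a video-shaped
+    # tensor [N, C, F, H, W], which is how the q_sample / posterior
+    # coefficients above are applied. The schedule below is a stand-in, not
+    # the schedule used in training.
+    schedule = np.linspace(0.9999, 0.98, 1000)
+    t = th.tensor([0, 499, 999])
+    coeff = _extract_into_tensor(schedule, t, (3, 4, 16, 32, 32))
+    print(coeff.shape)           # torch.Size([3, 4, 16, 32, 32])
+    print(coeff[:, 0, 0, 0, 0])  # schedule[0], schedule[499], schedule[999]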
diff --git a/interpolation/diffusion/respace.py b/interpolation/diffusion/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..31c0092380367db477f154a39bf172ff64712fc4
--- /dev/null
+++ b/interpolation/diffusion/respace.py
@@ -0,0 +1,130 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+import torch
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+    For example, if there are 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+            raise ValueError(
+                f"cannot create exactly {desired_count} steps with an integer stride"
+            )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ # @torch.compile
+ def training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ # self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ # if self.rescale_timesteps:
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
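+
+
+if __name__ == "__main__":
+    # Minimal sketch (illustrative values): space_timesteps picks which of the
+    # original timesteps a SpacedDiffusion keeps. "ddim50" selects 50 evenly
+    # strided steps out of a 1000-step process; a list of section counts keeps
+    # a different number of steps per equally-sized section.
+    ddim_steps = space_timesteps(1000, "ddim50")
+    print(len(ddim_steps), min(ddim_steps), max(ddim_steps))  # 50 0 980
+
+    sectioned = space_timesteps(300, [10, 15, 20])
+    print(len(sectioned))  # 45 steps: 10 + 15 + 20 across three 100-step sections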
diff --git a/interpolation/diffusion/timestep_sampler.py b/interpolation/diffusion/timestep_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f
--- /dev/null
+++ b/interpolation/diffusion/timestep_sampler.py
@@ -0,0 +1,150 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+        :param batch_size: the number of timesteps to sample (one per batch element).
+ :param device: the torch device to save to.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+ indices = th.from_numpy(indices_np).long().to(device)
+ weights_np = 1 / (len(p) * p[indices_np])
+ weights = th.from_numpy(weights_np).float().to(device)
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+ batch_sizes = [
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
+ for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(
+ batch_sizes,
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+ )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+ ]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros(
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
+ )
+        self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
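+
+
+if __name__ == "__main__":
+    # Minimal sketch (illustrative): the samplers only read `num_timesteps`
+    # from the diffusion object, so a dummy stand-in is enough here. Uniform
+    # sampling yields importance weights of ~1.0; LossSecondMomentResampler
+    # would instead reweight timesteps by the RMS of their recent losses.
+    class _DummyDiffusion:
+        num_timesteps = 1000
+
+    sampler = UniformSampler(_DummyDiffusion())
+    t, weights = sampler.sample(batch_size=4, device=th.device("cpu"))
+    print(t, weights)  # 4 random timestep indices, weights all ~1.0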
diff --git a/interpolation/download.py b/interpolation/download.py
new file mode 100644
index 0000000000000000000000000000000000000000..b7daa0660d0c2eea1e730cbe88585de2e0723dcf
--- /dev/null
+++ b/interpolation/download.py
@@ -0,0 +1,22 @@
+# All rights reserved.
+
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+import torch
+import os
+
+
+pretrained_models = {''}
+
+
+def find_model(model_name):
+ """
+    Loads a pre-trained model checkpoint from a local path, preferring the EMA weights when present.
+ """
+ assert os.path.isfile(model_name), f'Could not find checkpoint at {model_name}'
+ checkpoint = torch.load(model_name, map_location=lambda storage, loc: storage)
+ if "ema" in checkpoint: # supports checkpoints from train.py
+ checkpoint = checkpoint["ema"]
+ return checkpoint
+
diff --git a/interpolation/models/__init__.py b/interpolation/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..507091b8c274533eca620577bae301541d5a91f6
--- /dev/null
+++ b/interpolation/models/__init__.py
@@ -0,0 +1,33 @@
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+from .unet import UNet3DConditionModel
+from torch.optim.lr_scheduler import LambdaLR
+
+def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
+ from torch.optim.lr_scheduler import LambdaLR
+ def fn(step):
+ if warmup_steps > 0:
+ return min(step / warmup_steps, 1)
+ else:
+ return 1
+ return LambdaLR(optimizer, fn)
+
+
+def get_lr_scheduler(optimizer, name, **kwargs):
+ if name == 'warmup':
+ return customized_lr_scheduler(optimizer, **kwargs)
+ elif name == 'cosine':
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ return CosineAnnealingLR(optimizer, **kwargs)
+ else:
+ raise NotImplementedError(name)
+
+def get_models(args, ckpt_path):
+
+ if 'TSR' in args.model:
+ return UNet3DConditionModel.from_pretrained_2d(ckpt_path, subfolder="unet", use_concat=args.use_concat, copy_no_mask=args.copy_no_mask)
+ else:
+        raise NotImplementedError('{} Model Not Supported!'.format(args.model))
+
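+
+if __name__ == "__main__":
+    # Minimal sketch (illustrative values): the "warmup" scheduler scales the
+    # learning rate linearly from 0 to the optimizer's base LR over
+    # `warmup_steps`, then holds it constant. warmup_steps=4 is chosen only to
+    # keep the printout short; the default above is 5000.
+    import torch
+
+    opt = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
+    sched = get_lr_scheduler(opt, "warmup", warmup_steps=4)
+    for step in range(6):
+        print(step, sched.get_last_lr())  # 0.0, 2.5e-5, 5e-5, 7.5e-5, 1e-4, 1e-4
+        opt.step()
+        sched.step()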
diff --git a/interpolation/models/attention.py b/interpolation/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..1c05879560bd827c2096e557dcaa17fadc8c1bb0
--- /dev/null
+++ b/interpolation/models/attention.py
@@ -0,0 +1,665 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+from dataclasses import dataclass
+from typing import Optional
+
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import FeedForward, AdaLayerNorm
+
+from einops import rearrange, repeat
+
+try:
+ from diffusers.models.modeling_utils import ModelMixin
+except ImportError:
+    from diffusers.modeling_utils import ModelMixin  # diffusers 0.11.1
+
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+class CrossAttention(nn.Module):
+ r"""
+    Copied from diffusers 0.11.1.
+ A cross attention layer.
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ self.dim_head = dim_head
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = False
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ # print(use_relative_position)
+ self.use_relative_position = use_relative_position
+ if self.use_relative_position:
+ # print(dim_head)
+ # print(heads)
+ # adopt https://github.com/huggingface/transformers/blob/8a817e1ecac6a420b1bdc701fcc33535a3b96ff5/src/transformers/models/bert/modeling_bert.py#L265
+ self.max_position_embeddings = 32
+ self.distance_embedding = nn.Embedding(2 * self.max_position_embeddings - 1, dim_head)
+
+ self.dropout = nn.Dropout(dropout)
+
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def reshape_for_scores(self, tensor):
+ # split heads and dims
+ # tensor should be [b (h w)] f (d nd)
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ return tensor
+
+ def same_batch_dim_to_heads(self, tensor):
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
+ # if self.use_relative_position:
+ # print('before attention query shape', query.shape)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
+ # if self.use_relative_position:
+ # print('before attention query shape', query.shape)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ if self.use_relative_position:
+ query = self.reshape_for_scores(self.reshape_batch_dim_to_heads(query))
+ key = self.reshape_for_scores(self.reshape_batch_dim_to_heads(key))
+ value = self.reshape_for_scores(self.reshape_batch_dim_to_heads(value))
+
+            # torch.baddbmm only accepts 3-D tensors
+ # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm
+ attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2))
+
+ # print('attention_scores shape', attention_scores.shape)
+
+ # print(query.shape) # [b (h w)] nd f d
+ query_length, key_length = query.shape[2], key.shape[2]
+ # print('query shape', query.shape)
+ # print('key shape', key.shape)
+ position_ids_l = torch.arange(query_length, dtype=torch.long, device=query.device).view(-1, 1) # hidden_states.device
+ position_ids_r = torch.arange(key_length, dtype=torch.long, device=key.device).view(1, -1) # hidden_states.device
+ distance = position_ids_l - position_ids_r
+ # print('distance shape', distance.shape)
+ positional_embedding = self.distance_embedding(distance + self.max_position_embeddings - 1)
+ positional_embedding = positional_embedding.to(dtype=query.dtype) # fp16 compatibility
+ # print('positional_embedding shape', positional_embedding.shape)
+ relative_position_scores_query = torch.einsum("bhld, lrd -> bhlr", query, positional_embedding)
+ relative_position_scores_key = torch.einsum("bhrd, lrd -> bhlr", key, positional_embedding)
+ # print('relative_position_scores_key shape', relative_position_scores_key.shape)
+ attention_scores = attention_scores + relative_position_scores_query + relative_position_scores_key
+ # print(attention_scores.shape)
+
+ attention_scores = attention_scores / math.sqrt(self.dim_head)
+
+ # Normalize the attention scores to probabilities.
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ hidden_states = torch.matmul(attention_probs, value)
+ # print(hidden_states.shape)
+ hidden_states = self.same_batch_dim_to_heads(hidden_states)
+ # print(hidden_states.shape)
+ # exit()
+
+ else:
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ # print(attention_probs.shape)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+ # print(attention_probs.shape)
+
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+ # print(hidden_states.shape)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ # print(hidden_states.shape)
+ # exit()
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ # TODO attention_mask
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # Define input layers
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ # Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if use_linear_projection:
+ self.proj_out = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+ # Input
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
+
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ video_length=video_length
+ )
+
+ # Output
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ # print(only_cross_attention)
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+ self.use_first_frame = use_first_frame
+
+ # SC-Attn
+ if use_first_frame:
+ self.attn1 = SparseCausalAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ # print(cross_attention_dim)
+ else:
+ self.attn1 = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ upcast_attention=upcast_attention,
+ )
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # Cross-Attn
+ if cross_attention_dim is not None:
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ else:
+ self.attn2 = None
+
+ if cross_attention_dim is not None:
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ else:
+ self.norm2 = None
+
+ # Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.norm3 = nn.LayerNorm(dim)
+
+ # Temp-Attn
+ self.attn_temp = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ use_relative_position=use_relative_position
+ )
+ nn.init.zeros_(self.attn_temp.to_out[0].weight.data)
+ self.norm_temp = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op=None):
+        if not is_xformers_available():
+            raise ModuleNotFoundError(
+                "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+                " xformers",
+                name="xformers",
+            )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ if self.attn2 is not None:
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ # self.attn_temp._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
+ # SparseCausal-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ )
+
+ if self.only_cross_attention:
+ hidden_states = (
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+ )
+ else:
+ if self.use_first_frame:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+ else:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
+
+ if self.attn2 is not None:
+ # Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ # Temporal-Attention
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+ norm_hidden_states = (
+ self.norm_temp(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temp(hidden_states)
+ )
+ hidden_states = self.attn_temp(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
+
+
+class SparseCausalAttention(CrossAttention):
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ former_frame_index = torch.arange(video_length) - 1
+ former_frame_index[0] = 0
+
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
+ key = rearrange(key, "b f d c -> (b f) d c")
+
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
+ value = rearrange(value, "b f d c -> (b f) d c")
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
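+
+
+if __name__ == "__main__":
+    # Minimal shape-check sketch (illustrative sizes): a CrossAttention layer
+    # used as self-attention maps [batch, tokens, query_dim] back to the same
+    # shape; passing cross_attention_dim / encoder_hidden_states would make it
+    # attend to text embeddings instead.
+    attn = CrossAttention(query_dim=320, heads=8, dim_head=40)
+    tokens = torch.randn(2, 64, 320)  # e.g. 2 samples x 64 spatial tokens
+    out = attn(tokens)
+    print(out.shape)  # torch.Size([2, 64, 320])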
diff --git a/interpolation/models/clip.py b/interpolation/models/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..0d77f4d1a725a5d7ed0c4e10a69602e80b91fec1
--- /dev/null
+++ b/interpolation/models/clip.py
@@ -0,0 +1,124 @@
+import numpy
+import torch.nn as nn
+from transformers import CLIPTokenizer, CLIPTextModel
+
+import transformers
+transformers.logging.set_verbosity_error()
+
+"""
+You will encounter the following warning:
+- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
+or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
+- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
+that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
+
+https://github.com/CompVis/stable-diffusion/issues/97
+according to this issue, this warning is safe.
+
+This is expected since the vision backbone of the CLIP model is not needed to run Stable Diffusion.
+You can safely ignore the warning, it is not an error.
+
+This CLIP usage comes from U-ViT and is the same as in Stable Diffusion.
+"""
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ def __init__(self, sd_path, device="cuda", max_length=77):
+ super().__init__()
+ self.tokenizer = CLIPTokenizer.from_pretrained(sd_path, subfolder="tokenizer", use_fast=False)
+ self.transformer = CLIPTextModel.from_pretrained(sd_path, subfolder="text_encoder")
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+
+ def freeze(self):
+ self.transformer = self.transformer.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, max_length=self.max_length, return_length=True,
+ return_overflowing_tokens=False, padding="max_length", return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.transformer(input_ids=tokens)
+
+ z = outputs.last_hidden_state
+ return z
+
+ def encode(self, text):
+ return self(text)
+
+
+class TextEmbedder(nn.Module):
+ """
+ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
+ """
+ def __init__(self, args, dropout_prob=0.1):
+ super().__init__()
+ self.text_encodder = FrozenCLIPEmbedder(args)
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, text_prompts, force_drop_ids=None):
+ """
+ Drops text to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
+ else:
+ # TODO
+ drop_ids = force_drop_ids == 1
+ labels = list(numpy.where(drop_ids, "None", text_prompts))
+ # print(labels)
+ return labels
+
+ def forward(self, text_prompts, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
+ embeddings = self.text_encodder(text_prompts)
+ return embeddings
+
+
+if __name__ == '__main__':
+
+ r"""
+ Returns:
+
+ Examples from CLIPTextModel:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPTextModel
+
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+    # NOTE: placeholder path -- point this at a local Stable Diffusion directory
+    # containing the "tokenizer" and "text_encoder" subfolders.
+    sd_path = "/path/to/stable-diffusion"
+    text_encoder = TextEmbedder(sd_path, dropout_prob=0.00001).to(device)
+    text_encoder1 = FrozenCLIPEmbedder(sd_path).to(device)
+
+ text_prompt = ["a photo of a cat", "a photo of a dog", 'a photo of a dog human']
+ # text_prompt = ('None', 'None', 'None')
+ output = text_encoder(text_prompts=text_prompt, train=True)
+ output1 = text_encoder1(text_prompt)
+ # print(output)
+ print(output.shape)
+ print(output1.shape)
+ print((output==output1).all())
diff --git a/interpolation/models/resnet.py b/interpolation/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..c1307ad9a0fe26b6b1eca913b024c188df2dd86e
--- /dev/null
+++ b/interpolation/models/resnet.py
@@ -0,0 +1,212 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+
+
+class InflatedConv3d(nn.Conv2d):
+ def forward(self, x):
+ video_length = x.shape[2]
+
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ x = super().forward(x)
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+ return x
+
+
+class Upsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ raise NotImplementedError
+ elif use_conv:
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
+
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, hidden_states, output_size=None):
+ assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ raise NotImplementedError
+
+        # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
+
+
+class Downsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ raise NotImplementedError
+
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states):
+ assert hidden_states.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ raise NotImplementedError
+
+ assert hidden_states.shape[1] == self.channels
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ time_emb_proj_out_channels = out_channels
+ elif self.time_embedding_norm == "scale_shift":
+ time_emb_proj_out_channels = out_channels * 2
+ else:
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
+class Mish(torch.nn.Module):
+ def forward(self, hidden_states):
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
\ No newline at end of file
diff --git a/interpolation/models/unet.py b/interpolation/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..66b161f18f6552a4a9f7838c461c9ed10c26c7bf
--- /dev/null
+++ b/interpolation/models/unet.py
@@ -0,0 +1,576 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import json
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput, logging
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+
+try:
+ from diffusers.models.modeling_utils import ModelMixin
+except ImportError:
+    from diffusers.modeling_utils import ModelMixin  # diffusers 0.11.1
+
+try:
+ from .unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+ )
+ from .resnet import InflatedConv3d
+except ImportError:
+ from unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+ )
+ from resnet import InflatedConv3d
+
+
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+class UNet3DConditionModel(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ sample_size: Optional[int] = None, # 64
+ in_channels: int = 4,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ down_block_types: Tuple[str] = (
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D",
+ ),
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = False,
+ block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1280,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = False,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = None,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+
+ # print(use_first_frame)
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=(1, 1))
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim)
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+ # print(only_cross_attention)
+ # exit()
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+ # print(attention_head_dim)
+ # exit()
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ )
+ self.down_blocks.append(down_block)
+
+ # mid
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
+ self.mid_block = UNetMidBlock3DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ # count how many layers upsample the videos
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+                `"max"`, maximum amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
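+        # Illustrative calls (not part of the original code) showing the three accepted forms:
+        #   model.set_attention_slice("auto")  # each layer computes attention in two slices
+        #   model.set_attention_slice("max")   # every layer runs one slice at a time
+        #   model.set_attention_slice(2)       # the same slice size is broadcast to every layer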
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_slicable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_slicable_dims(module)
+
+ num_slicable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_slicable_layers * [1]
+
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+                f"You have provided {len(slice_size)} slice sizes, but {self.config} has {len(sliceable_head_dims)} different"
+                f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ encoder_hidden_states: torch.Tensor = None,
+ class_labels: Optional[torch.Tensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ) -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+ Args:
+            sample (`torch.FloatTensor`): (batch, channel, num_frames, height, width) noisy inputs tensor
+            timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+            encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`UNet3DConditionOutput`] instead of a plain tuple.
+
+        Returns:
+            [`UNet3DConditionOutput`] or `tuple`:
+            [`UNet3DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+            returning a tuple, the first element is the sample tensor.
+ """
+        # By default, samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
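+        # For the default configuration above (three upsampling blocks) the factor is 2**3 = 8,
+        # so the spatial sides of `sample` should normally be multiples of 8 (e.g. 32x32 latents).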
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ # print(emb.shape) # torch.Size([3, 1280])
+ # print(class_emb.shape) # torch.Size([3, 1280])
+ emb = emb + class_emb
+
+ # pre-process
+ sample = self.conv_in(sample)
+
+ # down
+ down_block_res_samples = (sample,)
+ for downsample_block in self.down_blocks:
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # mid
+ sample = self.mid_block(
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+
+ # up
+ for i, upsample_block in enumerate(self.up_blocks):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+ # print(sample.shape)
+
+ if not return_dict:
+ return (sample,)
+ sample = UNet3DConditionOutput(sample=sample)
+ return sample
+
+ def forward_with_cfg(self,
+ x,
+ t,
+ encoder_hidden_states = None,
+ class_labels: Optional[torch.Tensor] = None,
+ cfg_scale=4.0):
+ """
+ Forward, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
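+        # Note (inferred from the code below and from the callers in sample.py): the first and
+        # second halves of `x` are expected to hold identical latents for the conditional and
+        # unconditional passes, with the conditioning carried by `encoder_hidden_states`.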
+ half = x[: len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ model_out = self.forward(combined, t, encoder_hidden_states, class_labels).sample
+        # Classifier-free guidance is applied to the four latent channels; any remaining
+        # channels (none with the default out_channels=4) are passed through unchanged.
+        eps, rest = model_out[:, :4], model_out[:, 4:]  # b c f h w
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+ @classmethod
+ def from_pretrained_2d(cls, pretrained_model_path, subfolder=None, use_concat=False, copy_no_mask=False):
+ if subfolder is not None:
+ pretrained_model_path = os.path.join(pretrained_model_path, subfolder)
+
+
+ config_file = os.path.join(pretrained_model_path, 'config.json')
+ if not os.path.isfile(config_file):
+ raise RuntimeError(f"{config_file} does not exist")
+ with open(config_file, "r") as f:
+ config = json.load(f)
+ config["_class_name"] = cls.__name__
+ config["down_block_types"] = [
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "DownBlock3D"
+ ]
+ config["up_block_types"] = [
+ "UpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D"
+ ]
+
+ config["use_first_frame"] = True
+
+ if copy_no_mask:
+ config["in_channels"] = 8
+ else:
+ if use_concat:
+ config["in_channels"] = 9
+
+
+ from diffusers.utils import WEIGHTS_NAME # diffusion_pytorch_model.bin
+
+
+ model = cls.from_config(config)
+ model_file = os.path.join(pretrained_model_path, WEIGHTS_NAME)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+
+
+ if use_concat:
+ new_state_dict = {}
+ conv_in_weight = state_dict["conv_in.weight"]
+
+ print(f'from_pretrained_2d copy_no_mask = {copy_no_mask}')
+ if copy_no_mask:
+ new_conv_in_channel = 8
+ new_conv_in_list = [0, 1, 2, 3, 4, 5, 6, 7]
+ else:
+ new_conv_in_channel = 9
+ new_conv_in_list = [0, 1, 2, 3, 4, 5, 6, 7, 8]
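+            # The pretrained 2D UNet has a 4-channel conv_in. Only the first four input channels
+            # of the inflated conv are initialized from those weights; the extra conditioning
+            # channels start from zeros (see the loop below), so they are ignored at initialization.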
+ new_conv_weight = torch.zeros((conv_in_weight.shape[0], new_conv_in_channel, *conv_in_weight.shape[2:]), dtype=conv_in_weight.dtype)
+
+ for i, j in zip([0, 1, 2, 3], new_conv_in_list):
+ new_conv_weight[:, j] = conv_in_weight[:, i]
+ new_state_dict["conv_in.weight"] = new_conv_weight
+ new_state_dict["conv_in.bias"] = state_dict["conv_in.bias"]
+ for k, v in model.state_dict().items():
+ # print(k)
+ if '_temp.' in k:
+ new_state_dict.update({k: v})
+ elif 'conv_in' in k:
+ continue
+ else:
+ new_state_dict[k] = v
+ # # tmp
+ # if 'class_embedding' in k:
+ # state_dict.update({k: v})
+ # breakpoint()
+ model.load_state_dict(new_state_dict)
+ else:
+ for k, v in model.state_dict().items():
+ # print(k)
+ if '_temp.' in k:
+ state_dict.update({k: v})
+ model.load_state_dict(state_dict)
+ return model
+
+if __name__ == '__main__':
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ pretrained_model_path = "/nvme/maxin/work/large-dit-video/pretrained/stable-diffusion-v1-4/" # 43
+ unet = UNet3DConditionModel.from_pretrained_2d(pretrained_model_path, subfolder="unet").to(device)
+
+ noisy_latents = torch.randn((3, 4, 16, 32, 32)).to(device)
+ bsz = noisy_latents.shape[0]
+ timesteps = torch.randint(0, 1000, (bsz,)).to(device)
+ timesteps = timesteps.long()
+ encoder_hidden_states = torch.randn((bsz, 77, 768)).to(device)
+ class_labels = torch.randn((bsz, )).to(device)
+
+ model_pred = unet(sample=noisy_latents, timestep=timesteps,
+ encoder_hidden_states=encoder_hidden_states,
+ class_labels=class_labels).sample
+ print(model_pred.shape)
\ No newline at end of file
diff --git a/interpolation/models/unet_blocks.py b/interpolation/models/unet_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..0bfed25db240043dcc42b7e4459f5ca52d3cd902
--- /dev/null
+++ b/interpolation/models/unet_blocks.py
@@ -0,0 +1,619 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import torch
+from torch import nn
+
+try:
+ from .attention import Transformer3DModel
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
+except ImportError:
+ from attention import Transformer3DModel
+ from resnet import Downsample3D, ResnetBlock3D, Upsample3D
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ use_first_frame=False,
+ use_relative_position=False,
+):
+ # print(down_block_type)
+ # print(use_first_frame)
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "CrossAttnDownBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+ return CrossAttnDownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ use_first_frame=False,
+ use_relative_position=False,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "CrossAttnUpBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+ return CrossAttnUpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ )
+ )
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class CrossAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ # print(use_first_frame)
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ attention_mask=None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
diff --git a/interpolation/models/utils.py b/interpolation/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..b6cfac1ff6c2c99c87920372482251dc2b2fce34
--- /dev/null
+++ b/interpolation/models/utils.py
@@ -0,0 +1,215 @@
+# adapted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+
+import numpy as np
+import torch.nn as nn
+
+from einops import repeat
+
+
+#################################################################################
+# Unet Utils #
+#################################################################################
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
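+    # Illustrative usage only (the surrounding names are hypothetical):
+    #   out = checkpoint(block.forward, (x, emb), block.parameters(), self.use_checkpoint)
+    # recomputes `block.forward` in the backward pass instead of storing its activations.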
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
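+    # In the non-repeat branch below, row n of the output is
+    #   [cos(t_n * f_0), ..., cos(t_n * f_{half-1}), sin(t_n * f_0), ..., sin(t_n * f_{half-1})]
+    # with frequencies f_i = max_period ** (-i / half), zero-padded when `dim` is odd.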
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
+ return embedding
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+# class HybridConditioner(nn.Module):
+
+# def __init__(self, c_concat_config, c_crossattn_config):
+# super().__init__()
+# self.concat_conditioner = instantiate_from_config(c_concat_config)
+# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+# def forward(self, c_concat, c_crossattn):
+# c_concat = self.concat_conioner(c_concat)
+# c_crossattn = self.crossattn_conditioner(c_crossattn)
+# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += torch.DoubleTensor([matmul_ops])
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
\ No newline at end of file
diff --git a/interpolation/sample.py b/interpolation/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..f4413e8407338e5ee5fe0df6625f42aef6d9671a
--- /dev/null
+++ b/interpolation/sample.py
@@ -0,0 +1,312 @@
+"""
+We introduce a temporal interpolation network to enhance the smoothness of generated videos and to synthesize richer temporal details.
+It takes a 16-frame base video as input and produces an upsampled output consisting of 61 frames.
+"""
+
+import os
+import sys
+import math
+try:
+ import utils
+
+ from diffusion import create_diffusion
+ from download import find_model
+except ImportError:
+ sys.path.append(os.path.split(sys.path[0])[0])
+
+ import utils
+
+ from diffusion import create_diffusion
+ from download import find_model
+
+import torch
+import argparse
+import torchvision
+
+from einops import rearrange
+from models import get_models
+from torchvision.utils import save_image
+from diffusers.models import AutoencoderKL
+from models.clip import TextEmbedder
+from omegaconf import OmegaConf
+from PIL import Image
+import numpy as np
+from torchvision import transforms
+sys.path.append("..")
+from datasets import video_transforms
+from decord import VideoReader
+from utils import mask_generation, mask_generation_before
+from natsort import natsorted
+from diffusers.utils.import_utils import is_xformers_available
+
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+
+def get_input(args):
+ input_path = args.input_path
+ transform_video = transforms.Compose([
+ video_transforms.ToTensorVideo(), # TCHW
+ # video_transforms.CenterCropResizeVideo((args.image_h, args.image_w)),
+ video_transforms.ResizeVideo((args.image_h, args.image_w)),
+ transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
+ ])
+ temporal_sample_func = video_transforms.TemporalRandomCrop(args.num_frames * args.frame_interval)
+ if input_path is not None:
+ print(f'loading video from {input_path}')
+ if os.path.isdir(input_path):
+ file_list = os.listdir(input_path)
+ video_frames = []
+ for file in file_list:
+ if file.endswith('jpg') or file.endswith('png'):
+ image = torch.as_tensor(np.array(Image.open(file), dtype=np.uint8, copy=True)).unsqueeze(0)
+ video_frames.append(image)
+ else:
+ continue
+ n = 0
+ video_frames = torch.cat(video_frames, dim=0).permute(0, 3, 1, 2) # f,c,h,w
+ video_frames = transform_video(video_frames)
+ return video_frames, n
+ elif os.path.isfile(input_path):
+ _, full_file_name = os.path.split(input_path)
+            file_name, extension = os.path.splitext(full_file_name)
+            if extension == '.mp4':
+ video_reader = VideoReader(input_path)
+ total_frames = len(video_reader)
+ start_frame_ind, end_frame_ind = temporal_sample_func(total_frames)
+ frame_indice = np.linspace(start_frame_ind, end_frame_ind-1, args.num_frames, dtype=int)
+ video_frames = torch.from_numpy(video_reader.get_batch(frame_indice).asnumpy()).permute(0, 3, 1, 2).contiguous()
+ video_frames = transform_video(video_frames)
+ n = 3
+ del video_reader
+ return video_frames, n
+ else:
+                raise TypeError(f'{extension} is not supported!')
+ else:
+            raise ValueError('Please check your input path!')
+ else:
+ print('given video is None, using text to video')
+ video_frames = torch.zeros(16,3,args.latent_h,args.latent_w,dtype=torch.uint8)
+ args.mask_type = 'all'
+ video_frames = transform_video(video_frames)
+ n = 0
+ return video_frames, n
+
+
+def auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,):
+
+
+ b,f,c,h,w=video_input.shape
+ latent_h = args.image_size[0] // 8
+ latent_w = args.image_size[1] // 8
+
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w
+
+ masked_video = rearrange(masked_video, 'b f c h w -> (b f) c h w').contiguous()
+ masked_video = vae.encode(masked_video).latent_dist.sample().mul_(0.18215)
+ masked_video = rearrange(masked_video, '(b f) c h w -> b c f h w', b=b).contiguous()
+ mask = torch.nn.functional.interpolate(mask[:,:,0,:], size=(latent_h, latent_w)).unsqueeze(1)
+
+
+ masked_video = torch.cat([masked_video] * 2) if args.do_classifier_free_guidance else masked_video
+ mask = torch.cat([mask] * 2) if args.do_classifier_free_guidance else mask
+ z = torch.cat([z] * 2) if args.do_classifier_free_guidance else z
+
+ prompt_all = [prompt] + [args.negative_prompt] if args.do_classifier_free_guidance else [prompt]
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
+ model_kwargs = dict(encoder_hidden_states=text_prompt, class_labels=None)
+
+ if args.use_ddim_sample_loop:
+ samples = diffusion.ddim_sample_loop(
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \
+ progress=True, device=device, mask=mask, x_start=masked_video, use_concat=args.use_concat
+ )
+ else:
+ samples = diffusion.p_sample_loop(
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \
+ progress=True, device=device, mask=mask, x_start=masked_video, use_concat=args.use_concat
+ ) # torch.Size([2, 4, 16, 32, 32])
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
+
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
+ return video_clip
+
+
+def auto_inpainting_copy_no_mask(args, video_input, prompt, vae, text_encoder, diffusion, model, device,):
+
+ b,f,c,h,w=video_input.shape
+ latent_h = args.image_size[0] // 8
+ latent_w = args.image_size[1] // 8
+
+ video_input = rearrange(video_input, 'b f c h w -> (b f) c h w').contiguous()
+ video_input = vae.encode(video_input).latent_dist.sample().mul_(0.18215)
+ video_input = rearrange(video_input, '(b f) c h w -> b c f h w', b=b).contiguous()
+
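+    # Keep every 4th latent frame as a key frame, duplicate each key frame 4x along the
+    # temporal axis, and trim the ends so the copied video lines up with the noise latents.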
+ lr_indice = torch.IntTensor([i for i in range(0,62,4)]).to(device)
+ copied_video = torch.index_select(video_input, 2, lr_indice)
+ copied_video = torch.repeat_interleave(copied_video, 4, dim=2)
+ copied_video = copied_video[:,:,1:-2,:,:]
+ copied_video = torch.cat([copied_video] * 2) if args.do_classifier_free_guidance else copied_video
+
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ z = torch.randn(1, 4, args.num_frames, args.latent_h, args.latent_w, device=device) # b,c,f,h,w
+ z = torch.cat([z] * 2) if args.do_classifier_free_guidance else z
+
+ prompt_all = [prompt] + [args.negative_prompt] if args.do_classifier_free_guidance else [prompt]
+ text_prompt = text_encoder(text_prompts=prompt_all, train=False)
+ model_kwargs = dict(encoder_hidden_states=text_prompt, class_labels=None)
+
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ if args.use_ddim_sample_loop:
+ samples = diffusion.ddim_sample_loop(
+ model.forward_with_cfg, z.shape, z, clip_denoised=False, model_kwargs=model_kwargs, \
+ progress=True, device=device, mask=None, x_start=copied_video, use_concat=args.use_concat, copy_no_mask=args.copy_no_mask,
+ )
+ else:
+        raise ValueError('Only DDIM sampling is implemented for now')
+
+ samples, _ = samples.chunk(2, dim=0) # [1, 4, 16, 32, 32]
+
+ video_clip = samples[0].permute(1, 0, 2, 3).contiguous() # [16, 4, 32, 32]
+ video_clip = vae.decode(video_clip / 0.18215).sample # [16, 3, 256, 256]
+ return video_clip
+
+
+
+def main(args):
+
+ for seed in args.seed_list:
+
+ args.seed = seed
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+ # print(f'torch.seed() = {torch.seed()}')
+
+ print('sampling begins')
+ torch.set_grad_enabled(False)
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ # device = "cpu"
+
+ ckpt_path = args.pretrained_path + "/lavie_interpolation.pt"
+ sd_path = args.pretrained_path + "/stable-diffusion-v1-4"
+ for ckpt in [ckpt_path]:
+
+ ckpt_num = str(ckpt_path).zfill(7)
+
+ # Load model:
+ latent_h = args.image_size[0] // 8
+ latent_w = args.image_size[1] // 8
+ args.image_h = args.image_size[0]
+ args.image_w = args.image_size[1]
+ args.latent_h = latent_h
+ args.latent_w = latent_w
+ print(f'args.copy_no_mask = {args.copy_no_mask}')
+ model = get_models(args, sd_path).to(device)
+
+ if args.use_compile:
+ model = torch.compile(model)
+ if args.enable_xformers_memory_efficient_attention:
+ if is_xformers_available():
+ model.enable_xformers_memory_efficient_attention()
+ # model.enable_vae_slicing() # ziqi added
+ else:
+ raise ValueError("xformers is not available. Make sure it is installed correctly")
+
+ # Auto-download a pre-trained model or load a custom checkpoint from train.py:
+ print(f'loading model from {ckpt_path}')
+
+ # load ckpt
+ state_dict = find_model(ckpt_path)
+
+ print(f'state_dict["conv_in.weight"].shape = {state_dict["conv_in.weight"].shape}') # [320, 8, 3, 3]
+ print('loading succeed')
+            model.load_state_dict(state_dict)
+
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+
+ model.eval() # important!
+ diffusion = create_diffusion(str(args.num_sampling_steps))
+ vae = AutoencoderKL.from_pretrained(sd_path, subfolder="vae").to(device)
+ text_encoder = TextEmbedder(sd_path).to(device)
+
+ video_list = os.listdir(args.input_folder)
+ args.input_path_list = [os.path.join(args.input_folder, video) for video in video_list]
+ for input_path in args.input_path_list:
+
+ args.input_path = input_path
+
+ print(f'=======================================')
+ if not args.input_path.endswith('.mp4'):
+ print(f'Skipping {args.input_path}')
+ continue
+
+ print(f'args.input_path = {args.input_path}')
+
+ torch.manual_seed(args.seed)
+ torch.cuda.manual_seed(args.seed)
+
+ # Labels to condition the model with (feel free to change):
+ video_name = args.input_path.split('/')[-1].split('.mp4')[0]
+ args.prompt = [video_name]
+ print(f'args.prompt = {args.prompt}')
+ prompts = args.prompt
+ class_name = [p + args.additional_prompt for p in prompts]
+
+ if not os.path.exists(os.path.join(args.output_folder)):
+ os.makedirs(os.path.join(args.output_folder))
+                video_input, reserve_frames = get_input(args) # f,c,h,w
+ video_input = video_input.to(device).unsqueeze(0) # b,f,c,h,w
+ if args.copy_no_mask:
+ pass
+ else:
+ mask = mask_generation_before(args.mask_type, video_input.shape, video_input.dtype, device) # b,f,c,h,w
+
+                if args.copy_no_mask:
+                    pass
+                else:
+                    masked_video = video_input * (mask == 0)
+
+ all_video = []
+                if reserve_frames != 0:
+ all_video.append(video_input)
+ for idx, prompt in enumerate(class_name):
+ if idx == 0:
+ if args.copy_no_mask:
+ video_clip = auto_inpainting_copy_no_mask(args, video_input, prompt, vae, text_encoder, diffusion, model, device,)
+ video_clip_ = video_clip.unsqueeze(0)
+ all_video.append(video_clip_)
+ else:
+ video_clip = auto_inpainting(args, video_input, masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,)
+ video_clip_ = video_clip.unsqueeze(0)
+ all_video.append(video_clip_)
+ else:
+ raise NotImplementedError
+ masked_video = video_input * (mask == 0)
+ video_clip = auto_inpainting_copy_no_mask(args, video_clip.unsqueeze(0), masked_video, mask, prompt, vae, text_encoder, diffusion, model, device,)
+ video_clip_ = video_clip.unsqueeze(0)
+ all_video.append(video_clip_[:, 3:])
+ video_ = ((video_clip * 0.5 + 0.5) * 255).add_(0.5).clamp_(0, 255).to(dtype=torch.uint8).cpu().permute(0, 2, 3, 1)
+ for fps in args.fps_list:
+ save_path = args.output_folder
+ if not os.path.exists(os.path.join(save_path)):
+ os.makedirs(os.path.join(save_path))
+ local_save_path = os.path.join(save_path, f'{video_name}.mp4')
+ print(f'save in {local_save_path}')
+ torchvision.io.write_video(local_save_path, video_, fps=fps)
+
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, required=True)
+ args = parser.parse_args()
+ main(**OmegaConf.load(args.config))
+
+
diff --git a/interpolation/utils.py b/interpolation/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..3219c0af49cb1bd039abe6899aab0ec85a4a4b1e
--- /dev/null
+++ b/interpolation/utils.py
@@ -0,0 +1,371 @@
+import os
+import math
+import torch
+import logging
+import subprocess
+import numpy as np
+import torch.distributed as dist
+from torch.utils.tensorboard import SummaryWriter
+
+# from torch._six import inf
+from torch import inf
+from PIL import Image
+from typing import Union, Iterable
+from collections import OrderedDict
+
+
+_tensor_or_tensors = Union[torch.Tensor, Iterable[torch.Tensor]]
+
+#################################################################################
+# Training Helper Functions #
+#################################################################################
+
+#################################################################################
+# Training Clip Gradients #
+#################################################################################
+
+def get_grad_norm(
+ parameters: _tensor_or_tensors, norm_type: float = 2.0) -> torch.Tensor:
+    r"""
+    Adapted from torch.nn.utils.clip_grad_norm_
+
+    Computes the gradient norm of an iterable of parameters.
+
+    The norm is computed over all gradients together, as if they were
+    concatenated into a single vector. Gradients are not modified.
+
+    Args:
+        parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+            single Tensor whose gradient norm will be computed
+        norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+            infinity norm.
+
+    Returns:
+        Total norm of the parameter gradients (viewed as a single vector).
+    """
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ grads = [p.grad for p in parameters if p.grad is not None]
+ norm_type = float(norm_type)
+ if len(grads) == 0:
+ return torch.tensor(0.)
+ device = grads[0].device
+ if norm_type == inf:
+ norms = [g.detach().abs().max().to(device) for g in grads]
+ total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
+ return total_norm
+
+def clip_grad_norm_(
+ parameters: _tensor_or_tensors, max_norm: float, norm_type: float = 2.0,
+ error_if_nonfinite: bool = False, clip_grad = True) -> torch.Tensor:
+ r"""
+ Copy from torch.nn.utils.clip_grad_norm_
+
+ Clips gradient norm of an iterable of parameters.
+
+ The norm is computed over all gradients together, as if they were
+ concatenated into a single vector. Gradients are modified in-place.
+
+ Args:
+ parameters (Iterable[Tensor] or Tensor): an iterable of Tensors or a
+ single Tensor that will have gradients normalized
+ max_norm (float or int): max norm of the gradients
+ norm_type (float or int): type of the used p-norm. Can be ``'inf'`` for
+ infinity norm.
+ error_if_nonfinite (bool): if True, an error is thrown if the total
+ norm of the gradients from :attr:`parameters` is ``nan``,
+ ``inf``, or ``-inf``. Default: False (will switch to True in the future)
+
+ Returns:
+ Total norm of the parameter gradients (viewed as a single vector).
+ """
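+    # Typical training-loop use (illustrative): grad_norm = clip_grad_norm_(model.parameters(), max_norm=1.0)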
+ if isinstance(parameters, torch.Tensor):
+ parameters = [parameters]
+ grads = [p.grad for p in parameters if p.grad is not None]
+ max_norm = float(max_norm)
+ norm_type = float(norm_type)
+ if len(grads) == 0:
+ return torch.tensor(0.)
+ device = grads[0].device
+ if norm_type == inf:
+ norms = [g.detach().abs().max().to(device) for g in grads]
+ total_norm = norms[0] if len(norms) == 1 else torch.max(torch.stack(norms))
+ else:
+ total_norm = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
+ # print(total_norm)
+
+ if clip_grad:
+ if error_if_nonfinite and torch.logical_or(total_norm.isnan(), total_norm.isinf()):
+ raise RuntimeError(
+ f'The total norm of order {norm_type} for gradients from '
+ '`parameters` is non-finite, so it cannot be clipped. To disable '
+ 'this error and scale the gradients by the non-finite norm anyway, '
+ 'set `error_if_nonfinite=False`')
+ clip_coef = max_norm / (total_norm + 1e-6)
+ # Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
+ # avoids a `if clip_coef < 1:` conditional which can require a CPU <=> device synchronization
+ # when the gradients do not reside in CPU memory.
+ clip_coef_clamped = torch.clamp(clip_coef, max=1.0)
+ for g in grads:
+ g.detach().mul_(clip_coef_clamped.to(g.device))
+ # gradient_cliped = torch.norm(torch.stack([torch.norm(g.detach(), norm_type).to(device) for g in grads]), norm_type)
+ # print(gradient_cliped)
+ return total_norm
+
+#################################################################################
+# Training Logger #
+#################################################################################
+
+def create_logger(logging_dir):
+ """
+ Create a logger that writes to a log file and stdout.
+ """
+ if dist.get_rank() == 0: # real logger
+ logging.basicConfig(
+ level=logging.INFO,
+ # format='[\033[34m%(asctime)s\033[0m] %(message)s',
+ format='[%(asctime)s] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
+ )
+ logger = logging.getLogger(__name__)
+
+ else: # dummy logger (does nothing)
+ logger = logging.getLogger(__name__)
+ logger.addHandler(logging.NullHandler())
+ return logger
+
+def create_accelerate_logger(logging_dir, is_main_process=False):
+ """
+ Create a logger that writes to a log file and stdout.
+ """
+ if is_main_process: # real logger
+ logging.basicConfig(
+ level=logging.INFO,
+ # format='[\033[34m%(asctime)s\033[0m] %(message)s',
+ format='[%(asctime)s] %(message)s',
+ datefmt='%Y-%m-%d %H:%M:%S',
+ handlers=[logging.StreamHandler(), logging.FileHandler(f"{logging_dir}/log.txt")]
+ )
+ logger = logging.getLogger(__name__)
+ else: # dummy logger (does nothing)
+ logger = logging.getLogger(__name__)
+ logger.addHandler(logging.NullHandler())
+ return logger
+
+
+def create_tensorboard(tensorboard_dir):
+ """
+ Create a tensorboard writer that saves losses.
+ Only the rank-0 process gets a real SummaryWriter; other ranks get None.
+ """
+ if dist.get_rank() == 0: # real tensorboard writer
+ writer = SummaryWriter(tensorboard_dir)
+ else: # dummy writer (does nothing)
+ writer = None
+ return writer
+
+def write_tensorboard(writer, *args):
+ '''
+ Write the loss information to a tensorboard file.
+ Only for PyTorch DDP mode.
+ '''
+ if dist.get_rank() == 0: # real tensorboard
+ writer.add_scalar(args[0], args[1], args[2])
+
+#################################################################################
+# EMA Update/ DDP Training Utils #
+#################################################################################
+
+@torch.no_grad()
+def update_ema(ema_model, model, decay=0.9999):
+ """
+ Step the EMA model towards the current model.
+ """
+ ema_params = OrderedDict(ema_model.named_parameters())
+ model_params = OrderedDict(model.named_parameters())
+
+ for name, param in model_params.items():
+ # TODO: Consider applying only to params that require_grad to avoid small numerical changes of pos_embed
+ ema_params[name].mul_(decay).add_(param.data, alpha=1 - decay)
+
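+# Illustrative sketch (not part of the original code): where `update_ema` typically sits
+# in a training step, assuming `ema`, `model`, `opt` and a scalar `loss` already exist.
+def _example_ema_training_step(ema, model, opt, loss):
+ opt.zero_grad()
+ loss.backward()
+ opt.step()
+ # The EMA copy trails the online weights with the default decay of 0.9999.
+ update_ema(ema, model)
+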
+def requires_grad(model, flag=True):
+ """
+ Set requires_grad flag for all parameters in a model.
+ """
+ for p in model.parameters():
+ p.requires_grad = flag
+
+def cleanup():
+ """
+ End DDP training.
+ """
+ dist.destroy_process_group()
+
+
+def setup_distributed(backend="nccl", port=None):
+ """Initialize distributed training environment.
+ support both slurm and torch.distributed.launch
+ see torch.distributed.init_process_group() for more details
+ """
+ num_gpus = torch.cuda.device_count()
+
+ if "SLURM_JOB_ID" in os.environ:
+ rank = int(os.environ["SLURM_PROCID"])
+ world_size = int(os.environ["SLURM_NTASKS"])
+ node_list = os.environ["SLURM_NODELIST"]
+ addr = subprocess.getoutput(f"scontrol show hostname {node_list} | head -n1")
+ # specify master port
+ if port is not None:
+ os.environ["MASTER_PORT"] = str(port)
+ elif "MASTER_PORT" not in os.environ:
+ # os.environ["MASTER_PORT"] = "29566"
+ os.environ["MASTER_PORT"] = str(29566 + num_gpus)
+ if "MASTER_ADDR" not in os.environ:
+ os.environ["MASTER_ADDR"] = addr
+ os.environ["WORLD_SIZE"] = str(world_size)
+ os.environ["LOCAL_RANK"] = str(rank % num_gpus)
+ os.environ["RANK"] = str(rank)
+ else:
+ rank = int(os.environ["RANK"])
+ world_size = int(os.environ["WORLD_SIZE"])
+
+ # torch.cuda.set_device(rank % num_gpus)
+
+ dist.init_process_group(
+ backend=backend,
+ world_size=world_size,
+ rank=rank,
+ )
+
+#################################################################################
+# Testing Utils #
+#################################################################################
+
+def save_video_grid(video, nrow=None):
+ b, t, h, w, c = video.shape
+
+ if nrow is None:
+ nrow = math.ceil(math.sqrt(b))
+ ncol = math.ceil(b / nrow)
+ padding = 1
+ video_grid = torch.zeros((t, (padding + h) * nrow + padding,
+ (padding + w) * ncol + padding, c), dtype=torch.uint8)
+
+ for i in range(b):
+ row = i // ncol
+ col = i % ncol
+ start_r = (padding + h) * row
+ start_c = (padding + w) * col
+ video_grid[:, start_r:start_r + h, start_c:start_c + w] = video[i]
+
+ return video_grid
+
+
+#################################################################################
+# MMCV Utils #
+#################################################################################
+
+
+def collect_env():
+ """Collect the information of the running environments."""
+ # Copyright (c) OpenMMLab. All rights reserved.
+ from mmcv.utils import collect_env as collect_base_env
+ from mmcv.utils import get_git_hash
+
+ env_info = collect_base_env()
+ env_info['MMClassification'] = get_git_hash()[:7]
+
+ for name, val in env_info.items():
+ print(f'{name}: {val}')
+
+ print(torch.cuda.get_arch_list())
+ print(torch.version.cuda)
+
+#################################################################################
+# Long video generation Utils #
+#################################################################################
+
+def mask_generation(mask_type, shape, dtype, device):
+ b, c, f, h, w = shape
+ if mask_type.startswith('random'):
+ num = float(mask_type.split('random')[-1])
+ mask_f = torch.ones(1, 1, f, 1, 1, dtype=dtype, device=device)
+ indices = torch.randperm(f, device=device)[:int(f*num)]
+ mask_f[0, 0, indices, :, :] = 0
+ mask = mask_f.expand(b, c, -1, h, w)
+ elif mask_type.startswith('first'):
+ num = int(mask_type.split('first')[-1])
+ mask_f = torch.cat([torch.zeros(1, 1, num, 1, 1, dtype=dtype, device=device),
+ torch.ones(1, 1, f-num, 1, 1, dtype=dtype, device=device)], dim=2)
+ mask = mask_f.expand(b, c, -1, h, w)
+ else:
+ raise ValueError(f"Invalid mask type: {mask_type}")
+ return mask
+
+
+
+def mask_generation_before(mask_type, shape, dtype, device):
+ b, f, c, h, w = shape
+ if mask_type.startswith('random'):
+ num = float(mask_type.split('random')[-1])
+ mask_f = torch.ones(1, f, 1, 1, 1, dtype=dtype, device=device)
+ indices = torch.randperm(f, device=device)[:int(f*num)]
+ mask_f[0, indices, :, :, :] = 0
+ mask = mask_f.expand(b, -1, c, h, w)
+ elif mask_type.startswith('first'):
+ num = int(mask_type.split('first')[-1])
+ mask_f = torch.cat([torch.zeros(1, num, 1, 1, 1, dtype=dtype, device=device),
+ torch.ones(1, f-num, 1, 1, 1, dtype=dtype, device=device)], dim=1)
+ mask = mask_f.expand(b, -1, c, h, w)
+ elif mask_type.startswith('uniform'):
+ p = float(mask_type.split('uniform')[-1])
+ mask_f = torch.ones(1, f, 1, 1, 1, dtype=dtype, device=device)
+ mask_f[0, torch.rand(f, device=device) < p, :, :, :] = 0
+ mask = mask_f.expand(b, -1, c, h, w)
+ elif mask_type.startswith('all'):
+ mask = torch.ones(b,f,c,h,w,dtype=dtype,device=device)
+ elif mask_type.startswith('onelast'):
+ num = int(mask_type.split('onelast')[-1])
+ mask_one = torch.zeros(1,1,1,1,1, dtype=dtype, device=device)
+ mask_mid = torch.ones(1,f-2*num,1,1,1,dtype=dtype, device=device)
+ mask_last = torch.zeros_like(mask_one)
+ mask = torch.cat([mask_one]*num + [mask_mid] + [mask_last]*num, dim=1)
+ mask = mask.expand(b, -1, c, h, w)
+ elif mask_type.startswith('interpolate'):
+ mask_f = []
+ for i in range(4):
+ mask_zero = torch.zeros(1,1,1,1,1, dtype=dtype, device=device)
+ mask_f.append(mask_zero)
+ mask_one = torch.ones(1,3,1,1,1, dtype=dtype, device=device)
+ mask_f.append(mask_one)
+ mask = torch.cat(mask_f, dim=1)
+ elif mask_type.startswith('tsr'):
+ mask_f = []
+ mask_zero = torch.zeros(1,1,1,1,1, dtype=dtype, device=device)
+ mask_one = torch.ones(1,3,1,1,1, dtype=dtype, device=device)
+ for i in range(15):
+ mask_f.append(mask_zero) # not masked
+ mask_f.append(mask_one) # masked
+ mask_f.append(mask_zero) # not masked
+ mask = torch.cat(mask_f, dim=1)
+ mask = mask.expand(b, -1, c, h, w)
+ else:
+ raise ValueError(f"Invalid mask type: {mask_type}")
+
+ return mask
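+
+
+# Illustrative sketch (not part of the original code): building a conditioning mask for
+# long-video generation, assuming a latent batch `z` of shape (b, f, c, h, w).
+def _example_first_frame_conditioning_mask(z):
+ # 'first8' assigns 0 to the first 8 frames and 1 to the remaining f - 8 frames;
+ # how the two values are interpreted (conditioning vs. generated) is up to the caller.
+ return mask_generation_before('first8', z.shape, z.dtype, z.device)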
diff --git a/pretrained_models/put_pre-trained_weights_here.txt b/pretrained_models/put_pre-trained_weights_here.txt
new file mode 100644
index 0000000000000000000000000000000000000000..d7bf9943847113cf93c998a6fec2e8d6a8324404
--- /dev/null
+++ b/pretrained_models/put_pre-trained_weights_here.txt
@@ -0,0 +1 @@
+put pre-trained weights here.
\ No newline at end of file
diff --git a/vsr/configs/sample.yaml b/vsr/configs/sample.yaml
new file mode 100644
index 0000000000000000000000000000000000000000..96234f563b8f12e7c816bcdc1d1e52d084b029b7
--- /dev/null
+++ b/vsr/configs/sample.yaml
@@ -0,0 +1,6 @@
+pretrained_path: "../pretrained_models"
+input_path: "../res/base"
+output_path: "../res/vsr"
+noise_level: 50
+guidance_scale: 5
+inference_steps: 50
diff --git a/vsr/configs/unet_3d_config.json b/vsr/configs/unet_3d_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..ccebb1884d60b9f309b0b3fd45f9d284034e1728
--- /dev/null
+++ b/vsr/configs/unet_3d_config.json
@@ -0,0 +1,66 @@
+{
+ "_class_name": "UNet3DVSRModel",
+ "_diffusers_version": "0.9.0.dev0",
+ "_name_or_path": "hf-models/stable-diffusion-x4-upscaler/unet",
+ "act_fn": "silu",
+ "attention_head_dim": 8,
+ "block_out_channels": [
+ 256,
+ 512,
+ 512,
+ 1024
+ ],
+ "center_input_sample": false,
+ "cross_attention_dim": 1024,
+ "down_block_types": [
+ "DownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D"
+ ],
+ "downsample_padding": 1,
+ "dual_cross_attention": false,
+ "flip_sin_to_cos": true,
+ "freq_shift": 0,
+ "in_channels": 7,
+ "layers_per_block": 2,
+ "mid_block_scale_factor": 1,
+ "norm_eps": 1e-05,
+ "norm_num_groups": 32,
+ "num_class_embeds": 1000,
+ "only_cross_attention": [
+ true,
+ true,
+ true,
+ false
+ ],
+ "out_channels": 4,
+ "sample_size": 128,
+ "up_block_types": [
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "UpBlock3D"
+ ],
+ "use_linear_projection": true,
+
+ "down_temporal_idx": [0, 1, 2, 3],
+ "mid_temporal": true,
+ "up_temporal_idx": [0, 1, 2, 3],
+ "temporal_module_config": {
+ "num_attention_layers": 1,
+ "attention_block_types": [
+ "",
+ ""
+ ],
+ "cross_frame_attention_mode": "0_i-1_i",
+ "temporal_shift_fold_div": 2,
+ "temporal_shift_direction": "right",
+ "use_dcn_warpping": false,
+ "use_deformable_conv": true,
+ "attention_dim_div": 2
+ },
+ "use_first_frame": false,
+ "video_condition": false,
+ "freeze_pretrained_2d_upsampler": true
+}
diff --git a/vsr/configs/vae_config.json b/vsr/configs/vae_config.json
new file mode 100644
index 0000000000000000000000000000000000000000..3dca8898737c4bbb92ccd502b080507483ad1728
--- /dev/null
+++ b/vsr/configs/vae_config.json
@@ -0,0 +1,28 @@
+{
+ "_class_name": "AutoencoderKL",
+ "_diffusers_version": "0.9.0.dev0",
+ "_name_or_path": "hf-models/stable-diffusion-x4-upscaler/vae",
+ "act_fn": "silu",
+ "block_out_channels": [
+ 128,
+ 256,
+ 512
+ ],
+ "down_block_types": [
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D",
+ "DownEncoderBlock2D"
+ ],
+ "in_channels": 3,
+ "latent_channels": 4,
+ "layers_per_block": 2,
+ "norm_num_groups": 32,
+ "out_channels": 3,
+ "sample_size": 256,
+ "up_block_types": [
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D",
+ "UpDecoderBlock2D"
+ ],
+ "scaling_factor": 0.08333
+}
diff --git a/vsr/diffusion/__init__.py b/vsr/diffusion/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..d171551ccbeff3f7d38a6e89624815f7ebd1e4db
--- /dev/null
+++ b/vsr/diffusion/__init__.py
@@ -0,0 +1,54 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from . import gaussian_diffusion as gd
+from .respace import SpacedDiffusion, space_timesteps
+
+
+# !important
+def create_diffusion(
+ timestep_respacing="",
+ noise_schedule="linear", # 'linear' for training
+ use_kl=False,
+ rescale_learned_sigmas=False,
+ prediction_type='v_prediction',
+ variance_type='fixed_small',
+ beta_start=0.0001,
+ beta_end=0.02,
+ # beta_start=0.00085,
+ # beta_end=0.012,
+ diffusion_steps=1000
+):
+ betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps, beta_start=beta_start, beta_end=beta_end)
+ if prediction_type == 'epsilon':
+ model_mean_type = gd.ModelMeanType.EPSILON # EPSILON type for stable-diffusion-2-1 512
+ elif prediction_type == 'xstart':
+ model_mean_type = gd.ModelMeanType.START_X
+ elif prediction_type == 'v_prediction':
+ model_mean_type = gd.ModelMeanType.PREVIOUS_V # gd.ModelMeanType.PREVIOUS_V for stable-diffusion-2-1 768/x4-upscaler
+ else:
+ raise ValueError(f"unknown prediction_type: {prediction_type}")
+
+ if variance_type == 'fixed_small':
+ model_var_type = gd.ModelVarType.FIXED_SMALL
+ elif variance_type == 'fixed_large':
+ model_var_type = gd.ModelVarType.FIXED_LARGE
+ elif variance_type == 'learned_range':
+ model_var_type = gd.ModelVarType.LEARNED_RANGE
+ else:
+ raise ValueError(f"unknown variance_type: {variance_type}")
+
+ if use_kl:
+ loss_type = gd.LossType.RESCALED_KL
+ elif rescale_learned_sigmas:
+ loss_type = gd.LossType.RESCALED_MSE
+ else:
+ loss_type = gd.LossType.MSE
+ if timestep_respacing is None or timestep_respacing == "":
+ timestep_respacing = [diffusion_steps]
+ return SpacedDiffusion(
+ use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
+ betas=betas,
+ model_mean_type=(model_mean_type),
+ model_var_type=(model_var_type),
+ loss_type=loss_type
+ # rescale_timesteps=rescale_timesteps,
+ )
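+
+
+# Illustrative sketch (not part of the original code): building a respaced diffusion
+# with fewer sampling steps. The "ddimN" string selects N evenly strided timesteps out
+# of the full training schedule (see space_timesteps in respace.py).
+def _example_create_respaced_diffusion(inference_steps=50):
+ return create_diffusion(timestep_respacing=f"ddim{inference_steps}")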
diff --git a/vsr/diffusion/diffusion_utils.py b/vsr/diffusion/diffusion_utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..e493a6a3ecb91e553a53cc7eadee5cc0d1753060
--- /dev/null
+++ b/vsr/diffusion/diffusion_utils.py
@@ -0,0 +1,88 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+import torch as th
+import numpy as np
+
+
+def normal_kl(mean1, logvar1, mean2, logvar2):
+ """
+ Compute the KL divergence between two gaussians.
+ Shapes are automatically broadcasted, so batches can be compared to
+ scalars, among other use cases.
+ """
+ tensor = None
+ for obj in (mean1, logvar1, mean2, logvar2):
+ if isinstance(obj, th.Tensor):
+ tensor = obj
+ break
+ assert tensor is not None, "at least one argument must be a Tensor"
+
+ # Force variances to be Tensors. Broadcasting helps convert scalars to
+ # Tensors, but it does not work for th.exp().
+ logvar1, logvar2 = [
+ x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor)
+ for x in (logvar1, logvar2)
+ ]
+
+ return 0.5 * (
+ -1.0
+ + logvar2
+ - logvar1
+ + th.exp(logvar1 - logvar2)
+ + ((mean1 - mean2) ** 2) * th.exp(-logvar2)
+ )
+
+
+def approx_standard_normal_cdf(x):
+ """
+ A fast approximation of the cumulative distribution function of the
+ standard normal.
+ """
+ return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
+
+
+def continuous_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a continuous Gaussian distribution.
+ :param x: the targets
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ normalized_x = centered_x * inv_stdv
+ log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
+ return log_probs
+
+
+def discretized_gaussian_log_likelihood(x, *, means, log_scales):
+ """
+ Compute the log-likelihood of a Gaussian distribution discretizing to a
+ given image.
+ :param x: the target images. It is assumed that this was uint8 values,
+ rescaled to the range [-1, 1].
+ :param means: the Gaussian mean Tensor.
+ :param log_scales: the Gaussian log stddev Tensor.
+ :return: a tensor like x of log probabilities (in nats).
+ """
+ assert x.shape == means.shape == log_scales.shape
+ centered_x = x - means
+ inv_stdv = th.exp(-log_scales)
+ plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
+ cdf_plus = approx_standard_normal_cdf(plus_in)
+ min_in = inv_stdv * (centered_x - 1.0 / 255.0)
+ cdf_min = approx_standard_normal_cdf(min_in)
+ log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
+ log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
+ cdf_delta = cdf_plus - cdf_min
+ log_probs = th.where(
+ x < -0.999,
+ log_cdf_plus,
+ th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
+ )
+ assert log_probs.shape == x.shape
+ return log_probs
diff --git a/vsr/diffusion/gaussian_diffusion.py b/vsr/diffusion/gaussian_diffusion.py
new file mode 100644
index 0000000000000000000000000000000000000000..be197ae8b899ecb5e763226b8c56de335fa123d6
--- /dev/null
+++ b/vsr/diffusion/gaussian_diffusion.py
@@ -0,0 +1,923 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+
+import math
+
+import numpy as np
+import torch as th
+import enum
+
+from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+class ModelMeanType(enum.Enum):
+ """
+ Which type of output the model predicts.
+ """
+
+ PREVIOUS_V = enum.auto() # the model predicts v (v-parameterization for VSR; see section 2.4 of https://imagen.research.google/video/paper.pdf)
+ PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
+ START_X = enum.auto() # the model predicts x_0
+ EPSILON = enum.auto() # the model predicts epsilon
+
+
+class ModelVarType(enum.Enum):
+ """
+ What is used as the model's output variance.
+ The LEARNED_RANGE option has been added to allow the model to predict
+ values between FIXED_SMALL and FIXED_LARGE, making its job easier.
+ """
+
+ LEARNED = enum.auto()
+ FIXED_SMALL = enum.auto()
+ FIXED_LARGE = enum.auto()
+ LEARNED_RANGE = enum.auto()
+
+
+class LossType(enum.Enum):
+ MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
+ RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances)
+ KL = enum.auto() # use the variational lower-bound
+ RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
+
+ def is_vb(self):
+ return self == LossType.KL or self == LossType.RESCALED_KL
+
+
+def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ warmup_time = int(num_diffusion_timesteps * warmup_frac)
+ betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
+ return betas
+
+
+def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
+ """
+ This is the deprecated API for creating beta schedules.
+ See get_named_beta_schedule() for the new library of schedules.
+ """
+ if beta_schedule == "scaled_linear":
+ betas = (np.linspace(beta_start**0.5, beta_end**0.5, num_diffusion_timesteps, dtype=np.float64)** 2)
+ elif beta_schedule == "linear":
+ betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "warmup10":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
+ elif beta_schedule == "warmup50":
+ betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
+ elif beta_schedule == "const":
+ betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
+ elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
+ betas = 1.0 / np.linspace(
+ num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64
+ )
+ else:
+ raise NotImplementedError(beta_schedule)
+ assert betas.shape == (num_diffusion_timesteps,)
+ return betas
+
+
+def get_named_beta_schedule(schedule_name, num_diffusion_timesteps, beta_start, beta_end):
+ """
+ Get a pre-defined beta schedule for the given name.
+ The beta schedule library consists of beta schedules which remain similar
+ in the limit of num_diffusion_timesteps.
+ Beta schedules may be added, but should not be removed or changed once
+ they are committed to maintain backwards compatibility.
+ """
+ if schedule_name in ["linear", "scaled_linear"]:
+ # Linear schedule from Ho et al, extended to work for any number of
+ # diffusion steps.
+ scale = 1000 / num_diffusion_timesteps
+ return get_beta_schedule(
+ schedule_name,
+ beta_start=scale * beta_start,
+ beta_end=scale * beta_end,
+ num_diffusion_timesteps=num_diffusion_timesteps,
+ )
+ elif schedule_name == "squaredcos_cap_v2":
+ return betas_for_alpha_bar(
+ num_diffusion_timesteps,
+ lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
+ )
+ else:
+ raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
+
+
+def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function,
+ which defines the cumulative product of (1-beta) over time from t = [0,1].
+ :param num_diffusion_timesteps: the number of betas to produce.
+ :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
+ produces the cumulative product of (1-beta) up to that
+ part of the diffusion process.
+ :param max_beta: the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+ """
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return np.array(betas)
+
+
+class GaussianDiffusion:
+ """
+ Utilities for training and sampling diffusion models.
+ Originally ported from this codebase:
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
+ :param betas: a 1-D numpy array of betas for each diffusion timestep,
+ starting at T and going to 1.
+ """
+
+ def __init__(
+ self,
+ *,
+ betas,
+ model_mean_type,
+ model_var_type,
+ loss_type
+ ):
+
+ self.model_mean_type = model_mean_type
+ self.model_var_type = model_var_type
+ self.loss_type = loss_type
+
+ # Use float64 for accuracy.
+ betas = np.array(betas, dtype=np.float64)
+ self.betas = betas
+ assert len(betas.shape) == 1, "betas must be 1-D"
+ assert (betas > 0).all() and (betas <= 1).all()
+
+ self.num_timesteps = int(betas.shape[0])
+
+ alphas = 1.0 - betas
+ self.alphas_cumprod = np.cumprod(alphas, axis=0)
+ self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
+ self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
+ assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
+ self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
+ self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
+ self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
+ self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
+
+ # calculations for posterior q(x_{t-1} | x_t, x_0)
+ self.posterior_variance = (
+ betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
+ self.posterior_log_variance_clipped = np.log(
+ np.append(self.posterior_variance[1], self.posterior_variance[1:])
+ ) if len(self.posterior_variance) > 1 else np.array([])
+
+ self.posterior_mean_coef1 = (
+ betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
+ )
+ self.posterior_mean_coef2 = (
+ (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
+ )
+
+ def q_mean_variance(self, x_start, t):
+ """
+ Get the distribution q(x_t | x_0).
+ :param x_start: the [N x C x ...] tensor of noiseless inputs.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :return: A tuple (mean, variance, log_variance), all of x_start's shape.
+ """
+ mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
+ log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
+ return mean, variance, log_variance
+
+ def q_sample(self, x_start, t, noise=None):
+ """
+ Diffuse the data for a given number of diffusion steps.
+ In other words, sample from q(x_t | x_0).
+ :param x_start: the initial data batch.
+ :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
+ :param noise: if specified, the split-out normal noise.
+ :return: A noisy version of x_start.
+ """
+ if noise is None:
+ noise = th.randn_like(x_start)
+ assert noise.shape == x_start.shape
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
+ + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
+ )
+
+ def q_posterior_mean_variance(self, x_start, x_t, t):
+ """
+ Compute the mean and variance of the diffusion posterior:
+ q(x_{t-1} | x_t, x_0)
+ """
+ assert x_start.shape == x_t.shape
+ posterior_mean = (
+ _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
+ + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
+ )
+ posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
+ posterior_log_variance_clipped = _extract_into_tensor(
+ self.posterior_log_variance_clipped, t, x_t.shape
+ )
+ assert (
+ posterior_mean.shape[0]
+ == posterior_variance.shape[0]
+ == posterior_log_variance_clipped.shape[0]
+ == x_start.shape[0]
+ )
+ return posterior_mean, posterior_variance, posterior_log_variance_clipped
+
+ def get_v(self, x_start, noise, t):
+ # v-prediction parameterization
+ # training loss type for stable-diffusion-2-1 768/x4-upscaler
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * noise -
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * x_start
+ )
+
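+ # Note (added for clarity, not in the original code): with abar = alphas_cumprod[t],
+ # v = sqrt(abar) * eps - sqrt(1 - abar) * x_0 and x_t = sqrt(abar) * x_0 + sqrt(1 - abar) * eps,
+ # so x_0 = sqrt(abar) * x_t - sqrt(1 - abar) * v and eps = sqrt(abar) * v + sqrt(1 - abar) * x_t,
+ # which is exactly what _predict_xstart_from_z_and_v / _predict_eps_from_z_and_v compute below.
+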
+ def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
+ """
+ Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
+ the initial x, x_0.
+ :param model: the model, which takes a signal and a batch of timesteps
+ as input.
+ :param x: the [N x C x ...] tensor at time t.
+ :param t: a 1-D Tensor of timesteps.
+ :param clip_denoised: if True, clip the denoised signal into [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample. Applies before
+ clip_denoised.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict with the following keys:
+ - 'mean': the model mean output.
+ - 'variance': the model variance output.
+ - 'log_variance': the log of 'variance'.
+ - 'pred_xstart': the prediction for x_0.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+
+ B, F, C = x.shape[:3]
+ assert t.shape == (B,)
+ model_output = model(x, t, **model_kwargs)
+ if hasattr(model_output, "sample"): # e.g. the tav unet returns an output object wrapping the tensor
+ model_output = model_output.sample
+
+ # for v prediction
+ # if self.model_mean_type == ModelMeanType.PREVIOUS_V:
+ # model_output = self._predict_eps_from_z_and_v(x, t, model_output)
+ if isinstance(model_output, tuple):
+ model_output, extra = model_output
+ else:
+ extra = None
+
+ if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
+ assert model_output.shape == (B, F, C * 2, *x.shape[3:])
+ model_output, model_var_values = th.split(model_output, C, dim=2)
+ min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
+ max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
+ # The model_var_values is [-1, 1] for [min_var, max_var].
+ frac = (model_var_values + 1) / 2
+ model_log_variance = frac * max_log + (1 - frac) * min_log
+ model_variance = th.exp(model_log_variance)
+ else:
+ model_variance, model_log_variance = {
+ # for fixedlarge, we set the initial (log-)variance like so
+ # to get a better decoder log likelihood.
+ ModelVarType.FIXED_LARGE: (
+ np.append(self.posterior_variance[1], self.betas[1:]),
+ np.log(np.append(self.posterior_variance[1], self.betas[1:])),
+ ),
+ ModelVarType.FIXED_SMALL: (
+ self.posterior_variance,
+ self.posterior_log_variance_clipped,
+ ),
+ }[self.model_var_type]
+ model_variance = _extract_into_tensor(model_variance, t, x.shape)
+ model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
+
+ def process_xstart(x):
+ if denoised_fn is not None:
+ x = denoised_fn(x)
+ if clip_denoised:
+ return x.clamp(-1, 1)
+ return x
+
+ if self.model_mean_type == ModelMeanType.START_X:
+ pred_xstart = process_xstart(model_output)
+ else:
+ if self.model_mean_type == ModelMeanType.PREVIOUS_V:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_z_and_v(x_t=x, t=t, v=model_output)
+ )
+ else:
+ pred_xstart = process_xstart(
+ self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output)
+ )
+
+ if self.model_mean_type == ModelMeanType.PREVIOUS_V:
+ eps = self._predict_eps_from_z_and_v(x_t=x, t=t, v=model_output)
+ else:
+ eps = model_output
+
+ model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
+
+ assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
+ return {
+ "eps": eps,
+ "mean": model_mean,
+ "variance": model_variance,
+ "log_variance": model_log_variance,
+ "pred_xstart": pred_xstart,
+ "extra": extra,
+ }
+
+ def _predict_xstart_from_eps(self, x_t, t, eps):
+ assert x_t.shape == eps.shape
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
+ - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
+ )
+
+ def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
+ return (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
+
+ # for v prediction
+ def _predict_xstart_from_z_and_v(self, x_t, t, v):
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * x_t -
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * v
+ )
+
+ # for v prediction
+ def _predict_eps_from_z_and_v(self, x_t, t, v):
+ return (
+ _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_t.shape) * v +
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_t.shape) * x_t
+ )
+
+
+ def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute the mean for the previous step, given a function cond_fn that
+ computes the gradient of a conditional log probability with respect to
+ x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
+ condition on y.
+ This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
+ """
+ gradient = cond_fn(x, t, **model_kwargs)
+ new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
+ return new_mean
+
+ def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
+ """
+ Compute what the p_mean_variance output would have been, should the
+ model's score function be conditioned by cond_fn.
+ See condition_mean() for details on cond_fn.
+ Unlike condition_mean(), this instead uses the conditioning strategy
+ from Song et al (2020).
+ """
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+
+ eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
+ eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
+
+ out = p_mean_var.copy()
+ out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
+ out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
+ return out
+
+ def p_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ ):
+ """
+ Sample x_{t-1} from the model at the given timestep.
+ :param model: the model to sample from.
+ :param x: the current tensor at x_{t-1}.
+ :param t: the value of t, starting at 0 for the first diffusion step.
+ :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - 'sample': a random sample from the model.
+ - 'pred_xstart': a prediction of x_0.
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ noise = th.randn_like(x)
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ if cond_fn is not None:
+ out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def p_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model.
+ :param model: the model module.
+ :param shape: the shape of the samples, (N, C, H, W).
+ :param noise: if specified, the noise from the encoder to sample.
+ Should be of the same shape as `shape`.
+ :param clip_denoised: if True, clip x_start predictions to [-1, 1].
+ :param denoised_fn: if not None, a function which applies to the
+ x_start prediction before it is used to sample.
+ :param cond_fn: if not None, this is a gradient function that acts
+ similarly to the model.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param device: if specified, the device to create the samples on.
+ If not specified, use a model parameter's device.
+ :param progress: if True, show a tqdm progress bar.
+ :return: a non-differentiable batch of samples.
+ """
+ final = None
+ for sample in self.p_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ ):
+ final = sample
+ return final["sample"]
+
+ def p_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ ):
+ """
+ Generate samples from the model and yield intermediate samples from
+ each timestep of diffusion.
+ Arguments are the same as p_sample_loop().
+ Returns a generator over dicts, where each dict is the return value of
+ p_sample().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.p_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ )
+ yield out
+ img = out["sample"]
+
+ def ddim_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t-1} from the model using DDIM.
+ Same usage as p_sample().
+ """
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ if self.model_mean_type == ModelMeanType.PREVIOUS_V:
+ eps = out["eps"]
+ else:
+ eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
+
+ alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
+ alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
+ sigma = (
+ eta
+ * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar))
+ * th.sqrt(1 - alpha_bar / alpha_bar_prev)
+ )
+ # Equation 12.
+ noise = th.randn_like(x)
+ mean_pred = (
+ out["pred_xstart"] * th.sqrt(alpha_bar_prev)
+ + th.sqrt(1 - alpha_bar_prev - sigma ** 2) * eps
+ )
+ nonzero_mask = (
+ (t != 0).float().view(-1, *([1] * (len(x.shape) - 1)))
+ ) # no noise when t == 0
+ sample = mean_pred + nonzero_mask * sigma * noise
+ return {"sample": sample, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_reverse_sample(
+ self,
+ model,
+ x,
+ t,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ eta=0.0,
+ ):
+ """
+ Sample x_{t+1} from the model using DDIM reverse ODE.
+ """
+ assert eta == 0.0, "Reverse ODE only for deterministic path"
+ out = self.p_mean_variance(
+ model,
+ x,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ model_kwargs=model_kwargs,
+ )
+ if cond_fn is not None:
+ out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
+ # Usually our model outputs epsilon, but we re-derive it
+ # in case we used x_start or x_prev prediction.
+ eps = (
+ _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x
+ - out["pred_xstart"]
+ ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
+ alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
+
+ # Equation 12. reversed
+ mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
+
+ return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
+
+ def ddim_sample_loop(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Generate samples from the model using DDIM.
+ Same usage as p_sample_loop().
+ """
+ final = None
+ for sample in self.ddim_sample_loop_progressive(
+ model,
+ shape,
+ noise=noise,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=progress,
+ eta=eta,
+ ):
+ final = sample
+ return final["sample"]
+
+ def ddim_sample_loop_progressive(
+ self,
+ model,
+ shape,
+ noise=None,
+ clip_denoised=True,
+ denoised_fn=None,
+ cond_fn=None,
+ model_kwargs=None,
+ device=None,
+ progress=False,
+ eta=0.0,
+ ):
+ """
+ Use DDIM to sample from the model and yield intermediate samples from
+ each timestep of DDIM.
+ Same usage as p_sample_loop_progressive().
+ """
+ if device is None:
+ device = next(model.parameters()).device
+ assert isinstance(shape, (tuple, list))
+ if noise is not None:
+ img = noise
+ else:
+ img = th.randn(*shape, device=device)
+ indices = list(range(self.num_timesteps))[::-1]
+
+ if progress:
+ # Lazy import so that we don't depend on tqdm.
+ from tqdm.auto import tqdm
+
+ indices = tqdm(indices)
+
+ for i in indices:
+ t = th.tensor([i] * shape[0], device=device)
+ with th.no_grad():
+ out = self.ddim_sample(
+ model,
+ img,
+ t,
+ clip_denoised=clip_denoised,
+ denoised_fn=denoised_fn,
+ cond_fn=cond_fn,
+ model_kwargs=model_kwargs,
+ eta=eta,
+ )
+ yield out
+ img = out["sample"]
+
+ def _vb_terms_bpd(
+ self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None
+ ):
+ """
+ Get a term for the variational lower-bound.
+ The resulting units are bits (rather than nats, as one might expect).
+ This allows for comparison to other papers.
+ :return: a dict with the following keys:
+ - 'output': a shape [N] tensor of NLLs or KLs.
+ - 'pred_xstart': the x_0 predictions.
+ """
+ true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )
+ out = self.p_mean_variance(
+ model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs
+ )
+ kl = normal_kl(
+ true_mean, true_log_variance_clipped, out["mean"], out["log_variance"]
+ )
+ kl = mean_flat(kl) / np.log(2.0)
+
+ decoder_nll = -discretized_gaussian_log_likelihood(
+ x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
+ )
+ assert decoder_nll.shape == x_start.shape
+ decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
+
+ # At the first timestep return the decoder NLL,
+ # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
+ output = th.where((t == 0), decoder_nll, kl)
+ return {"output": output, "pred_xstart": out["pred_xstart"]}
+
+
+ def training_losses(self, model, x_start, t, loss_mask=None, model_kwargs=None, noise=None):
+ """
+ Compute training losses for a single timestep.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param t: a batch of timestep indices.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :param noise: if specified, the specific Gaussian noise to try to remove.
+ :return: a dict with the key "loss" containing a tensor of shape [N].
+ Some mean or variance settings may also have other keys.
+ """
+ if model_kwargs is None:
+ model_kwargs = {}
+ if noise is None:
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start, t, noise=noise)
+
+ terms = {}
+
+ if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] = self._vb_terms_bpd(
+ model=model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ model_kwargs=model_kwargs,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_KL:
+ terms["loss"] *= self.num_timesteps
+ elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE: # loss_type=LossType.MSE by default
+ model_output = model(x_t, t, **model_kwargs)
+ if hasattr(model_output, "sample"): # e.g. the tav unet returns an output object wrapping the tensor
+ model_output = model_output.sample
+
+ if self.model_var_type in [
+ ModelVarType.LEARNED,
+ ModelVarType.LEARNED_RANGE,
+ ]:
+ B, F, C = x_t.shape[:3]
+ assert model_output.shape == (B, F, C * 2, *x_t.shape[3:])
+ model_output, model_var_values = th.split(model_output, C, dim=2)
+ # Learn the variance using the variational bound, but don't let
+ # it affect our mean prediction.
+ frozen_out = th.cat([model_output.detach(), model_var_values], dim=2)
+ terms["vb"] = self._vb_terms_bpd(
+ model=lambda *args, r=frozen_out: r,
+ x_start=x_start,
+ x_t=x_t,
+ t=t,
+ clip_denoised=False,
+ )["output"]
+ if self.loss_type == LossType.RESCALED_MSE:
+ # Divide by 1000 for equivalence with initial implementation.
+ # Without a factor of 1/1000, the VB term hurts the MSE term.
+ terms["vb"] *= self.num_timesteps / 1000.0
+ print("=======")
+
+ target = {
+ ModelMeanType.PREVIOUS_V: self.get_v(x_start=x_start, noise=noise, t=t),
+ ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(
+ x_start=x_start, x_t=x_t, t=t
+ )[0],
+ ModelMeanType.START_X: x_start,
+ ModelMeanType.EPSILON: noise,
+ }[self.model_mean_type]
+ assert model_output.shape == target.shape == x_start.shape
+ if loss_mask is not None:
+ terms["mse"] = mean_flat(((target - model_output) ** 2) * loss_mask)
+ else:
+ terms["mse"] = mean_flat((target - model_output) ** 2)
+
+ if "vb" in terms:
+ terms["loss"] = terms["mse"] + terms["vb"]
+ else:
+ terms["loss"] = terms["mse"]
+ else:
+ raise NotImplementedError(self.loss_type)
+
+ return terms
+
+ def _prior_bpd(self, x_start):
+ """
+ Get the prior KL term for the variational lower-bound, measured in
+ bits-per-dim.
+ This term can't be optimized, as it only depends on the encoder.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :return: a batch of [N] KL values (in bits), one per batch element.
+ """
+ batch_size = x_start.shape[0]
+ t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
+ qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
+ kl_prior = normal_kl(
+ mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0
+ )
+ return mean_flat(kl_prior) / np.log(2.0)
+
+ def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
+ """
+ Compute the entire variational lower-bound, measured in bits-per-dim,
+ as well as other related quantities.
+ :param model: the model to evaluate loss on.
+ :param x_start: the [N x C x ...] tensor of inputs.
+ :param clip_denoised: if True, clip denoised samples.
+ :param model_kwargs: if not None, a dict of extra keyword arguments to
+ pass to the model. This can be used for conditioning.
+ :return: a dict containing the following keys:
+ - total_bpd: the total variational lower-bound, per batch element.
+ - prior_bpd: the prior term in the lower-bound.
+ - vb: an [N x T] tensor of terms in the lower-bound.
+ - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
+ - mse: an [N x T] tensor of epsilon MSEs for each timestep.
+ """
+ device = x_start.device
+ batch_size = x_start.shape[0]
+
+ vb = []
+ xstart_mse = []
+ mse = []
+ for t in list(range(self.num_timesteps))[::-1]:
+ t_batch = th.tensor([t] * batch_size, device=device)
+ noise = th.randn_like(x_start)
+ x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
+ # Calculate VLB term at the current timestep
+ with th.no_grad():
+ out = self._vb_terms_bpd(
+ model,
+ x_start=x_start,
+ x_t=x_t,
+ t=t_batch,
+ clip_denoised=clip_denoised,
+ model_kwargs=model_kwargs,
+ )
+ vb.append(out["output"])
+ xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
+ eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
+ mse.append(mean_flat((eps - noise) ** 2))
+
+ vb = th.stack(vb, dim=1)
+ xstart_mse = th.stack(xstart_mse, dim=1)
+ mse = th.stack(mse, dim=1)
+
+ prior_bpd = self._prior_bpd(x_start)
+ total_bpd = vb.sum(dim=1) + prior_bpd
+ return {
+ "total_bpd": total_bpd,
+ "prior_bpd": prior_bpd,
+ "vb": vb,
+ "xstart_mse": xstart_mse,
+ "mse": mse,
+ }
+
+
+def _extract_into_tensor(arr, timesteps, broadcast_shape):
+ """
+ Extract values from a 1-D numpy array for a batch of indices.
+ :param arr: the 1-D numpy array.
+ :param timesteps: a tensor of indices into the array to extract.
+ :param broadcast_shape: a larger shape of K dimensions with the batch
+ dimension equal to the length of timesteps.
+ :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
+ """
+ res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
+ while len(res.shape) < len(broadcast_shape):
+ res = res[..., None]
+ return res + th.zeros(broadcast_shape, device=timesteps.device)
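+
+
+# Illustrative sketch (not part of the original code): a minimal DDIM sampling call,
+# assuming `diffusion` is a GaussianDiffusion/SpacedDiffusion instance and `model` is a
+# denoiser taking (x, t, **model_kwargs).
+def _example_ddim_sampling(diffusion, model, shape, device, model_kwargs=None):
+ return diffusion.ddim_sample_loop(
+ model,
+ shape,
+ clip_denoised=True,
+ model_kwargs=model_kwargs,
+ device=device,
+ progress=True,
+ eta=0.0,
+ )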
diff --git a/vsr/diffusion/respace.py b/vsr/diffusion/respace.py
new file mode 100644
index 0000000000000000000000000000000000000000..31c0092380367db477f154a39bf172ff64712fc4
--- /dev/null
+++ b/vsr/diffusion/respace.py
@@ -0,0 +1,130 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+import torch
+import numpy as np
+import torch as th
+
+from .gaussian_diffusion import GaussianDiffusion
+
+
+def space_timesteps(num_timesteps, section_counts):
+ """
+ Create a list of timesteps to use from an original diffusion process,
+ given the number of timesteps we want to take from equally-sized portions
+ of the original process.
+ For example, if there's 300 timesteps and the section counts are [10,15,20]
+ then the first 100 timesteps are strided to be 10 timesteps, the second 100
+ are strided to be 15 timesteps, and the final 100 are strided to be 20.
+ If the stride is a string starting with "ddim", then the fixed striding
+ from the DDIM paper is used, and only one section is allowed.
+ :param num_timesteps: the number of diffusion steps in the original
+ process to divide up.
+ :param section_counts: either a list of numbers, or a string containing
+ comma-separated numbers, indicating the step count
+ per section. As a special case, use "ddimN" where N
+ is a number of steps to use the striding from the
+ DDIM paper.
+ :return: a set of diffusion steps from the original process to use.
+ """
+ if isinstance(section_counts, str):
+ if section_counts.startswith("ddim"):
+ desired_count = int(section_counts[len("ddim") :])
+ for i in range(1, num_timesteps):
+ if len(range(0, num_timesteps, i)) == desired_count:
+ return set(range(0, num_timesteps, i))
+ raise ValueError(
+ f"cannot create exactly {num_timesteps} steps with an integer stride"
+ )
+ section_counts = [int(x) for x in section_counts.split(",")]
+ size_per = num_timesteps // len(section_counts)
+ extra = num_timesteps % len(section_counts)
+ start_idx = 0
+ all_steps = []
+ for i, section_count in enumerate(section_counts):
+ size = size_per + (1 if i < extra else 0)
+ if size < section_count:
+ raise ValueError(
+ f"cannot divide section of {size} steps into {section_count}"
+ )
+ if section_count <= 1:
+ frac_stride = 1
+ else:
+ frac_stride = (size - 1) / (section_count - 1)
+ cur_idx = 0.0
+ taken_steps = []
+ for _ in range(section_count):
+ taken_steps.append(start_idx + round(cur_idx))
+ cur_idx += frac_stride
+ all_steps += taken_steps
+ start_idx += size
+ return set(all_steps)
+
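+# Worked examples (added for clarity, not in the original code):
+# space_timesteps(300, [10, 15, 20]) strides the first 100 original steps down to 10,
+# the middle 100 down to 15 and the last 100 down to 20 (45 timesteps in total);
+# space_timesteps(1000, "ddim50") returns set(range(0, 1000, 20)).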
+
+class SpacedDiffusion(GaussianDiffusion):
+ """
+ A diffusion process which can skip steps in a base diffusion process.
+ :param use_timesteps: a collection (sequence or set) of timesteps from the
+ original diffusion process to retain.
+ :param kwargs: the kwargs to create the base diffusion process.
+ """
+
+ def __init__(self, use_timesteps, **kwargs):
+ self.use_timesteps = set(use_timesteps)
+ self.timestep_map = []
+ self.original_num_steps = len(kwargs["betas"])
+
+ base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
+ last_alpha_cumprod = 1.0
+ new_betas = []
+ for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
+ if i in self.use_timesteps:
+ new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
+ last_alpha_cumprod = alpha_cumprod
+ self.timestep_map.append(i)
+ kwargs["betas"] = np.array(new_betas)
+ super().__init__(**kwargs)
+
+ def p_mean_variance(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
+
+ # @torch.compile
+ def training_losses(
+ self, model, *args, **kwargs
+ ): # pylint: disable=signature-differs
+ return super().training_losses(self._wrap_model(model), *args, **kwargs)
+
+ def condition_mean(self, cond_fn, *args, **kwargs):
+ return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def condition_score(self, cond_fn, *args, **kwargs):
+ return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
+
+ def _wrap_model(self, model):
+ if isinstance(model, _WrappedModel):
+ return model
+ return _WrappedModel(
+ model, self.timestep_map, self.original_num_steps
+ )
+
+ def _scale_timesteps(self, t):
+ # Scaling is done by the wrapped model.
+ return t
+
+
+class _WrappedModel:
+ def __init__(self, model, timestep_map, original_num_steps):
+ self.model = model
+ self.timestep_map = timestep_map
+ # self.rescale_timesteps = rescale_timesteps
+ self.original_num_steps = original_num_steps
+
+ def __call__(self, x, ts, **kwargs):
+ map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
+ new_ts = map_tensor[ts]
+ # if self.rescale_timesteps:
+ # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
+ return self.model(x, new_ts, **kwargs)
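+
+
+# Note (added for clarity, not in the original code): the wrapped model always receives
+# the original timestep indices. For example, with space_timesteps(1000, "ddim50") the
+# respaced index 1 is mapped through timestep_map back to original step 20 before the
+# underlying model is called.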
diff --git a/vsr/diffusion/scheduling_ddim.py b/vsr/diffusion/scheduling_ddim.py
new file mode 100644
index 0000000000000000000000000000000000000000..8564006cac2d477a645178c01bb8ac18b871eacd
--- /dev/null
+++ b/vsr/diffusion/scheduling_ddim.py
@@ -0,0 +1,462 @@
+# Copyright 2023 Stanford University Team and The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# DISCLAIMER: This code is strongly influenced by https://github.com/pesser/pytorch_diffusion
+# and https://github.com/hojonathanho/diffusion
+
+import math
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import numpy as np
+import torch
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput
+
+try:
+ from diffusers.utils import randn_tensor
+except:
+ from diffusers.utils.torch_utils import randn_tensor
+
+from diffusers.schedulers.scheduling_utils import KarrasDiffusionSchedulers, SchedulerMixin
+
+
+@dataclass
+# Copied from diffusers.schedulers.scheduling_ddpm.DDPMSchedulerOutput with DDPM->DDIM
+class DDIMSchedulerOutput(BaseOutput):
+ """
+ Output class for the scheduler's step function output.
+
+ Args:
+ prev_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ Computed sample (x_{t-1}) of previous timestep. `prev_sample` should be used as next model input in the
+ denoising loop.
+ pred_original_sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` for images):
+ The predicted denoised sample (x_{0}) based on the model output from the current timestep.
+ `pred_original_sample` can be used to preview progress or for guidance.
+ """
+
+ prev_sample: torch.FloatTensor
+ pred_original_sample: Optional[torch.FloatTensor] = None
+
+
+# Copied from diffusers.schedulers.scheduling_ddpm.betas_for_alpha_bar
+def betas_for_alpha_bar(num_diffusion_timesteps, max_beta=0.999) -> torch.Tensor:
+ """
+ Create a beta schedule that discretizes the given alpha_t_bar function, which defines the cumulative product of
+ (1-beta) over time from t = [0,1].
+
+ Contains a function alpha_bar that takes an argument t and transforms it to the cumulative product of (1-beta) up
+ to that part of the diffusion process.
+
+
+ Args:
+ num_diffusion_timesteps (`int`): the number of betas to produce.
+ max_beta (`float`): the maximum beta to use; use values lower than 1 to
+ prevent singularities.
+
+ Returns:
+ betas (`np.ndarray`): the betas used by the scheduler to step the model outputs
+ """
+
+ def alpha_bar(time_step):
+ return math.cos((time_step + 0.008) / 1.008 * math.pi / 2) ** 2
+
+ betas = []
+ for i in range(num_diffusion_timesteps):
+ t1 = i / num_diffusion_timesteps
+ t2 = (i + 1) / num_diffusion_timesteps
+ betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
+ return torch.tensor(betas, dtype=torch.float32)
+
+
+class DDIMScheduler(SchedulerMixin, ConfigMixin):
+ """
+ Denoising diffusion implicit models is a scheduler that extends the denoising procedure introduced in denoising
+ diffusion probabilistic models (DDPMs) with non-Markovian guidance.
+
+ [`~ConfigMixin`] takes care of storing all config attributes that are passed in the scheduler's `__init__`
+ function, such as `num_train_timesteps`. They can be accessed via `scheduler.config.num_train_timesteps`.
+ [`SchedulerMixin`] provides general loading and saving functionality via the [`SchedulerMixin.save_pretrained`] and
+ [`~SchedulerMixin.from_pretrained`] functions.
+
+ For more details, see the original paper: https://arxiv.org/abs/2010.02502
+
+ Args:
+ num_train_timesteps (`int`): number of diffusion steps used to train the model.
+ beta_start (`float`): the starting `beta` value of inference.
+ beta_end (`float`): the final `beta` value.
+ beta_schedule (`str`):
+ the beta schedule, a mapping from a beta range to a sequence of betas for stepping the model. Choose from
+ `linear`, `scaled_linear`, or `squaredcos_cap_v2`.
+ trained_betas (`np.ndarray`, optional):
+ option to pass an array of betas directly to the constructor to bypass `beta_start`, `beta_end` etc.
+ clip_sample (`bool`, default `True`):
+ option to clip predicted sample for numerical stability.
+ clip_sample_range (`float`, default `1.0`):
+ the maximum magnitude for sample clipping. Valid only when `clip_sample=True`.
+ set_alpha_to_one (`bool`, default `True`):
+ each diffusion step uses the value of alphas product at that step and at the previous one. For the final
+ step there is no previous alpha. When this option is `True` the previous alpha product is fixed to `1`,
+ otherwise it uses the value of alpha at step 0.
+ steps_offset (`int`, default `0`):
+ an offset added to the inference steps. You can use a combination of `offset=1` and
+ `set_alpha_to_one=False`, to make the last step use step 0 for the previous alpha product, as done in
+ stable diffusion.
+ prediction_type (`str`, default `epsilon`, optional):
+ prediction type of the scheduler function, one of `epsilon` (predicting the noise of the diffusion
+ process), `sample` (directly predicting the noisy sample), or `v_prediction` (see section 2.4 of
+ https://imagen.research.google/video/paper.pdf)
+ thresholding (`bool`, default `False`):
+ whether to use the "dynamic thresholding" method (introduced by Imagen, https://arxiv.org/abs/2205.11487).
+ Note that the thresholding method is unsuitable for latent-space diffusion models (such as
+ stable-diffusion).
+ dynamic_thresholding_ratio (`float`, default `0.995`):
+ the ratio for the dynamic thresholding method. Default is `0.995`, the same as Imagen
+ (https://arxiv.org/abs/2205.11487). Valid only when `thresholding=True`.
+ sample_max_value (`float`, default `1.0`):
+ the threshold value for dynamic thresholding. Valid only when `thresholding=True`.
+ """
+
+ _compatibles = [e.name for e in KarrasDiffusionSchedulers]
+ order = 1
+
+ @register_to_config
+ def __init__(
+ self,
+ num_train_timesteps: int = 1000,
+ beta_start: float = 0.0001,
+ beta_end: float = 0.02,
+ beta_schedule: str = "linear",
+ trained_betas: Optional[Union[np.ndarray, List[float]]] = None,
+ clip_sample: bool = True,
+ set_alpha_to_one: bool = True,
+ steps_offset: int = 0,
+ prediction_type: str = "epsilon",
+ thresholding: bool = False,
+ dynamic_thresholding_ratio: float = 0.995,
+ clip_sample_range: float = 1.0,
+ sample_max_value: float = 1.0,
+ ):
+ if trained_betas is not None:
+ self.betas = torch.tensor(trained_betas, dtype=torch.float32)
+ elif beta_schedule == "linear":
+ self.betas = torch.linspace(beta_start, beta_end, num_train_timesteps, dtype=torch.float32)
+ elif beta_schedule == "scaled_linear":
+ # this schedule is very specific to the latent diffusion model.
+ self.betas = (
+ torch.linspace(beta_start**0.5, beta_end**0.5, num_train_timesteps, dtype=torch.float32) ** 2
+ )
+ elif beta_schedule == "squaredcos_cap_v2":
+ # Glide cosine schedule
+ self.betas = betas_for_alpha_bar(num_train_timesteps)
+ else:
+ raise NotImplementedError(f"{beta_schedule} is not implemented for {self.__class__}")
+
+ self.alphas = 1.0 - self.betas
+ self.alphas_cumprod = torch.cumprod(self.alphas, dim=0)
+
+ # At every step in ddim, we are looking into the previous alphas_cumprod
+ # For the final step, there is no previous alphas_cumprod because we are already at 0
+ # `set_alpha_to_one` decides whether we set this parameter simply to one or
+ # whether we use the final alpha of the "non-previous" one.
+ self.final_alpha_cumprod = torch.tensor(1.0) if set_alpha_to_one else self.alphas_cumprod[0]
+
+ # standard deviation of the initial noise distribution
+ self.init_noise_sigma = 1.0
+
+ # setable values
+ self.num_inference_steps = None
+ self.timesteps = torch.from_numpy(np.arange(0, num_train_timesteps)[::-1].copy().astype(np.int64))
+
+ def scale_model_input(self, sample: torch.FloatTensor, timestep: Optional[int] = None) -> torch.FloatTensor:
+ """
+ Ensures interchangeability with schedulers that need to scale the denoising model input depending on the
+ current timestep.
+
+ Args:
+ sample (`torch.FloatTensor`): input sample
+ timestep (`int`, optional): current timestep
+
+ Returns:
+ `torch.FloatTensor`: scaled input sample
+ """
+ return sample
+
+ def _get_variance(self, timestep, prev_timestep):
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+ beta_prod_t = 1 - alpha_prod_t
+ beta_prod_t_prev = 1 - alpha_prod_t_prev
+
+ variance = (beta_prod_t_prev / beta_prod_t) * (1 - alpha_prod_t / alpha_prod_t_prev)
+
+ return variance
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler._threshold_sample
+ def _threshold_sample(self, sample: torch.FloatTensor) -> torch.FloatTensor:
+ """
+ "Dynamic thresholding: At each sampling step we set s to a certain percentile absolute pixel value in xt0 (the
+ prediction of x_0 at timestep t), and if s > 1, then we threshold xt0 to the range [-s, s] and then divide by
+ s. Dynamic thresholding pushes saturated pixels (those near -1 and 1) inwards, thereby actively preventing
+ pixels from saturation at each step. We find that dynamic thresholding results in significantly better
+ photorealism as well as better image-text alignment, especially when using very large guidance weights."
+
+ https://arxiv.org/abs/2205.11487
+ """
+ dtype = sample.dtype
+ batch_size, channels, height, width = sample.shape
+
+ if dtype not in (torch.float32, torch.float64):
+ sample = sample.float() # upcast for quantile calculation, and clamp not implemented for cpu half
+
+ # Flatten sample for doing quantile calculation along each image
+ sample = sample.reshape(batch_size, channels * height * width)
+
+ abs_sample = sample.abs() # "a certain percentile absolute pixel value"
+
+ s = torch.quantile(abs_sample, self.config.dynamic_thresholding_ratio, dim=1)
+ s = torch.clamp(
+ s, min=1, max=self.config.sample_max_value
+ ) # When clamped to min=1, equivalent to standard clipping to [-1, 1]
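+ # Note: with the default `sample_max_value=1.0` this clamp forces s == 1, so
+ # thresholding degenerates to plain clipping to [-1, 1]; only a larger
+ # `sample_max_value` lets the dynamic percentile actually take effect.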
+
+ s = s.unsqueeze(1) # (batch_size, 1) because clamp will broadcast along dim=0
+ sample = torch.clamp(sample, -s, s) / s # "we threshold xt0 to the range [-s, s] and then divide by s"
+
+ sample = sample.reshape(batch_size, channels, height, width)
+ sample = sample.to(dtype)
+
+ return sample
+
+ # def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ # """
+ # Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ # Args:
+ # num_inference_steps (`int`):
+ # the number of diffusion steps used when generating samples with a pre-trained model.
+ # """
+
+ # if num_inference_steps > self.config.num_train_timesteps:
+ # raise ValueError(
+ # f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ # f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ # f" maximal {self.config.num_train_timesteps} timesteps."
+ # )
+
+ # self.num_inference_steps = num_inference_steps
+ # step_ratio = self.config.num_train_timesteps // self.num_inference_steps
+ # # creates integer timesteps by multiplying by ratio
+ # # casting to int to avoid issues when num_inference_step is power of 3
+ # timesteps = (np.arange(0, num_inference_steps) * step_ratio).round()[::-1].copy().astype(np.int64)
+ # self.timesteps = torch.from_numpy(timesteps).to(device)
+ # self.timesteps += self.config.steps_offset
+
+ def set_timesteps(self, num_inference_steps: int, device: Union[str, torch.device] = None):
+ """
+ Sets the discrete timesteps used for the diffusion chain. Supporting function to be run before inference.
+
+ Args:
+ num_inference_steps (`int`):
+ the number of diffusion steps used when generating samples with a pre-trained model.
+ """
+
+ if num_inference_steps > self.config.num_train_timesteps:
+ raise ValueError(
+ f"`num_inference_steps`: {num_inference_steps} cannot be larger than `self.config.train_timesteps`:"
+ f" {self.config.num_train_timesteps} as the unet model trained with this scheduler can only handle"
+ f" maximal {self.config.num_train_timesteps} timesteps."
+ )
+
+ self.num_inference_steps = num_inference_steps
+ # evenly space `num_inference_steps` timesteps between `steps_offset` and
+ # `num_train_timesteps` via linspace, then round them to integer timesteps
+ timesteps = np.linspace(self.config.steps_offset, self.config.num_train_timesteps, num_inference_steps)
+ timesteps = timesteps.round()[::-1].copy().astype(np.int64)
+ self.timesteps = torch.from_numpy(timesteps).to(device)
+ self.timesteps += self.config.steps_offset
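+ # Illustrative example (not from upstream): with num_train_timesteps=1000,
+ # steps_offset=0 and num_inference_steps=4, np.linspace(0, 1000, 4).round()
+ # gives [0, 333, 667, 1000], which is reversed to the descending grid
+ # [1000, 667, 333, 0] consumed by `step`.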
+
+ def step(
+ self,
+ model_output: torch.FloatTensor,
+ timestep: int,
+ sample: torch.FloatTensor,
+ eta: float = 0.0,
+ use_clipped_model_output: bool = False,
+ generator=None,
+ variance_noise: Optional[torch.FloatTensor] = None,
+ return_dict: bool = True,
+ ) -> Union[DDIMSchedulerOutput, Tuple]:
+ """
+ Predict the sample at the previous timestep by reversing the SDE. Core function to propagate the diffusion
+ process from the learned model outputs (most often the predicted noise).
+
+ Args:
+ model_output (`torch.FloatTensor`): direct output from learned diffusion model.
+ timestep (`int`): current discrete timestep in the diffusion chain.
+ sample (`torch.FloatTensor`):
+ current instance of sample being created by diffusion process.
+ eta (`float`): weight of noise for added noise in diffusion step.
+ use_clipped_model_output (`bool`): if `True`, compute "corrected" `model_output` from the clipped
+ predicted original sample. Necessary because predicted original sample is clipped to [-1, 1] when
+ `self.config.clip_sample` is `True`. If no clipping has happened, "corrected" `model_output` would
+ coincide with the one provided as input and `use_clipped_model_output` will have no effect.
+ generator: random number generator.
+ variance_noise (`torch.FloatTensor`): instead of generating noise for the variance using `generator`, we
+ can directly provide the noise for the variance itself. This is useful for methods such as
+ CycleDiffusion. (https://arxiv.org/abs/2210.05559)
+ return_dict (`bool`): option for returning tuple rather than DDIMSchedulerOutput class
+
+ Returns:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] or `tuple`:
+ [`~schedulers.scheduling_utils.DDIMSchedulerOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+
+ """
+ if self.num_inference_steps is None:
+ raise ValueError(
+ "Number of inference steps is 'None', you need to run 'set_timesteps' after creating the scheduler"
+ )
+
+ # See formulas (12) and (16) of DDIM paper https://arxiv.org/pdf/2010.02502.pdf
+ # Ideally, read the DDIM paper for an in-depth understanding
+
+ # Notation (<variable name> -> <name in paper>)
+ # - pred_noise_t -> e_theta(x_t, t)
+ # - pred_original_sample -> f_theta(x_t, t) or x_0
+ # - std_dev_t -> sigma_t
+ # - eta -> η
+ # - pred_sample_direction -> "direction pointing to x_t"
+ # - pred_prev_sample -> "x_t-1"
+
+ # 1. get previous step value (=t-1)
+ prev_timestep = timestep - self.config.num_train_timesteps // self.num_inference_steps
+
+ # 2. compute alphas, betas
+ alpha_prod_t = self.alphas_cumprod[timestep]
+ alpha_prod_t_prev = self.alphas_cumprod[prev_timestep] if prev_timestep >= 0 else self.final_alpha_cumprod
+
+ beta_prod_t = 1 - alpha_prod_t
+
+ # 3. compute predicted original sample from predicted noise also called
+ # "predicted x_0" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ # print('===========', self.config.prediction_type)
+ # self.config.prediction_type = "v_prediction"
+ if self.config.prediction_type == "epsilon":
+ pred_original_sample = (sample - beta_prod_t ** (0.5) * model_output) / alpha_prod_t ** (0.5)
+ pred_epsilon = model_output
+ elif self.config.prediction_type == "sample":
+ pred_original_sample = model_output
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+ elif self.config.prediction_type == "v_prediction":
+ pred_original_sample = (alpha_prod_t**0.5) * sample - (beta_prod_t**0.5) * model_output
+ pred_epsilon = (alpha_prod_t**0.5) * model_output + (beta_prod_t**0.5) * sample
+ else:
+ raise ValueError(
+ f"prediction_type given as {self.config.prediction_type} must be one of `epsilon`, `sample`, or"
+ " `v_prediction`"
+ )
+
+ # 4. Clip or threshold "predicted x_0"
+ if self.config.thresholding:
+ pred_original_sample = self._threshold_sample(pred_original_sample)
+ elif self.config.clip_sample:
+ pred_original_sample = pred_original_sample.clamp(
+ -self.config.clip_sample_range, self.config.clip_sample_range
+ )
+
+ # 5. compute variance: "sigma_t(η)" -> see formula (16)
+ # σ_t = sqrt((1 − α_{t−1}) / (1 − α_t)) * sqrt(1 − α_t / α_{t−1})
+ variance = self._get_variance(timestep, prev_timestep)
+ std_dev_t = eta * variance ** (0.5)
+
+ if use_clipped_model_output:
+ # the pred_epsilon is always re-derived from the clipped x_0 in Glide
+ pred_epsilon = (sample - alpha_prod_t ** (0.5) * pred_original_sample) / beta_prod_t ** (0.5)
+
+ # 6. compute "direction pointing to x_t" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ pred_sample_direction = (1 - alpha_prod_t_prev - std_dev_t**2) ** (0.5) * pred_epsilon
+
+ # 7. compute x_t without "random noise" of formula (12) from https://arxiv.org/pdf/2010.02502.pdf
+ prev_sample = alpha_prod_t_prev ** (0.5) * pred_original_sample + pred_sample_direction
+
+ if eta > 0:
+ if variance_noise is not None and generator is not None:
+ raise ValueError(
+ "Cannot pass both generator and variance_noise. Please make sure that either `generator` or"
+ " `variance_noise` stays `None`."
+ )
+
+ if variance_noise is None:
+ variance_noise = randn_tensor(
+ model_output.shape, generator=generator, device=model_output.device, dtype=model_output.dtype
+ )
+ variance = std_dev_t * variance_noise
+
+ prev_sample = prev_sample + variance
+
+ if not return_dict:
+ return (prev_sample,)
+
+ return DDIMSchedulerOutput(prev_sample=prev_sample, pred_original_sample=pred_original_sample)
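+ # Minimal usage sketch (illustrative only; `unet`, `latent_shape` and `device`
+ # are placeholders, not names defined in this repository):
+ #
+ # scheduler.set_timesteps(50)
+ # sample = torch.randn(latent_shape, device=device)
+ # for t in scheduler.timesteps:
+ # noise_pred = unet(sample, t).sample
+ # sample = scheduler.step(noise_pred, t, sample).prev_sample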
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.add_noise
+ def add_noise(
+ self,
+ original_samples: torch.FloatTensor,
+ noise: torch.FloatTensor,
+ timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+ alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+ timesteps = timesteps.to(original_samples.device)
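+ # The lines below implement the closed-form forward process
+ # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * noise,
+ # broadcasting the per-timestep coefficients over the sample dimensions.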
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+ return noisy_samples
+
+ # Copied from diffusers.schedulers.scheduling_ddpm.DDPMScheduler.get_velocity
+ def get_velocity(
+ self, sample: torch.FloatTensor, noise: torch.FloatTensor, timesteps: torch.IntTensor
+ ) -> torch.FloatTensor:
+ # Make sure alphas_cumprod and timestep have same device and dtype as sample
+ alphas_cumprod = self.alphas_cumprod.to(device=sample.device, dtype=sample.dtype)
+ timesteps = timesteps.to(sample.device)
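+ # Computes the v-prediction target
+ # v = sqrt(alpha_bar_t) * noise - sqrt(1 - alpha_bar_t) * sample,
+ # i.e. the quantity the model predicts when `prediction_type == "v_prediction"`.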
+
+ sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+ sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+ while len(sqrt_alpha_prod.shape) < len(sample.shape):
+ sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+ sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+ while len(sqrt_one_minus_alpha_prod.shape) < len(sample.shape):
+ sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+ velocity = sqrt_alpha_prod * noise - sqrt_one_minus_alpha_prod * sample
+ return velocity
+
+ def __len__(self):
+ return self.config.num_train_timesteps
diff --git a/vsr/diffusion/timestep_sampler.py b/vsr/diffusion/timestep_sampler.py
new file mode 100644
index 0000000000000000000000000000000000000000..a3f369847677d8dbaaadb8297691b1be92cf189f
--- /dev/null
+++ b/vsr/diffusion/timestep_sampler.py
@@ -0,0 +1,150 @@
+# Modified from OpenAI's diffusion repos
+# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
+# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
+# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+
+from abc import ABC, abstractmethod
+
+import numpy as np
+import torch as th
+import torch.distributed as dist
+
+
+def create_named_schedule_sampler(name, diffusion):
+ """
+ Create a ScheduleSampler from a library of pre-defined samplers.
+ :param name: the name of the sampler.
+ :param diffusion: the diffusion object to sample for.
+ """
+ if name == "uniform":
+ return UniformSampler(diffusion)
+ elif name == "loss-second-moment":
+ return LossSecondMomentResampler(diffusion)
+ else:
+ raise NotImplementedError(f"unknown schedule sampler: {name}")
+
+
+class ScheduleSampler(ABC):
+ """
+ A distribution over timesteps in the diffusion process, intended to reduce
+ variance of the objective.
+ By default, samplers perform unbiased importance sampling, in which the
+ objective's mean is unchanged.
+ However, subclasses may override sample() to change how the resampled
+ terms are reweighted, allowing for actual changes in the objective.
+ """
+
+ @abstractmethod
+ def weights(self):
+ """
+ Get a numpy array of weights, one per diffusion step.
+ The weights needn't be normalized, but must be positive.
+ """
+
+ def sample(self, batch_size, device):
+ """
+ Importance-sample timesteps for a batch.
+ :param batch_size: the number of timesteps.
+ :param device: the torch device to save to.
+ :return: a tuple (timesteps, weights):
+ - timesteps: a tensor of timestep indices.
+ - weights: a tensor of weights to scale the resulting losses.
+ """
+ w = self.weights()
+ p = w / np.sum(w)
+ indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
+ indices = th.from_numpy(indices_np).long().to(device)
+ weights_np = 1 / (len(p) * p[indices_np])
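+ # Weights of 1 / (T * p(t)) keep the estimator unbiased:
+ # E_{t~p}[ L(t) / (T * p(t)) ] = (1/T) * sum_t L(t), the uniform-timestep objective.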
+ weights = th.from_numpy(weights_np).float().to(device)
+ return indices, weights
+
+
+class UniformSampler(ScheduleSampler):
+ def __init__(self, diffusion):
+ self.diffusion = diffusion
+ self._weights = np.ones([diffusion.num_timesteps])
+
+ def weights(self):
+ return self._weights
+
+
+class LossAwareSampler(ScheduleSampler):
+ def update_with_local_losses(self, local_ts, local_losses):
+ """
+ Update the reweighting using losses from a model.
+ Call this method from each rank with a batch of timesteps and the
+ corresponding losses for each of those timesteps.
+ This method will perform synchronization to make sure all of the ranks
+ maintain the exact same reweighting.
+ :param local_ts: an integer Tensor of timesteps.
+ :param local_losses: a 1D Tensor of losses.
+ """
+ batch_sizes = [
+ th.tensor([0], dtype=th.int32, device=local_ts.device)
+ for _ in range(dist.get_world_size())
+ ]
+ dist.all_gather(
+ batch_sizes,
+ th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
+ )
+
+ # Pad all_gather batches to be the maximum batch size.
+ batch_sizes = [x.item() for x in batch_sizes]
+ max_bs = max(batch_sizes)
+
+ timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
+ loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
+ dist.all_gather(timestep_batches, local_ts)
+ dist.all_gather(loss_batches, local_losses)
+ timesteps = [
+ x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]
+ ]
+ losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
+ self.update_with_all_losses(timesteps, losses)
+
+ @abstractmethod
+ def update_with_all_losses(self, ts, losses):
+ """
+ Update the reweighting using losses from a model.
+ Sub-classes should override this method to update the reweighting
+ using losses from the model.
+ This method directly updates the reweighting without synchronizing
+ between workers. It is called by update_with_local_losses from all
+ ranks with identical arguments. Thus, it should have deterministic
+ behavior to maintain state across workers.
+ :param ts: a list of int timesteps.
+ :param losses: a list of float losses, one per timestep.
+ """
+
+
+class LossSecondMomentResampler(LossAwareSampler):
+ def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
+ self.diffusion = diffusion
+ self.history_per_term = history_per_term
+ self.uniform_prob = uniform_prob
+ self._loss_history = np.zeros(
+ [diffusion.num_timesteps, history_per_term], dtype=np.float64
+ )
+ self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int64)
+
+ def weights(self):
+ if not self._warmed_up():
+ return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
+ weights = np.sqrt(np.mean(self._loss_history ** 2, axis=-1))
+ weights /= np.sum(weights)
+ weights *= 1 - self.uniform_prob
+ weights += self.uniform_prob / len(weights)
+ return weights
+
+ def update_with_all_losses(self, ts, losses):
+ for t, loss in zip(ts, losses):
+ if self._loss_counts[t] == self.history_per_term:
+ # Shift out the oldest loss term.
+ self._loss_history[t, :-1] = self._loss_history[t, 1:]
+ self._loss_history[t, -1] = loss
+ else:
+ self._loss_history[t, self._loss_counts[t]] = loss
+ self._loss_counts[t] += 1
+
+ def _warmed_up(self):
+ return (self._loss_counts == self.history_per_term).all()
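+
+# Usage sketch (illustrative; `diffusion`, `model` and `per_example_loss` are
+# placeholders for the caller's diffusion object and loss computation):
+#
+# sampler = create_named_schedule_sampler("loss-second-moment", diffusion)
+# t, weights = sampler.sample(batch_size, device)
+# losses = per_example_loss(model, x_start, t) # shape: (batch_size,)
+# if isinstance(sampler, LossAwareSampler):
+# sampler.update_with_all_losses(t.tolist(), losses.detach().tolist())
+# loss = (losses * weights).mean()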
diff --git a/vsr/models/__init__.py b/vsr/models/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..23fe8e3be56d23ed269e7e9dfd89987eb68e47ea
--- /dev/null
+++ b/vsr/models/__init__.py
@@ -0,0 +1,28 @@
+from .unet import UNet3DVSRModel
+from torch.optim.lr_scheduler import LambdaLR
+
+def customized_lr_scheduler(optimizer, warmup_steps=5000): # 5000 from u-vit
+ def fn(step):
+ if warmup_steps > 0:
+ return min(step / warmup_steps, 1)
+ else:
+ return 1
+ return LambdaLR(optimizer, fn)
+
+
+def get_lr_scheduler(optimizer, name, **kwargs):
+ if name == 'warmup':
+ return customized_lr_scheduler(optimizer, **kwargs)
+ elif name == 'cosine':
+ from torch.optim.lr_scheduler import CosineAnnealingLR
+ return CosineAnnealingLR(optimizer, **kwargs)
+ else:
+ raise NotImplementedError(name)
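+
+# Example (illustrative; `optimizer` stands for any torch optimizer): the 'warmup'
+# option scales the learning rate linearly from 0 to its base value over the first
+# `warmup_steps` steps and keeps it constant afterwards, e.g.
+# scheduler = get_lr_scheduler(optimizer, 'warmup', warmup_steps=5000)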
+
+def get_models():
+ config_path = "./configs/unet_3d_config.json"
+ pretrained_model_path = "./pretrained_models/upscaler4x/unet/diffusion_pytorch_model.bin"
+ return UNet3DVSRModel.from_pretrained_2d(config_path, pretrained_model_path)
+
+
\ No newline at end of file
diff --git a/vsr/models/attention.py b/vsr/models/attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..6b4f4514752d23e916d50c68e18a7a672a3b1fcf
--- /dev/null
+++ b/vsr/models/attention.py
@@ -0,0 +1,826 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/attention.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+from dataclasses import dataclass
+from typing import Optional
+
+import math
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import FeedForward, AdaLayerNorm
+
+try:
+ from .resnet import ResnetBlock3DCNN
+except ImportError:
+ # fall back to a plain import when this file is run outside the package
+ from resnet import ResnetBlock3DCNN
+
+from rotary_embedding_torch import RotaryEmbedding
+from typing import Callable, Optional
+from einops import rearrange, repeat
+
+@dataclass
+class Transformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+def exists(x):
+ return x is not None
+
+
+class CrossAttention(nn.Module):
+ r"""
+ Copied from diffusers 0.11.1.
+ A cross attention layer.
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+ # print('num head', heads)
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ self.dim_head = dim_head
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = False
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ # print(use_relative_position)
+ self.use_relative_position = use_relative_position
+ if self.use_relative_position:
+ self.rotary_emb = RotaryEmbedding(min(32, dim_head))
+ # # print(dim_head)
+ # # print(heads)
+ # # adopt https://github.com/huggingface/transformers/blob/8a817e1ecac6a420b1bdc701fcc33535a3b96ff5/src/transformers/models/bert/modeling_bert.py#L265
+ # self.max_position_embeddings = 32
+ # self.distance_embedding = nn.Embedding(2 * self.max_position_embeddings - 1, dim_head)
+
+ # self.dropout = nn.Dropout(dropout)
+
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
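+ # Shape bookkeeping for the two helpers above:
+ # reshape_heads_to_batch_dim: (B, S, H*D) -> (B*H, S, D)
+ # reshape_batch_dim_to_heads: (B*H, S, D) -> (B, S, H*D)
+ # with H = self.heads and D = dim_head.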
+
+ def reshape_for_scores(self, tensor):
+ # split heads and dims
+ # tensor should be [b (h w)] f (d nd)
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).contiguous()
+ return tensor
+
+ def same_batch_dim_to_heads(self, tensor):
+ batch_size, head_size, seq_len, dim = tensor.shape # [b (h w)] nd f d
+ tensor = tensor.reshape(batch_size, seq_len, dim * head_size)
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, use_image_num=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
+
+ # print('before reshpape query shape', query.shape)
+ dim = query.shape[-1]
+ if not self.use_relative_position:
+ query = self.reshape_heads_to_batch_dim(query) # [b (h w) nd] f d
+ # print('after reshape query shape', query.shape)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ if not self.use_relative_position:
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ # print('query shape', query.shape)
+ # print('key shape', key.shape)
+ # print('value shape', value.shape)
+
+ if attention_mask is not None:
+ # print('attention_mask', attention_mask.shape)
+ # print('attention_scores', attention_scores.shape)
+ # exit()
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+ # print(attention_probs.shape)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+ # print(attention_probs.shape)
+
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+ # print(hidden_states.shape)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ # print(hidden_states.shape)
+ # exit()
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+
+class Transformer3DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ rotary_emb: bool = None,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # Define input layers
+ self.in_channels = in_channels
+
+ # add 3D CNN for VSR
+ # if only_cross_attention == False: # x8 down
+ # self.resblock_temporal = ResnetBlock3DCNN(in_channels=in_channels, kernel=(3,3,3), temb_channels=None)
+ # else:
+ # self.resblock_temporal = ResnetBlock3DCNN(in_channels=in_channels, kernel=(5,1,1), temb_channels=None)
+
+ self.resblock_temporal = ResnetBlock3DCNN(in_channels=in_channels, kernel=(3,1,1), temb_channels=None)
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ # Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if use_linear_projection:
+ self.proj_out = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+ # Input
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
+
+ batch, channel, height, weight = hidden_states.shape
+
+ # 3D CNN for VSR
+ hidden_states = rearrange(hidden_states, "(b f) c h w -> b c f h w", f=video_length)
+ hidden_states = self.resblock_temporal(hidden_states)
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w", f=video_length)
+
+ residual = hidden_states
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ video_length=video_length
+ )
+
+ # Output
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+ if not return_dict:
+ return (output,)
+
+ return Transformer3DModelOutput(sample=output)
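+ # Shape contract of this block: video latents of shape (batch, channels, frames,
+ # height, width) in and out; `encoder_hidden_states` of shape (batch, tokens, dim)
+ # is repeated per frame before being passed to the transformer blocks.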
+
+
+class BasicTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ rotary_emb: bool = None,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ # print(only_cross_attention)
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+ self.use_first_frame = use_first_frame # False for VSR
+
+ # SC-Attn
+ if use_first_frame and only_cross_attention == False:
+ self.attn1 = SparseCausalAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ # print(cross_attention_dim)
+ else:
+ self.attn1 = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ )
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # Cross-Attn
+ if cross_attention_dim is not None:
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ )
+ else:
+ self.attn2 = None
+
+ if cross_attention_dim is not None:
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ else:
+ self.norm2 = None
+
+
+ # Temp-Attn for VSR
+ self.attn_temporal = TemporalAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=None,
+ upcast_attention=upcast_attention,
+ rotary_emb=rotary_emb,
+ )
+ nn.init.zeros_(self.attn_temporal.to_out[0].weight.data)
+ self.norm_temporal = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ # Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.norm3 = nn.LayerNorm(dim)
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None):
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ if self.attn2 is not None:
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
+ # SparseCausal-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ )
+
+ if self.only_cross_attention: # Cross-Attention
+ hidden_states = (
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+ )
+ else: # Self-Attention
+ if self.use_first_frame:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+ else:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
+
+
+ if self.attn2 is not None:
+ # Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+
+ # Temporal-Attention for VSR
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length).contiguous()
+ norm_hidden_states = (
+ self.norm_temporal(hidden_states, timestep) if self.use_ada_layer_norm else self.norm_temporal(hidden_states)
+ )
+ hidden_states = self.attn_temporal(norm_hidden_states) + hidden_states
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d).contiguous()
+
+ # Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+ return hidden_states
+
+
+class SparseCausalAttention(CrossAttention):
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ former_frame_index = torch.arange(video_length) - 1
+ former_frame_index[0] = 0
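+ # Sparse-causal pattern: each frame attends to the first frame and to its
+ # immediately preceding frame (frame 0 attends to itself twice); the two index
+ # selections concatenated below build those key/value sets.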
+
+ key = rearrange(key, "(b f) d c -> b f d c", f=video_length)
+ key = torch.cat([key[:, [0] * video_length], key[:, former_frame_index]], dim=2)
+ key = rearrange(key, "b f d c -> (b f) d c")
+
+ value = rearrange(value, "(b f) d c -> b f d c", f=video_length)
+ value = torch.cat([value[:, [0] * video_length], value[:, former_frame_index]], dim=2)
+ value = rearrange(value, "b f d c -> (b f) d c")
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+class TemporalAttention(CrossAttention):
+ def __init__(self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ rotary_emb=None):
+ super().__init__(query_dim, cross_attention_dim, heads, dim_head, dropout, bias, upcast_attention, upcast_softmax, added_kv_proj_dim, norm_num_groups)
+ # relative time positional embeddings
+ self.time_rel_pos_bias = RelativePositionBias(heads=heads, max_distance=32) # realistically will not be able to generate that many frames of video... yet
+ self.rotary_emb = rotary_emb
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ time_rel_pos_bias = self.time_rel_pos_bias(hidden_states.shape[1], device=hidden_states.device)
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ encoder_hidden_states = encoder_hidden_states
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states) # [b (h w)] f (nd * d)
+ dim = query.shape[-1]
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask, time_rel_pos_bias)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+
+ def _attention(self, query, key, value, attention_mask=None, time_rel_pos_bias=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ # print('query shape', query.shape)
+ # print('key shape', key.shape)
+ # print('value shape', value.shape)
+ # reshape for adding the temporal relative position bias
+ query = self.scale * rearrange(query, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
+ key = rearrange(key, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
+ value = rearrange(value, 'b f (h d) -> b h f d', h=self.heads) # d: dim_head; n: heads
+ # print('query shape', query.shape)
+ # print('key shape', key.shape)
+ # print('value shape', value.shape)
+
+ # torch.baddbmm only accepts 3-D tensors
+ # https://runebook.dev/zh/docs/pytorch/generated/torch.baddbmm
+ # attention_scores = self.scale * torch.matmul(query, key.transpose(-1, -2))
+ if exists(self.rotary_emb):
+ query = self.rotary_emb.rotate_queries_or_keys(query)
+ key = self.rotary_emb.rotate_queries_or_keys(key)
+
+ attention_scores = torch.einsum('... h i d, ... h j d -> ... h i j', query, key)
+ # print('attention_scores shape', attention_scores.shape)
+ # print('time_rel_pos_bias shape', time_rel_pos_bias.shape)
+ # print('attention_mask shape', attention_mask.shape)
+
+ attention_scores = attention_scores + time_rel_pos_bias
+ # print(attention_scores.shape)
+
+ # bert from huggin face
+ # attention_scores = attention_scores / math.sqrt(self.dim_head)
+
+ # # Normalize the attention scores to probabilities.
+ # attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+
+ if attention_mask is not None:
+ # add attention mask
+ attention_scores = attention_scores + attention_mask
+
+ # vdm
+ attention_scores = attention_scores - attention_scores.amax(dim = -1, keepdim = True).detach()
+
+ attention_probs = nn.functional.softmax(attention_scores, dim=-1)
+ # print(attention_probs[0][0])
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ # hidden_states = torch.matmul(attention_probs, value)
+ hidden_states = torch.einsum('... h i j, ... h j d -> ... h i d', attention_probs, value)
+ # print(hidden_states.shape)
+ # hidden_states = self.same_batch_dim_to_heads(hidden_states)
+ hidden_states = rearrange(hidden_states, 'b h f d -> b f (h d)')
+ # print(hidden_states.shape)
+ # exit()
+ return hidden_states
+
+class RelativePositionBias(nn.Module):
+ def __init__(
+ self,
+ heads=8,
+ num_buckets=32,
+ max_distance=128,
+ ):
+ super().__init__()
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
+ ret = 0
+ n = -relative_position
+
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).long()
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def forward(self, n, device):
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
+ rel_pos = rearrange(k_pos, 'j -> 1 j') - rearrange(q_pos, 'i -> i 1')
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+ values = self.relative_attention_bias(rp_bucket)
+ return rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
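+
+# Illustrative note: this is the T5-style bucketed relative position bias, where half
+# of the buckets cover small exact offsets and the rest grow logarithmically up to
+# `max_distance`. For a clip of n frames, calling a `RelativePositionBias(heads=H)`
+# instance with (n, device) returns an (H, n, n) bias that TemporalAttention adds to
+# its attention scores.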
diff --git a/vsr/models/autoencoder_kl.py b/vsr/models/autoencoder_kl.py
new file mode 100644
index 0000000000000000000000000000000000000000..1983c3f1cefdbccf47659d79e4b4180fb9cc6455
--- /dev/null
+++ b/vsr/models/autoencoder_kl.py
@@ -0,0 +1,334 @@
+# Copyright 2023 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+from dataclasses import dataclass
+from typing import Optional, Tuple, Union
+
+import torch
+import torch.nn as nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.utils import BaseOutput
+
+try:
+ from diffusers.utils import apply_forward_hook
+except:
+ from diffusers.utils.accelerate_utils import apply_forward_hook
+
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.vae import Decoder, DecoderOutput, DiagonalGaussianDistribution, Encoder
+
+
+@dataclass
+class AutoencoderKLOutput(BaseOutput):
+ """
+ Output of AutoencoderKL encoding method.
+
+ Args:
+ latent_dist (`DiagonalGaussianDistribution`):
+ Encoded outputs of `Encoder` represented as the mean and logvar of `DiagonalGaussianDistribution`.
+ `DiagonalGaussianDistribution` allows for sampling latents from the distribution.
+ """
+
+ latent_dist: "DiagonalGaussianDistribution"
+
+
+class AutoencoderKL(ModelMixin, ConfigMixin):
+ r"""Variational Autoencoder (VAE) model with KL loss from the paper Auto-Encoding Variational Bayes by Diederik P. Kingma
+ and Max Welling.
+
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for the generic methods the library
+ implements for all the model (such as downloading or saving, etc.)
+
+ Parameters:
+ in_channels (int, *optional*, defaults to 3): Number of channels in the input image.
+ out_channels (int, *optional*, defaults to 3): Number of channels in the output.
+ down_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("DownEncoderBlock2D",)`): Tuple of downsample block types.
+ up_block_types (`Tuple[str]`, *optional*, defaults to :
+ obj:`("UpDecoderBlock2D",)`): Tuple of upsample block types.
+ block_out_channels (`Tuple[int]`, *optional*, defaults to :
+ obj:`(64,)`): Tuple of block output channels.
+ act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
+ latent_channels (`int`, *optional*, defaults to 4): Number of channels in the latent space.
+ sample_size (`int`, *optional*, defaults to `32`): Sample input size.
+ scaling_factor (`float`, *optional*, defaults to 0.18215):
+ The component-wise standard deviation of the trained latent space computed using the first batch of the
+ training set. This is used to scale the latent space to have unit variance when training the diffusion
+ model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
+ diffusion model. When decoding, the latents are scaled back to the original scale with the formula: `z = 1
+ / scaling_factor * z`. For more details, refer to sections 4.3.2 and D.1 of the [High-Resolution Image
+ Synthesis with Latent Diffusion Models](https://arxiv.org/abs/2112.10752) paper.
+ """
+
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ in_channels: int = 3,
+ out_channels: int = 3,
+ down_block_types: Tuple[str] = ("DownEncoderBlock2D",),
+ up_block_types: Tuple[str] = ("UpDecoderBlock2D",),
+ block_out_channels: Tuple[int] = (64,),
+ layers_per_block: int = 1,
+ act_fn: str = "silu",
+ latent_channels: int = 4,
+ norm_num_groups: int = 32,
+ sample_size: int = 32,
+ scaling_factor: float = 0.18215,
+ ):
+ super().__init__()
+
+ # pass init params to Encoder
+ self.encoder = Encoder(
+ in_channels=in_channels,
+ out_channels=latent_channels,
+ down_block_types=down_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ act_fn=act_fn,
+ norm_num_groups=norm_num_groups,
+ double_z=True,
+ )
+
+ # pass init params to Decoder
+ self.decoder = Decoder(
+ in_channels=latent_channels,
+ out_channels=out_channels,
+ up_block_types=up_block_types,
+ block_out_channels=block_out_channels,
+ layers_per_block=layers_per_block,
+ norm_num_groups=norm_num_groups,
+ act_fn=act_fn,
+ )
+
+ self.quant_conv = nn.Conv2d(2 * latent_channels, 2 * latent_channels, 1)
+ self.post_quant_conv = nn.Conv2d(latent_channels, latent_channels, 1)
+
+ self.use_slicing = False
+ self.use_tiling = False
+
+ # only relevant if vae tiling is enabled
+ self.tile_sample_min_size = self.config.sample_size
+ sample_size = (
+ self.config.sample_size[0]
+ if isinstance(self.config.sample_size, (list, tuple))
+ else self.config.sample_size
+ )
+ self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
+ self.tile_overlap_factor = 0.25
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (Encoder, Decoder)):
+ module.gradient_checkpointing = value
+
+ def enable_tiling(self, use_tiling: bool = True):
+ r"""
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
+ compute decoding and encoding in several steps. This is useful to save a large amount of memory and to allow
+ the processing of larger images.
+ """
+ self.use_tiling = use_tiling
+
+ def disable_tiling(self):
+ r"""
+ Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
+ computing decoding in one step.
+ """
+ self.enable_tiling(False)
+
+ def enable_slicing(self):
+ r"""
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
+ """
+ self.use_slicing = True
+
+ def disable_slicing(self):
+ r"""
+ Disable sliced VAE decoding. If `enable_slicing` was previously invoked, this method will go back to computing
+ decoding in one step.
+ """
+ self.use_slicing = False
+
+ @apply_forward_hook
+ def encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+ if self.use_tiling and (x.shape[-1] > self.tile_sample_min_size or x.shape[-2] > self.tile_sample_min_size):
+ return self.tiled_encode(x, return_dict=return_dict)
+
+ h = self.encoder(x)
+ moments = self.quant_conv(h)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
+ def _decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ if self.use_tiling and (z.shape[-1] > self.tile_latent_min_size or z.shape[-2] > self.tile_latent_min_size):
+ return self.tiled_decode(z, return_dict=return_dict)
+
+ z = self.post_quant_conv(z)
+ dec = self.decoder(z)
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ @apply_forward_hook
+ def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+ if self.use_slicing and z.shape[0] > 1:
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
+ decoded = torch.cat(decoded_slices)
+ else:
+ decoded = self._decode(z).sample
+
+ if not return_dict:
+ return (decoded,)
+
+ return DecoderOutput(sample=decoded)
+
+    def blend_v(self, a, b, blend_extent):
+        # Linearly cross-fade the bottom rows of tile `a` into the top rows of tile `b`.
+        for y in range(min(a.shape[2], b.shape[2], blend_extent)):
+            b[:, :, y, :] = a[:, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, y, :] * (y / blend_extent)
+        return b
+
+    def blend_h(self, a, b, blend_extent):
+        # Linearly cross-fade the right columns of tile `a` into the left columns of tile `b`.
+        for x in range(min(a.shape[3], b.shape[3], blend_extent)):
+            b[:, :, :, x] = a[:, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, x] * (x / blend_extent)
+        return b
+
+ def tiled_encode(self, x: torch.FloatTensor, return_dict: bool = True) -> AutoencoderKLOutput:
+        r"""Encode a batch of images using a tiled encoder.
+
+        When tiling is enabled, the VAE splits the input tensor into overlapping tiles and encodes them in
+        several steps. This keeps memory use roughly constant regardless of image size. The result differs
+        slightly from non-tiled encoding because each tile is encoded independently; to avoid visible seams,
+        the tiles overlap and are blended together into a smooth output. You may still see tile-sized changes
+        in the look of the output, but they should be much less noticeable.
+
+        Args:
+            x (`torch.FloatTensor`): Input batch of images.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return an [`AutoencoderKLOutput`] instead of a plain tuple.
+        """
+ overlap_size = int(self.tile_sample_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_latent_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_latent_min_size - blend_extent
+
+        # Split the image into overlapping tiles of size tile_sample_min_size and encode them separately.
+ rows = []
+ for i in range(0, x.shape[2], overlap_size):
+ row = []
+ for j in range(0, x.shape[3], overlap_size):
+ tile = x[:, :, i : i + self.tile_sample_min_size, j : j + self.tile_sample_min_size]
+ tile = self.encoder(tile)
+ tile = self.quant_conv(tile)
+ row.append(tile)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ moments = torch.cat(result_rows, dim=2)
+ posterior = DiagonalGaussianDistribution(moments)
+
+ if not return_dict:
+ return (posterior,)
+
+ return AutoencoderKLOutput(latent_dist=posterior)
+
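+    # Tiling arithmetic shared by tiled_encode (above) and tiled_decode (below), summarised for reference:
+    # adjacent input tiles start `overlap_size` apart (a 1 - tile_overlap_factor fraction of the input tile
+    # size), each processed tile is cross-faded with its upper/left neighbour over `blend_extent` rows/columns
+    # at the output resolution via blend_v/blend_h, and only the first `row_limit` rows/columns of every
+    # blended tile are kept, so the concatenated result lines up without visible seams.
+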
+ def tiled_decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
+        r"""Decode a batch of latents using a tiled decoder.
+
+        When tiling is enabled, the VAE splits the latent tensor into overlapping tiles and decodes them in
+        several steps. This keeps memory use roughly constant regardless of image size. The result differs
+        slightly from non-tiled decoding because each tile is decoded independently; to avoid visible seams,
+        the tiles overlap and are blended together into a smooth output. You may still see tile-sized changes
+        in the look of the output, but they should be much less noticeable.
+
+        Args:
+            z (`torch.FloatTensor`): Input batch of latent vectors.
+            return_dict (`bool`, *optional*, defaults to `True`):
+                Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+        """
+ overlap_size = int(self.tile_latent_min_size * (1 - self.tile_overlap_factor))
+ blend_extent = int(self.tile_sample_min_size * self.tile_overlap_factor)
+ row_limit = self.tile_sample_min_size - blend_extent
+
+        # Split z into overlapping tiles of size tile_latent_min_size and decode them separately.
+        # The tiles overlap to avoid seams between them.
+ rows = []
+ for i in range(0, z.shape[2], overlap_size):
+ row = []
+ for j in range(0, z.shape[3], overlap_size):
+ tile = z[:, :, i : i + self.tile_latent_min_size, j : j + self.tile_latent_min_size]
+ tile = self.post_quant_conv(tile)
+ decoded = self.decoder(tile)
+ row.append(decoded)
+ rows.append(row)
+ result_rows = []
+ for i, row in enumerate(rows):
+ result_row = []
+ for j, tile in enumerate(row):
+ # blend the above tile and the left tile
+ # to the current tile and add the current tile to the result row
+ if i > 0:
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent)
+ if j > 0:
+ tile = self.blend_h(row[j - 1], tile, blend_extent)
+ result_row.append(tile[:, :, :row_limit, :row_limit])
+ result_rows.append(torch.cat(result_row, dim=3))
+
+ dec = torch.cat(result_rows, dim=2)
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ sample_posterior: bool = False,
+ return_dict: bool = True,
+ generator: Optional[torch.Generator] = None,
+ ) -> Union[DecoderOutput, torch.FloatTensor]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): Input sample.
+ sample_posterior (`bool`, *optional*, defaults to `False`):
+ Whether to sample from the posterior.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`DecoderOutput`] instead of a plain tuple.
+ """
+ x = sample
+ posterior = self.encode(x).latent_dist
+ if sample_posterior:
+ z = posterior.sample(generator=generator)
+ else:
+ z = posterior.mode()
+ dec = self.decode(z).sample
+
+ if not return_dict:
+ return (dec,)
+
+ return DecoderOutput(sample=dec)
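+
+
+if __name__ == "__main__":
+    # Minimal usage sketch (illustrative only): a tiny, randomly initialised config shows the
+    # encode -> scale -> decode round trip described in the class docstring. The actual VSR
+    # pipeline loads pretrained weights instead of constructing the VAE like this.
+    vae = AutoencoderKL(
+        down_block_types=("DownEncoderBlock2D", "DownEncoderBlock2D"),
+        up_block_types=("UpDecoderBlock2D", "UpDecoderBlock2D"),
+        block_out_channels=(32, 64),
+        latent_channels=4,
+        sample_size=64,
+    )
+    x = torch.randn(1, 3, 64, 64)
+    posterior = vae.encode(x).latent_dist
+    z = posterior.sample() * vae.config.scaling_factor  # z = z * scaling_factor
+    rec = vae.decode(z / vae.config.scaling_factor).sample  # undo the scaling before decoding
+    print(rec.shape)  # torch.Size([1, 3, 64, 64])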
diff --git a/vsr/models/clip.py b/vsr/models/clip.py
new file mode 100644
index 0000000000000000000000000000000000000000..13ec05118f729cc27e6b7b4107a300ecfbdf2d77
--- /dev/null
+++ b/vsr/models/clip.py
@@ -0,0 +1,127 @@
+import numpy
+import torch.nn as nn
+from transformers import CLIPTokenizer, CLIPTextModel
+from diffusers import StableDiffusionUpscalePipeline
+
+"""
+Loading the text encoder will emit the following warning:
+- This IS expected if you are initializing CLIPTextModel from the checkpoint of a model trained on another task
+or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
+- This IS NOT expected if you are initializing CLIPTextModel from the checkpoint of a model
+that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
+
+According to https://github.com/CompVis/stable-diffusion/issues/97 this warning is safe to ignore:
+it appears because the vision backbone of the CLIP model is not needed to run Stable Diffusion, so the
+unused weights are simply dropped. It is not an error.
+
+This CLIP usage follows U-ViT and is the same as in Stable Diffusion.
+"""
+
+class AbstractEncoder(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def encode(self, *args, **kwargs):
+ raise NotImplementedError
+
+
+class FrozenCLIPEmbedder(AbstractEncoder):
+ """Uses the CLIP transformer encoder for text (from Hugging Face)"""
+ def __init__(self, device="cuda", max_length=77):
+ super().__init__()
+ # self.tokenizer = CLIPTokenizer.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
+ # self.text_encoder = CLIPTextModel.from_pretrained('laion/CLIP-ViT-H-14-laion2B-s32B-b79K')
+ # TBD: change to https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K/blob/main/config.json
+ # model_id = "stabilityai/stable-diffusion-x4-upscaler" # For VSR
+ # upscale_pipeline = StableDiffusionUpscalePipeline.from_pretrained(model_id)
+ upscale_pipeline = StableDiffusionUpscalePipeline.from_pretrained('./pretrained_models/upscaler4x')
+ self.tokenizer = upscale_pipeline.tokenizer
+ self.text_encoder = upscale_pipeline.text_encoder
+ self.device = device
+ self.max_length = max_length
+ self.freeze()
+
+ def freeze(self):
+ self.text_encoder = self.text_encoder.eval()
+ for param in self.parameters():
+ param.requires_grad = False
+
+ def forward(self, text):
+ batch_encoding = self.tokenizer(text, truncation=True, padding="max_length", max_length=self.max_length,
+ return_tensors="pt")
+ tokens = batch_encoding["input_ids"].to(self.device)
+ outputs = self.text_encoder(input_ids=tokens)
+
+ # return outputs.last_hidden_state
+ return outputs[0]
+
+ def encode(self, text):
+ return self(text)
+
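+# Output shape note (illustrative): with max_length=77 the embedder returns the CLIP text encoder's last
+# hidden state of shape (batch_size, 77, hidden_dim), which is what the upscaler UNet consumes as
+# encoder_hidden_states.
+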
+
+class TextEmbedder(nn.Module):
+ """
+ Embeds text prompt into vector representations. Also handles text dropout for classifier-free guidance.
+ """
+ def __init__(self, dropout_prob=0.1):
+ super().__init__()
+ self.text_encodder = FrozenCLIPEmbedder()
+ self.dropout_prob = dropout_prob
+
+ def token_drop(self, text_prompts, force_drop_ids=None):
+ """
+ Drops text to enable classifier-free guidance.
+ """
+ if force_drop_ids is None:
+ drop_ids = numpy.random.uniform(0, 1, len(text_prompts)) < self.dropout_prob
+ else:
+ drop_ids = force_drop_ids == 1
+ labels = list(numpy.where(drop_ids, "None", text_prompts))
+ # print(labels)
+ return labels
+
+ def forward(self, text_prompts, train, force_drop_ids=None):
+ use_dropout = self.dropout_prob > 0
+ if (train and use_dropout) or (force_drop_ids is not None):
+ text_prompts = self.token_drop(text_prompts, force_drop_ids)
+ embeddings = self.text_encodder(text_prompts)
+ return embeddings
+
+
+if __name__ == '__main__':
+
+ r"""
+ Returns:
+
+ Examples from CLIPTextModel:
+
+ ```python
+ >>> from transformers import AutoTokenizer, CLIPTextModel
+
+ >>> model = CLIPTextModel.from_pretrained("openai/clip-vit-base-patch32")
+ >>> tokenizer = AutoTokenizer.from_pretrained("openai/clip-vit-base-patch32")
+
+ >>> inputs = tokenizer(["a photo of a cat", "a photo of a dog"], padding=True, return_tensors="pt")
+
+ >>> outputs = model(**inputs)
+ >>> last_hidden_state = outputs.last_hidden_state
+ >>> pooled_output = outputs.pooler_output # pooled (EOS token) states
+ ```"""
+
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+
+ text_encoder = TextEmbedder(dropout_prob=0.00001).to(device)
+ text_encoder1 = FrozenCLIPEmbedder().to(device)
+
+ text_prompt = ["a photo of a cat", "a photo of a dog", 'a photo of a dog human']
+ # text_prompt = ('None', 'None', 'None')
+ output = text_encoder(text_prompts=text_prompt, train=True)
+ output1 = text_encoder1(text_prompt)
+ # print(output)
+ print(output.shape)
+ print(output1.shape)
+ print((output==output1).all())
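+
+    # Follow-up check (assumes the same local `./pretrained_models/upscaler4x` checkpoint used above):
+    # at sampling time, classifier-free guidance pairs each prompt with the unconditional token "None",
+    # mirroring what TextEmbedder.token_drop() substitutes during training.
+    uncond = text_encoder1(["None"] * len(text_prompt))
+    print(uncond.shape)  # same (batch, 77, hidden_dim) shape as the conditional embeddings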
\ No newline at end of file
diff --git a/vsr/models/diffusers_attention.py b/vsr/models/diffusers_attention.py
new file mode 100644
index 0000000000000000000000000000000000000000..c104074eeab835cec0a9ce95f1b5753d05a67566
--- /dev/null
+++ b/vsr/models/diffusers_attention.py
@@ -0,0 +1,983 @@
+# Copyright 2022 The HuggingFace Team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+import math
+from dataclasses import dataclass
+from typing import Callable, Optional
+
+import torch
+import torch.nn.functional as F
+from torch import nn
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.embeddings import ImagePositionalEmbeddings
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import FeedForward, AdaLayerNorm
+from diffusers.models.attention_processor import Attention
+
+
+@dataclass
+class Transformer2DModelOutput(BaseOutput):
+ """
+ Args:
+ sample (`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)` or `(batch size, num_vector_embeds - 1, num_latent_pixels)` if [`Transformer2DModel`] is discrete):
+ Hidden states conditioned on `encoder_hidden_states` input. If discrete, returns probability distributions
+ for the unnoised latent pixels.
+ """
+
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+class Transformer2DModel(ModelMixin, ConfigMixin):
+ """
+ Transformer model for image-like data. Takes either discrete (classes of vector embeddings) or continuous (actual
+ embeddings) inputs.
+
+ When input is continuous: First, project the input (aka embedding) and reshape to b, t, d. Then apply standard
+ transformer action. Finally, reshape to image.
+
+ When input is discrete: First, input (classes of latent pixels) is converted to embeddings and has positional
+ embeddings applied, see `ImagePositionalEmbeddings`. Then apply standard transformer action. Finally, predict
+ classes of unnoised image.
+
+ Note that it is assumed one of the input classes is the masked latent pixel. The predicted classes of the unnoised
+ image do not contain a prediction for the masked pixel as the unnoised image cannot be masked.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input and output.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+ `ImagePositionalEmbeddings`.
+ num_vector_embeds (`int`, *optional*):
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+            up to, but not more than, `num_embeds_ada_norm` steps.
+ attention_bias (`bool`, *optional*):
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
+ """
+
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ use_linear_projection: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+        # 1. Transformer2DModel can process both standard continuous images of shape `(batch_size, num_channels, width, height)` as well as quantized image embeddings of shape `(batch_size, num_image_vectors)`
+ # Define whether input is continuous or discrete depending on configuration
+ self.is_input_continuous = in_channels is not None
+ self.is_input_vectorized = num_vector_embeds is not None
+
+ if self.is_input_continuous and self.is_input_vectorized:
+ raise ValueError(
+ f"Cannot define both `in_channels`: {in_channels} and `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is None."
+ )
+ elif not self.is_input_continuous and not self.is_input_vectorized:
+ raise ValueError(
+ f"Has to define either `in_channels`: {in_channels} or `num_vector_embeds`: {num_vector_embeds}. Make"
+ " sure that either `in_channels` or `num_vector_embeds` is not None."
+ )
+
+ # 2. Define input layers
+ if self.is_input_continuous:
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ assert sample_size is not None, "Transformer2DModel over discrete input must provide sample_size"
+ assert num_vector_embeds is not None, "Transformer2DModel over discrete input must provide num_embed"
+
+ self.height = sample_size
+ self.width = sample_size
+ self.num_vector_embeds = num_vector_embeds
+ self.num_latent_pixels = self.height * self.width
+
+ self.latent_image_embedding = ImagePositionalEmbeddings(
+ num_embed=num_vector_embeds, embed_dim=inner_dim, height=self.height, width=self.width
+ )
+
+ # 3. Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ BasicTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if self.is_input_continuous:
+ if use_linear_projection:
+ self.proj_out = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+ elif self.is_input_vectorized:
+ self.norm_out = nn.LayerNorm(inner_dim)
+ self.out = nn.Linear(inner_dim, self.num_vector_embeds - 1)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+ """
+ Args:
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+                When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+ hidden_states
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.long`, *optional*):
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
+ tensor.
+ """
+ # 1. Input
+ if self.is_input_continuous:
+            batch, channel, height, width = hidden_states.shape
+            residual = hidden_states
+
+            hidden_states = self.norm(hidden_states)
+            if not self.use_linear_projection:
+                hidden_states = self.proj_in(hidden_states)
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+            else:
+                inner_dim = hidden_states.shape[1]
+                hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * width, inner_dim)
+                hidden_states = self.proj_in(hidden_states)
+ elif self.is_input_vectorized:
+ hidden_states = self.latent_image_embedding(hidden_states)
+
+ # 2. Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timestep)
+
+ # 3. Output
+ if self.is_input_continuous:
+ if not self.use_linear_projection:
+                hidden_states = (
+                    hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+                )
+                hidden_states = self.proj_out(hidden_states)
+            else:
+                hidden_states = self.proj_out(hidden_states)
+                hidden_states = (
+                    hidden_states.reshape(batch, height, width, inner_dim).permute(0, 3, 1, 2).contiguous()
+                )
+
+ output = hidden_states + residual
+ elif self.is_input_vectorized:
+ hidden_states = self.norm_out(hidden_states)
+ logits = self.out(hidden_states)
+ # (batch, self.num_vector_embeds - 1, self.num_latent_pixels)
+ logits = logits.permute(0, 2, 1)
+
+ # log(p(x_0))
+ output = F.log_softmax(logits.double(), dim=1).float()
+
+ if not return_dict:
+ return (output,)
+
+ return Transformer2DModelOutput(sample=output)
+
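+# Shape walkthrough for the continuous path of Transformer2DModel.forward (comment only, not executed):
+#   input hidden_states: (batch, in_channels, height, width)
+#   proj_in + reshape   -> (batch, height * width, inner_dim), where inner_dim = num_heads * head_dim
+#   transformer blocks  -> (batch, height * width, inner_dim)  via self-/cross-attention + feed-forward
+#   reshape + proj_out  -> (batch, in_channels, height, width), to which the input residual is added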
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other. Originally ported from here, but adapted
+ to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ Uses three q, k, v linear layers to compute attention.
+
+ Parameters:
+ channels (`int`): The number of channels in the input and output.
+ num_head_channels (`int`, *optional*):
+ The number of channels in each head. If None, then `num_heads` = 1.
+ norm_num_groups (`int`, *optional*, defaults to 32): The number of groups to use for group norm.
+ rescale_output_factor (`float`, *optional*, defaults to 1.0): The factor to rescale the output by.
+ eps (`float`, *optional*, defaults to 1e-5): The epsilon value to use for group norm.
+ """
+
+ # IMPORTANT;TODO(Patrick, William) - this class will be deprecated soon. Do not use it anymore
+
+ def __init__(
+ self,
+ channels: int,
+ num_head_channels: Optional[int] = None,
+ norm_num_groups: int = 32,
+ rescale_output_factor: float = 1.0,
+ eps: float = 1e-5,
+ ):
+ super().__init__()
+ self.channels = channels
+
+ self.num_heads = channels // num_head_channels if num_head_channels is not None else 1
+ self.num_head_size = num_head_channels
+ self.group_norm = nn.GroupNorm(num_channels=channels, num_groups=norm_num_groups, eps=eps, affine=True)
+
+ # define q,k,v as linear layers
+ self.query = nn.Linear(channels, channels)
+ self.key = nn.Linear(channels, channels)
+ self.value = nn.Linear(channels, channels)
+
+ self.rescale_output_factor = rescale_output_factor
+        self.proj_attn = nn.Linear(channels, channels, bias=True)
+
+ self._use_memory_efficient_attention_xformers = False
+
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None):
+ if not is_xformers_available():
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.num_heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.num_heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def forward(self, hidden_states):
+ residual = hidden_states
+ batch, channel, height, width = hidden_states.shape
+
+ # norm
+ hidden_states = self.group_norm(hidden_states)
+
+ hidden_states = hidden_states.view(batch, channel, height * width).transpose(1, 2)
+
+ # proj to q, k, v
+ query_proj = self.query(hidden_states)
+ key_proj = self.key(hidden_states)
+ value_proj = self.value(hidden_states)
+
+ scale = 1 / math.sqrt(self.channels / self.num_heads)
+
+ query_proj = self.reshape_heads_to_batch_dim(query_proj)
+ key_proj = self.reshape_heads_to_batch_dim(key_proj)
+ value_proj = self.reshape_heads_to_batch_dim(value_proj)
+
+ if self._use_memory_efficient_attention_xformers:
+ # Memory efficient attention
+ hidden_states = xformers.ops.memory_efficient_attention(query_proj, key_proj, value_proj, attn_bias=None)
+ hidden_states = hidden_states.to(query_proj.dtype)
+ else:
+ attention_scores = torch.baddbmm(
+ torch.empty(
+ query_proj.shape[0],
+ query_proj.shape[1],
+ key_proj.shape[1],
+ dtype=query_proj.dtype,
+ device=query_proj.device,
+ ),
+ query_proj,
+ key_proj.transpose(-1, -2),
+ beta=0,
+ alpha=scale,
+ )
+ attention_probs = torch.softmax(attention_scores.float(), dim=-1).type(attention_scores.dtype)
+ hidden_states = torch.bmm(attention_probs, value_proj)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+
+ # compute next hidden_states
+ hidden_states = self.proj_attn(hidden_states)
+
+ hidden_states = hidden_states.transpose(-1, -2).reshape(batch, channel, height, width)
+
+ # res connect and rescale
+ hidden_states = (hidden_states + residual) / self.rescale_output_factor
+ return hidden_states
+
+
+class BasicTransformerBlock(nn.Module):
+ r"""
+ A basic Transformer block.
+
+ Parameters:
+ dim (`int`): The number of channels in the input and output.
+ num_attention_heads (`int`): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The size of the encoder_hidden_states vector for cross attention.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+        num_embeds_ada_norm (`int`, *optional*):
+            The number of diffusion steps used during training. See `Transformer2DModel`.
+        attention_bias (`bool`, *optional*, defaults to `False`):
+            Configure if the attentions should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ num_attention_heads: int,
+ attention_head_dim: int,
+ dropout=0.0,
+ cross_attention_dim: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ attention_bias: bool = False,
+ only_cross_attention: bool = False,
+ upcast_attention: bool = False,
+ ):
+ super().__init__()
+ self.only_cross_attention = only_cross_attention
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+
+ # 1. Self-Attn
+ self.attn1 = CrossAttention(
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ cross_attention_dim=cross_attention_dim if only_cross_attention else None,
+ upcast_attention=upcast_attention,
+ ) # is a self-attention
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+
+ # 2. Cross-Attn
+ if cross_attention_dim is not None:
+ self.attn2 = CrossAttention(
+ query_dim=dim,
+ cross_attention_dim=cross_attention_dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+ ) # is self-attn if encoder_hidden_states is none
+ else:
+ self.attn2 = None
+
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ if cross_attention_dim is not None:
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+ else:
+ self.norm2 = None
+
+ # 3. Feed-forward
+ self.norm3 = nn.LayerNorm(dim)
+
+    def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op: Optional[Callable] = None):
+        if not is_xformers_available():
+            raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ self.attn1._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+ if self.attn2 is not None:
+ self.attn2._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None):
+ # 1. Self-Attention
+ norm_hidden_states = (
+ self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ )
+
+ if self.only_cross_attention:
+ hidden_states = (
+ self.attn1(norm_hidden_states, encoder_hidden_states, attention_mask=attention_mask) + hidden_states
+ )
+ else:
+ hidden_states = self.attn1(norm_hidden_states, attention_mask=attention_mask) + hidden_states
+
+ if self.attn2 is not None:
+ # 2. Cross-Attention
+ norm_hidden_states = (
+ self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ )
+ hidden_states = (
+ self.attn2(
+ norm_hidden_states, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ + hidden_states
+ )
+
+ # 3. Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ return hidden_states
+
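+# Sub-layer ordering in BasicTransformerBlock (comment only): the block is pre-norm, so each sub-layer
+# sees normalised inputs and adds its output back onto the running residual:
+#   x = x + attn1(norm1(x))           # self-attention (cross-attention if only_cross_attention)
+#   x = x + attn2(norm2(x), context)  # cross-attention, present only when cross_attention_dim is set
+#   x = x + ff(norm3(x))              # feed-forward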
+
+class CrossAttention(nn.Module):
+ r"""
+ A cross attention layer.
+
+ Parameters:
+ query_dim (`int`): The number of channels in the query.
+ cross_attention_dim (`int`, *optional*):
+ The number of channels in the encoder_hidden_states. If not given, defaults to `query_dim`.
+ heads (`int`, *optional*, defaults to 8): The number of heads to use for multi-head attention.
+ dim_head (`int`, *optional*, defaults to 64): The number of channels in each head.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ bias (`bool`, *optional*, defaults to False):
+ Set to `True` for the query, key, and value linear layers to contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ query_dim: int,
+ cross_attention_dim: Optional[int] = None,
+ heads: int = 8,
+ dim_head: int = 64,
+ dropout: float = 0.0,
+ bias=False,
+ upcast_attention: bool = False,
+ upcast_softmax: bool = False,
+ added_kv_proj_dim: Optional[int] = None,
+ norm_num_groups: Optional[int] = None,
+ ):
+ super().__init__()
+ inner_dim = dim_head * heads
+ cross_attention_dim = cross_attention_dim if cross_attention_dim is not None else query_dim
+ self.upcast_attention = upcast_attention
+ self.upcast_softmax = upcast_softmax
+
+ self.scale = dim_head**-0.5
+
+ self.heads = heads
+ # for slice_size > 0 the attention score computation
+ # is split across the batch axis to save memory
+ # You can set slice_size with `set_attention_slice`
+ self.sliceable_head_dim = heads
+ self._slice_size = None
+ self._use_memory_efficient_attention_xformers = False
+ self.added_kv_proj_dim = added_kv_proj_dim
+
+ if norm_num_groups is not None:
+ self.group_norm = nn.GroupNorm(num_channels=inner_dim, num_groups=norm_num_groups, eps=1e-5, affine=True)
+ else:
+ self.group_norm = None
+
+ self.to_q = nn.Linear(query_dim, inner_dim, bias=bias)
+ self.to_k = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+ self.to_v = nn.Linear(cross_attention_dim, inner_dim, bias=bias)
+
+ if self.added_kv_proj_dim is not None:
+ self.add_k_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+ self.add_v_proj = nn.Linear(added_kv_proj_dim, cross_attention_dim)
+
+ self.to_out = nn.ModuleList([])
+ self.to_out.append(nn.Linear(inner_dim, query_dim))
+ self.to_out.append(nn.Dropout(dropout))
+
+ def reshape_heads_to_batch_dim(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size, seq_len, head_size, dim // head_size)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size * head_size, seq_len, dim // head_size)
+ return tensor
+
+ def reshape_batch_dim_to_heads(self, tensor):
+ batch_size, seq_len, dim = tensor.shape
+ head_size = self.heads
+ tensor = tensor.reshape(batch_size // head_size, head_size, seq_len, dim)
+ tensor = tensor.permute(0, 2, 1, 3).reshape(batch_size // head_size, seq_len, dim * head_size)
+ return tensor
+
+ def set_attention_slice(self, slice_size):
+ if slice_size is not None and slice_size > self.sliceable_head_dim:
+ raise ValueError(f"slice_size {slice_size} has to be smaller or equal to {self.sliceable_head_dim}.")
+
+ self._slice_size = slice_size
+
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None):
+ batch_size, sequence_length, _ = hidden_states.shape
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ key = self.to_k(hidden_states)
+ value = self.to_v(hidden_states)
+ encoder_hidden_states_key_proj = self.add_k_proj(encoder_hidden_states)
+ encoder_hidden_states_value_proj = self.add_v_proj(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+ encoder_hidden_states_key_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_key_proj)
+ encoder_hidden_states_value_proj = self.reshape_heads_to_batch_dim(encoder_hidden_states_value_proj)
+
+ key = torch.concat([encoder_hidden_states_key_proj, key], dim=1)
+ value = torch.concat([encoder_hidden_states_value_proj, value], dim=1)
+ else:
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+ return hidden_states
+
+ def _attention(self, query, key, value, attention_mask=None):
+ if self.upcast_attention:
+ query = query.float()
+ key = key.float()
+
+ attention_scores = torch.baddbmm(
+ torch.empty(query.shape[0], query.shape[1], key.shape[1], dtype=query.dtype, device=query.device),
+ query,
+ key.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attention_scores = attention_scores + attention_mask
+
+ if self.upcast_softmax:
+ attention_scores = attention_scores.float()
+
+ attention_probs = attention_scores.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attention_probs = attention_probs.to(value.dtype)
+
+ # compute attention output
+ hidden_states = torch.bmm(attention_probs, value)
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _sliced_attention(self, query, key, value, sequence_length, dim, attention_mask):
+ batch_size_attention = query.shape[0]
+ hidden_states = torch.zeros(
+ (batch_size_attention, sequence_length, dim // self.heads), device=query.device, dtype=query.dtype
+ )
+ slice_size = self._slice_size if self._slice_size is not None else hidden_states.shape[0]
+ for i in range(hidden_states.shape[0] // slice_size):
+ start_idx = i * slice_size
+ end_idx = (i + 1) * slice_size
+
+ query_slice = query[start_idx:end_idx]
+ key_slice = key[start_idx:end_idx]
+
+ if self.upcast_attention:
+ query_slice = query_slice.float()
+ key_slice = key_slice.float()
+
+ attn_slice = torch.baddbmm(
+ torch.empty(slice_size, query.shape[1], key.shape[1], dtype=query_slice.dtype, device=query.device),
+ query_slice,
+ key_slice.transpose(-1, -2),
+ beta=0,
+ alpha=self.scale,
+ )
+
+ if attention_mask is not None:
+ attn_slice = attn_slice + attention_mask[start_idx:end_idx]
+
+ if self.upcast_softmax:
+ attn_slice = attn_slice.float()
+
+ attn_slice = attn_slice.softmax(dim=-1)
+
+ # cast back to the original dtype
+ attn_slice = attn_slice.to(value.dtype)
+ attn_slice = torch.bmm(attn_slice, value[start_idx:end_idx])
+
+ hidden_states[start_idx:end_idx] = attn_slice
+
+ # reshape hidden_states
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
+ def _memory_efficient_attention_xformers(self, query, key, value, attention_mask):
+ query = query.contiguous()
+ key = key.contiguous()
+ value = value.contiguous()
+ hidden_states = xformers.ops.memory_efficient_attention(query, key, value, attn_bias=attention_mask)
+ hidden_states = self.reshape_batch_dim_to_heads(hidden_states)
+ return hidden_states
+
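+# Memory note for CrossAttention (comment only): reshape_heads_to_batch_dim folds the head axis into the
+# batch axis, so attention over (B, N, heads * d) runs as (B * heads, N, d) batched matmuls. The sliced
+# path then iterates over chunks of that folded batch axis (controlled by set_attention_slice), trading
+# some speed for a bounded peak memory, while the xformers path computes the same attention fused.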
+
+class FeedForward(nn.Module):
+ r"""
+ A feed-forward layer.
+
+ Parameters:
+ dim (`int`): The number of channels in the input.
+ dim_out (`int`, *optional*): The number of channels in the output. If not given, defaults to `dim`.
+ mult (`int`, *optional*, defaults to 4): The multiplier to use for the hidden dimension.
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ """
+
+ def __init__(
+ self,
+ dim: int,
+ dim_out: Optional[int] = None,
+ mult: int = 4,
+ dropout: float = 0.0,
+ activation_fn: str = "geglu",
+ ):
+ super().__init__()
+ inner_dim = int(dim * mult)
+ dim_out = dim_out if dim_out is not None else dim
+
+        if activation_fn == "gelu":
+            act_fn = GELU(dim, inner_dim)
+        elif activation_fn == "geglu":
+            act_fn = GEGLU(dim, inner_dim)
+        elif activation_fn == "geglu-approximate":
+            act_fn = ApproximateGELU(dim, inner_dim)
+        else:
+            raise ValueError(f"Unsupported activation function: {activation_fn}")
+
+ self.net = nn.ModuleList([])
+ # project in
+ self.net.append(act_fn)
+ # project dropout
+ self.net.append(nn.Dropout(dropout))
+ # project out
+ self.net.append(nn.Linear(inner_dim, dim_out))
+
+ def forward(self, hidden_states):
+ for module in self.net:
+ hidden_states = module(hidden_states)
+ return hidden_states
+
+
+class GELU(nn.Module):
+ r"""
+ GELU activation function
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+
+ def gelu(self, gate):
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states = self.proj(hidden_states)
+ hidden_states = self.gelu(hidden_states)
+ return hidden_states
+
+
+# feedforward
+class GEGLU(nn.Module):
+ r"""
+ A variant of the gated linear unit activation function from https://arxiv.org/abs/2002.05202.
+
+ Parameters:
+ dim_in (`int`): The number of channels in the input.
+ dim_out (`int`): The number of channels in the output.
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out * 2)
+
+ def gelu(self, gate):
+ if gate.device.type != "mps":
+ return F.gelu(gate)
+ # mps: gelu is not implemented for float16
+ return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype)
+
+ def forward(self, hidden_states):
+ hidden_states, gate = self.proj(hidden_states).chunk(2, dim=-1)
+ return hidden_states * self.gelu(gate)
+
+
+class ApproximateGELU(nn.Module):
+ """
+ The approximate form of Gaussian Error Linear Unit (GELU)
+
+ For more details, see section 2: https://arxiv.org/abs/1606.08415
+ """
+
+ def __init__(self, dim_in: int, dim_out: int):
+ super().__init__()
+ self.proj = nn.Linear(dim_in, dim_out)
+
+ def forward(self, x):
+ x = self.proj(x)
+ return x * torch.sigmoid(1.702 * x)
+
+
+class AdaLayerNorm(nn.Module):
+ """
+ Norm layer modified to incorporate timestep embeddings.
+ """
+
+ def __init__(self, embedding_dim, num_embeddings):
+ super().__init__()
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
+
+ def forward(self, x, timestep):
+ emb = self.linear(self.silu(self.emb(timestep)))
+ scale, shift = torch.chunk(emb, 2)
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
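+# AdaLayerNorm modulation (comment only): the timestep index selects a learned embedding, which SiLU +
+# Linear map to a (scale, shift) pair; the normalised activations are then modulated as
+# norm(x) * (1 + scale) + shift, so each denoising step gets its own affine re-scaling.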
+
+class DualTransformer2DModel(nn.Module):
+ """
+ Dual transformer wrapper that combines two `Transformer2DModel`s for mixed inference.
+
+ Parameters:
+ num_attention_heads (`int`, *optional*, defaults to 16): The number of heads to use for multi-head attention.
+ attention_head_dim (`int`, *optional*, defaults to 88): The number of channels in each head.
+ in_channels (`int`, *optional*):
+ Pass if the input is continuous. The number of channels in the input and output.
+ num_layers (`int`, *optional*, defaults to 1): The number of layers of Transformer blocks to use.
+ dropout (`float`, *optional*, defaults to 0.1): The dropout probability to use.
+ cross_attention_dim (`int`, *optional*): The number of encoder_hidden_states dimensions to use.
+ sample_size (`int`, *optional*): Pass if the input is discrete. The width of the latent images.
+ Note that this is fixed at training time as it is used for learning a number of position embeddings. See
+ `ImagePositionalEmbeddings`.
+ num_vector_embeds (`int`, *optional*):
+ Pass if the input is discrete. The number of classes of the vector embeddings of the latent pixels.
+ Includes the class for the masked latent pixel.
+ activation_fn (`str`, *optional*, defaults to `"geglu"`): Activation function to be used in feed-forward.
+ num_embeds_ada_norm ( `int`, *optional*): Pass if at least one of the norm_layers is `AdaLayerNorm`.
+ The number of diffusion steps used during training. Note that this is fixed at training time as it is used
+ to learn a number of embeddings that are added to the hidden states. During inference, you can denoise for
+            up to, but not more than, `num_embeds_ada_norm` steps.
+ attention_bias (`bool`, *optional*):
+ Configure if the TransformerBlocks' attention should contain a bias parameter.
+ """
+
+ def __init__(
+ self,
+ num_attention_heads: int = 16,
+ attention_head_dim: int = 88,
+ in_channels: Optional[int] = None,
+ num_layers: int = 1,
+ dropout: float = 0.0,
+ norm_num_groups: int = 32,
+ cross_attention_dim: Optional[int] = None,
+ attention_bias: bool = False,
+ sample_size: Optional[int] = None,
+ num_vector_embeds: Optional[int] = None,
+ activation_fn: str = "geglu",
+ num_embeds_ada_norm: Optional[int] = None,
+ ):
+ super().__init__()
+ self.transformers = nn.ModuleList(
+ [
+ Transformer2DModel(
+ num_attention_heads=num_attention_heads,
+ attention_head_dim=attention_head_dim,
+ in_channels=in_channels,
+ num_layers=num_layers,
+ dropout=dropout,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_bias=attention_bias,
+ sample_size=sample_size,
+ num_vector_embeds=num_vector_embeds,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ )
+ for _ in range(2)
+ ]
+ )
+
+ # Variables that can be set by a pipeline:
+
+ # The ratio of transformer1 to transformer2's output states to be combined during inference
+ self.mix_ratio = 0.5
+
+ # The shape of `encoder_hidden_states` is expected to be
+ # `(batch_size, condition_lengths[0]+condition_lengths[1], num_features)`
+ self.condition_lengths = [77, 257]
+
+ # Which transformer to use to encode which condition.
+ # E.g. `(1, 0)` means that we'll use `transformers[1](conditions[0])` and `transformers[0](conditions[1])`
+ self.transformer_index_for_condition = [1, 0]
+
+ def forward(
+ self, hidden_states, encoder_hidden_states, timestep=None, attention_mask=None, return_dict: bool = True
+ ):
+ """
+ Args:
+ hidden_states ( When discrete, `torch.LongTensor` of shape `(batch size, num latent pixels)`.
+ When continuous, `torch.FloatTensor` of shape `(batch size, channel, height, width)`): Input
+ hidden_states
+ encoder_hidden_states ( `torch.LongTensor` of shape `(batch size, encoder_hidden_states dim)`, *optional*):
+ Conditional embeddings for cross attention layer. If not given, cross-attention defaults to
+ self-attention.
+ timestep ( `torch.long`, *optional*):
+ Optional timestep to be applied as an embedding in AdaLayerNorm's. Used to indicate denoising step.
+ attention_mask (`torch.FloatTensor`, *optional*):
+ Optional attention mask to be applied in CrossAttention
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.attention.Transformer2DModelOutput`] or `tuple`: [`~models.attention.Transformer2DModelOutput`]
+ if `return_dict` is True, otherwise a `tuple`. When returning a tuple, the first element is the sample
+ tensor.
+ """
+ input_states = hidden_states
+
+ encoded_states = []
+ tokens_start = 0
+ # attention_mask is not used yet
+ for i in range(2):
+ # for each of the two transformers, pass the corresponding condition tokens
+ condition_state = encoder_hidden_states[:, tokens_start : tokens_start + self.condition_lengths[i]]
+ transformer_index = self.transformer_index_for_condition[i]
+ encoded_state = self.transformers[transformer_index](
+ input_states,
+ encoder_hidden_states=condition_state,
+ timestep=timestep,
+ return_dict=False,
+ )[0]
+ encoded_states.append(encoded_state - input_states)
+ tokens_start += self.condition_lengths[i]
+
+ output_states = encoded_states[0] * self.mix_ratio + encoded_states[1] * (1 - self.mix_ratio)
+ output_states = output_states + input_states
+
+ if not return_dict:
+ return (output_states,)
+
+ return Transformer2DModelOutput(sample=output_states)
\ No newline at end of file
diff --git a/vsr/models/pipeline_stable_diffusion_upscale_video_3d.py b/vsr/models/pipeline_stable_diffusion_upscale_video_3d.py
new file mode 100644
index 0000000000000000000000000000000000000000..e83ca6bf4c7cfa75c260e3d1cc088f8bdc5eb240
--- /dev/null
+++ b/vsr/models/pipeline_stable_diffusion_upscale_video_3d.py
@@ -0,0 +1,780 @@
+
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import inspect
+from typing import Any, Callable, List, Optional, Union
+
+import numpy as np
+import math
+import PIL
+import torch
+import torch.nn.functional as F
+from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer
+
+from diffusers.loaders import TextualInversionLoaderMixin
+from diffusers.models import AutoencoderKL, UNet2DConditionModel
+from diffusers.schedulers import DDPMScheduler
+# from diffusers.schedulers import DDIMScheduler
+from diffusion.scheduling_ddim import DDIMScheduler
+
+from diffusers.utils import deprecate, is_accelerate_available, is_accelerate_version, logging
+
+try:
+ from diffusers.utils import randn_tensor
+except ImportError:
+    # newer diffusers releases moved randn_tensor to diffusers.utils.torch_utils
+    from diffusers.utils.torch_utils import randn_tensor
+
+from diffusers.pipelines.pipeline_utils import DiffusionPipeline
+from diffusers.pipelines.stable_diffusion import StableDiffusionPipelineOutput
+
+from einops import rearrange
+
+# from datasets.data_utils import filter2D
+# from datasets.degradations import random_mixed_kernels, bivariate_Gaussian
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+
+def preprocess(image):
+ if isinstance(image, torch.Tensor):
+ return image
+ elif isinstance(image, PIL.Image.Image):
+ image = [image]
+
+ if isinstance(image[0], PIL.Image.Image):
+ w, h = image[0].size
+ w, h = (x - x % 64 for x in (w, h)) # resize to integer multiple of 64
+
+ image = [np.array(i.resize((w, h)))[None, :] for i in image]
+ image = np.concatenate(image, axis=0)
+ image = np.array(image).astype(np.float32) / 255.0
+ image = image.transpose(0, 3, 1, 2)
+ image = 2.0 * image - 1.0
+ image = torch.from_numpy(image)
+ elif isinstance(image[0], torch.Tensor):
+ image = torch.cat(image, dim=0)
+ return image
+
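+# Example of what preprocess() produces (comment only, numbers illustrative): a 515x517 PIL image is
+# resized down to 512x512 (the largest multiples of 64 not exceeding its size), converted to float32,
+# transposed to NCHW and mapped from [0, 1] to [-1, 1], giving a tensor of shape (1, 3, 512, 512).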
+
+class StableDiffusionUpscalePipeline(DiffusionPipeline, TextualInversionLoaderMixin):
+ _optional_components = ["feature_extractor"]
+
+ def __init__(
+ self,
+ vae: AutoencoderKL,
+ text_encoder: CLIPTextModel,
+ tokenizer: CLIPTokenizer,
+ unet: UNet2DConditionModel,
+ low_res_scheduler: DDPMScheduler,
+ # scheduler: KarrasDiffusionSchedulers,
+ scheduler: DDIMScheduler,
+ feature_extractor: Optional[CLIPImageProcessor] = None,
+ max_noise_level: int = 350,
+ ):
+ super().__init__()
+
+        # Check whether the vae config has `scaling_factor` set to 0.08333; if not, warn via a
+        # deprecation message and patch the config so the upscaler produces correct results.
+        if hasattr(vae, "config"):
+ is_vae_scaling_factor_set_to_0_08333 = (
+ hasattr(vae.config, "scaling_factor") and vae.config.scaling_factor == 0.08333
+ )
+ if not is_vae_scaling_factor_set_to_0_08333:
+ deprecation_message = (
+ "The configuration file of the vae does not contain `scaling_factor` or it is set to"
+ f" {vae.config.scaling_factor}, which seems highly unlikely. If your checkpoint is a fine-tuned"
+ " version of `stabilityai/stable-diffusion-x4-upscaler` you should change 'scaling_factor' to"
+                    " 0.08333. Please make sure to update the config accordingly, as not doing so might lead to"
+ " incorrect results in future versions. If you have downloaded this checkpoint from the Hugging"
+ " Face Hub, it would be very nice if you could open a Pull Request for the `vae/config.json` file"
+ )
+ deprecate("wrong scaling_factor", "1.0.0", deprecation_message, standard_warn=False)
+ vae.register_to_config(scaling_factor=0.08333)
+ # TODO: remove
+ print(f'=============vae.config.scaling_factor: {vae.config.scaling_factor}==================')
+
+ self.register_modules(
+ vae=vae,
+ text_encoder=text_encoder,
+ tokenizer=tokenizer,
+ unet=unet,
+ low_res_scheduler=low_res_scheduler,
+ scheduler=scheduler,
+ feature_extractor=feature_extractor,
+ )
+ self.register_to_config(max_noise_level=max_noise_level)
+
+ def enable_sequential_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
+        text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to
+        `torch.device('meta')` and loaded to GPU only when their specific submodule has its `forward` method called.
+ """
+ if is_accelerate_available():
+ from accelerate import cpu_offload
+ else:
+ raise ImportError("Please install accelerate via `pip install accelerate`")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ cpu_offload(cpu_offloaded_model, device)
+
+ def enable_model_cpu_offload(self, gpu_id=0):
+ r"""
+ Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
+ to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
+ method is called, and the model remains on the GPU until the next model runs. Memory savings are lower than with
+ `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
+ """
+ if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
+ from accelerate import cpu_offload_with_hook
+ else:
+ raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
+
+ device = torch.device(f"cuda:{gpu_id}")
+
+ if self.device.type != "cpu":
+ self.to("cpu", silence_dtype_warnings=True)
+ torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
+
+ hook = None
+ for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
+ if cpu_offloaded_model is not None:
+ _, hook = cpu_offload_with_hook(cpu_offloaded_model, device, prev_module_hook=hook)
+
+ # We'll offload the last model manually.
+ self.final_offload_hook = hook
+
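+ # Typical usage (sketch): call `pipe.enable_model_cpu_offload()` once after loading the pipeline
+ # and before the first `pipe(...)` call; whole sub-models are then moved to the GPU on demand.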
+ @property
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._execution_device
+ def _execution_device(self):
+ r"""
+ Returns the device on which the pipeline's models will be executed. After calling
+ `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
+ hooks.
+ """
+ if not hasattr(self.unet, "_hf_hook"):
+ return self.device
+ for module in self.unet.modules():
+ if (
+ hasattr(module, "_hf_hook")
+ and hasattr(module._hf_hook, "execution_device")
+ and module._hf_hook.execution_device is not None
+ ):
+ return torch.device(module._hf_hook.execution_device)
+ return self.device
+
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline._encode_prompt
+ def _encode_prompt(
+ self,
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt=None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ ):
+ r"""
+ Encodes the prompt into text encoder hidden states.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ prompt to be encoded
+ device: (`torch.device`):
+ torch device
+ num_images_per_prompt (`int`):
+ number of images that should be generated per prompt
+ do_classifier_free_guidance (`bool`):
+ whether to use classifier free guidance or not
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is
+ less than `1`).
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ """
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ if prompt_embeds is None:
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ prompt = self.maybe_convert_prompt(prompt, self.tokenizer)
+
+ text_inputs = self.tokenizer(
+ prompt,
+ padding="max_length",
+ max_length=self.tokenizer.model_max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+ text_input_ids = text_inputs.input_ids
+ untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
+
+ if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
+ text_input_ids, untruncated_ids
+ ):
+ removed_text = self.tokenizer.batch_decode(
+ untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
+ )
+ logger.warning(
+ "The following part of your input was truncated because CLIP can only handle sequences up to"
+ f" {self.tokenizer.model_max_length} tokens: {removed_text}"
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = text_inputs.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ prompt_embeds = self.text_encoder(
+ text_input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ prompt_embeds = prompt_embeds[0]
+
+ prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ bs_embed, seq_len, _ = prompt_embeds.shape
+ # duplicate text embeddings for each generation per prompt, using mps friendly method
+ prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ prompt_embeds = prompt_embeds.view(bs_embed * num_images_per_prompt, seq_len, -1)
+
+ # get unconditional embeddings for classifier free guidance
+ if do_classifier_free_guidance and negative_prompt_embeds is None:
+ uncond_tokens: List[str]
+ if negative_prompt is None:
+ uncond_tokens = [""] * batch_size
+ elif type(prompt) is not type(negative_prompt):
+ raise TypeError(
+ f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
+ f" {type(prompt)}."
+ )
+ elif isinstance(negative_prompt, str):
+ uncond_tokens = [negative_prompt]
+ elif batch_size != len(negative_prompt):
+ raise ValueError(
+ f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
+ f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
+ " the batch size of `prompt`."
+ )
+ else:
+ uncond_tokens = negative_prompt
+
+ # textual inversion: process multi-vector tokens if necessary
+ if isinstance(self, TextualInversionLoaderMixin):
+ uncond_tokens = self.maybe_convert_prompt(uncond_tokens, self.tokenizer)
+
+ max_length = prompt_embeds.shape[1]
+ uncond_input = self.tokenizer(
+ uncond_tokens,
+ padding="max_length",
+ max_length=max_length,
+ truncation=True,
+ return_tensors="pt",
+ )
+
+ if hasattr(self.text_encoder.config, "use_attention_mask") and self.text_encoder.config.use_attention_mask:
+ attention_mask = uncond_input.attention_mask.to(device)
+ else:
+ attention_mask = None
+
+ negative_prompt_embeds = self.text_encoder(
+ uncond_input.input_ids.to(device),
+ attention_mask=attention_mask,
+ )
+ negative_prompt_embeds = negative_prompt_embeds[0]
+
+ if do_classifier_free_guidance:
+ # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
+ seq_len = negative_prompt_embeds.shape[1]
+
+ negative_prompt_embeds = negative_prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
+
+ negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_images_per_prompt, 1)
+ negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_images_per_prompt, seq_len, -1)
+
+ # For classifier free guidance, we need to do two forward passes.
+ # Here we concatenate the unconditional and text embeddings into a single batch
+ # to avoid doing two forward passes
+ prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
+
+ return prompt_embeds
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.prepare_extra_step_kwargs
+ def prepare_extra_step_kwargs(self, generator, eta):
+ # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
+ # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
+ # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
+ # and should be between [0, 1]
+
+ accepts_eta = "eta" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ extra_step_kwargs = {}
+ if accepts_eta:
+ extra_step_kwargs["eta"] = eta
+
+ # check if the scheduler accepts generator
+ accepts_generator = "generator" in set(inspect.signature(self.scheduler.step).parameters.keys())
+ if accepts_generator:
+ extra_step_kwargs["generator"] = generator
+ return extra_step_kwargs
+
+ # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.StableDiffusionPipeline.decode_latents
+ def decode_latents(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = (image / 2 + 0.5).clamp(0, 1)
+ # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
+ image = image.cpu().permute(0, 2, 3, 1).float().numpy()
+ return image
+
+ def decode_latents_vsr(self, latents):
+ latents = 1 / self.vae.config.scaling_factor * latents
+ image = self.vae.decode(latents).sample
+ image = image.clamp(-1, 1).cpu()
+ return image
+
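+ # Unlike `decode_latents`, `decode_latents_vsr` keeps the decoded frames as a torch tensor in
+ # [-1, 1] on the CPU, so chunks of frames can be concatenated later without a numpy round-trip.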
+ def check_inputs(
+ self,
+ prompt,
+ image,
+ noise_level,
+ callback_steps,
+ negative_prompt=None,
+ prompt_embeds=None,
+ negative_prompt_embeds=None,
+ ):
+ if (callback_steps is None) or (
+ callback_steps is not None and (not isinstance(callback_steps, int) or callback_steps <= 0)
+ ):
+ raise ValueError(
+ f"`callback_steps` has to be a positive integer but is {callback_steps} of type"
+ f" {type(callback_steps)}."
+ )
+
+ if prompt is not None and prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
+ " only forward one of the two."
+ )
+ elif prompt is None and prompt_embeds is None:
+ raise ValueError(
+ "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
+ )
+ elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
+ raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")
+
+ if negative_prompt is not None and negative_prompt_embeds is not None:
+ raise ValueError(
+ f"Cannot forward both `negative_prompt`: {negative_prompt} and `negative_prompt_embeds`:"
+ f" {negative_prompt_embeds}. Please make sure to only forward one of the two."
+ )
+
+ if prompt_embeds is not None and negative_prompt_embeds is not None:
+ if prompt_embeds.shape != negative_prompt_embeds.shape:
+ raise ValueError(
+ "`prompt_embeds` and `negative_prompt_embeds` must have the same shape when passed directly, but"
+ f" got: `prompt_embeds` {prompt_embeds.shape} != `negative_prompt_embeds`"
+ f" {negative_prompt_embeds.shape}."
+ )
+
+ if (
+ not isinstance(image, torch.Tensor)
+ and not isinstance(image, PIL.Image.Image)
+ and not isinstance(image, list)
+ ):
+ raise ValueError(
+ f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or `list` but is {type(image)}"
+ )
+
+ # verify batch size of prompt and image are same if image is a list or tensor
+ if isinstance(image, list) or isinstance(image, torch.Tensor):
+ if isinstance(prompt, str):
+ batch_size = 1
+ else:
+ batch_size = len(prompt)
+ if isinstance(image, list):
+ image_batch_size = len(image)
+ else:
+ image_batch_size = image.shape[0]
+ if batch_size != image_batch_size:
+ raise ValueError(
+ f"`prompt` has batch size {batch_size} and `image` has batch size {image_batch_size}."
+ " Please make sure that passed `prompt` matches the batch size of `image`."
+ )
+
+ # check noise level
+ if noise_level > self.config.max_noise_level:
+ raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {noise_level}")
+
+ def prepare_latents_3d(self, batch_size, num_channels_latents, seq_len, height, width, dtype, device, generator, latents=None):
+ shape = (batch_size, num_channels_latents, seq_len, height, width)
+ if latents is None:
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ else:
+ if latents.shape != shape:
+ raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
+ latents = latents.to(device)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = latents * self.scheduler.init_noise_sigma
+ return latents
+
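+ # Worked example (hedged): for a 4-frame clip at 64x64 latent resolution with 4 latent channels,
+ # `prepare_latents_3d(1, 4, 4, 64, 64, ...)` returns Gaussian noise of shape (1, 4, 4, 64, 64),
+ # already scaled by `scheduler.init_noise_sigma`.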
+ def get_timesteps(self, num_inference_steps, strength, device):
+ # get the original timestep using init_timestep
+ init_timestep = min(int(num_inference_steps * strength), num_inference_steps)
+ t_start = max(num_inference_steps - init_timestep, 0)
+ timesteps = self.scheduler.timesteps[t_start * self.scheduler.order :]
+
+ return timesteps, num_inference_steps - t_start
+
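+ # Example (hedged): with `num_inference_steps=50` and `strength=0.6`, `init_timestep` is 30 and
+ # `t_start` is 20, so only the last 30 scheduler timesteps are run (an img2img-style partial schedule).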
+ def prepare_latents_inversion(self, image, timestep, batch_size, num_images_per_prompt, dtype, device, generator=None):
+
+ image = image.to(device=device, dtype=dtype)
+ batch_size = batch_size * num_images_per_prompt
+
+ b = image.shape[0]
+ image = rearrange(image, 'b c t h w -> (b t) c h w').contiguous()
+ image = F.interpolate(image, scale_factor=4, mode='bicubic')
+ image = image.to(dtype=torch.float32)
+ init_latents = self.vae.encode(image).latent_dist.sample(generator)
+ torch.cuda.empty_cache()
+ init_latents = rearrange(init_latents, '(b t) c h w -> b c t h w', b=b).contiguous()
+
+ init_latents = self.vae.config.scaling_factor * init_latents
+ init_latents = init_latents.to(dtype=torch.float16)
+
+ # add noise
+ shape = init_latents.shape
+ noise = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
+ # get latents
+ init_latents = self.scheduler.add_noise(init_latents, noise, timestep)
+ # DEBUG
+ # init_latents = noise
+ print('timestep', timestep)
+
+ # scale the initial noise by the standard deviation required by the scheduler
+ latents = init_latents * self.scheduler.init_noise_sigma
+ return latents
+
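+ # `prepare_latents_inversion` is an alternative initialization (not used by the default path in
+ # `__call__` below): it encodes the 4x bicubic-upsampled low-res frames with the VAE and noises
+ # them to `timestep`, i.e. an SDEdit/img2img-style start instead of pure Gaussian noise.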
+ @torch.no_grad()
+ def __call__(
+ self,
+ prompt: Union[str, List[str]] = None,
+ image: Union[torch.FloatTensor, PIL.Image.Image, List[PIL.Image.Image]] = None,
+ num_inference_steps: int = 75,
+ guidance_scale: float = 9.0,
+ noise_level: int = 20,
+ negative_prompt: Optional[Union[str, List[str]]] = None,
+ num_images_per_prompt: Optional[int] = 1,
+ eta: float = 0.0,
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
+ latents: Optional[torch.FloatTensor] = None,
+ prompt_embeds: Optional[torch.FloatTensor] = None,
+ negative_prompt_embeds: Optional[torch.FloatTensor] = None,
+ return_dict: bool = True,
+ callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
+ callback_steps: int = 1,
+ ):
+ r"""
+ Function invoked when calling the pipeline for generation.
+
+ Args:
+ prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`
+ instead.
+ image (`PIL.Image.Image` or List[`PIL.Image.Image`] or `torch.FloatTensor`):
+ `Image` or tensor representing an image batch to be upscaled (this pipeline expects a video tensor of shape `(b, c, t, h, w)`).
+ num_inference_steps (`int`, *optional*, defaults to 75):
+ The number of denoising steps. More denoising steps usually lead to a higher quality image at the
+ expense of slower inference.
+ guidance_scale (`float`, *optional*, defaults to 9.0):
+ Guidance scale as defined in [Classifier-Free Diffusion Guidance](https://arxiv.org/abs/2207.12598).
+ `guidance_scale` is defined as `w` of equation 2. of [Imagen
+ Paper](https://arxiv.org/pdf/2205.11487.pdf). Guidance scale is enabled by setting `guidance_scale >
+ 1`. A higher guidance scale encourages the model to generate images that are closely linked to the text
+ `prompt`, usually at the expense of lower image quality.
+ negative_prompt (`str` or `List[str]`, *optional*):
+ The prompt or prompts not to guide the image generation. If not defined, one has to pass
+ `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale`
+ is less than `1`).
+ num_images_per_prompt (`int`, *optional*, defaults to 1):
+ The number of images to generate per prompt.
+ eta (`float`, *optional*, defaults to 0.0):
+ Corresponds to parameter eta (η) in the DDIM paper: https://arxiv.org/abs/2010.02502. Only applies to
+ [`schedulers.DDIMScheduler`], will be ignored for others.
+ generator (`torch.Generator`, *optional*):
+ One or a list of [torch generator(s)](https://pytorch.org/docs/stable/generated/torch.Generator.html)
+ to make generation deterministic.
+ latents (`torch.FloatTensor`, *optional*):
+ Pre-generated noisy latents, sampled from a Gaussian distribution, to be used as inputs for image
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
+ tensor will be generated by sampling using the supplied random `generator`.
+ prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
+ provided, text embeddings will be generated from `prompt` input argument.
+ negative_prompt_embeds (`torch.FloatTensor`, *optional*):
+ Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
+ weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
+ argument.
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
+ plain tuple.
+ callback (`Callable`, *optional*):
+ A function that will be called every `callback_steps` steps during inference. The function will be
+ called with the following arguments: `callback(step: int, timestep: int, latents: torch.FloatTensor)`.
+ callback_steps (`int`, *optional*, defaults to 1):
+ The frequency at which the `callback` function will be called. If not specified, the callback will be
+ called at every step.
+
+ Examples:
+ ```py
+ >>> import requests
+ >>> from PIL import Image
+ >>> from io import BytesIO
+ >>> from diffusers import StableDiffusionUpscalePipeline
+ >>> import torch
+
+ >>> # load model and scheduler
+ >>> model_id = "stabilityai/stable-diffusion-x4-upscaler"
+ >>> pipeline = StableDiffusionUpscalePipeline.from_pretrained(
+ ... model_id, revision="fp16", torch_dtype=torch.float16
+ ... )
+ >>> pipeline = pipeline.to("cuda")
+
+ >>> # let's download an image
+ >>> url = "https://huggingface.co/datasets/hf-internal-testing/diffusers-images/resolve/main/sd2-upscale/low_res_cat.png"
+ >>> response = requests.get(url)
+ >>> low_res_img = Image.open(BytesIO(response.content)).convert("RGB")
+ >>> low_res_img = low_res_img.resize((128, 128))
+ >>> prompt = "a white cat"
+
+ >>> upscaled_image = pipeline(prompt=prompt, image=low_res_img).images[0]
+ >>> upscaled_image.save("upsampled_cat.png")
+ ```
+ """
+
+ # 1. Check inputs
+ self.check_inputs(
+ prompt,
+ image,
+ noise_level,
+ callback_steps,
+ negative_prompt,
+ prompt_embeds,
+ negative_prompt_embeds,
+ )
+
+ if image is None:
+ raise ValueError("`image` input cannot be undefined.")
+
+ # 2. Define call parameters
+ if prompt is not None and isinstance(prompt, str):
+ batch_size = 1
+ elif prompt is not None and isinstance(prompt, list):
+ batch_size = len(prompt)
+ else:
+ batch_size = prompt_embeds.shape[0]
+
+ device = self._execution_device
+ # here `guidance_scale` is defined analogously to the guidance weight `w` of equation (2)
+ # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
+ # corresponds to doing no classifier free guidance.
+ do_classifier_free_guidance = guidance_scale > 1.0
+
+ # 3. Encode input prompt
+ prompt_embeds = self._encode_prompt(
+ prompt,
+ device,
+ num_images_per_prompt,
+ do_classifier_free_guidance,
+ negative_prompt,
+ prompt_embeds=prompt_embeds,
+ negative_prompt_embeds=negative_prompt_embeds,
+ )
+
+ # 4. Preprocess image
+ # image = preprocess(image)
+ image = image.to(dtype=prompt_embeds.dtype, device=device)
+
+ # 5. Add noise to image
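+ # The x4-upscaler UNet is conditioned on a noise-augmented low-res input: `low_res_scheduler`
+ # adds `noise_level` steps of DDPM noise to `image`, and the same `noise_level` is later fed to
+ # the UNet as `class_labels`.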
+ noise_level = torch.tensor([noise_level], dtype=torch.long, device=device)
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=prompt_embeds.dtype)
+ image = self.low_res_scheduler.add_noise(image, noise, noise_level)
+ # image = image.clamp(-1, 1)
+
+ # debug
+ # image = rearrange(image, 'b c t h w -> (b t) c h w').contiguous().cpu()
+ # return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
+
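+ # Duplicate the low-res conditioning image and its noise level so they match the
+ # (negative prompt, prompt) halves of the classifier-free-guidance batch and `num_images_per_prompt`.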
+ batch_multiplier = 2 if do_classifier_free_guidance else 1
+ image = torch.cat([image] * batch_multiplier * num_images_per_prompt)
+ # TODO:
+ # noise_level = noise_level*0
+ noise_level = torch.cat([noise_level] * image.shape[0])
+
+ ####################### Random Noise ########################
+ # 5. set timesteps
+ self.scheduler.set_timesteps(num_inference_steps, device=device)
+ timesteps = self.scheduler.timesteps
+
+ # 6. Prepare latent variables
+ seq_len, height, width = image.shape[2:]
+ # TODO: for downsample_2x
+ # height, width = height//2, width//2
+ num_channels_latents = self.vae.config.latent_channels
+ latents = self.prepare_latents_3d(
+ batch_size * num_images_per_prompt,
+ num_channels_latents,
+ seq_len,
+ height,
+ width,
+ prompt_embeds.dtype,
+ device,
+ generator,
+ latents,
+ ) # b c t h w
+ # print('latents', latents.shape)
+
+ ####################### Random Noise + Latent ########################
+ # # 5. Prepare timesteps
+ # self.scheduler.set_timesteps(num_inference_steps, device=device)
+ # timesteps, num_inference_steps = self.get_timesteps(num_inference_steps, strength=1, device=device)
+ # # DEBUG
+ # # timesteps = self.scheduler.timesteps
+ # latent_timestep = timesteps[:1].repeat(batch_size * num_images_per_prompt)
+
+ # # 6. Prepare latent variables
+ # # b c t h w
+ # b = image.shape[0]
+ # num_channels_latents = self.vae.config.latent_channels
+ # latents = self.prepare_latents_inversion(
+ # image[:b//2],
+ # latent_timestep,
+ # batch_size,
+ # num_images_per_prompt,
+ # prompt_embeds.dtype,
+ # device,
+ # generator,
+ # )
+ # print('latents', latents.shape)
+
+ # 7. Check that sizes of image and latents match
+ num_channels_image = image.shape[1]
+ if num_channels_latents + num_channels_image != self.unet.config.in_channels:
+ raise ValueError(
+ f"Incorrect configuration settings! The config of `pipeline.unet`: {self.unet.config} expects"
+ f" {self.unet.config.in_channels} but received `num_channels_latents`: {num_channels_latents} +"
+ f" `num_channels_image`: {num_channels_image} "
+ f" = {num_channels_latents+num_channels_image}. Please verify the config of"
+ " `pipeline.unet` or your `image` input."
+ )
+
+ # 8. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
+ extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
+
+ # 9. Denoising loop
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
+ for i, t in enumerate(timesteps):
+ torch.cuda.empty_cache() # delete for VSR
+ # expand the latents if we are doing classifier free guidance
+ latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
+
+ # concat latents, mask, masked_image_latents in the channel dimension
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
+ #latent_model_input = torch.cat([latent_model_input, image], dim=1)
+ # print(f'========== latent_model_input: {latent_model_input.shape} ============')
+ # print(f'========== image: {image.shape} ============')
+ noise_pred = self.unet(
+ latent_model_input, t, image, encoder_hidden_states=prompt_embeds, class_labels=noise_level
+ ).sample
+
+ # perform guidance
+ if do_classifier_free_guidance:
+ noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
+ noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
+
+ # compute the previous noisy sample x_t -> x_t-1
+ latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample
+
+ # call the callback, if provided
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
+ progress_bar.update()
+ if callback is not None and i % callback_steps == 0:
+ callback(i, t, latents)
+
+ del latent_model_input, noise_pred
+
+
+ # 10. Post-processing
+ # make sure the VAE is in float32 mode, as it overflows in float16
+ self.vae.to(dtype=torch.float32)
+
+ # TODO(Patrick, William) - clean up when attention is refactored
+ use_torch_2_0_attn = hasattr(F, "scaled_dot_product_attention")
+ use_xformers = self.vae.decoder.mid_block.attentions[0]._use_memory_efficient_attention_xformers
+ # if xformers or torch_2_0 is used attention block does not need
+ # to be in float32 which can save lots of memory
+ if not use_torch_2_0_attn and not use_xformers:
+ self.vae.post_quant_conv.to(latents.dtype)
+ self.vae.decoder.conv_in.to(latents.dtype)
+ self.vae.decoder.mid_block.to(latents.dtype)
+ else:
+ latents = latents.float()
+
+ # 11. Convert to frames
+ short_seq = 4
+ # b c t h w
+ latents = rearrange(latents, 'b c t h w -> (b t) c h w').contiguous()
+ if latents.shape[0] > short_seq: # for VSR
+ image = []
+ for start_f in range(0, latents.shape[0], short_seq):
+ torch.cuda.empty_cache() # delete for VSR
+ end_f = min(latents.shape[0], start_f + short_seq)
+ image_ = self.decode_latents_vsr(latents[start_f:end_f])
+ image.append(image_)
+ del image_
+ image = torch.cat(image, dim=0)
+ else:
+ image = self.decode_latents_vsr(latents)
+
+ # Offload last model to CPU
+ if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
+ self.final_offload_hook.offload()
+
+ if not return_dict:
+ return (image, None)
+
+ return StableDiffusionPipelineOutput(images=image, nsfw_content_detected=None)
diff --git a/vsr/models/resnet.py b/vsr/models/resnet.py
new file mode 100644
index 0000000000000000000000000000000000000000..490fb323dc43c67acce260a2ff9a7e80f021f76c
--- /dev/null
+++ b/vsr/models/resnet.py
@@ -0,0 +1,316 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/resnet.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+from einops import rearrange
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+class Mish(torch.nn.Module):
+ def forward(self, hidden_states):
+ return hidden_states * torch.tanh(torch.nn.functional.softplus(hidden_states))
+
+class InflatedConv3d(nn.Conv2d):
+ def forward(self, x):
+ video_length = x.shape[2]
+
+ x = rearrange(x, "b c f h w -> (b f) c h w")
+ x = super().forward(x)
+ x = rearrange(x, "(b f) c h w -> b c f h w", f=video_length)
+
+ return x
+
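+# `InflatedConv3d` applies an ordinary 2D convolution to every frame of a 5D (b, c, f, h, w) video
+# tensor by folding the frame axis into the batch axis, so pretrained 2D weights can be reused as-is.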
+
+class Upsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, use_conv_transpose=False, out_channels=None, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.use_conv_transpose = use_conv_transpose
+ self.name = name
+
+ conv = None
+ if use_conv_transpose:
+ raise NotImplementedError
+ elif use_conv:
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, padding=1)
+
+ if name == "conv":
+ self.conv = conv
+ else:
+ self.Conv2d_0 = conv
+
+ def forward(self, hidden_states, output_size=None):
+ assert hidden_states.shape[1] == self.channels
+
+ if self.use_conv_transpose:
+ raise NotImplementedError
+
+ # Cast to float32 as the 'upsample_nearest2d_out_frame' op does not support bfloat16
+ dtype = hidden_states.dtype
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(torch.float32)
+
+ # upsample_nearest_nhwc fails with large batch sizes. see https://github.com/huggingface/diffusers/issues/984
+ if hidden_states.shape[0] >= 64:
+ hidden_states = hidden_states.contiguous()
+
+ # if `output_size` is passed we force the interpolation output
+ # size and do not make use of `scale_factor=2`
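+ # With `scale_factor=[1.0, 2.0, 2.0]` the temporal dimension is kept and only H and W are doubled.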
+ if output_size is None:
+ hidden_states = F.interpolate(hidden_states, scale_factor=[1.0, 2.0, 2.0], mode="nearest")
+ else:
+ hidden_states = F.interpolate(hidden_states, size=output_size, mode="nearest")
+
+ # If the input is bfloat16, we cast back to bfloat16
+ if dtype == torch.bfloat16:
+ hidden_states = hidden_states.to(dtype)
+
+ if self.use_conv:
+ if self.name == "conv":
+ hidden_states = self.conv(hidden_states)
+ else:
+ hidden_states = self.Conv2d_0(hidden_states)
+
+ return hidden_states
+
+
+class Downsample3D(nn.Module):
+ def __init__(self, channels, use_conv=False, out_channels=None, padding=1, name="conv"):
+ super().__init__()
+ self.channels = channels
+ self.out_channels = out_channels or channels
+ self.use_conv = use_conv
+ self.padding = padding
+ stride = 2
+ self.name = name
+
+ if use_conv:
+ conv = InflatedConv3d(self.channels, self.out_channels, 3, stride=stride, padding=padding)
+ else:
+ raise NotImplementedError
+
+ if name == "conv":
+ self.Conv2d_0 = conv
+ self.conv = conv
+ elif name == "Conv2d_0":
+ self.conv = conv
+ else:
+ self.conv = conv
+
+ def forward(self, hidden_states):
+ assert hidden_states.shape[1] == self.channels
+ if self.use_conv and self.padding == 0:
+ raise NotImplementedError
+
+ assert hidden_states.shape[1] == self.channels
+ hidden_states = self.conv(hidden_states)
+
+ return hidden_states
+
+
+class ResnetBlock3D(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ self.conv1 = InflatedConv3d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ time_emb_proj_out_channels = out_channels
+ elif self.time_embedding_norm == "scale_shift":
+ time_emb_proj_out_channels = out_channels * 2
+ else:
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = InflatedConv3d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = InflatedConv3d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, input_tensor, temb):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
+
+
+class ResnetBlock3DCNN(nn.Module):
+ def __init__(
+ self,
+ *,
+ in_channels,
+ out_channels=None,
+ kernel=(3,1,1),
+ conv_shortcut=False,
+ dropout=0.0,
+ temb_channels=512,
+ groups=32,
+ groups_out=None,
+ pre_norm=True,
+ eps=1e-6,
+ non_linearity="swish",
+ time_embedding_norm="default",
+ output_scale_factor=1.0,
+ use_in_shortcut=None,
+ ):
+ super().__init__()
+ self.pre_norm = pre_norm
+ self.pre_norm = True
+ self.in_channels = in_channels
+ out_channels = in_channels if out_channels is None else out_channels
+ self.out_channels = out_channels
+ self.use_conv_shortcut = conv_shortcut
+ self.time_embedding_norm = time_embedding_norm
+ self.output_scale_factor = output_scale_factor
+
+ if groups_out is None:
+ groups_out = groups
+
+ self.norm1 = torch.nn.GroupNorm(num_groups=groups, num_channels=in_channels, eps=eps, affine=True)
+
+ padding = tuple((kernel[i] - 1) // 2 for i in range(3)) # nn.Conv3d expects a tuple of ints, not a generator
+ self.conv1 = nn.Conv3d(in_channels, out_channels, kernel_size=kernel, stride=(1,1,1), padding=padding)
+
+ if temb_channels is not None:
+ if self.time_embedding_norm == "default":
+ time_emb_proj_out_channels = out_channels
+ elif self.time_embedding_norm == "scale_shift":
+ time_emb_proj_out_channels = out_channels * 2
+ else:
+ raise ValueError(f"unknown time_embedding_norm : {self.time_embedding_norm} ")
+
+ self.time_emb_proj = torch.nn.Linear(temb_channels, time_emb_proj_out_channels)
+ else:
+ self.time_emb_proj = None
+
+ self.norm2 = torch.nn.GroupNorm(num_groups=groups_out, num_channels=out_channels, eps=eps, affine=True)
+ self.dropout = torch.nn.Dropout(dropout)
+ self.conv2 = nn.Conv3d(out_channels, out_channels, kernel_size=(3,1,1), stride=(1,1,1), padding=(1,0,0))
+
+ if non_linearity == "swish":
+ self.nonlinearity = lambda x: F.silu(x)
+ elif non_linearity == "mish":
+ self.nonlinearity = Mish()
+ elif non_linearity == "silu":
+ self.nonlinearity = nn.SiLU()
+
+ self.use_in_shortcut = self.in_channels != self.out_channels if use_in_shortcut is None else use_in_shortcut
+
+ self.conv_shortcut = None
+ if self.use_in_shortcut:
+ self.conv_shortcut = nn.Conv3d(in_channels, out_channels, kernel_size=(1,1,1), stride=(1,1,1), padding=(0,0,0))
+
+ def forward(self, input_tensor, temb=None):
+ hidden_states = input_tensor
+
+ hidden_states = self.norm1(hidden_states)
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.conv1(hidden_states)
+
+ if temb is not None:
+ temb = self.time_emb_proj(self.nonlinearity(temb))[:, :, None, None, None]
+
+ if temb is not None and self.time_embedding_norm == "default":
+ hidden_states = hidden_states + temb
+
+ hidden_states = self.norm2(hidden_states)
+
+ if temb is not None and self.time_embedding_norm == "scale_shift":
+ scale, shift = torch.chunk(temb, 2, dim=1)
+ hidden_states = hidden_states * (1 + scale) + shift
+
+ hidden_states = self.nonlinearity(hidden_states)
+
+ hidden_states = self.dropout(hidden_states)
+ hidden_states = self.conv2(hidden_states)
+
+ if self.conv_shortcut is not None:
+ input_tensor = self.conv_shortcut(input_tensor)
+
+ output_tensor = (input_tensor + hidden_states) / self.output_scale_factor
+
+ return output_tensor
\ No newline at end of file
diff --git a/vsr/models/temporal_module.py b/vsr/models/temporal_module.py
new file mode 100644
index 0000000000000000000000000000000000000000..6a8ffb7aa8da232507831bf6fb241340b0f8bc9d
--- /dev/null
+++ b/vsr/models/temporal_module.py
@@ -0,0 +1,684 @@
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import torch
+import numpy as np
+import torch.nn.functional as F
+from torch import nn
+import torchvision
+# from torch_utils.ops import grid_sample_gradfix
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.utils import BaseOutput
+from diffusers.utils.import_utils import is_xformers_available
+from diffusers.models.attention import FeedForward
+# from diffusers.models.attention_processor import Attention
+
+try:
+ from .diffusers_attention import CrossAttention
+ from .resnet import Downsample3D, Upsample3D, InflatedConv3d, ResnetBlock3D, ResnetBlock3DCNN
+
+except ImportError: # fall back to absolute imports when the module is run as a script
+ from diffusers_attention import CrossAttention
+ from resnet import Downsample3D, Upsample3D, InflatedConv3d, ResnetBlock3D, ResnetBlock3DCNN
+
+from einops import rearrange, repeat
+import math
+
+import pdb
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
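+# Note: zero-initializing the final projection of a residual branch makes the newly added temporal
+# modules start out as (near) identity mappings, so the pretrained 2D upscaler behaviour is preserved
+# at the start of fine-tuning (a common trick, also used in ControlNet-style zero convolutions).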
+
+def grid_sample_align(input, grid):
+ return torch.nn.functional.grid_sample(input=input, grid=grid, mode='bilinear', padding_mode='zeros', align_corners=True)
+
+
+@dataclass
+class TemporalTransformer3DModelOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+if is_xformers_available():
+ import xformers
+ import xformers.ops
+else:
+ xformers = None
+
+
+class EmptyTemporalModule3D(nn.Module):
+ def __init__(self):
+ super().__init__()
+
+ def forward(self, hidden_states, condition_video=None, encoder_hidden_states=None, timesteps=None, temb=None, attention_mask=None):
+ return hidden_states
+
+
+class TemporalModule3D(nn.Module):
+ def __init__(
+ self,
+ in_channels=None,
+ out_channels=None,
+
+ num_attention_layers=None,
+ num_attention_head=8,
+ attention_head_dim=None,
+ cross_attention_dim=768,
+ temb_channels=512,
+
+ dropout=0.,
+ attention_bias=False,
+ activation_fn="geglu",
+ only_cross_attention=False,
+ upcast_attention=False,
+
+ norm_num_groups=8,
+ use_linear_projection=True,
+ use_scale_shift=False, # setting this to True consistently produces NaN loss; root cause unknown
+
+ attention_block_types: Tuple[str]=None,
+ cross_frame_attention_mode=None,
+ temporal_shift_fold_div=None,
+ temporal_shift_direction=None,
+
+ use_dcn_warpping=None,
+ use_deformable_conv=None,
+
+ attention_dim_div: int = None,
+ video_condition=False,
+ ):
+ super().__init__()
+ assert len(attention_block_types) == 2
+
+ self.use_scale_shift = use_scale_shift
+ self.video_condition = video_condition
+
+ self.non_linearity = nn.SiLU()
+
+ # 1. 3d cnn
+ if self.video_condition:
+ video_condition_dim = int(out_channels//4)
+ self.v_cond_conv = ResnetBlock3D(in_channels=3, out_channels=video_condition_dim, temb_channels=temb_channels, groups=3, groups_out=32)
+ self.resblocks_3d_t = ResnetBlock3DCNN(in_channels=in_channels+video_condition_dim, out_channels=in_channels, kernel=(5,1,1), temb_channels=temb_channels)
+ else:
+ self.resblocks_3d_t = ResnetBlock3DCNN(in_channels=in_channels, out_channels=in_channels, kernel=(5,1,1), temb_channels=temb_channels)
+
+ self.resblocks_3d_s = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, temb_channels=temb_channels, groups=32, groups_out=32)
+
+ # 2. transformer blocks
+ if not (attention_block_types[0]=='' and attention_block_types[1]==''):
+ attentions = TemporalTransformer3DModel(
+ num_attention_heads=num_attention_head,
+ attention_head_dim=attention_head_dim if attention_head_dim is not None else in_channels // num_attention_head // attention_dim_div,
+
+ in_channels=in_channels,
+ num_layers=num_attention_layers,
+ dropout=dropout,
+ norm_num_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attention_bias=attention_bias,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=1000, # adaptive norm for timestep embedding injection
+ use_linear_projection=use_linear_projection,
+
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+
+ attention_block_types=attention_block_types,
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_shift_fold_div=temporal_shift_fold_div,
+ temporal_shift_direction=temporal_shift_direction,
+
+ use_dcn_warpping=use_dcn_warpping,
+ use_deformable_conv=use_deformable_conv,
+ )
+ self.attentions = nn.ModuleList([attentions])
+
+ if use_scale_shift:
+ self.scale_shift_conv = zero_module(InflatedConv3d(in_channels=in_channels, out_channels=in_channels * 2, kernel_size=1, stride=1, padding=0))
+ else:
+ self.shift_conv = zero_module(InflatedConv3d(in_channels=in_channels, out_channels=in_channels, kernel_size=1, stride=1, padding=0))
+
+
+ def forward(self, hidden_states, condition_video=None, encoder_hidden_states=None, timesteps=None, temb=None, attention_mask=None):
+ input_tensor = hidden_states
+
+ if self.video_condition:
+ # obtain video attention
+ assert condition_video is not None
+ if isinstance(condition_video, dict):
+ condition_video = condition_video[hidden_states.shape[-1]]
+ hidden_condition = self.v_cond_conv(condition_video, temb)
+ hidden_states = torch.cat([hidden_states, hidden_condition], dim=1)
+
+ # 3DCNN
+ hidden_states = self.resblocks_3d_t(hidden_states, temb)
+ hidden_states = self.resblocks_3d_s(hidden_states, temb)
+
+ if hasattr(self, "attentions"):
+ for attn in self.attentions:
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states, timestep=timesteps).sample
+
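+ # Inject the temporal features back into the input either as a FiLM-style (1 + scale) * x + shift
+ # modulation or as a plain additive residual; both branches end in zero-initialized convs (see
+ # __init__), so the module initially behaves as an identity mapping.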
+ if self.use_scale_shift:
+ hidden_states = self.scale_shift_conv(hidden_states)
+ scale, shift = torch.chunk(hidden_states, chunks=2, dim=1)
+ hidden_states = (1 + scale) * input_tensor + shift
+ else:
+ hidden_states = self.shift_conv(hidden_states)
+ hidden_states = input_tensor + hidden_states
+
+ return hidden_states
+
+
+class TemporalTransformer3DModel(ModelMixin, ConfigMixin):
+ @register_to_config
+ def __init__(
+ self,
+ num_attention_heads=None,
+ attention_head_dim=None,
+ in_channels=None,
+ num_layers=None,
+ dropout=None,
+ norm_num_groups=None,
+ cross_attention_dim=None,
+ attention_bias=None,
+ activation_fn=None,
+ num_embeds_ada_norm=None,
+ use_linear_projection=None,
+ only_cross_attention=None,
+ upcast_attention=None,
+
+ attention_block_types=None,
+ cross_frame_attention_mode=None,
+ temporal_shift_fold_div=None,
+ temporal_shift_direction=None,
+
+ use_dcn_warpping=None,
+ use_deformable_conv=None,
+ ):
+ super().__init__()
+ self.use_linear_projection = use_linear_projection
+ self.num_attention_heads = num_attention_heads
+ self.attention_head_dim = attention_head_dim
+ inner_dim = num_attention_heads * attention_head_dim
+
+ # Define input layers
+ self.in_channels = in_channels
+
+ self.norm = torch.nn.GroupNorm(num_groups=norm_num_groups, num_channels=in_channels, eps=1e-6, affine=True)
+ if use_linear_projection:
+ self.proj_in = nn.Linear(in_channels, inner_dim)
+ else:
+ self.proj_in = nn.Conv2d(in_channels, inner_dim, kernel_size=1, stride=1, padding=0)
+
+ # Define transformers blocks
+ self.transformer_blocks = nn.ModuleList(
+ [
+ TemporalTransformerBlock(
+ inner_dim,
+ num_attention_heads,
+ attention_head_dim,
+ dropout=dropout,
+ cross_attention_dim=cross_attention_dim,
+ activation_fn=activation_fn,
+ num_embeds_ada_norm=num_embeds_ada_norm,
+ attention_bias=attention_bias,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+
+ attention_block_types=attention_block_types,
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_shift_fold_div=temporal_shift_fold_div,
+ temporal_shift_direction=temporal_shift_direction,
+
+ use_dcn_warpping=use_dcn_warpping,
+ use_deformable_conv=use_deformable_conv,
+ )
+ for d in range(num_layers)
+ ]
+ )
+
+ # 4. Define output layers
+ if use_linear_projection:
+ self.proj_out = nn.Linear(inner_dim, in_channels)
+ else:
+ self.proj_out = nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, return_dict: bool = True):
+ # Input
+ assert hidden_states.dim() == 5, f"Expected hidden_states to have ndim=5, but got ndim={hidden_states.dim()}."
+ video_length = hidden_states.shape[2]
+ hidden_states = rearrange(hidden_states, "b c f h w -> (b f) c h w")
+ if encoder_hidden_states is not None:
+ encoder_hidden_states = repeat(encoder_hidden_states, 'b n c -> (b f) n c', f=video_length)
+
+ batch, channel, height, weight = hidden_states.shape
+ residual = hidden_states
+
+ hidden_states = self.norm(hidden_states)
+ if not self.use_linear_projection:
+ hidden_states = self.proj_in(hidden_states)
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ else:
+ inner_dim = hidden_states.shape[1]
+ hidden_states = hidden_states.permute(0, 2, 3, 1).reshape(batch, height * weight, inner_dim)
+ hidden_states = self.proj_in(hidden_states)
+
+ # Blocks
+ for block in self.transformer_blocks:
+ hidden_states = block(
+ hidden_states,
+ encoder_hidden_states=encoder_hidden_states,
+ timestep=timestep,
+ video_length=video_length
+ )
+
+ # Output
+ if not self.use_linear_projection:
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+ hidden_states = self.proj_out(hidden_states)
+ else:
+ hidden_states = self.proj_out(hidden_states)
+ hidden_states = (
+ hidden_states.reshape(batch, height, weight, inner_dim).permute(0, 3, 1, 2).contiguous()
+ )
+
+ output = hidden_states + residual
+
+ output = rearrange(output, "(b f) c h w -> b c f h w", f=video_length)
+ if not return_dict:
+ return (output,)
+
+ return TemporalTransformer3DModelOutput(sample=output)
+
+
+class TemporalTransformerBlock(nn.Module):
+ def __init__(
+ self,
+ dim=None,
+ num_attention_heads=None,
+ attention_head_dim=None,
+ dropout=None,
+ cross_attention_dim=None,
+ activation_fn=None,
+ num_embeds_ada_norm=None,
+ attention_bias=None,
+ only_cross_attention=None,
+ upcast_attention=None,
+
+ attention_block_types=None,
+ cross_frame_attention_mode=None,
+ temporal_shift_fold_div=None,
+ temporal_shift_direction=None,
+
+ use_dcn_warpping=None,
+ use_deformable_conv=None,
+ ):
+ super().__init__()
+ assert len(attention_block_types) == 2
+
+ self.only_cross_attention = only_cross_attention
+ self.use_ada_layer_norm = num_embeds_ada_norm is not None
+ self.use_dcn_warpping = use_dcn_warpping
+
+ # 1. Spatial-Attn (self)
+ if not attention_block_types[0] == '':
+ self.attn_spatial = VersatileSelfAttention(
+ attention_mode=attention_block_types[0],
+
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_shift_fold_div=temporal_shift_fold_div,
+ temporal_shift_direction=temporal_shift_direction,
+ )
+ nn.init.zeros_(self.attn_spatial.to_out[0].weight.data)
+ self.norm1 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ # 2. Temporal-Attn (self)
+ self.attn_temporal = VersatileSelfAttention(
+ attention_mode=attention_block_types[1],
+
+ query_dim=dim,
+ heads=num_attention_heads,
+ dim_head=attention_head_dim,
+ dropout=dropout,
+ bias=attention_bias,
+ upcast_attention=upcast_attention,
+
+ cross_frame_attention_mode=cross_frame_attention_mode,
+ temporal_shift_fold_div=temporal_shift_fold_div,
+ temporal_shift_direction=temporal_shift_direction,
+ )
+ nn.init.zeros_(self.attn_temporal.to_out[0].weight.data)
+ self.norm2 = AdaLayerNorm(dim, num_embeds_ada_norm) if self.use_ada_layer_norm else nn.LayerNorm(dim)
+
+ self.dcn_module = WarpModule(
+ in_channels=dim,
+ use_deformable_conv=use_deformable_conv,
+ ) if use_dcn_warpping else None
+
+ # 3. Feed-forward
+ self.ff = FeedForward(dim, dropout=dropout, activation_fn=activation_fn)
+ self.norm3 = nn.LayerNorm(dim)
+
+
+ def set_use_memory_efficient_attention_xformers(self, use_memory_efficient_attention_xformers: bool, attention_op=None):
+ if not is_xformers_available():
+ print("Here is how to install it")
+ raise ModuleNotFoundError(
+ "Refer to https://github.com/facebookresearch/xformers for more information on how to install"
+ " xformers",
+ name="xformers",
+ )
+ elif not torch.cuda.is_available():
+ raise ValueError(
+ "torch.cuda.is_available() should be True but is False. xformers' memory efficient attention is only"
+ " available for GPU "
+ )
+ else:
+ try:
+ # Make sure we can run the memory efficient attention
+ _ = xformers.ops.memory_efficient_attention(
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ torch.randn((1, 2, 40), device="cuda"),
+ )
+ except Exception as e:
+ raise e
+ if hasattr(self, "attn_spatial"):
+ self.attn_spatial._use_memory_efficient_attention_xformers = use_memory_efficient_attention_xformers
+
+ def forward(self, hidden_states, encoder_hidden_states=None, timestep=None, attention_mask=None, video_length=None):
+ # 1. Spatial-Attention
+ if hasattr(self, "attn_spatial") and hasattr(self, "norm1"):
+ norm_hidden_states = self.norm1(hidden_states, timestep) if self.use_ada_layer_norm else self.norm1(hidden_states)
+ hidden_states = self.attn_spatial(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+
+ # 2. Temporal-Attention
+ norm_hidden_states = self.norm2(hidden_states, timestep) if self.use_ada_layer_norm else self.norm2(hidden_states)
+ if not self.use_dcn_warpping:
+ hidden_states = self.attn_temporal(norm_hidden_states, attention_mask=attention_mask, video_length=video_length) + hidden_states
+ else:
+ hidden_states = self.dcn_module(
+ hidden_states,
+ offset_hidden_states=self.attn_temporal(norm_hidden_states, attention_mask=attention_mask, video_length=video_length),
+ )
+
+ # 3. Feed-forward
+ hidden_states = self.ff(self.norm3(hidden_states)) + hidden_states
+
+ return hidden_states
+
+
+class VersatileSelfAttention(CrossAttention):
+ def __init__(
+ self,
+ attention_mode=None,
+ cross_frame_attention_mode=None,
+ temporal_shift_fold_div=None,
+ temporal_shift_direction=None,
+ temporal_position_encoding=False,
+ temporal_position_encoding_max_len=24,
+ *args, **kwargs
+ ):
+ super().__init__(*args, **kwargs)
+ assert attention_mode in ("Temporal", "Spatial", "CrossFrame", "SpatialTemporalShift", None)
+ assert cross_frame_attention_mode in ("0_i-1", "i-1_i", "0_i-1_i", "i-1_i_i+1", None)
+
+ self.attention_mode = attention_mode
+
+ self.cross_frame_attention_mode = cross_frame_attention_mode
+
+ self.temporal_shift_fold_div = temporal_shift_fold_div
+ self.temporal_shift_direction = temporal_shift_direction
+
+ self.pos_encoder = PositionalEncoding(
+ kwargs["query_dim"],
+ dropout=0.,
+ max_len=temporal_position_encoding_max_len
+ ) if temporal_position_encoding else None
+
+ def temporal_token_concat(self, tensor, video_length):
+ # print("### temporal token concat")
+ current_frame_index = torch.arange(video_length)
+ former_frame_index = current_frame_index - 1
+ former_frame_index[0] = 0
+
+ later_frame_index = current_frame_index + 1
+ later_frame_index[-1] = -1
+
+ # (b f) d c
+ tensor = rearrange(tensor, "(b f) d c -> b f d c", f=video_length)
+
+ if self.cross_frame_attention_mode == "0_i-1":
+ tensor = torch.cat([tensor[:, [0] * video_length], tensor[:, former_frame_index]], dim=2)
+ elif self.cross_frame_attention_mode == "i-1_i":
+ tensor = torch.cat([tensor[:, former_frame_index], tensor[:, current_frame_index]], dim=2)
+ elif self.cross_frame_attention_mode == "0_i-1_i":
+ tensor = torch.cat([tensor[:, [0] * video_length], tensor[:, former_frame_index], tensor[:, current_frame_index]], dim=2)
+ elif self.cross_frame_attention_mode == "i-1_i_i+1":
+ tensor = torch.cat([tensor[:, former_frame_index], tensor[:, current_frame_index], tensor[:, later_frame_index]], dim=2)
+ else:
+ raise NotImplementedError
+
+ tensor = rearrange(tensor, "b f d c -> (b f) d c")
+ return tensor
+
+ def temporal_shift(self, tensor, video_length):
+ # print("### temporal shift")
+ # (b f) d c
+ tensor = rearrange(tensor, "(b f) d c -> b f d c", f=video_length)
+ n_channels = tensor.shape[-1]
+ fold = n_channels // self.temporal_shift_fold_div
+
+ if self.temporal_shift_direction != "right":
+ raise NotImplementedError
+
+ tensor_out = torch.zeros_like(tensor)
+ tensor_out[:, 1:, :, :fold] = tensor[:, :-1, :, :fold]
+ tensor_out[:, :, :, fold:] = tensor[:, :, :, fold:]
+
+ tensor_out = rearrange(tensor_out, "b f d c -> (b f) d c")
+ return tensor_out
+
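+ # Sketch of the shift (hedged): with e.g. fold_div=8 and "right" shift, the first 1/8 of the
+ # channels of frame f are replaced by those of frame f-1 (frame 0 gets zeros), letting the
+ # attention mix information across neighbouring frames at no extra parameter cost (TSM-style).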
+ def forward(self, hidden_states, encoder_hidden_states=None, attention_mask=None, video_length=None):
+ # pdb.set_trace()
+ batch_size, sequence_length, _ = hidden_states.shape
+ assert encoder_hidden_states is None
+
+ # (b f) d c
+ if self.attention_mode == "Temporal":
+ # print("### temporal reshape")
+ d = hidden_states.shape[1]
+ hidden_states = rearrange(hidden_states, "(b f) d c -> (b d) f c", f=video_length)
+
+ if self.pos_encoder is not None:
+ hidden_states = self.pos_encoder(hidden_states)
+
+
+ if self.group_norm is not None:
+ hidden_states = self.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+ query = self.to_q(hidden_states)
+ dim = query.shape[-1]
+ query = self.reshape_heads_to_batch_dim(query)
+
+ if self.added_kv_proj_dim is not None:
+ raise NotImplementedError
+
+ encoder_hidden_states = encoder_hidden_states if encoder_hidden_states is not None else hidden_states
+ key = self.to_k(encoder_hidden_states)
+ value = self.to_v(encoder_hidden_states)
+
+ if self.attention_mode == "SpatialTemporalShift":
+ key = self.temporal_shift(key, video_length=video_length)
+ value = self.temporal_shift(value, video_length=video_length)
+ elif self.attention_mode == "CrossFrame":
+ key = self.temporal_token_concat(key, video_length=video_length)
+ value = self.temporal_token_concat(value, video_length=video_length)
+
+ key = self.reshape_heads_to_batch_dim(key)
+ value = self.reshape_heads_to_batch_dim(value)
+
+ if attention_mask is not None:
+ if attention_mask.shape[-1] != query.shape[1]:
+ target_length = query.shape[1]
+ attention_mask = F.pad(attention_mask, (0, target_length), value=0.0)
+ attention_mask = attention_mask.repeat_interleave(self.heads, dim=0)
+
+ # attention, what we cannot get enough of
+ if self._use_memory_efficient_attention_xformers:
+ hidden_states = self._memory_efficient_attention_xformers(query, key, value, attention_mask)
+ # Some versions of xformers return output in fp32, cast it back to the dtype of the input
+ hidden_states = hidden_states.to(query.dtype)
+ else:
+ if self._slice_size is None or query.shape[0] // self._slice_size == 1:
+ hidden_states = self._attention(query, key, value, attention_mask)
+ else:
+ hidden_states = self._sliced_attention(query, key, value, sequence_length, dim, attention_mask)
+
+ # linear proj
+ hidden_states = self.to_out[0](hidden_states)
+
+ # dropout
+ hidden_states = self.to_out[1](hidden_states)
+
+ if self.attention_mode == "Temporal":
+ hidden_states = rearrange(hidden_states, "(b d) f c -> (b f) d c", d=d)
+
+ return hidden_states
+
+
+class WarpModule(nn.Module):
+ def __init__(
+ self,
+ in_channels=None,
+ use_deformable_conv=None,
+ ):
+ super().__init__()
+ self.use_deformable_conv = use_deformable_conv
+
+ self.conv = None
+ self.dcn_weight = None
+ if use_deformable_conv:
+ self.conv = nn.Conv2d(in_channels*2, 27, kernel_size=3, stride=1, padding=1)
+ self.dcn_weight = nn.Parameter(torch.randn(in_channels, in_channels, 3, 3) / np.sqrt(in_channels * 3 * 3))
+ self.alpha = nn.Parameter(torch.zeros(1, in_channels, 1, 1))
+ else:
+ self.conv = zero_module(nn.Conv2d(in_channels, 2, kernel_size=3, stride=1, padding=1))
+
+ def forward(self, hidden_states, offset_hidden_states):
+ # (b f) d c
+ spatial_dim = hidden_states.shape[1]
+ size = int(spatial_dim ** 0.5)
+ assert size ** 2 == spatial_dim
+
+ hidden_states = rearrange(hidden_states, "b (h w) c -> b c h w", h=size)
+ offset_hidden_states = rearrange(offset_hidden_states, "b (h w) c -> b c h w", h=size)
+
+ concat_hidden_states = torch.cat([hidden_states, offset_hidden_states], dim=1)
+
+ input_tensor = hidden_states
+ if self.use_deformable_conv:
+ offset_x, offset_y, offsets_mask = torch.chunk(self.conv(concat_hidden_states), chunks=3, dim=1)
+ offsets_mask = offsets_mask.sigmoid() * 2
+ offsets = torch.cat([offset_x, offset_y], dim=1)
+ hidden_states = torchvision.ops.deform_conv2d(
+ hidden_states,
+ offset=offsets,
+ weight=self.dcn_weight,
+ mask=offsets_mask,
+ padding=1,
+ )
+ hidden_states = self.alpha * hidden_states + input_tensor
+
+ else:
+ offsets = self.conv(concat_hidden_states)
+ hidden_states = self.optical_flow_warping(hidden_states, offsets)
+
+ hidden_states = rearrange(hidden_states, "b c h w -> b (h w) c")
+ return hidden_states
+
+ @staticmethod
+ def optical_flow_warping(x, flo):
+ """
+ warp an image/tensor (im2) back to im1, according to the optical flow
+
+ x: [B, C, H, W] (im2)
+ flo: [B, 2, H, W] flow
+ pad_mode (optional): ref to https://pytorch.org/docs/stable/nn.functional.html#grid-sample
+ "zeros": use 0 for out-of-bound grid locations,
+ "border": use border values for out-of-bound grid locations
+ """
+ dtype = x.dtype
+ if dtype != torch.float32:
+ x = x.to(torch.float32)
+ B, C, H, W = x.size()
+ # mesh grid
+ xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
+ yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
+ xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
+ yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
+ grid = torch.cat((xx, yy), 1).float().to(flo.device)
+
+ vgrid = grid + flo
+
+ # scale grid to [-1,1]
+ vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W - 1, 1) - 1.0
+ vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H - 1, 1) - 1.0
+
+ vgrid = vgrid.permute(0, 2, 3, 1)
+ # output = grid_sample_gradfix.grid_sample_align(x, vgrid)
+ output = grid_sample_align(x, vgrid)
+ #output = torch.nn.functional.grid_sample(x, vgrid, padding_mode='zeros', mode='bilinear', align_corners=True)
+
+ mask = torch.ones_like(x)
+ # mask = grid_sample_gradfix.grid_sample_align(mask, vgrid)
+        mask = grid_sample_align(mask, vgrid)
+ #mask = torch.nn.functional.grid_sample(mask, vgrid, padding_mode='zeros', mode='bilinear', align_corners=True)
+
+ mask[mask < 0.9999] = 0
+ mask[mask > 0] = 1
+ results = output * mask
+ if dtype != torch.float32:
+ results = results.to(dtype)
+ return results
+
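+# Usage sketch for the warping helper above (illustrative only; it assumes
+# grid_sample_align, defined earlier in this file, matches F.grid_sample with
+# align_corners=True as in the commented-out call). A zero flow is an identity warp:
+#
+#   >>> x = torch.randn(1, 64, 32, 32)
+#   >>> flo = torch.zeros(1, 2, 32, 32)
+#   >>> warped = WarpModule.optical_flow_warping(x, flo)   # same shape and values as x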
+
+class AdaLayerNorm(nn.Module):
+ """
+ Norm layer modified to incorporate timestep embeddings.
+ """
+ def __init__(self, embedding_dim, num_embeddings):
+ super().__init__()
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
+ self.silu = nn.SiLU()
+ self.linear = nn.Linear(embedding_dim, embedding_dim * 2)
+ self.norm = nn.LayerNorm(embedding_dim, elementwise_affine=False)
+
+ def forward(self, x, timestep):
+ timestep = repeat(timestep, "b -> (b r)", r=x.shape[0] // timestep.shape[0])
+
+ emb = self.linear(self.silu(self.emb(timestep))).unsqueeze(1) # (b f) 1 2d
+ scale, shift = torch.chunk(emb, 2, dim=-1)
+ x = self.norm(x) * (1 + scale) + shift
+ return x
+
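+# Usage sketch for AdaLayerNorm (illustrative; shapes follow the (b f) d c layout used above):
+#
+#   >>> norm = AdaLayerNorm(embedding_dim=320, num_embeddings=1000)
+#   >>> x = torch.randn(2 * 16, 64, 320)          # (b f) d c with b=2, f=16
+#   >>> t = torch.randint(0, 1000, (2,))          # one timestep per batch element
+#   >>> norm(x, t).shape                          # timestep repeated across the 16 frames
+#   torch.Size([32, 64, 320])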
diff --git a/vsr/models/unet.py b/vsr/models/unet.py
new file mode 100644
index 0000000000000000000000000000000000000000..f30bc66acd049728a20fc97c2d56c6edea09df02
--- /dev/null
+++ b/vsr/models/unet.py
@@ -0,0 +1,654 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_condition.py
+
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
+
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import json
+import math
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn import functional as F
+import einops
+
+from diffusers.configuration_utils import ConfigMixin, register_to_config
+from diffusers.models.modeling_utils import ModelMixin
+from diffusers.models.embeddings import TimestepEmbedding, Timesteps
+from diffusers.utils import BaseOutput, logging
+
+try:
+ from .unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+ )
+ from .resnet import InflatedConv3d
+ from .temporal_module import TemporalModule3D, EmptyTemporalModule3D
+except ImportError:
+ from unet_blocks import (
+ CrossAttnDownBlock3D,
+ CrossAttnUpBlock3D,
+ DownBlock3D,
+ UNetMidBlock3DCrossAttn,
+ UpBlock3D,
+ get_down_block,
+ get_up_block,
+ )
+ from resnet import InflatedConv3d
+ from temporal_module import TemporalModule3D, EmptyTemporalModule3D
+
+from rotary_embedding_torch import RotaryEmbedding
+
+logger = logging.get_logger(__name__) # pylint: disable=invalid-name
+
+def zero_module(module):
+ for p in module.parameters():
+ nn.init.zeros_(p)
+ return module
+
+class RelativePositionBias(nn.Module):
+ def __init__(
+ self,
+ heads=8,
+ num_buckets=32,
+ max_distance=128,
+ ):
+ super().__init__()
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(relative_position, num_buckets=32, max_distance=128):
+ ret = 0
+ n = -relative_position
+
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = max_exact + (
+ torch.log(n.float() / max_exact) / math.log(max_distance / max_exact) * (num_buckets - max_exact)
+ ).long()
+ val_if_large = torch.min(val_if_large, torch.full_like(val_if_large, num_buckets - 1))
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def forward(self, n, device):
+ q_pos = torch.arange(n, dtype = torch.long, device = device)
+ k_pos = torch.arange(n, dtype = torch.long, device = device)
+ rel_pos = einops.rearrange(k_pos, 'j -> 1 j') - einops.rearrange(q_pos, 'i -> i 1')
+ rp_bucket = self._relative_position_bucket(rel_pos, num_buckets = self.num_buckets, max_distance = self.max_distance)
+ values = self.relative_attention_bias(rp_bucket)
+ return einops.rearrange(values, 'i j h -> h i j') # num_heads, num_frames, num_frames
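+
+# Usage sketch (illustrative): one (num_frames x num_frames) bias map per head,
+# which the temporal attention layers can add to their logits.
+#
+#   >>> rpb = RelativePositionBias(heads=8, num_buckets=32, max_distance=128)
+#   >>> rpb(16, device=torch.device("cpu")).shape
+#   torch.Size([8, 16, 16])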
+
+@dataclass
+class UNet3DConditionOutput(BaseOutput):
+ sample: torch.FloatTensor
+
+
+class UNet3DVSRModel(ModelMixin, ConfigMixin):
+ _supports_gradient_checkpointing = True
+
+ @register_to_config
+ def __init__(
+ self,
+ ### Temporal Module Additional Kwargs ###
+ down_temporal_idx = (0,1,2),
+ mid_temporal = False,
+ up_temporal_idx = (0,1,2),
+ video_condition = True,
+ temporal_module_config = None,
+
+ sample_size: Optional[int] = None, # 80
+ in_channels: int = 7,
+ out_channels: int = 4,
+ center_input_sample: bool = False,
+ max_noise_level: int = 350,
+ flip_sin_to_cos: bool = True,
+ freq_shift: int = 0,
+ attention_head_dim: Union[int, Tuple[int]] = 8,
+ block_out_channels: Tuple[int] = (
+ 256,
+ 512,
+ 512,
+ 1024
+ ),
+ down_block_types: Tuple[str] = (
+ "DownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D",
+ "CrossAttnDownBlock3D"
+ ),
+ mid_block_type: str = "UNetMidBlock3DCrossAttn",
+ up_block_types: Tuple[str] = (
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "CrossAttnUpBlock3D",
+ "UpBlock3D"
+ ),
+ only_cross_attention: Union[bool, Tuple[bool]] = (
+ True,
+ True,
+ True,
+ False
+ ),
+ layers_per_block: int = 2,
+ downsample_padding: int = 1,
+ mid_block_scale_factor: float = 1,
+ act_fn: str = "silu",
+ norm_num_groups: int = 32,
+ norm_eps: float = 1e-5,
+ cross_attention_dim: int = 1024,
+ dual_cross_attention: bool = False,
+ use_linear_projection: bool = True,
+ class_embed_type: Optional[str] = None,
+ num_class_embeds: Optional[int] = 1000,
+ upcast_attention: bool = False,
+ resnet_time_scale_shift: str = "default",
+ use_first_frame: bool = False,
+ use_relative_position: bool = False,
+ ):
+ super().__init__()
+
+ self.sample_size = sample_size
+ time_embed_dim = block_out_channels[0] * 4
+
+ # input
+ self.conv_in = InflatedConv3d(in_channels, block_out_channels[0], kernel_size=3, padding=1)
+
+ # time
+ self.time_proj = Timesteps(block_out_channels[0], flip_sin_to_cos, freq_shift)
+ timestep_input_dim = block_out_channels[0]
+
+ self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+
+ # class embedding
+ if class_embed_type is None and num_class_embeds is not None:
+ self.class_embedding = nn.Embedding(num_class_embeds, time_embed_dim) # VSR for noise level
+ elif class_embed_type == "timestep":
+ self.class_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)
+ elif class_embed_type == "identity":
+ self.class_embedding = nn.Identity(time_embed_dim, time_embed_dim)
+ else:
+ self.class_embedding = None
+
+ self.down_blocks = nn.ModuleList([])
+ self.mid_block = None
+ self.up_blocks = nn.ModuleList([])
+ self.video_condition = video_condition
+
+ # Temporal Modules
+ self.down_temporal_blocks = nn.ModuleList([])
+ self.mid_temporal_block = None
+ self.up_temporal_blocks = nn.ModuleList([])
+
+ if isinstance(only_cross_attention, bool):
+ only_cross_attention = [only_cross_attention] * len(down_block_types)
+
+
+ if isinstance(attention_head_dim, int):
+ attention_head_dim = (attention_head_dim,) * len(down_block_types)
+
+
+ self.temporal_rotary_emb = RotaryEmbedding(32)
+
+ # down
+ output_channel = block_out_channels[0]
+ for i, down_block_type in enumerate(down_block_types):
+ input_channel = output_channel
+ output_channel = block_out_channels[i]
+ is_final_block = i == len(block_out_channels) - 1
+
+ down_block = get_down_block(
+ down_block_type,
+ num_layers=layers_per_block,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ add_downsample=not is_final_block,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[i],
+ downsample_padding=downsample_padding,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=self.temporal_rotary_emb,
+ )
+ self.down_blocks.append(down_block)
+
+ # Down Sample Temporal Modules
+ down_temporal_block = TemporalModule3D(
+ in_channels=output_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ video_condition=video_condition,
+ **temporal_module_config,
+ ) if i in down_temporal_idx else EmptyTemporalModule3D()
+ self.down_temporal_blocks.append(down_temporal_block)
+ # mid
+ if mid_block_type == "UNetMidBlock3DCrossAttn":
+ self.mid_block = UNetMidBlock3DCrossAttn(
+ in_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ output_scale_factor=mid_block_scale_factor,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attention_head_dim[-1],
+ resnet_groups=norm_num_groups,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=self.temporal_rotary_emb,
+ )
+ else:
+ raise ValueError(f"unknown mid_block_type : {mid_block_type}")
+
+ self.mid_temporal_block = TemporalModule3D(
+ in_channels=block_out_channels[-1],
+ out_channels=block_out_channels[-1],
+ temb_channels=time_embed_dim,
+ video_condition=video_condition,
+ **temporal_module_config,
+ ) if mid_temporal else EmptyTemporalModule3D()
+
+ # count how many layers upsample the videos
+ self.num_upsamplers = 0
+
+ # up
+ reversed_block_out_channels = list(reversed(block_out_channels))
+ reversed_attention_head_dim = list(reversed(attention_head_dim))
+ only_cross_attention = list(reversed(only_cross_attention))
+ output_channel = reversed_block_out_channels[0]
+ for i, up_block_type in enumerate(up_block_types):
+ is_final_block = i == len(block_out_channels) - 1
+
+ prev_output_channel = output_channel
+ output_channel = reversed_block_out_channels[i]
+ input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]
+
+ # add upsample block for all BUT final layer
+ if not is_final_block:
+ add_upsample = True
+ self.num_upsamplers += 1
+ else:
+ add_upsample = False
+
+ up_block = get_up_block(
+ up_block_type,
+ num_layers=layers_per_block + 1,
+ in_channels=input_channel,
+ out_channels=output_channel,
+ prev_output_channel=prev_output_channel,
+ temb_channels=time_embed_dim,
+ add_upsample=add_upsample,
+ resnet_eps=norm_eps,
+ resnet_act_fn=act_fn,
+ resnet_groups=norm_num_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=reversed_attention_head_dim[i],
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention[i],
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=self.temporal_rotary_emb,
+ )
+ self.up_blocks.append(up_block)
+ prev_output_channel = output_channel
+
+ up_temporal_block = TemporalModule3D(
+ in_channels=output_channel,
+ out_channels=output_channel,
+ temb_channels=time_embed_dim,
+ video_condition=video_condition,
+ **temporal_module_config,
+ ) if i in up_temporal_idx else EmptyTemporalModule3D()
+ self.up_temporal_blocks.append(up_temporal_block)
+
+ # out
+ self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=norm_num_groups, eps=norm_eps)
+ self.conv_act = nn.SiLU()
+ self.conv_out = InflatedConv3d(block_out_channels[0], out_channels, kernel_size=3, padding=1)
+
+ def set_attention_slice(self, slice_size):
+ r"""
+ Enable sliced attention computation.
+
+ When this option is enabled, the attention module will split the input tensor in slices, to compute attention
+ in several steps. This is useful to save some memory in exchange for a small speed decrease.
+
+ Args:
+ slice_size (`str` or `int` or `list(int)`, *optional*, defaults to `"auto"`):
+ When `"auto"`, halves the input to the attention heads, so attention will be computed in two steps. If
+ `"max"`, maxium amount of memory will be saved by running only one slice at a time. If a number is
+ provided, uses as many slices as `attention_head_dim // slice_size`. In this case, `attention_head_dim`
+ must be a multiple of `slice_size`.
+ """
+ sliceable_head_dims = []
+
+ def fn_recursive_retrieve_slicable_dims(module: torch.nn.Module):
+ if hasattr(module, "set_attention_slice"):
+ sliceable_head_dims.append(module.sliceable_head_dim)
+
+ for child in module.children():
+ fn_recursive_retrieve_slicable_dims(child)
+
+ # retrieve number of attention layers
+ for module in self.children():
+ fn_recursive_retrieve_slicable_dims(module)
+
+ num_slicable_layers = len(sliceable_head_dims)
+
+ if slice_size == "auto":
+ # half the attention head size is usually a good trade-off between
+ # speed and memory
+ slice_size = [dim // 2 for dim in sliceable_head_dims]
+ elif slice_size == "max":
+ # make smallest slice possible
+ slice_size = num_slicable_layers * [1]
+
+ slice_size = num_slicable_layers * [slice_size] if not isinstance(slice_size, list) else slice_size
+
+ if len(slice_size) != len(sliceable_head_dims):
+ raise ValueError(
+ f"You have provided {len(slice_size)}, but {self.config} has {len(sliceable_head_dims)} different"
+ f" attention layers. Make sure to match `len(slice_size)` to be {len(sliceable_head_dims)}."
+ )
+
+ for i in range(len(slice_size)):
+ size = slice_size[i]
+ dim = sliceable_head_dims[i]
+ if size is not None and size > dim:
+ raise ValueError(f"size {size} has to be smaller or equal to {dim}.")
+
+ # Recursively walk through all the children.
+ # Any children which exposes the set_attention_slice method
+ # gets the message
+ def fn_recursive_set_attention_slice(module: torch.nn.Module, slice_size: List[int]):
+ if hasattr(module, "set_attention_slice"):
+ module.set_attention_slice(slice_size.pop())
+
+ for child in module.children():
+ fn_recursive_set_attention_slice(child, slice_size)
+
+ reversed_slice_size = list(reversed(slice_size))
+ for module in self.children():
+ fn_recursive_set_attention_slice(module, reversed_slice_size)
+
+ def _set_gradient_checkpointing(self, module, value=False):
+ if isinstance(module, (CrossAttnDownBlock3D, DownBlock3D, CrossAttnUpBlock3D, UpBlock3D)):
+ module.gradient_checkpointing = value
+
+ def forward(
+ self,
+ sample: torch.FloatTensor,
+ timestep: Union[torch.Tensor, float, int],
+ low_res: torch.FloatTensor,
+ # encoder_hidden_states: torch.Tensor,
+ encoder_hidden_states = None,
+ class_labels: Optional[torch.Tensor] = 20,
+ low_res_clean: Optional[torch.FloatTensor] = None,
+ attention_mask: Optional[torch.Tensor] = None,
+ return_dict: bool = True,
+ ): # -> Union[UNet3DConditionOutput, Tuple]:
+ r"""
+ Args:
+ sample (`torch.FloatTensor`): (batch, channel, seq_length, height, width) noisy inputs tensor
+ timestep (`torch.FloatTensor` or `float` or `int`): (batch) timesteps
+ encoder_hidden_states (`torch.FloatTensor`): (batch, sequence_length, feature_dim) encoder hidden states
+ class_labels: noise level
+ return_dict (`bool`, *optional*, defaults to `True`):
+ Whether or not to return a [`models.unet_2d_condition.UNet2DConditionOutput`] instead of a plain tuple.
+
+ Returns:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] or `tuple`:
+ [`~models.unet_2d_condition.UNet2DConditionOutput`] if `return_dict` is True, otherwise a `tuple`. When
+ returning a tuple, the first element is the sample tensor.
+ """
+        # By default samples have to be at least a multiple of the overall upsampling factor.
+        # The overall upsampling factor is equal to 2 ** (# num of upsampling layers).
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
+ # on the fly if necessary.
+ default_overall_up_factor = 2**self.num_upsamplers
+
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
+ forward_upsample_size = False
+ upsample_size = None
+
+ if self.video_condition:
+ low_res_dict = {}
+ low_res_dict[low_res.shape[-1]] = low_res
+ for s in [1/2., 1/4., 1/8.]:
+ low_res_ds = F.interpolate(low_res, scale_factor=(1, s, s), mode='area')
+ low_res_dict[low_res_ds.shape[-1]] = low_res_ds
+ else:
+ low_res_dict = None
+
+ sample = torch.cat([sample, low_res], dim=1) # concat on C: 4+3=7
+
+ #print(f'==============={sample.shape}================')
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
+ logger.info("Forward upsample size to force interpolation output size.")
+ forward_upsample_size = True
+
+ # prepare attention_mask
+ if attention_mask is not None:
+ attention_mask = (1 - attention_mask.to(sample.dtype)) * -10000.0
+ attention_mask = attention_mask.unsqueeze(1)
+
+ # center input if necessary
+ if self.config.center_input_sample:
+ sample = 2 * sample - 1.0
+
+ # time
+ timesteps = timestep
+ if not torch.is_tensor(timesteps):
+ # This would be a good case for the `match` statement (Python 3.10+)
+ is_mps = sample.device.type == "mps"
+ if isinstance(timestep, float):
+ dtype = torch.float32 if is_mps else torch.float64
+ else:
+ dtype = torch.int32 if is_mps else torch.int64
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
+ elif len(timesteps.shape) == 0:
+ timesteps = timesteps[None].to(sample.device)
+
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
+ timesteps = timesteps.expand(sample.shape[0])
+
+ t_emb = self.time_proj(timesteps)
+
+ # timesteps does not contain any weights and will always return f32 tensors
+ # but time_embedding might actually be running in fp16. so we need to cast here.
+ # there might be better ways to encapsulate this.
+ t_emb = t_emb.to(dtype=self.dtype)
+ emb = self.time_embedding(t_emb)
+
+ if self.class_embedding is not None:
+ if class_labels is None:
+ raise ValueError("class_labels should be provided when num_class_embeds > 0")
+ # check noise level
+ if torch.any(class_labels > self.config.max_noise_level):
+ raise ValueError(f"`noise_level` has to be <= {self.config.max_noise_level} but is {class_labels}")
+
+ if self.config.class_embed_type == "timestep":
+ class_labels = self.time_proj(class_labels)
+
+ class_emb = self.class_embedding(class_labels).to(dtype=self.dtype)
+ emb = emb + class_emb
+
+
+ # pre-process
+ sample = self.conv_in(sample)
+
+ # down
+ down_block_res_samples = (sample,)
+ for downsample_block, down_temporal_block in zip(self.down_blocks, self.down_temporal_blocks):
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
+ sample, res_samples = downsample_block(
+ hidden_states=sample,
+ temb=emb,
+ encoder_hidden_states=encoder_hidden_states,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample, res_samples = downsample_block(hidden_states=sample, temb=emb)
+
+ down_block_res_samples += res_samples
+
+ # 1. temporal modeling during down sample
+ sample = down_temporal_block(
+ hidden_states=sample,
+ condition_video=low_res_dict,
+ encoder_hidden_states=encoder_hidden_states,
+ timesteps=timesteps,
+ temb=emb,
+ )
+ # mid
+ sample = self.mid_block(
+ sample, emb, encoder_hidden_states=encoder_hidden_states, attention_mask=attention_mask
+ )
+ # 2. temporal modeling at mid block
+ sample = self.mid_temporal_block(
+ hidden_states=sample,
+ condition_video=low_res_dict,
+ encoder_hidden_states=encoder_hidden_states,
+ timesteps=timesteps,
+ temb=emb,
+ )
+
+ # up
+ for i, (upsample_block, up_temporal_block) in enumerate(zip(self.up_blocks, self.up_temporal_blocks)):
+ is_final_block = i == len(self.up_blocks) - 1
+
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
+
+ # if we have not reached the final block and need to forward the
+ # upsample size, we do it here
+ if not is_final_block and forward_upsample_size:
+ upsample_size = down_block_res_samples[-1].shape[2:]
+
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
+ sample = upsample_block(
+ hidden_states=sample,
+ temb=emb,
+ res_hidden_states_tuple=res_samples,
+ encoder_hidden_states=encoder_hidden_states,
+ upsample_size=upsample_size,
+ attention_mask=attention_mask,
+ )
+ else:
+ sample = upsample_block(
+ hidden_states=sample, temb=emb, res_hidden_states_tuple=res_samples, upsample_size=upsample_size
+ )
+
+ # 3. temporal modeling during up sample
+ sample = up_temporal_block(
+ hidden_states=sample,
+ condition_video=low_res_dict,
+ encoder_hidden_states=encoder_hidden_states,
+ timesteps=timesteps,
+ temb=emb,
+ )
+
+ # post-process
+ sample = self.conv_norm_out(sample)
+ sample = self.conv_act(sample)
+ sample = self.conv_out(sample)
+ # print(sample.shape)
+
+ if not return_dict:
+ return (sample,)
+
+ return UNet3DConditionOutput(sample=sample)
+
+ def forward_with_cfg(self,
+ x,
+ t,
+ low_res,
+ encoder_hidden_states = None,
+ class_labels: Optional[torch.Tensor] = 20,
+ cfg_scale=4.0,
+ use_fp16=False):
+ """
+ Forward, but also batches the unconditional forward pass for classifier-free guidance.
+ """
+ # https://github.com/openai/glide-text2im/blob/main/notebooks/text2im.ipynb
+ half = x[: len(x) // 2]
+ combined = torch.cat([half, half], dim=0)
+ if use_fp16:
+ combined = combined.to(dtype=torch.float16)
+ model_out = self.forward(combined, t, low_res, encoder_hidden_states, class_labels).sample
+        # Apply classifier-free guidance to the first four (latent) channels; with the
+        # default out_channels=4 this covers the whole prediction. The commented line
+        # below is the three-channel variant (b c f h w layout).
+ eps, rest = model_out[:, :4], model_out[:, 4:]
+ # eps, rest = model_out[:, :3], model_out[:, 3:] # b c f h w
+ cond_eps, uncond_eps = torch.split(eps, len(eps) // 2, dim=0)
+ half_eps = uncond_eps + cfg_scale * (cond_eps - uncond_eps)
+ eps = torch.cat([half_eps, half_eps], dim=0)
+ return torch.cat([eps, rest], dim=1)
+
+
+ @classmethod
+ def from_pretrained_2d(cls, config_path, pretrained_model_path):
+ if not os.path.isfile(config_path):
+ raise RuntimeError(f"{config_path} does not exist")
+ with open(config_path, "r") as f:
+ config = json.load(f)
+ config["_class_name"] = cls.__name__
+ freeze_pretrained_2d_upsampler = config["freeze_pretrained_2d_upsampler"]
+
+ model = cls.from_config(config)
+ model_file = os.path.join(pretrained_model_path)
+ if not os.path.isfile(model_file):
+ raise RuntimeError(f"{model_file} does not exist")
+ state_dict = torch.load(model_file, map_location="cpu")
+ for k, v in model.state_dict().items():
+ if 'temporal' in k:
+ print(f'New layers: {k}')
+ state_dict.update({k: v})
+
+ model.load_state_dict(state_dict, strict=True)
+
+ if freeze_pretrained_2d_upsampler:
+ print("Freeze pretrained 2d upsampler!")
+ for k, v in model.named_parameters():
+ if not 'temporal' in k:
+ v.requires_grad = False
+ return model
+
+if __name__ == '__main__':
+ import torch
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ config_path = "./configs/unet_3d_config.json"
+ # pretrained_model_path = "./pretrained_models/unet_diffusion_pytorch_model.bin"
+ # unet = UNet3DVSRModel.from_pretrained_2d(config_path, pretrained_model_path).to(device)
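+    # A commented-out smoke-test sketch (an assumption: it relies on the JSON config above
+    # providing a valid `temporal_module_config` and on a GPU with enough memory):
+    #
+    # with open(config_path, "r") as f:
+    #     config = json.load(f)
+    # unet = UNet3DVSRModel.from_config(config).to(device)
+    # latents = torch.randn(1, 4, 8, 40, 40, device=device)    # b c f h w noisy latents
+    # low_res = torch.randn(1, 3, 8, 40, 40, device=device)    # concatenated on c inside forward
+    # noise_level = torch.tensor([20], device=device)
+    # out = unet(latents, timestep=10, low_res=low_res, class_labels=noise_level).sample
+    # print(out.shape)                                          # expected: (1, 4, 8, 40, 40)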
diff --git a/vsr/models/unet_blocks.py b/vsr/models/unet_blocks.py
new file mode 100644
index 0000000000000000000000000000000000000000..066637ff78439e61864e5893a122b32f07e58f89
--- /dev/null
+++ b/vsr/models/unet_blocks.py
@@ -0,0 +1,629 @@
+# Adapted from https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/unet_2d_blocks.py
+import os
+import sys
+sys.path.append(os.path.split(sys.path[0])[0])
+
+import torch
+from torch import nn
+
+try:
+ from .attention import Transformer3DModel
+ from .resnet import Downsample3D, ResnetBlock3D, Upsample3D
+except ImportError:
+ from attention import Transformer3DModel
+ from resnet import Downsample3D, ResnetBlock3D, Upsample3D
+
+
+def get_down_block(
+ down_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ temb_channels,
+ add_downsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ downsample_padding=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=None,
+):
+ # print(down_block_type)
+ # print(use_first_frame)
+ down_block_type = down_block_type[7:] if down_block_type.startswith("UNetRes") else down_block_type
+ if down_block_type == "DownBlock3D":
+ return DownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif down_block_type == "CrossAttnDownBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnDownBlock3D")
+ return CrossAttnDownBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ add_downsample=add_downsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ downsample_padding=downsample_padding,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ raise ValueError(f"{down_block_type} does not exist.")
+
+
+def get_up_block(
+ up_block_type,
+ num_layers,
+ in_channels,
+ out_channels,
+ prev_output_channel,
+ temb_channels,
+ add_upsample,
+ resnet_eps,
+ resnet_act_fn,
+ attn_num_head_channels,
+ resnet_groups=None,
+ cross_attention_dim=None,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ resnet_time_scale_shift="default",
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=None,
+):
+ up_block_type = up_block_type[7:] if up_block_type.startswith("UNetRes") else up_block_type
+ if up_block_type == "UpBlock3D":
+ return UpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ )
+ elif up_block_type == "CrossAttnUpBlock3D":
+ if cross_attention_dim is None:
+ raise ValueError("cross_attention_dim must be specified for CrossAttnUpBlock3D")
+ return CrossAttnUpBlock3D(
+ num_layers=num_layers,
+ in_channels=in_channels,
+ out_channels=out_channels,
+ prev_output_channel=prev_output_channel,
+ temb_channels=temb_channels,
+ add_upsample=add_upsample,
+ resnet_eps=resnet_eps,
+ resnet_act_fn=resnet_act_fn,
+ resnet_groups=resnet_groups,
+ cross_attention_dim=cross_attention_dim,
+ attn_num_head_channels=attn_num_head_channels,
+ dual_cross_attention=dual_cross_attention,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ resnet_time_scale_shift=resnet_time_scale_shift,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ raise ValueError(f"{up_block_type} does not exist.")
+
+
+class UNetMidBlock3DCrossAttn(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ output_scale_factor=1.0,
+ cross_attention_dim=1280,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=None,
+ ):
+ super().__init__()
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+ resnet_groups = resnet_groups if resnet_groups is not None else min(in_channels // 4, 32)
+
+ # there is always at least one resnet
+ resnets = [
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ ]
+ attentions = []
+
+ for _ in range(num_layers):
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ in_channels // attn_num_head_channels,
+ in_channels=in_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ )
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=in_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ hidden_states = self.resnets[0](hidden_states, temb)
+ for attn, resnet in zip(self.attentions, self.resnets[1:]):
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+ hidden_states = resnet(hidden_states, temb)
+
+ return hidden_states
+
+
+class CrossAttnDownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ downsample_padding=1,
+ add_downsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=None,
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ # print(use_first_frame)
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ )
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None, encoder_hidden_states=None, attention_mask=None):
+ output_states = ()
+
+ for resnet, attn in zip(self.resnets, self.attentions):
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class DownBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_downsample=True,
+ downsample_padding=1,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ in_channels = in_channels if i == 0 else out_channels
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=in_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_downsample:
+ self.downsamplers = nn.ModuleList(
+ [
+ Downsample3D(
+ out_channels, use_conv=True, out_channels=out_channels, padding=downsample_padding, name="op"
+ )
+ ]
+ )
+ else:
+ self.downsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, temb=None):
+ output_states = ()
+
+ for resnet in self.resnets:
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ output_states += (hidden_states,)
+
+ if self.downsamplers is not None:
+ for downsampler in self.downsamplers:
+ hidden_states = downsampler(hidden_states)
+
+ output_states += (hidden_states,)
+
+ return hidden_states, output_states
+
+
+class CrossAttnUpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ out_channels: int,
+ prev_output_channel: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ attn_num_head_channels=1,
+ cross_attention_dim=1280,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ dual_cross_attention=False,
+ use_linear_projection=False,
+ only_cross_attention=False,
+ upcast_attention=False,
+ use_first_frame=False,
+ use_relative_position=False,
+ rotary_emb=None
+ ):
+ super().__init__()
+ resnets = []
+ attentions = []
+
+ self.has_cross_attention = True
+ self.attn_num_head_channels = attn_num_head_channels
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+ if dual_cross_attention:
+ raise NotImplementedError
+ attentions.append(
+ Transformer3DModel(
+ attn_num_head_channels,
+ out_channels // attn_num_head_channels,
+ in_channels=out_channels,
+ num_layers=1,
+ cross_attention_dim=cross_attention_dim,
+ norm_num_groups=resnet_groups,
+ use_linear_projection=use_linear_projection,
+ only_cross_attention=only_cross_attention,
+ upcast_attention=upcast_attention,
+ use_first_frame=use_first_frame,
+ use_relative_position=use_relative_position,
+ rotary_emb=rotary_emb,
+ )
+ )
+
+ self.attentions = nn.ModuleList(attentions)
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(
+ self,
+ hidden_states,
+ res_hidden_states_tuple,
+ temb=None,
+ encoder_hidden_states=None,
+ upsample_size=None,
+ attention_mask=None,
+ ):
+ for resnet, attn in zip(self.resnets, self.attentions):
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module, return_dict=None):
+ def custom_forward(*inputs):
+ if return_dict is not None:
+ return module(*inputs, return_dict=return_dict)
+ else:
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ hidden_states = torch.utils.checkpoint.checkpoint(
+ create_custom_forward(attn, return_dict=False),
+ hidden_states,
+ encoder_hidden_states,
+ )[0]
+ else:
+ hidden_states = resnet(hidden_states, temb)
+ hidden_states = attn(hidden_states, encoder_hidden_states=encoder_hidden_states).sample
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
+
+
+class UpBlock3D(nn.Module):
+ def __init__(
+ self,
+ in_channels: int,
+ prev_output_channel: int,
+ out_channels: int,
+ temb_channels: int,
+ dropout: float = 0.0,
+ num_layers: int = 1,
+ resnet_eps: float = 1e-6,
+ resnet_time_scale_shift: str = "default",
+ resnet_act_fn: str = "swish",
+ resnet_groups: int = 32,
+ resnet_pre_norm: bool = True,
+ output_scale_factor=1.0,
+ add_upsample=True,
+ ):
+ super().__init__()
+ resnets = []
+
+ for i in range(num_layers):
+ res_skip_channels = in_channels if (i == num_layers - 1) else out_channels
+ resnet_in_channels = prev_output_channel if i == 0 else out_channels
+
+ resnets.append(
+ ResnetBlock3D(
+ in_channels=resnet_in_channels + res_skip_channels,
+ out_channels=out_channels,
+ temb_channels=temb_channels,
+ eps=resnet_eps,
+ groups=resnet_groups,
+ dropout=dropout,
+ time_embedding_norm=resnet_time_scale_shift,
+ non_linearity=resnet_act_fn,
+ output_scale_factor=output_scale_factor,
+ pre_norm=resnet_pre_norm,
+ )
+ )
+
+ self.resnets = nn.ModuleList(resnets)
+
+ if add_upsample:
+ self.upsamplers = nn.ModuleList([Upsample3D(out_channels, use_conv=True, out_channels=out_channels)])
+ else:
+ self.upsamplers = None
+
+ self.gradient_checkpointing = False
+
+ def forward(self, hidden_states, res_hidden_states_tuple, temb=None, upsample_size=None):
+ for resnet in self.resnets:
+ # pop res hidden states
+ res_hidden_states = res_hidden_states_tuple[-1]
+ res_hidden_states_tuple = res_hidden_states_tuple[:-1]
+ hidden_states = torch.cat([hidden_states, res_hidden_states], dim=1)
+
+ if self.training and self.gradient_checkpointing:
+
+ def create_custom_forward(module):
+ def custom_forward(*inputs):
+ return module(*inputs)
+
+ return custom_forward
+
+ hidden_states = torch.utils.checkpoint.checkpoint(create_custom_forward(resnet), hidden_states, temb)
+ else:
+ hidden_states = resnet(hidden_states, temb)
+
+ if self.upsamplers is not None:
+ for upsampler in self.upsamplers:
+ hidden_states = upsampler(hidden_states, upsample_size)
+
+ return hidden_states
\ No newline at end of file
diff --git a/vsr/models/upscaling.py b/vsr/models/upscaling.py
new file mode 100644
index 0000000000000000000000000000000000000000..f0f62ecd358a839b9a7c6a2d0e403ed9150ed2d3
--- /dev/null
+++ b/vsr/models/upscaling.py
@@ -0,0 +1,95 @@
+import torch
+import torch.nn as nn
+import numpy as np
+from functools import partial
+from inspect import isfunction
+
+def exists(x):
+ return x is not None
+
+def default(val, d):
+ if exists(val):
+ return val
+ return d() if isfunction(d) else d
+
+def extract_into_tensor(a, t, x_shape):
+ b, *_ = t.shape
+ out = a.gather(-1, t)
+ return out.reshape(b, *((1,) * (len(x_shape) - 1)))
+
+def make_beta_schedule(n_timestep, linear_start=1e-4, linear_end=2e-2):
+ betas = (
+ torch.linspace(linear_start ** 0.5, linear_end ** 0.5, n_timestep, dtype=torch.float64) ** 2
+ )
+ return betas.numpy()
+
+class AbstractLowScaleModel(nn.Module):
+ # for concatenating a downsampled image to the latent representation
+ def __init__(self, noise_schedule_config=None):
+ super(AbstractLowScaleModel, self).__init__()
+ if noise_schedule_config is not None:
+ self.register_schedule(**noise_schedule_config)
+
+ def register_schedule(self, timesteps=1000, linear_start=1e-4, linear_end=2e-2):
+ betas = make_beta_schedule(timesteps, linear_start=linear_start, linear_end=linear_end)
+ alphas = 1. - betas
+ alphas_cumprod = np.cumprod(alphas, axis=0)
+ alphas_cumprod_prev = np.append(1., alphas_cumprod[:-1])
+
+ timesteps, = betas.shape
+ self.num_timesteps = int(timesteps)
+ self.linear_start = linear_start
+ self.linear_end = linear_end
+ assert alphas_cumprod.shape[0] == self.num_timesteps, 'alphas have to be defined for each timestep'
+
+ to_torch = partial(torch.tensor, dtype=torch.float32)
+
+ self.register_buffer('betas', to_torch(betas))
+ self.register_buffer('alphas_cumprod', to_torch(alphas_cumprod))
+ self.register_buffer('alphas_cumprod_prev', to_torch(alphas_cumprod_prev))
+
+ # calculations for diffusion q(x_t | x_{t-1}) and others
+ self.register_buffer('sqrt_alphas_cumprod', to_torch(np.sqrt(alphas_cumprod)))
+ self.register_buffer('sqrt_one_minus_alphas_cumprod', to_torch(np.sqrt(1. - alphas_cumprod)))
+ self.register_buffer('log_one_minus_alphas_cumprod', to_torch(np.log(1. - alphas_cumprod)))
+ self.register_buffer('sqrt_recip_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod)))
+ self.register_buffer('sqrt_recipm1_alphas_cumprod', to_torch(np.sqrt(1. / alphas_cumprod - 1)))
+
+ def q_sample(self, x_start, t, noise=None):
+ noise = default(noise, lambda: torch.randn_like(x_start))
+ return (extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start +
+ extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise)
+
+ def forward(self, x):
+ return x, None
+
+ def decode(self, x):
+ return x
+
+
+class SimpleImageConcat(AbstractLowScaleModel):
+ # no noise level conditioning
+ def __init__(self):
+ super(SimpleImageConcat, self).__init__(noise_schedule_config=None)
+ self.max_noise_level = 0
+
+ def forward(self, x):
+ # fix to constant noise level
+ return x, torch.zeros(x.shape[0], device=x.device).long()
+
+
+class ImageConcatWithNoiseAugmentation(AbstractLowScaleModel):
+ def __init__(self, noise_schedule_config, max_noise_level=1000, to_cuda=False):
+ super().__init__(noise_schedule_config=noise_schedule_config)
+ self.max_noise_level = max_noise_level
+
+ def forward(self, x, noise_level=None):
+ if noise_level is None:
+ noise_level = torch.randint(0, self.max_noise_level, (x.shape[0],), device=x.device).long()
+ else:
+ assert isinstance(noise_level, torch.Tensor)
+ z = self.q_sample(x, noise_level)
+ return z, noise_level
+
+
+
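+# Usage sketch (illustrative): augment a low-resolution conditioning input with a random
+# amount of diffusion noise and keep the sampled level for conditioning the upscaler.
+#
+#   >>> aug = ImageConcatWithNoiseAugmentation(
+#   ...     noise_schedule_config={"timesteps": 1000, "linear_start": 1e-4, "linear_end": 2e-2},
+#   ...     max_noise_level=350,
+#   ... )
+#   >>> x = torch.randn(2, 3, 64, 64)
+#   >>> z, noise_level = aug(x)    # z is x noised at a random level in [0, 350)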
diff --git a/vsr/models/utils.py b/vsr/models/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..76e205d1709568885e58b56888f62c8805ad4f91
--- /dev/null
+++ b/vsr/models/utils.py
@@ -0,0 +1,215 @@
+# adopted from
+# https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
+# and
+# https://github.com/lucidrains/denoising-diffusion-pytorch/blob/7706bdfc6f527f58d33f84b7b522e61e6e3164b3/denoising_diffusion_pytorch/denoising_diffusion_pytorch.py
+# and
+# https://github.com/openai/guided-diffusion/blob/0ba878e517b276c45d1195eb29f6f5f72659a05b/guided_diffusion/nn.py
+#
+# thanks!
+
+
+import os
+import math
+import torch
+
+import numpy as np
+import torch.nn as nn
+
+from einops import repeat
+
+
+#################################################################################
+# Unet Utils #
+#################################################################################
+
+def checkpoint(func, inputs, params, flag):
+ """
+ Evaluate a function without caching intermediate activations, allowing for
+ reduced memory at the expense of extra compute in the backward pass.
+ :param func: the function to evaluate.
+ :param inputs: the argument sequence to pass to `func`.
+ :param params: a sequence of parameters `func` depends on but does not
+ explicitly take as arguments.
+ :param flag: if False, disable gradient checkpointing.
+ """
+ if flag:
+ args = tuple(inputs) + tuple(params)
+ return CheckpointFunction.apply(func, len(inputs), *args)
+ else:
+ return func(*inputs)
+
+
+class CheckpointFunction(torch.autograd.Function):
+ @staticmethod
+ def forward(ctx, run_function, length, *args):
+ ctx.run_function = run_function
+ ctx.input_tensors = list(args[:length])
+ ctx.input_params = list(args[length:])
+
+ with torch.no_grad():
+ output_tensors = ctx.run_function(*ctx.input_tensors)
+ return output_tensors
+
+ @staticmethod
+ def backward(ctx, *output_grads):
+ ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
+ with torch.enable_grad():
+ # Fixes a bug where the first op in run_function modifies the
+ # Tensor storage in place, which is not allowed for detach()'d
+ # Tensors.
+ shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
+ output_tensors = ctx.run_function(*shallow_copies)
+ input_grads = torch.autograd.grad(
+ output_tensors,
+ ctx.input_tensors + ctx.input_params,
+ output_grads,
+ allow_unused=True,
+ )
+ del ctx.input_tensors
+ del ctx.input_params
+ del output_tensors
+ return (None, None) + input_grads
+
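+# Usage sketch for the custom checkpointing wrapper above (illustrative):
+#
+#   >>> layer = nn.Linear(64, 64)
+#   >>> x = torch.randn(8, 64, requires_grad=True)
+#   >>> y = checkpoint(layer, (x,), tuple(layer.parameters()), True)
+#   >>> y.sum().backward()    # activations are recomputed during the backward pass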
+
+def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
+ """
+ Create sinusoidal timestep embeddings.
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ if not repeat_only:
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ else:
+ embedding = repeat(timesteps, 'b -> b d', d=dim).contiguous()
+ return embedding
+
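+# Usage sketch (illustrative):
+#
+#   >>> emb = timestep_embedding(torch.tensor([0, 10, 999]), dim=320)
+#   >>> emb.shape
+#   torch.Size([3, 320])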
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+def scale_module(module, scale):
+ """
+ Scale the parameters of a module and return it.
+ """
+ for p in module.parameters():
+ p.detach().mul_(scale)
+ return module
+
+
+def mean_flat(tensor):
+ """
+ Take the mean over all non-batch dimensions.
+ """
+ return tensor.mean(dim=list(range(1, len(tensor.shape))))
+
+
+def normalization(channels):
+ """
+ Make a standard normalization layer.
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ return GroupNorm32(32, channels)
+
+
+# PyTorch 1.7 has SiLU, but we support PyTorch 1.5.
+class SiLU(nn.Module):
+ def forward(self, x):
+ return x * torch.sigmoid(x)
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+def conv_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D convolution module.
+ """
+ if dims == 1:
+ return nn.Conv1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.Conv2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.Conv3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+def linear(*args, **kwargs):
+ """
+ Create a linear module.
+ """
+ return nn.Linear(*args, **kwargs)
+
+
+def avg_pool_nd(dims, *args, **kwargs):
+ """
+ Create a 1D, 2D, or 3D average pooling module.
+ """
+ if dims == 1:
+ return nn.AvgPool1d(*args, **kwargs)
+ elif dims == 2:
+ return nn.AvgPool2d(*args, **kwargs)
+ elif dims == 3:
+ return nn.AvgPool3d(*args, **kwargs)
+ raise ValueError(f"unsupported dimensions: {dims}")
+
+
+# class HybridConditioner(nn.Module):
+
+# def __init__(self, c_concat_config, c_crossattn_config):
+# super().__init__()
+# self.concat_conditioner = instantiate_from_config(c_concat_config)
+# self.crossattn_conditioner = instantiate_from_config(c_crossattn_config)
+
+# def forward(self, c_concat, c_crossattn):
+# c_concat = self.concat_conditioner(c_concat)
+# c_crossattn = self.crossattn_conditioner(c_crossattn)
+# return {'c_concat': [c_concat], 'c_crossattn': [c_crossattn]}
+
+
+def noise_like(shape, device, repeat=False):
+ repeat_noise = lambda: torch.randn((1, *shape[1:]), device=device).repeat(shape[0], *((1,) * (len(shape) - 1)))
+ noise = lambda: torch.randn(shape, device=device)
+ return repeat_noise() if repeat else noise()
+
+def count_flops_attn(model, _x, y):
+ """
+ A counter for the `thop` package to count the operations in an
+ attention operation.
+ Meant to be used like:
+ macs, params = thop.profile(
+ model,
+ inputs=(inputs, timestamps),
+ custom_ops={QKVAttention: QKVAttention.count_flops},
+ )
+ """
+ b, c, *spatial = y[0].shape
+ num_spatial = int(np.prod(spatial))
+ # We perform two matmuls with the same number of ops.
+ # The first computes the weight matrix, the second computes
+ # the combination of the value vectors.
+ matmul_ops = 2 * b * (num_spatial ** 2) * c
+ model.total_ops += torch.DoubleTensor([matmul_ops])
+
+def count_params(model, verbose=False):
+ total_params = sum(p.numel() for p in model.parameters())
+ if verbose:
+ print(f"{model.__class__.__name__} has {total_params * 1.e-6:.2f} M params.")
+ return total_params
\ No newline at end of file
diff --git a/vsr/sample.py b/vsr/sample.py
new file mode 100644
index 0000000000000000000000000000000000000000..a0f948b622a7f17a35ec2debbbd55f68e200ab3b
--- /dev/null
+++ b/vsr/sample.py
@@ -0,0 +1,151 @@
+import io
+import os
+import sys
+import argparse
+o_path = os.getcwd()
+sys.path.append(o_path)
+
+import torch
+import time
+import json
+import numpy as np
+import imageio
+import torchvision
+from einops import rearrange
+
+from models.autoencoder_kl import AutoencoderKL
+from models.unet import UNet3DVSRModel
+from models.pipeline_stable_diffusion_upscale_video_3d import StableDiffusionUpscalePipeline
+from diffusers import DDIMScheduler
+from omegaconf import OmegaConf
+
+
+def main(args):
+
+ device = "cuda"
+
+ # ---------------------- load models ----------------------
+ pipeline = StableDiffusionUpscalePipeline.from_pretrained(args.pretrained_path + '/stable-diffusion-x4-upscaler', torch_dtype=torch.float16)
+
+ # vae
+ pipeline.vae = AutoencoderKL.from_config("configs/vae_config.json")
+ pretrained_model = args.pretrained_path + "/stable-diffusion-x4-upscaler/vae/diffusion_pytorch_model.bin"
+ pipeline.vae.load_state_dict(torch.load(pretrained_model, map_location="cpu"))
+
+ # unet
+ config_path = "./configs/unet_3d_config.json"
+ with open(config_path, "r") as f:
+ config = json.load(f)
+ config['video_condition'] = False
+ pipeline.unet = UNet3DVSRModel.from_config(config)
+
+ pretrained_model = args.pretrained_path + "/lavie_vsr.pt"
+ checkpoint = torch.load(pretrained_model, map_location="cpu")['ema']
+
+ pipeline.unet.load_state_dict(checkpoint, True)
+ pipeline.unet = pipeline.unet.half()
+ pipeline.unet.eval() # important!
+
+ # DDIMScheduler
+ with open(args.pretrained_path + '/stable-diffusion-x4-upscaler/scheduler/scheduler_config.json', "r") as f:
+ config = json.load(f)
+ config["beta_schedule"] = "linear"
+ pipeline.scheduler = DDIMScheduler.from_config(config)
+
+ pipeline = pipeline.to("cuda")
+
+ # ---------------------- load user's prompt ----------------------
+ # input
+ video_root = args.input_path
+ video_list = sorted(os.listdir(video_root))
+ print('video num:', len(video_list))
+
+ # output
+ save_root = args.output_path
+ os.makedirs(save_root, exist_ok=True)
+
+ # inference params
+ noise_level = args.noise_level
+ guidance_scale = args.guidance_scale
+ num_inference_steps = args.inference_steps
+
+ # ---------------------- start inferencing ----------------------
+ for i, video_name in enumerate(video_list):
+ video_name = video_name.replace('.mp4', '')
+ print(f'[{i+1}/{len(video_list)}]: ', video_name)
+
+ lr_path = f"{video_root}/{video_name}.mp4"
+ save_path = f"{save_root}/{video_name}.mp4"
+
+ prompt = video_name
+ print('Prompt: ', prompt)
+
+ negative_prompt = "blur, worst quality"
+
+ vframes, aframes, info = torchvision.io.read_video(filename=lr_path, pts_unit='sec', output_format='TCHW') # RGB
+ vframes = vframes / 255.
+ vframes = (vframes - 0.5) * 2 # T C H W [-1, 1]
+ t, _, h, w = vframes.shape
+ vframes = vframes.unsqueeze(dim=0) # 1 T C H W
+ vframes = rearrange(vframes, 'b t c h w -> b c t h w').contiguous() # 1 C T H W
+ print('Input_shape:', vframes.shape, 'Noise_level:', noise_level, 'Guidance_scale:', guidance_scale)
+
+ fps = info['video_fps']
+ generator = torch.Generator(device=device).manual_seed(10)
+
+ torch.cuda.synchronize()
+ start_time = time.time()
+
+ with torch.no_grad():
+ short_seq = 8
+ vframes_seq = vframes.shape[2]
+ if vframes_seq > short_seq: # for VSR
+ upscaled_video_list = []
+ for start_f in range(0, vframes_seq, short_seq):
+                    torch.cuda.empty_cache()  # free cached GPU memory between chunks
+                    end_f = min(vframes_seq, start_f + short_seq)
+                    print(f'Processing: [{start_f}-{end_f}/{vframes_seq}]')
+
+ upscaled_video_ = pipeline(
+ prompt,
+ image=vframes[:,:,start_f:end_f],
+ generator=generator,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ noise_level=noise_level,
+ negative_prompt=negative_prompt,
+ ).images # T C H W [-1, 1]
+ upscaled_video_list.append(upscaled_video_)
+ upscaled_video = torch.cat(upscaled_video_list, dim=0)
+ else:
+ upscaled_video = pipeline(
+ prompt,
+ image=vframes,
+ generator=generator,
+ num_inference_steps=num_inference_steps,
+ guidance_scale=guidance_scale,
+ noise_level=noise_level,
+ negative_prompt=negative_prompt,
+ ).images # T C H W [-1, 1]
+
+ torch.cuda.synchronize()
+ run_time = time.time() - start_time
+
+ print('Output:', upscaled_video.shape)
+
+ # save video
+ upscaled_video = (upscaled_video / 2 + 0.5).clamp(0, 1) * 255
+ upscaled_video = upscaled_video.permute(0, 2, 3, 1).to(torch.uint8)
+ upscaled_video = upscaled_video.numpy().astype(np.uint8)
+ imageio.mimwrite(save_path, upscaled_video, fps=fps, quality=9) # Highest quality is 10, lowest is 0
+
+ print(f'Save upscaled video "{video_name}" in {save_path}, time (sec): {run_time} \n')
+    print(f'\nAll results are saved in {save_root}')
+
+if __name__ == '__main__':
+
+ parser = argparse.ArgumentParser()
+ parser.add_argument("--config", type=str, default="")
+ args = parser.parse_args()
+
+ main(OmegaConf.load(args.config))
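+
+# Example of the OmegaConf YAML expected by --config (field names are taken from main()
+# above; the values are illustrative assumptions):
+#
+#   pretrained_path: ./pretrained_models
+#   input_path: ./inputs/low_res_videos
+#   output_path: ./results
+#   noise_level: 100
+#   guidance_scale: 7.0
+#   inference_steps: 50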