zxl
commited on
Commit
•
07c6a04
1
Parent(s):
bd6e6ad
first commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- CONTRIBUTING.md +37 -0
- LICENSE +0 -0
- README.md +16 -8
- app.py +508 -0
- docs/dsp.md +25 -0
- docs/pab.md +121 -0
- eval/pab/commom_metrics/README.md +6 -0
- eval/pab/commom_metrics/__init__.py +0 -0
- eval/pab/commom_metrics/calculate_lpips.py +97 -0
- eval/pab/commom_metrics/calculate_psnr.py +90 -0
- eval/pab/commom_metrics/calculate_ssim.py +116 -0
- eval/pab/commom_metrics/eval.py +160 -0
- eval/pab/experiments/__init__.py +0 -0
- eval/pab/experiments/attention_ablation.py +60 -0
- eval/pab/experiments/components_ablation.py +46 -0
- eval/pab/experiments/latte.py +57 -0
- eval/pab/experiments/opensora.py +44 -0
- eval/pab/experiments/opensora_plan.py +57 -0
- eval/pab/experiments/utils.py +22 -0
- eval/pab/vbench/VBench_full_info.json +0 -0
- eval/pab/vbench/cal_vbench.py +154 -0
- eval/pab/vbench/run_vbench.py +52 -0
- examples/cogvideo/sample.py +14 -0
- examples/latte/sample.py +24 -0
- examples/open_sora/sample.py +24 -0
- examples/open_sora_plan/sample.py +24 -0
- requirements.txt +25 -0
- setup.py +55 -0
- tests/__init__.py +0 -0
- videosys/__init__.py +19 -0
- videosys/core/__init__.py +0 -0
- videosys/core/comm.py +420 -0
- videosys/core/engine.py +132 -0
- videosys/core/mp_utils.py +270 -0
- videosys/core/pab_mgr.py +364 -0
- videosys/core/parallel_mgr.py +119 -0
- videosys/core/pipeline.py +34 -0
- videosys/core/shardformer/__init__.py +0 -0
- videosys/core/shardformer/t5/__init__.py +0 -0
- videosys/core/shardformer/t5/modeling.py +39 -0
- videosys/core/shardformer/t5/policy.py +68 -0
- videosys/datasets/dataloader.py +94 -0
- videosys/datasets/image_transform.py +42 -0
- videosys/datasets/video_transform.py +441 -0
- videosys/diffusion/__init__.py +41 -0
- videosys/diffusion/diffusion_utils.py +79 -0
- videosys/diffusion/gaussian_diffusion.py +829 -0
- videosys/diffusion/respace.py +119 -0
- videosys/diffusion/timestep_sampler.py +143 -0
- videosys/models/__init__.py +0 -0
CONTRIBUTING.md
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
## Coding Standards
|
2 |
+
|
3 |
+
### Unit Tests
|
4 |
+
We use [PyTest](https://docs.pytest.org/en/latest/) to execute tests. You can install pytest by `pip install pytest`. As some of the tests require initialization of the distributed backend, GPUs are needed to execute these tests.
|
5 |
+
|
6 |
+
To set up the environment for unit testing, first change your current directory to the root directory of your local ColossalAI repository, then run
|
7 |
+
```bash
|
8 |
+
pip install -r requirements/requirements-test.txt
|
9 |
+
```
|
10 |
+
If you encounter an error telling "Could not find a version that satisfies the requirement fbgemm-gpu==0.2.0", please downgrade your python version to 3.8 or 3.9 and try again.
|
11 |
+
|
12 |
+
If you only want to run CPU tests, you can run
|
13 |
+
|
14 |
+
```bash
|
15 |
+
pytest -m cpu tests/
|
16 |
+
```
|
17 |
+
|
18 |
+
If you have 8 GPUs on your machine, you can run the full test
|
19 |
+
|
20 |
+
```bash
|
21 |
+
pytest tests/
|
22 |
+
```
|
23 |
+
|
24 |
+
If you do not have 8 GPUs on your machine, do not worry. Unit testing will be automatically conducted when you put up a pull request to the main branch.
|
25 |
+
|
26 |
+
|
27 |
+
### Code Style
|
28 |
+
|
29 |
+
We have some static checks when you commit your code change, please make sure you can pass all the tests and make sure the coding style meets our requirements. We use pre-commit hook to make sure the code is aligned with the writing standard. To set up the code style checking, you need to follow the steps below.
|
30 |
+
|
31 |
+
```shell
|
32 |
+
# these commands are executed under the Colossal-AI directory
|
33 |
+
pip install pre-commit
|
34 |
+
pre-commit install
|
35 |
+
```
|
36 |
+
|
37 |
+
Code format checking will be automatically executed when you commit your changes.
|
LICENSE
ADDED
The diff for this file is too large to render.
See raw diff
|
|
README.md
CHANGED
@@ -1,12 +1,20 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
emoji:
|
4 |
-
colorFrom:
|
5 |
-
colorTo:
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.42.0
|
|
|
|
|
|
|
8 |
app_file: app.py
|
9 |
-
|
10 |
-
|
11 |
-
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
+
title: VideoSys-CogVideoX
|
3 |
+
emoji: 🎥
|
4 |
+
colorFrom: yellow
|
5 |
+
colorTo: green
|
6 |
sdk: gradio
|
7 |
sdk_version: 4.42.0
|
8 |
+
suggested_hardware: a10g-large
|
9 |
+
suggested_storage: large
|
10 |
+
app_port: 7860
|
11 |
app_file: app.py
|
12 |
+
models:
|
13 |
+
- THUDM/CogVideoX-2b
|
14 |
+
tags:
|
15 |
+
- cogvideox
|
16 |
+
- video-generation
|
17 |
+
- thudm
|
18 |
+
short_description: Text-to-Video
|
19 |
+
disable_embedding: false
|
20 |
+
---
|
app.py
ADDED
@@ -0,0 +1,508 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# # import gradio as gr
|
2 |
+
# # from videosys import CogVideoConfig, VideoSysEngine
|
3 |
+
# # import tempfile
|
4 |
+
# # import os
|
5 |
+
# # import logging
|
6 |
+
# # import uuid
|
7 |
+
|
8 |
+
# # logging.basicConfig(level=logging.INFO)
|
9 |
+
# # logger = logging.getLogger(__name__)
|
10 |
+
|
11 |
+
# # config = CogVideoConfig(world_size=1)
|
12 |
+
# # engine = VideoSysEngine(config)
|
13 |
+
|
14 |
+
# # def generate_video(prompt):
|
15 |
+
# # try:
|
16 |
+
# # video = engine.generate(prompt).video[0]
|
17 |
+
|
18 |
+
# # # 使用临时文件和唯一标识符
|
19 |
+
# # with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
|
20 |
+
# # temp_filename = temp_file.name
|
21 |
+
# # unique_filename = f"{uuid.uuid4().hex}.mp4"
|
22 |
+
# # output_path = os.path.join(tempfile.gettempdir(), unique_filename)
|
23 |
+
|
24 |
+
# # engine.save_video(video, output_path)
|
25 |
+
|
26 |
+
# # return output_path
|
27 |
+
# # except Exception as e:
|
28 |
+
# # logger.error(f"An error occurred: {str(e)}")
|
29 |
+
# # return None # 返回 None 而不是错误消息
|
30 |
+
|
31 |
+
# # iface = gr.Interface(
|
32 |
+
# # fn=generate_video,
|
33 |
+
# # inputs=gr.Textbox(lines=2, placeholder="Enter your prompt here..."),
|
34 |
+
# # outputs=gr.Video(label="Generated Video"),
|
35 |
+
# # title="CogVideoX-2b: Text-to-Video Generation",
|
36 |
+
# # description="Enter a text prompt to generate a video using CogVideoX-2b."
|
37 |
+
# # )
|
38 |
+
|
39 |
+
# # iface.launch()
|
40 |
+
|
41 |
+
|
42 |
+
# from videosys import CogVideoConfig, VideoSysEngine
|
43 |
+
# from videosys.models.cogvideo.pipeline import CogVideoPABConfig
|
44 |
+
# import os
|
45 |
+
|
46 |
+
# import gradio as gr
|
47 |
+
# import numpy as np
|
48 |
+
# import torch
|
49 |
+
# from openai import OpenAI
|
50 |
+
# from time import time
|
51 |
+
# import tempfile
|
52 |
+
# import uuid
|
53 |
+
# import logging
|
54 |
+
|
55 |
+
# logging.basicConfig(level=logging.INFO)
|
56 |
+
# logger = logging.getLogger(__name__)
|
57 |
+
|
58 |
+
# dtype = torch.bfloat16
|
59 |
+
# sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
60 |
+
|
61 |
+
# For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
62 |
+
# There are a few rules to follow:
|
63 |
+
|
64 |
+
# You will only ever output a single video description per user request.
|
65 |
+
|
66 |
+
# When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
67 |
+
# Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
68 |
+
|
69 |
+
# Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
70 |
+
# """
|
71 |
+
|
72 |
+
# def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
73 |
+
# if not os.environ.get("OPENAI_API_KEY"):
|
74 |
+
# return prompt
|
75 |
+
# client = OpenAI()
|
76 |
+
# text = prompt.strip()
|
77 |
+
|
78 |
+
# for i in range(retry_times):
|
79 |
+
# response = client.chat.completions.create(
|
80 |
+
# messages=[
|
81 |
+
# {"role": "system", "content": sys_prompt},
|
82 |
+
# {
|
83 |
+
# "role": "user",
|
84 |
+
# "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
|
85 |
+
# },
|
86 |
+
# {
|
87 |
+
# "role": "assistant",
|
88 |
+
# "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
|
89 |
+
# },
|
90 |
+
# {
|
91 |
+
# "role": "user",
|
92 |
+
# "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
|
93 |
+
# },
|
94 |
+
# {
|
95 |
+
# "role": "assistant",
|
96 |
+
# "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
|
97 |
+
# },
|
98 |
+
# {
|
99 |
+
# "role": "user",
|
100 |
+
# "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
|
101 |
+
# },
|
102 |
+
# {
|
103 |
+
# "role": "assistant",
|
104 |
+
# "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
|
105 |
+
# },
|
106 |
+
# {
|
107 |
+
# "role": "user",
|
108 |
+
# "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
|
109 |
+
# },
|
110 |
+
# ],
|
111 |
+
# model="glm-4-0520",
|
112 |
+
# temperature=0.01,
|
113 |
+
# top_p=0.7,
|
114 |
+
# stream=False,
|
115 |
+
# max_tokens=250,
|
116 |
+
# )
|
117 |
+
# if response.choices:
|
118 |
+
# return response.choices[0].message.content
|
119 |
+
# return prompt
|
120 |
+
|
121 |
+
# def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_gap=2):
|
122 |
+
# pab_config = CogVideoPABConfig(full_threshold=pab_threshold, full_gap=pab_gap)
|
123 |
+
# config = CogVideoConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
|
124 |
+
# engine = VideoSysEngine(config)
|
125 |
+
# return engine
|
126 |
+
|
127 |
+
|
128 |
+
|
129 |
+
# def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
|
130 |
+
# try:
|
131 |
+
# video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
|
132 |
+
|
133 |
+
# with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
|
134 |
+
# temp_file.name
|
135 |
+
# unique_filename = f"{uuid.uuid4().hex}.mp4"
|
136 |
+
# output_path = os.path.join(tempfile.gettempdir(), unique_filename)
|
137 |
+
|
138 |
+
# engine.save_video(video, output_path)
|
139 |
+
# return output_path
|
140 |
+
# except Exception as e:
|
141 |
+
# logger.error(f"An error occurred: {str(e)}")
|
142 |
+
# return None
|
143 |
+
|
144 |
+
|
145 |
+
|
146 |
+
# with gr.Blocks() as demo:
|
147 |
+
# gr.Markdown("""
|
148 |
+
# <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
|
149 |
+
# VideoSys Huggingface Space🤗
|
150 |
+
# </div>
|
151 |
+
# <div style="text-align: center;">
|
152 |
+
# <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys">🌐 Github</a>
|
153 |
+
# </div>
|
154 |
+
|
155 |
+
# <div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
|
156 |
+
# ⚠️ This demo is for academic research and experiential use only.
|
157 |
+
# Users should strictly adhere to local laws and ethics.
|
158 |
+
# </div>
|
159 |
+
# <div style="text-align: center; font-size: 15px; font-weight: bold; color: magenta; margin-bottom: 20px;">
|
160 |
+
# 💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.
|
161 |
+
# </div>
|
162 |
+
# """)
|
163 |
+
# with gr.Row():
|
164 |
+
# with gr.Column():
|
165 |
+
# prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="a bear hunting for prey", lines=5)
|
166 |
+
# with gr.Row():
|
167 |
+
# gr.Markdown(
|
168 |
+
# "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
|
169 |
+
# )
|
170 |
+
# enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
|
171 |
+
|
172 |
+
# with gr.Column():
|
173 |
+
# gr.Markdown(
|
174 |
+
# "**Optional Parameters** (default values are recommended)<br>"
|
175 |
+
# "Turn Inference Steps larger if you want more detailed video, but it will be slower.<br>"
|
176 |
+
# "50 steps are recommended for most cases. will cause 120 seconds for inference.<br>"
|
177 |
+
# )
|
178 |
+
# with gr.Row():
|
179 |
+
# num_inference_steps = gr.Number(label="Inference Steps", value=50)
|
180 |
+
# guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
|
181 |
+
# pab_gap = gr.Number(label="PAB Gap", value=2, precision=0)
|
182 |
+
# pab_threshold = gr.Textbox(label="PAB Threshold", value="100,850", lines=1)
|
183 |
+
# with gr.Row():
|
184 |
+
# generate_button = gr.Button("🎬 Generate Video")
|
185 |
+
# generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
|
186 |
+
|
187 |
+
# with gr.Column():
|
188 |
+
# with gr.Row():
|
189 |
+
# video_output = gr.Video(label="CogVideoX", width=720, height=480)
|
190 |
+
# with gr.Row():
|
191 |
+
# download_video_button = gr.File(label="📥 Download Video", visible=False)
|
192 |
+
# elapsed_time = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
|
193 |
+
# with gr.Row():
|
194 |
+
# video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
|
195 |
+
# with gr.Row():
|
196 |
+
# download_video_button_vs = gr.File(label="📥 Download Video", visible=False)
|
197 |
+
# elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
|
198 |
+
|
199 |
+
# def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
|
200 |
+
# # tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
|
201 |
+
# engine = load_model()
|
202 |
+
# t = time()
|
203 |
+
# video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
|
204 |
+
# elapsed_time = time() - t
|
205 |
+
# video_update = gr.update(visible=True, value=video_path)
|
206 |
+
# elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
|
207 |
+
|
208 |
+
# return video_path, video_update, elapsed_time
|
209 |
+
|
210 |
+
# def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
|
211 |
+
# # tensor = infer(prompt, num_inference_steps, guidance_scale, progress=progress)
|
212 |
+
# threshold = [int(i) for i in threshold.split(",")]
|
213 |
+
# gap = int(gap)
|
214 |
+
# engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap)
|
215 |
+
# t = time()
|
216 |
+
# video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
|
217 |
+
# elapsed_time = time() - t
|
218 |
+
# video_update = gr.update(visible=True, value=video_path)
|
219 |
+
# elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
|
220 |
+
|
221 |
+
# return video_path, video_update, elapsed_time
|
222 |
+
|
223 |
+
|
224 |
+
# def enhance_prompt_func(prompt):
|
225 |
+
# return convert_prompt(prompt, retry_times=1)
|
226 |
+
|
227 |
+
# generate_button.click(
|
228 |
+
# generate_vanilla,
|
229 |
+
# inputs=[prompt, num_inference_steps, guidance_scale],
|
230 |
+
# outputs=[video_output, download_video_button, elapsed_time],
|
231 |
+
# )
|
232 |
+
|
233 |
+
# generate_button_vs.click(
|
234 |
+
# generate_vs,
|
235 |
+
# inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold, pab_gap],
|
236 |
+
# outputs=[video_output_vs, download_video_button_vs, elapsed_time_vs],
|
237 |
+
# )
|
238 |
+
|
239 |
+
# enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
|
240 |
+
|
241 |
+
# if __name__ == "__main__":
|
242 |
+
# demo.launch()
|
243 |
+
|
244 |
+
|
245 |
+
|
246 |
+
import gradio as gr
|
247 |
+
from videosys import CogVideoConfig, VideoSysEngine
|
248 |
+
from videosys.models.cogvideo.pipeline import CogVideoPABConfig
|
249 |
+
import os
|
250 |
+
import numpy as np
|
251 |
+
import torch
|
252 |
+
from openai import OpenAI
|
253 |
+
from time import time
|
254 |
+
import tempfile
|
255 |
+
import uuid
|
256 |
+
import logging
|
257 |
+
|
258 |
+
logging.basicConfig(level=logging.INFO)
|
259 |
+
logger = logging.getLogger(__name__)
|
260 |
+
|
261 |
+
dtype = torch.bfloat16
|
262 |
+
sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
|
263 |
+
|
264 |
+
For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
|
265 |
+
There are a few rules to follow:
|
266 |
+
|
267 |
+
You will only ever output a single video description per user request.
|
268 |
+
|
269 |
+
When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
|
270 |
+
Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
|
271 |
+
|
272 |
+
Video descriptions must have the same num of words as examples below. Extra words will be ignored.
|
273 |
+
"""
|
274 |
+
|
275 |
+
def convert_prompt(prompt: str, retry_times: int = 3) -> str:
|
276 |
+
if not os.environ.get("OPENAI_API_KEY"):
|
277 |
+
return prompt
|
278 |
+
client = OpenAI()
|
279 |
+
text = prompt.strip()
|
280 |
+
|
281 |
+
for i in range(retry_times):
|
282 |
+
response = client.chat.completions.create(
|
283 |
+
messages=[
|
284 |
+
{"role": "system", "content": sys_prompt},
|
285 |
+
{
|
286 |
+
"role": "user",
|
287 |
+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"role": "assistant",
|
291 |
+
"content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
|
292 |
+
},
|
293 |
+
{
|
294 |
+
"role": "user",
|
295 |
+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
|
296 |
+
},
|
297 |
+
{
|
298 |
+
"role": "assistant",
|
299 |
+
"content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"role": "user",
|
303 |
+
"content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
|
304 |
+
},
|
305 |
+
{
|
306 |
+
"role": "assistant",
|
307 |
+
"content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
|
308 |
+
},
|
309 |
+
{
|
310 |
+
"role": "user",
|
311 |
+
"content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
|
312 |
+
},
|
313 |
+
],
|
314 |
+
model="glm-4-0520",
|
315 |
+
temperature=0.01,
|
316 |
+
top_p=0.7,
|
317 |
+
stream=False,
|
318 |
+
max_tokens=250,
|
319 |
+
)
|
320 |
+
if response.choices:
|
321 |
+
return response.choices[0].message.content
|
322 |
+
return prompt
|
323 |
+
|
324 |
+
def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_gap=2):
|
325 |
+
pab_config = CogVideoPABConfig(full_threshold=pab_threshold, full_gap=pab_gap)
|
326 |
+
config = CogVideoConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
|
327 |
+
engine = VideoSysEngine(config)
|
328 |
+
return engine
|
329 |
+
|
330 |
+
def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
|
331 |
+
try:
|
332 |
+
video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
|
333 |
+
|
334 |
+
with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
|
335 |
+
temp_file.name
|
336 |
+
unique_filename = f"{uuid.uuid4().hex}.mp4"
|
337 |
+
output_path = os.path.join(tempfile.gettempdir(), unique_filename)
|
338 |
+
|
339 |
+
engine.save_video(video, output_path)
|
340 |
+
return output_path
|
341 |
+
except Exception as e:
|
342 |
+
logger.error(f"An error occurred: {str(e)}")
|
343 |
+
return None
|
344 |
+
|
345 |
+
css = """
|
346 |
+
body {
|
347 |
+
font-family: Arial, sans-serif;
|
348 |
+
line-height: 1.6;
|
349 |
+
color: #333;
|
350 |
+
max-width: 1200px;
|
351 |
+
margin: 0 auto;
|
352 |
+
padding: 20px;
|
353 |
+
}
|
354 |
+
|
355 |
+
.container {
|
356 |
+
display: flex;
|
357 |
+
flex-direction: column;
|
358 |
+
gap: 20px;
|
359 |
+
}
|
360 |
+
|
361 |
+
.row {
|
362 |
+
display: flex;
|
363 |
+
flex-wrap: wrap;
|
364 |
+
gap: 20px;
|
365 |
+
}
|
366 |
+
|
367 |
+
.column {
|
368 |
+
flex: 1;
|
369 |
+
min-width: 0;
|
370 |
+
}
|
371 |
+
|
372 |
+
.textbox, .number-input, button {
|
373 |
+
width: 100%;
|
374 |
+
padding: 10px;
|
375 |
+
margin-bottom: 10px;
|
376 |
+
border: 1px solid #ddd;
|
377 |
+
border-radius: 4px;
|
378 |
+
}
|
379 |
+
|
380 |
+
button {
|
381 |
+
background-color: #4CAF50;
|
382 |
+
color: white;
|
383 |
+
border: none;
|
384 |
+
cursor: pointer;
|
385 |
+
transition: background-color 0.3s;
|
386 |
+
}
|
387 |
+
|
388 |
+
button:hover {
|
389 |
+
background-color: #45a049;
|
390 |
+
}
|
391 |
+
|
392 |
+
.video-output {
|
393 |
+
width: 100%;
|
394 |
+
max-width: 720px;
|
395 |
+
height: auto;
|
396 |
+
margin: 0 auto;
|
397 |
+
}
|
398 |
+
|
399 |
+
@media (max-width: 768px) {
|
400 |
+
.row {
|
401 |
+
flex-direction: column;
|
402 |
+
}
|
403 |
+
|
404 |
+
.column {
|
405 |
+
width: 100%;
|
406 |
+
}
|
407 |
+
|
408 |
+
.video-output {
|
409 |
+
width: 100%;
|
410 |
+
height: auto;
|
411 |
+
}
|
412 |
+
}
|
413 |
+
"""
|
414 |
+
|
415 |
+
with gr.Blocks(css=css) as demo:
|
416 |
+
gr.HTML("""
|
417 |
+
<div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
|
418 |
+
VideoSys Huggingface Space🤗
|
419 |
+
</div>
|
420 |
+
<div style="text-align: center;">
|
421 |
+
<a href="https://github.com/NUS-HPC-AI-Lab/VideoSys">🌐 Github</a>
|
422 |
+
</div>
|
423 |
+
<div style="text-align: center; font-size: 15px; font-weight: bold; color: red; margin-bottom: 20px;">
|
424 |
+
⚠️ This demo is for academic research and experiential use only.
|
425 |
+
Users should strictly adhere to local laws and ethics.
|
426 |
+
</div>
|
427 |
+
<div style="text-align: center; font-size: 15px; font-weight: bold; color: magenta; margin-bottom: 20px;">
|
428 |
+
💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.
|
429 |
+
</div>
|
430 |
+
""")
|
431 |
+
|
432 |
+
with gr.Row():
|
433 |
+
with gr.Column():
|
434 |
+
prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="a bear hunting for prey", lines=5)
|
435 |
+
with gr.Row():
|
436 |
+
gr.Markdown(
|
437 |
+
"✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
|
438 |
+
)
|
439 |
+
enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
|
440 |
+
|
441 |
+
with gr.Column():
|
442 |
+
gr.Markdown(
|
443 |
+
"**Optional Parameters** (default values are recommended)<br>"
|
444 |
+
"Turn Inference Steps larger if you want more detailed video, but it will be slower.<br>"
|
445 |
+
"50 steps are recommended for most cases. will cause 120 seconds for inference.<br>"
|
446 |
+
)
|
447 |
+
with gr.Row():
|
448 |
+
num_inference_steps = gr.Number(label="Inference Steps", value=50)
|
449 |
+
guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
|
450 |
+
pab_gap = gr.Number(label="PAB Gap", value=2, precision=0)
|
451 |
+
pab_threshold = gr.Textbox(label="PAB Threshold", value="100,850", lines=1)
|
452 |
+
with gr.Row():
|
453 |
+
generate_button = gr.Button("🎬 Generate Video")
|
454 |
+
generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
|
455 |
+
|
456 |
+
with gr.Column():
|
457 |
+
with gr.Row():
|
458 |
+
video_output = gr.Video(label="CogVideoX", width=720, height=480)
|
459 |
+
with gr.Row():
|
460 |
+
download_video_button = gr.File(label="📥 Download Video", visible=False)
|
461 |
+
elapsed_time = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
|
462 |
+
with gr.Row():
|
463 |
+
video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
|
464 |
+
with gr.Row():
|
465 |
+
download_video_button_vs = gr.File(label="📥 Download Video", visible=False)
|
466 |
+
elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
|
467 |
+
|
468 |
+
def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
|
469 |
+
engine = load_model()
|
470 |
+
t = time()
|
471 |
+
video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
|
472 |
+
elapsed_time = time() - t
|
473 |
+
video_update = gr.update(visible=True, value=video_path)
|
474 |
+
elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
|
475 |
+
|
476 |
+
return video_path, video_update, elapsed_time
|
477 |
+
|
478 |
+
def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
|
479 |
+
threshold = [int(i) for i in threshold.split(",")]
|
480 |
+
gap = int(gap)
|
481 |
+
engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap)
|
482 |
+
t = time()
|
483 |
+
video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
|
484 |
+
elapsed_time = time() - t
|
485 |
+
video_update = gr.update(visible=True, value=video_path)
|
486 |
+
elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
|
487 |
+
|
488 |
+
return video_path, video_update, elapsed_time
|
489 |
+
|
490 |
+
def enhance_prompt_func(prompt):
|
491 |
+
return convert_prompt(prompt, retry_times=1)
|
492 |
+
|
493 |
+
generate_button.click(
|
494 |
+
generate_vanilla,
|
495 |
+
inputs=[prompt, num_inference_steps, guidance_scale],
|
496 |
+
outputs=[video_output, download_video_button, elapsed_time],
|
497 |
+
)
|
498 |
+
|
499 |
+
generate_button_vs.click(
|
500 |
+
generate_vs,
|
501 |
+
inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold, pab_gap],
|
502 |
+
outputs=[video_output_vs, download_video_button_vs, elapsed_time_vs],
|
503 |
+
)
|
504 |
+
|
505 |
+
enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
|
506 |
+
|
507 |
+
if __name__ == "__main__":
|
508 |
+
demo.launch()
|
docs/dsp.md
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# DSP
|
2 |
+
|
3 |
+
paper: https://arxiv.org/abs/2403.10266
|
4 |
+
|
5 |
+
![dsp_overview](../assets/figures/dsp_overview.png)
|
6 |
+
|
7 |
+
|
8 |
+
DSP (Dynamic Sequence Parallelism) is a novel, elegant and super efficient sequence parallelism for [OpenSora](https://github.com/hpcaitech/Open-Sora), [Latte](https://github.com/Vchitect/Latte) and other multi-dimensional transformer architecture.
|
9 |
+
|
10 |
+
The key idea is to dynamically switch the parallelism dimension according to the current computation stage, leveraging the potential characteristics of multi-dimensional transformers. Compared with splitting head and sequence dimension as previous methods, it can reduce at least 75% of communication cost.
|
11 |
+
|
12 |
+
It achieves **3x** speed for training and **2x** speed for inference in OpenSora compared with sota sequence parallelism ([DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509)). For a 10s (80 frames) of 512x512 video, the inference latency of OpenSora is:
|
13 |
+
|
14 |
+
| Method | 1xH800 | 8xH800 (DS Ulysses) | 8xH800 (DSP) |
|
15 |
+
| ------ | ------ | ------ | ------ |
|
16 |
+
| Latency(s) | 106 | 45 | 22 |
|
17 |
+
|
18 |
+
The following is DSP's end-to-end throughput for training of OpenSora:
|
19 |
+
|
20 |
+
![dsp_overview](../assets/figures/dsp_exp.png)
|
21 |
+
|
22 |
+
|
23 |
+
### Usage
|
24 |
+
|
25 |
+
DSP is currently supported for: OpenSora, OpenSoraPlan and Latte. To enable DSP, you just need to launch with multiple GPUs.
|
docs/pab.md
ADDED
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Pyramid Attention Broadcast(PAB)
|
2 |
+
|
3 |
+
[[paper](https://arxiv.org/abs/2408.12588)][[blog](https://arxiv.org/abs/2403.10266)]
|
4 |
+
|
5 |
+
Pyramid Attention Broadcast(PAB)(#pyramid-attention-broadcastpab)
|
6 |
+
- [Pyramid Attention Broadcast(PAB)](#pyramid-attention-broadcastpab)
|
7 |
+
- [Insights](#insights)
|
8 |
+
- [Pyramid Attention Broadcast (PAB) Mechanism](#pyramid-attention-broadcast-pab-mechanism)
|
9 |
+
- [Experimental Results](#experimental-results)
|
10 |
+
- [Usage](#usage)
|
11 |
+
- [Supported Models](#supported-models)
|
12 |
+
- [Configuration for PAB](#configuration-for-pab)
|
13 |
+
- [Parameters](#parameters)
|
14 |
+
- [Example Configuration](#example-configuration)
|
15 |
+
|
16 |
+
|
17 |
+
We introduce Pyramid Attention Broadcast (PAB), the first approach that achieves real-time DiT-based video generation. By mitigating redundant attention computation, PAB achieves up to 21.6 FPS with 10.6x acceleration, without sacrificing quality across popular DiT-based video generation models including Open-Sora, Open-Sora-Plan, and Latte. Notably, as a training-free approach, PAB can enpower any future DiT-based video generation models with real-time capabilities.
|
18 |
+
|
19 |
+
## Insights
|
20 |
+
|
21 |
+
![method](../assets/figures/pab_motivation.png)
|
22 |
+
|
23 |
+
Our study reveals two key insights of three **attention mechanisms** within video diffusion transformers:
|
24 |
+
- First, attention differences across time steps exhibit a U-shaped pattern, with significant variations occurring during the first and last 15% of steps, while the middle 70% of steps show very stable, minor differences.
|
25 |
+
- Second, within the stable middle segment, the variability differs among attention types:
|
26 |
+
- **Spatial attention** varies the most, involving high-frequency elements like edges and textures;
|
27 |
+
- **Temporal attention** exhibits mid-frequency variations related to movements and dynamics in videos;
|
28 |
+
- **Cross-modal attention** is the most stable, linking text with video content, analogous to low-frequency signals reflecting textual semantics.
|
29 |
+
|
30 |
+
## Pyramid Attention Broadcast (PAB) Mechanism
|
31 |
+
|
32 |
+
![method](../assets/figures/pab_method.png)
|
33 |
+
|
34 |
+
Building on these insights, we propose a **pyramid attention broadcast(PAB)** mechanism to minimize unnecessary computations and optimize the utility of each attention module, as shown in Figure[xx figure] below.
|
35 |
+
|
36 |
+
In the middle segment, we broadcast one step's attention outputs to its subsequent several steps, thereby significantly reducing the computational cost on attention modules.
|
37 |
+
|
38 |
+
For more efficient broadcast and minimum influence to effect, we set varied broadcast ranges for different attentions based on their stability and differences.
|
39 |
+
**The smaller the variation in attention, the broader the potential broadcast range.**
|
40 |
+
|
41 |
+
|
42 |
+
## Experimental Results
|
43 |
+
Here are the results of our experiments, more results are shown in https://oahzxl.github.io/PAB:
|
44 |
+
|
45 |
+
![pab_vis](../assets/figures/pab_vis.png)
|
46 |
+
|
47 |
+
|
48 |
+
## Usage
|
49 |
+
|
50 |
+
### Supported Models
|
51 |
+
|
52 |
+
PAB currently supports Open-Sora, Open-Sora-Plan, and Latte.
|
53 |
+
|
54 |
+
### Configuration for PAB
|
55 |
+
|
56 |
+
To efficiently use the Pyramid Attention Broadcast (PAB) mechanism, configure the following parameters to control the broadcasting for different attention types. This helps reduce computational costs by skipping certain steps based on attention stability.
|
57 |
+
|
58 |
+
#### Parameters
|
59 |
+
|
60 |
+
- **spatial_broadcast**: Enable or disable broadcasting for spatial attention.
|
61 |
+
- Type: `True` or `False`
|
62 |
+
|
63 |
+
- **spatial_threshold**: Set the range of diffusion steps within which spatial attention is applied.
|
64 |
+
- Format: `[min_value, max_value]`
|
65 |
+
|
66 |
+
- **spatial_gap**: Number of blocks in model to skip during broadcasting for spatial attention.
|
67 |
+
- Type: Integer
|
68 |
+
|
69 |
+
- **temporal_broadcast**: Enable or disable broadcasting for temporal attention.
|
70 |
+
- Type: `True` or `False`
|
71 |
+
|
72 |
+
- **temporal_threshold**: Set the range of diffusion steps within which temporal attention is applied.
|
73 |
+
- Format: `[min_value, max_value]`
|
74 |
+
|
75 |
+
- **temporal_gap**: Number of steps to skip during broadcasting for temporal attention.
|
76 |
+
- Type: Integer
|
77 |
+
|
78 |
+
- **cross_broadcast**: Enable or disable broadcasting for cross-modal attention.
|
79 |
+
- Type: `True` or `False`
|
80 |
+
|
81 |
+
- **cross_threshold**: Set the range of diffusion steps within which cross-modal attention is applied.
|
82 |
+
- Format: `[min_value, max_value]`
|
83 |
+
|
84 |
+
- **cross_gap**: Number of steps to skip during broadcasting for cross-modal attention.
|
85 |
+
- Type: Integer
|
86 |
+
|
87 |
+
#### Example Configuration
|
88 |
+
|
89 |
+
```yaml
|
90 |
+
spatial_broadcast: True
|
91 |
+
spatial_threshold: [100, 800]
|
92 |
+
spatial_gap: 2
|
93 |
+
|
94 |
+
temporal_broadcast: True
|
95 |
+
temporal_threshold: [100, 800]
|
96 |
+
temporal_gap: 3
|
97 |
+
|
98 |
+
cross_broadcast: True
|
99 |
+
cross_threshold: [100, 900]
|
100 |
+
cross_gap: 5
|
101 |
+
```
|
102 |
+
|
103 |
+
Explanation:
|
104 |
+
|
105 |
+
- **Spatial Attention**:
|
106 |
+
- Broadcasting enabled (`spatial_broadcast: True`)
|
107 |
+
- Applied within the threshold range of 100 to 800
|
108 |
+
- Skips every 2 steps (`spatial_gap: 2`)
|
109 |
+
- Active within the first 28 steps (`spatial_block: [0, 28]`)
|
110 |
+
|
111 |
+
- **Temporal Attention**:
|
112 |
+
- Broadcasting enabled (`temporal_broadcast: True`)
|
113 |
+
- Applied within the threshold range of 100 to 800
|
114 |
+
- Skips every 3 steps (`temporal_gap: 3`)
|
115 |
+
|
116 |
+
- **Cross-Modal Attention**:
|
117 |
+
- Broadcasting enabled (`cross_broadcast: True`)
|
118 |
+
- Applied within the threshold range of 100 to 900
|
119 |
+
- Skips every 5 steps (`cross_gap: 5`)
|
120 |
+
|
121 |
+
Adjust these settings based on your specific needs to optimize the performance of each attention mechanism.
|
eval/pab/commom_metrics/README.md
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Common metrics
|
2 |
+
|
3 |
+
Include LPIPS, PSNR and SSIM.
|
4 |
+
|
5 |
+
The code is adapted from [common_metrics_on_video_quality
|
6 |
+
](https://github.com/JunyaoHu/common_metrics_on_video_quality).
|
eval/pab/commom_metrics/__init__.py
ADDED
File without changes
|
eval/pab/commom_metrics/calculate_lpips.py
ADDED
@@ -0,0 +1,97 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import lpips
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
spatial = True # Return a spatial map of perceptual distance.
|
6 |
+
|
7 |
+
# Linearly calibrated models (LPIPS)
|
8 |
+
loss_fn = lpips.LPIPS(net="alex", spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
|
9 |
+
# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
|
10 |
+
|
11 |
+
|
12 |
+
def trans(x):
|
13 |
+
# if greyscale images add channel
|
14 |
+
if x.shape[-3] == 1:
|
15 |
+
x = x.repeat(1, 1, 3, 1, 1)
|
16 |
+
|
17 |
+
# value range [0, 1] -> [-1, 1]
|
18 |
+
x = x * 2 - 1
|
19 |
+
|
20 |
+
return x
|
21 |
+
|
22 |
+
|
23 |
+
def calculate_lpips(videos1, videos2, device):
|
24 |
+
# image should be RGB, IMPORTANT: normalized to [-1,1]
|
25 |
+
|
26 |
+
assert videos1.shape == videos2.shape
|
27 |
+
|
28 |
+
# videos [batch_size, timestamps, channel, h, w]
|
29 |
+
|
30 |
+
# support grayscale input, if grayscale -> channel*3
|
31 |
+
# value range [0, 1] -> [-1, 1]
|
32 |
+
videos1 = trans(videos1)
|
33 |
+
videos2 = trans(videos2)
|
34 |
+
|
35 |
+
lpips_results = []
|
36 |
+
|
37 |
+
for video_num in range(videos1.shape[0]):
|
38 |
+
# get a video
|
39 |
+
# video [timestamps, channel, h, w]
|
40 |
+
video1 = videos1[video_num]
|
41 |
+
video2 = videos2[video_num]
|
42 |
+
|
43 |
+
lpips_results_of_a_video = []
|
44 |
+
for clip_timestamp in range(len(video1)):
|
45 |
+
# get a img
|
46 |
+
# img [timestamps[x], channel, h, w]
|
47 |
+
# img [channel, h, w] tensor
|
48 |
+
|
49 |
+
img1 = video1[clip_timestamp].unsqueeze(0).to(device)
|
50 |
+
img2 = video2[clip_timestamp].unsqueeze(0).to(device)
|
51 |
+
|
52 |
+
loss_fn.to(device)
|
53 |
+
|
54 |
+
# calculate lpips of a video
|
55 |
+
lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
|
56 |
+
lpips_results.append(lpips_results_of_a_video)
|
57 |
+
|
58 |
+
lpips_results = np.array(lpips_results)
|
59 |
+
|
60 |
+
lpips = {}
|
61 |
+
lpips_std = {}
|
62 |
+
|
63 |
+
for clip_timestamp in range(len(video1)):
|
64 |
+
lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
|
65 |
+
lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])
|
66 |
+
|
67 |
+
result = {
|
68 |
+
"value": lpips,
|
69 |
+
"value_std": lpips_std,
|
70 |
+
"video_setting": video1.shape,
|
71 |
+
"video_setting_name": "time, channel, heigth, width",
|
72 |
+
}
|
73 |
+
|
74 |
+
return result
|
75 |
+
|
76 |
+
|
77 |
+
# test code / using example
|
78 |
+
|
79 |
+
|
80 |
+
def main():
|
81 |
+
NUMBER_OF_VIDEOS = 8
|
82 |
+
VIDEO_LENGTH = 50
|
83 |
+
CHANNEL = 3
|
84 |
+
SIZE = 64
|
85 |
+
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
86 |
+
videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
87 |
+
device = torch.device("cuda")
|
88 |
+
# device = torch.device("cpu")
|
89 |
+
|
90 |
+
import json
|
91 |
+
|
92 |
+
result = calculate_lpips(videos1, videos2, device)
|
93 |
+
print(json.dumps(result, indent=4))
|
94 |
+
|
95 |
+
|
96 |
+
if __name__ == "__main__":
|
97 |
+
main()
|
eval/pab/commom_metrics/calculate_psnr.py
ADDED
@@ -0,0 +1,90 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
|
7 |
+
def img_psnr(img1, img2):
|
8 |
+
# [0,1]
|
9 |
+
# compute mse
|
10 |
+
# mse = np.mean((img1-img2)**2)
|
11 |
+
mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
|
12 |
+
# compute psnr
|
13 |
+
if mse < 1e-10:
|
14 |
+
return 100
|
15 |
+
psnr = 20 * math.log10(1 / math.sqrt(mse))
|
16 |
+
return psnr
|
17 |
+
|
18 |
+
|
19 |
+
def trans(x):
|
20 |
+
return x
|
21 |
+
|
22 |
+
|
23 |
+
def calculate_psnr(videos1, videos2):
|
24 |
+
# videos [batch_size, timestamps, channel, h, w]
|
25 |
+
|
26 |
+
assert videos1.shape == videos2.shape
|
27 |
+
|
28 |
+
videos1 = trans(videos1)
|
29 |
+
videos2 = trans(videos2)
|
30 |
+
|
31 |
+
psnr_results = []
|
32 |
+
|
33 |
+
for video_num in range(videos1.shape[0]):
|
34 |
+
# get a video
|
35 |
+
# video [timestamps, channel, h, w]
|
36 |
+
video1 = videos1[video_num]
|
37 |
+
video2 = videos2[video_num]
|
38 |
+
|
39 |
+
psnr_results_of_a_video = []
|
40 |
+
for clip_timestamp in range(len(video1)):
|
41 |
+
# get a img
|
42 |
+
# img [timestamps[x], channel, h, w]
|
43 |
+
# img [channel, h, w] numpy
|
44 |
+
|
45 |
+
img1 = video1[clip_timestamp].numpy()
|
46 |
+
img2 = video2[clip_timestamp].numpy()
|
47 |
+
|
48 |
+
# calculate psnr of a video
|
49 |
+
psnr_results_of_a_video.append(img_psnr(img1, img2))
|
50 |
+
|
51 |
+
psnr_results.append(psnr_results_of_a_video)
|
52 |
+
|
53 |
+
psnr_results = np.array(psnr_results)
|
54 |
+
|
55 |
+
psnr = {}
|
56 |
+
psnr_std = {}
|
57 |
+
|
58 |
+
for clip_timestamp in range(len(video1)):
|
59 |
+
psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
|
60 |
+
psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])
|
61 |
+
|
62 |
+
result = {
|
63 |
+
"value": psnr,
|
64 |
+
"value_std": psnr_std,
|
65 |
+
"video_setting": video1.shape,
|
66 |
+
"video_setting_name": "time, channel, heigth, width",
|
67 |
+
}
|
68 |
+
|
69 |
+
return result
|
70 |
+
|
71 |
+
|
72 |
+
# test code / using example
|
73 |
+
|
74 |
+
|
75 |
+
def main():
|
76 |
+
NUMBER_OF_VIDEOS = 8
|
77 |
+
VIDEO_LENGTH = 50
|
78 |
+
CHANNEL = 3
|
79 |
+
SIZE = 64
|
80 |
+
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
81 |
+
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
82 |
+
|
83 |
+
import json
|
84 |
+
|
85 |
+
result = calculate_psnr(videos1, videos2)
|
86 |
+
print(json.dumps(result, indent=4))
|
87 |
+
|
88 |
+
|
89 |
+
if __name__ == "__main__":
|
90 |
+
main()
|
eval/pab/commom_metrics/calculate_ssim.py
ADDED
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import cv2
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
|
6 |
+
def ssim(img1, img2):
|
7 |
+
C1 = 0.01**2
|
8 |
+
C2 = 0.03**2
|
9 |
+
img1 = img1.astype(np.float64)
|
10 |
+
img2 = img2.astype(np.float64)
|
11 |
+
kernel = cv2.getGaussianKernel(11, 1.5)
|
12 |
+
window = np.outer(kernel, kernel.transpose())
|
13 |
+
mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
|
14 |
+
mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
|
15 |
+
mu1_sq = mu1**2
|
16 |
+
mu2_sq = mu2**2
|
17 |
+
mu1_mu2 = mu1 * mu2
|
18 |
+
sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
|
19 |
+
sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
|
20 |
+
sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
|
21 |
+
ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
|
22 |
+
return ssim_map.mean()
|
23 |
+
|
24 |
+
|
25 |
+
def calculate_ssim_function(img1, img2):
|
26 |
+
# [0,1]
|
27 |
+
# ssim is the only metric extremely sensitive to gray being compared to b/w
|
28 |
+
if not img1.shape == img2.shape:
|
29 |
+
raise ValueError("Input images must have the same dimensions.")
|
30 |
+
if img1.ndim == 2:
|
31 |
+
return ssim(img1, img2)
|
32 |
+
elif img1.ndim == 3:
|
33 |
+
if img1.shape[0] == 3:
|
34 |
+
ssims = []
|
35 |
+
for i in range(3):
|
36 |
+
ssims.append(ssim(img1[i], img2[i]))
|
37 |
+
return np.array(ssims).mean()
|
38 |
+
elif img1.shape[0] == 1:
|
39 |
+
return ssim(np.squeeze(img1), np.squeeze(img2))
|
40 |
+
else:
|
41 |
+
raise ValueError("Wrong input image dimensions.")
|
42 |
+
|
43 |
+
|
44 |
+
def trans(x):
|
45 |
+
return x
|
46 |
+
|
47 |
+
|
48 |
+
def calculate_ssim(videos1, videos2):
|
49 |
+
# videos [batch_size, timestamps, channel, h, w]
|
50 |
+
|
51 |
+
assert videos1.shape == videos2.shape
|
52 |
+
|
53 |
+
videos1 = trans(videos1)
|
54 |
+
videos2 = trans(videos2)
|
55 |
+
|
56 |
+
ssim_results = []
|
57 |
+
|
58 |
+
for video_num in range(videos1.shape[0]):
|
59 |
+
# get a video
|
60 |
+
# video [timestamps, channel, h, w]
|
61 |
+
video1 = videos1[video_num]
|
62 |
+
video2 = videos2[video_num]
|
63 |
+
|
64 |
+
ssim_results_of_a_video = []
|
65 |
+
for clip_timestamp in range(len(video1)):
|
66 |
+
# get a img
|
67 |
+
# img [timestamps[x], channel, h, w]
|
68 |
+
# img [channel, h, w] numpy
|
69 |
+
|
70 |
+
img1 = video1[clip_timestamp].numpy()
|
71 |
+
img2 = video2[clip_timestamp].numpy()
|
72 |
+
|
73 |
+
# calculate ssim of a video
|
74 |
+
ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
|
75 |
+
|
76 |
+
ssim_results.append(ssim_results_of_a_video)
|
77 |
+
|
78 |
+
ssim_results = np.array(ssim_results)
|
79 |
+
|
80 |
+
ssim = {}
|
81 |
+
ssim_std = {}
|
82 |
+
|
83 |
+
for clip_timestamp in range(len(video1)):
|
84 |
+
ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
|
85 |
+
ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])
|
86 |
+
|
87 |
+
result = {
|
88 |
+
"value": ssim,
|
89 |
+
"value_std": ssim_std,
|
90 |
+
"video_setting": video1.shape,
|
91 |
+
"video_setting_name": "time, channel, heigth, width",
|
92 |
+
}
|
93 |
+
|
94 |
+
return result
|
95 |
+
|
96 |
+
|
97 |
+
# test code / using example
|
98 |
+
|
99 |
+
|
100 |
+
def main():
|
101 |
+
NUMBER_OF_VIDEOS = 8
|
102 |
+
VIDEO_LENGTH = 50
|
103 |
+
CHANNEL = 3
|
104 |
+
SIZE = 64
|
105 |
+
videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
106 |
+
videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
|
107 |
+
torch.device("cuda")
|
108 |
+
|
109 |
+
import json
|
110 |
+
|
111 |
+
result = calculate_ssim(videos1, videos2)
|
112 |
+
print(json.dumps(result, indent=4))
|
113 |
+
|
114 |
+
|
115 |
+
if __name__ == "__main__":
|
116 |
+
main()
|
eval/pab/commom_metrics/eval.py
ADDED
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import os
|
3 |
+
|
4 |
+
import imageio
|
5 |
+
import torch
|
6 |
+
import torchvision.transforms.functional as F
|
7 |
+
import tqdm
|
8 |
+
from calculate_lpips import calculate_lpips
|
9 |
+
from calculate_psnr import calculate_psnr
|
10 |
+
from calculate_ssim import calculate_ssim
|
11 |
+
|
12 |
+
|
13 |
+
def load_videos(directory, video_ids, file_extension):
|
14 |
+
videos = []
|
15 |
+
for video_id in video_ids:
|
16 |
+
video_path = os.path.join(directory, f"{video_id}.{file_extension}")
|
17 |
+
if os.path.exists(video_path):
|
18 |
+
video = load_video(video_path) # Define load_video based on how videos are stored
|
19 |
+
videos.append(video)
|
20 |
+
else:
|
21 |
+
raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}")
|
22 |
+
return videos
|
23 |
+
|
24 |
+
|
25 |
+
def load_video(video_path):
|
26 |
+
"""
|
27 |
+
Load a video from the given path and convert it to a PyTorch tensor.
|
28 |
+
"""
|
29 |
+
# Read the video using imageio
|
30 |
+
reader = imageio.get_reader(video_path, "ffmpeg")
|
31 |
+
|
32 |
+
# Extract frames and convert to a list of tensors
|
33 |
+
frames = []
|
34 |
+
for frame in reader:
|
35 |
+
# Convert the frame to a tensor and permute the dimensions to match (C, H, W)
|
36 |
+
frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
|
37 |
+
frames.append(frame_tensor)
|
38 |
+
|
39 |
+
# Stack the list of tensors into a single tensor with shape (T, C, H, W)
|
40 |
+
video_tensor = torch.stack(frames)
|
41 |
+
|
42 |
+
return video_tensor
|
43 |
+
|
44 |
+
|
45 |
+
def resize_video(video, target_height, target_width):
|
46 |
+
resized_frames = []
|
47 |
+
for frame in video:
|
48 |
+
resized_frame = F.resize(frame, [target_height, target_width])
|
49 |
+
resized_frames.append(resized_frame)
|
50 |
+
return torch.stack(resized_frames)
|
51 |
+
|
52 |
+
|
53 |
+
def preprocess_eval_video(eval_video, generated_video_shape):
|
54 |
+
T_gen, _, H_gen, W_gen = generated_video_shape
|
55 |
+
T_eval, _, H_eval, W_eval = eval_video.shape
|
56 |
+
|
57 |
+
if T_eval < T_gen:
|
58 |
+
raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
|
59 |
+
|
60 |
+
if H_eval < H_gen or W_eval < W_gen:
|
61 |
+
# Resize the video maintaining the aspect ratio
|
62 |
+
resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
|
63 |
+
resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
|
64 |
+
eval_video = resize_video(eval_video, resize_height, resize_width)
|
65 |
+
# Recalculate the dimensions
|
66 |
+
T_eval, _, H_eval, W_eval = eval_video.shape
|
67 |
+
|
68 |
+
# Center crop
|
69 |
+
start_h = (H_eval - H_gen) // 2
|
70 |
+
start_w = (W_eval - W_gen) // 2
|
71 |
+
cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
|
72 |
+
|
73 |
+
return cropped_video
|
74 |
+
|
75 |
+
|
76 |
+
def main(args):
|
77 |
+
device = "cuda"
|
78 |
+
gt_video_dir = args.gt_video_dir
|
79 |
+
generated_video_dir = args.generated_video_dir
|
80 |
+
|
81 |
+
video_ids = []
|
82 |
+
file_extension = "mp4"
|
83 |
+
for f in os.listdir(generated_video_dir):
|
84 |
+
if f.endswith(f".{file_extension}"):
|
85 |
+
video_ids.append(f.replace(f".{file_extension}", ""))
|
86 |
+
if not video_ids:
|
87 |
+
raise ValueError("No videos found in the generated video dataset. Exiting.")
|
88 |
+
|
89 |
+
print(f"Find {len(video_ids)} videos")
|
90 |
+
prompt_interval = 1
|
91 |
+
batch_size = 16
|
92 |
+
calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
|
93 |
+
|
94 |
+
lpips_results = []
|
95 |
+
psnr_results = []
|
96 |
+
ssim_results = []
|
97 |
+
|
98 |
+
total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
|
99 |
+
|
100 |
+
for idx, video_id in enumerate(tqdm.tqdm(range(total_len))):
|
101 |
+
gt_videos_tensor = []
|
102 |
+
generated_videos_tensor = []
|
103 |
+
for i in range(batch_size):
|
104 |
+
video_idx = idx * batch_size + i
|
105 |
+
if video_idx >= len(video_ids):
|
106 |
+
break
|
107 |
+
video_id = video_ids[video_idx]
|
108 |
+
generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.{file_extension}"))
|
109 |
+
generated_videos_tensor.append(generated_video)
|
110 |
+
eval_video = load_video(os.path.join(gt_video_dir, f"{video_id}.{file_extension}"))
|
111 |
+
gt_videos_tensor.append(eval_video)
|
112 |
+
gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
|
113 |
+
generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
|
114 |
+
|
115 |
+
if calculate_lpips_flag:
|
116 |
+
result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
|
117 |
+
result = result["value"].values()
|
118 |
+
result = sum(result) / len(result)
|
119 |
+
lpips_results.append(result)
|
120 |
+
|
121 |
+
if calculate_psnr_flag:
|
122 |
+
result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
|
123 |
+
result = result["value"].values()
|
124 |
+
result = sum(result) / len(result)
|
125 |
+
psnr_results.append(result)
|
126 |
+
|
127 |
+
if calculate_ssim_flag:
|
128 |
+
result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
|
129 |
+
result = result["value"].values()
|
130 |
+
result = sum(result) / len(result)
|
131 |
+
ssim_results.append(result)
|
132 |
+
|
133 |
+
if (idx + 1) % prompt_interval == 0:
|
134 |
+
out_str = ""
|
135 |
+
for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
|
136 |
+
result = sum(results) / len(results)
|
137 |
+
out_str += f"{name}: {result:.4f}, "
|
138 |
+
print(f"Processed {idx + 1} videos. {out_str[:-2]}")
|
139 |
+
|
140 |
+
out_str = ""
|
141 |
+
for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
|
142 |
+
result = sum(results) / len(results)
|
143 |
+
out_str += f"{name}: {result:.4f}, "
|
144 |
+
out_str = out_str[:-2]
|
145 |
+
|
146 |
+
# save
|
147 |
+
with open(f"./{os.path.basename(generated_video_dir)}.txt", "w+") as f:
|
148 |
+
f.write(out_str)
|
149 |
+
|
150 |
+
print(f"Processed all videos. {out_str}")
|
151 |
+
|
152 |
+
|
153 |
+
if __name__ == "__main__":
|
154 |
+
parser = argparse.ArgumentParser()
|
155 |
+
parser.add_argument("--gt_video_dir", type=str)
|
156 |
+
parser.add_argument("--generated_video_dir", type=str)
|
157 |
+
|
158 |
+
args = parser.parse_args()
|
159 |
+
|
160 |
+
main(args)
|
eval/pab/experiments/__init__.py
ADDED
File without changes
|
eval/pab/experiments/attention_ablation.py
ADDED
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import generate_func, read_prompt_list
|
2 |
+
|
3 |
+
import videosys
|
4 |
+
from videosys import OpenSoraConfig, OpenSoraPipeline
|
5 |
+
from videosys.models.open_sora import OpenSoraPABConfig
|
6 |
+
|
7 |
+
|
8 |
+
def attention_ablation_func(pab_kwargs, prompt_list, output_dir):
|
9 |
+
pab_config = OpenSoraPABConfig(**pab_kwargs)
|
10 |
+
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
11 |
+
pipeline = OpenSoraPipeline(config)
|
12 |
+
|
13 |
+
generate_func(pipeline, prompt_list, output_dir)
|
14 |
+
|
15 |
+
|
16 |
+
def main(prompt_list):
|
17 |
+
# spatial
|
18 |
+
gap_list = [2, 3, 4, 5]
|
19 |
+
for gap in gap_list:
|
20 |
+
pab_kwargs = {
|
21 |
+
"spatial_broadcast": True,
|
22 |
+
"spatial_gap": gap,
|
23 |
+
"temporal_broadcast": False,
|
24 |
+
"cross_broadcast": False,
|
25 |
+
"mlp_skip": False,
|
26 |
+
}
|
27 |
+
output_dir = f"./samples/attention_ablation/spatial_g{gap}"
|
28 |
+
attention_ablation_func(pab_kwargs, prompt_list, output_dir)
|
29 |
+
|
30 |
+
# temporal
|
31 |
+
gap_list = [3, 4, 5, 6]
|
32 |
+
for gap in gap_list:
|
33 |
+
pab_kwargs = {
|
34 |
+
"spatial_broadcast": False,
|
35 |
+
"temporal_broadcast": True,
|
36 |
+
"temporal_gap": gap,
|
37 |
+
"cross_broadcast": False,
|
38 |
+
"mlp_skip": False,
|
39 |
+
}
|
40 |
+
output_dir = f"./samples/attention_ablation/temporal_g{gap}"
|
41 |
+
attention_ablation_func(pab_kwargs, prompt_list, output_dir)
|
42 |
+
|
43 |
+
# cross
|
44 |
+
gap_list = [5, 6, 7, 8]
|
45 |
+
for gap in gap_list:
|
46 |
+
pab_kwargs = {
|
47 |
+
"spatial_broadcast": False,
|
48 |
+
"temporal_broadcast": False,
|
49 |
+
"cross_broadcast": True,
|
50 |
+
"cross_gap": gap,
|
51 |
+
"mlp_skip": False,
|
52 |
+
}
|
53 |
+
output_dir = f"./samples/attention_ablation/cross_g{gap}"
|
54 |
+
attention_ablation_func(pab_kwargs, prompt_list, output_dir)
|
55 |
+
|
56 |
+
|
57 |
+
if __name__ == "__main__":
|
58 |
+
videosys.initialize(42)
|
59 |
+
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
60 |
+
main(prompt_list)
|
eval/pab/experiments/components_ablation.py
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import generate_func, read_prompt_list
|
2 |
+
|
3 |
+
import videosys
|
4 |
+
from videosys import OpenSoraConfig, OpenSoraPipeline
|
5 |
+
from videosys.models.open_sora import OpenSoraPABConfig
|
6 |
+
|
7 |
+
|
8 |
+
def wo_spatial(prompt_list):
|
9 |
+
pab_config = OpenSoraPABConfig(spatial_broadcast=False)
|
10 |
+
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
11 |
+
pipeline = OpenSoraPipeline(config)
|
12 |
+
|
13 |
+
generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_spatial")
|
14 |
+
|
15 |
+
|
16 |
+
def wo_temporal(prompt_list):
|
17 |
+
pab_config = OpenSoraPABConfig(temporal_broadcast=False)
|
18 |
+
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
19 |
+
pipeline = OpenSoraPipeline(config)
|
20 |
+
|
21 |
+
generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_temporal")
|
22 |
+
|
23 |
+
|
24 |
+
def wo_cross(prompt_list):
|
25 |
+
pab_config = OpenSoraPABConfig(cross_broadcast=False)
|
26 |
+
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
27 |
+
pipeline = OpenSoraPipeline(config)
|
28 |
+
|
29 |
+
generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_cross")
|
30 |
+
|
31 |
+
|
32 |
+
def wo_mlp(prompt_list):
|
33 |
+
pab_config = OpenSoraPABConfig(mlp_skip=False)
|
34 |
+
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
35 |
+
pipeline = OpenSoraPipeline(config)
|
36 |
+
|
37 |
+
generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_mlp")
|
38 |
+
|
39 |
+
|
40 |
+
if __name__ == "__main__":
|
41 |
+
videosys.initialize(42)
|
42 |
+
prompt_list = read_prompt_list("./vbench/VBench_full_info.json")
|
43 |
+
wo_spatial(prompt_list)
|
44 |
+
wo_temporal(prompt_list)
|
45 |
+
wo_cross(prompt_list)
|
46 |
+
wo_mlp(prompt_list)
|
eval/pab/experiments/latte.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import generate_func, read_prompt_list
|
2 |
+
|
3 |
+
import videosys
|
4 |
+
from videosys import LatteConfig, LattePipeline
|
5 |
+
from videosys.models.latte import LattePABConfig
|
6 |
+
|
7 |
+
|
8 |
+
def eval_base(prompt_list):
|
9 |
+
config = LatteConfig()
|
10 |
+
pipeline = LattePipeline(config)
|
11 |
+
|
12 |
+
generate_func(pipeline, prompt_list, "./samples/latte_base", loop=5)
|
13 |
+
|
14 |
+
|
15 |
+
def eval_pab1(prompt_list):
|
16 |
+
pab_config = LattePABConfig(
|
17 |
+
spatial_gap=2,
|
18 |
+
temporal_gap=3,
|
19 |
+
cross_gap=6,
|
20 |
+
)
|
21 |
+
config = LatteConfig(enable_pab=True, pab_config=pab_config)
|
22 |
+
pipeline = LattePipeline(config)
|
23 |
+
|
24 |
+
generate_func(pipeline, prompt_list, "./samples/latte_pab1", loop=5)
|
25 |
+
|
26 |
+
|
27 |
+
def eval_pab2(prompt_list):
|
28 |
+
pab_config = LattePABConfig(
|
29 |
+
spatial_gap=3,
|
30 |
+
temporal_gap=4,
|
31 |
+
cross_gap=7,
|
32 |
+
)
|
33 |
+
config = LatteConfig(enable_pab=True, pab_config=pab_config)
|
34 |
+
pipeline = LattePipeline(config)
|
35 |
+
|
36 |
+
generate_func(pipeline, prompt_list, "./samples/latte_pab2", loop=5)
|
37 |
+
|
38 |
+
|
39 |
+
def eval_pab3(prompt_list):
|
40 |
+
pab_config = LattePABConfig(
|
41 |
+
spatial_gap=4,
|
42 |
+
temporal_gap=6,
|
43 |
+
cross_gap=9,
|
44 |
+
)
|
45 |
+
config = LatteConfig(enable_pab=True, pab_config=pab_config)
|
46 |
+
pipeline = LattePipeline(config)
|
47 |
+
|
48 |
+
generate_func(pipeline, prompt_list, "./samples/latte_pab3", loop=5)
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
videosys.initialize(42)
|
53 |
+
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
54 |
+
eval_base(prompt_list)
|
55 |
+
eval_pab1(prompt_list)
|
56 |
+
eval_pab2(prompt_list)
|
57 |
+
eval_pab3(prompt_list)
|
eval/pab/experiments/opensora.py
ADDED
@@ -0,0 +1,44 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import generate_func, read_prompt_list
|
2 |
+
|
3 |
+
import videosys
|
4 |
+
from videosys import OpenSoraConfig, OpenSoraPipeline
|
5 |
+
from videosys.models.open_sora import OpenSoraPABConfig
|
6 |
+
|
7 |
+
|
8 |
+
def eval_base(prompt_list):
|
9 |
+
config = OpenSoraConfig()
|
10 |
+
pipeline = OpenSoraPipeline(config)
|
11 |
+
|
12 |
+
generate_func(pipeline, prompt_list, "./samples/opensora_base", loop=5)
|
13 |
+
|
14 |
+
|
15 |
+
def eval_pab1(prompt_list):
|
16 |
+
config = OpenSoraConfig(enable_pab=True)
|
17 |
+
pipeline = OpenSoraPipeline(config)
|
18 |
+
|
19 |
+
generate_func(pipeline, prompt_list, "./samples/opensora_pab1", loop=5)
|
20 |
+
|
21 |
+
|
22 |
+
def eval_pab2(prompt_list):
|
23 |
+
pab_config = OpenSoraPABConfig(spatial_gap=3, temporal_gap=5, cross_gap=7)
|
24 |
+
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
25 |
+
pipeline = OpenSoraPipeline(config)
|
26 |
+
|
27 |
+
generate_func(pipeline, prompt_list, "./samples/opensora_pab2", loop=5)
|
28 |
+
|
29 |
+
|
30 |
+
def eval_pab3(prompt_list):
|
31 |
+
pab_config = OpenSoraPABConfig(spatial_gap=5, temporal_gap=7, cross_gap=9)
|
32 |
+
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
|
33 |
+
pipeline = OpenSoraPipeline(config)
|
34 |
+
|
35 |
+
generate_func(pipeline, prompt_list, "./samples/opensora_pab3", loop=5)
|
36 |
+
|
37 |
+
|
38 |
+
if __name__ == "__main__":
|
39 |
+
videosys.initialize(42)
|
40 |
+
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
41 |
+
eval_base(prompt_list)
|
42 |
+
eval_pab1(prompt_list)
|
43 |
+
eval_pab2(prompt_list)
|
44 |
+
eval_pab3(prompt_list)
|
eval/pab/experiments/opensora_plan.py
ADDED
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from utils import generate_func, read_prompt_list
|
2 |
+
|
3 |
+
import videosys
|
4 |
+
from videosys import OpenSoraPlanConfig, OpenSoraPlanPipeline
|
5 |
+
from videosys.models.open_sora_plan import OpenSoraPlanPABConfig
|
6 |
+
|
7 |
+
|
8 |
+
def eval_base(prompt_list):
|
9 |
+
config = OpenSoraPlanConfig()
|
10 |
+
pipeline = OpenSoraPlanPipeline(config)
|
11 |
+
|
12 |
+
generate_func(pipeline, prompt_list, "./samples/opensoraplan_base", loop=5)
|
13 |
+
|
14 |
+
|
15 |
+
def eval_pab1(prompt_list):
|
16 |
+
pab_config = OpenSoraPlanPABConfig(
|
17 |
+
spatial_gap=2,
|
18 |
+
temporal_gap=4,
|
19 |
+
cross_gap=6,
|
20 |
+
)
|
21 |
+
config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
|
22 |
+
pipeline = OpenSoraPlanPipeline(config)
|
23 |
+
|
24 |
+
generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab1", loop=5)
|
25 |
+
|
26 |
+
|
27 |
+
def eval_pab2(prompt_list):
|
28 |
+
pab_config = OpenSoraPlanPABConfig(
|
29 |
+
spatial_gap=3,
|
30 |
+
temporal_gap=5,
|
31 |
+
cross_gap=7,
|
32 |
+
)
|
33 |
+
config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
|
34 |
+
pipeline = OpenSoraPlanPipeline(config)
|
35 |
+
|
36 |
+
generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab2", loop=5)
|
37 |
+
|
38 |
+
|
39 |
+
def eval_pab3(prompt_list):
|
40 |
+
pab_config = OpenSoraPlanPABConfig(
|
41 |
+
spatial_gap=5,
|
42 |
+
temporal_gap=7,
|
43 |
+
cross_gap=9,
|
44 |
+
)
|
45 |
+
config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
|
46 |
+
pipeline = OpenSoraPlanPipeline(config)
|
47 |
+
|
48 |
+
generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab3", loop=5)
|
49 |
+
|
50 |
+
|
51 |
+
if __name__ == "__main__":
|
52 |
+
videosys.initialize(42)
|
53 |
+
prompt_list = read_prompt_list("vbench/VBench_full_info.json")
|
54 |
+
eval_base(prompt_list)
|
55 |
+
eval_pab1(prompt_list)
|
56 |
+
eval_pab2(prompt_list)
|
57 |
+
eval_pab3(prompt_list)
|
eval/pab/experiments/utils.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import os
|
3 |
+
|
4 |
+
import tqdm
|
5 |
+
|
6 |
+
from videosys.utils.utils import set_seed
|
7 |
+
|
8 |
+
|
9 |
+
def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict = {}):
|
10 |
+
kwargs["verbose"] = False
|
11 |
+
for prompt in tqdm.tqdm(prompt_list):
|
12 |
+
for l in range(loop):
|
13 |
+
set_seed(l)
|
14 |
+
video = pipeline.generate(prompt, **kwargs).video[0]
|
15 |
+
pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4"))
|
16 |
+
|
17 |
+
|
18 |
+
def read_prompt_list(prompt_list_path):
|
19 |
+
with open(prompt_list_path, "r") as f:
|
20 |
+
prompt_list = json.load(f)
|
21 |
+
prompt_list = [prompt["prompt_en"] for prompt in prompt_list]
|
22 |
+
return prompt_list
|
eval/pab/vbench/VBench_full_info.json
ADDED
The diff for this file is too large to render.
See raw diff
|
|
eval/pab/vbench/cal_vbench.py
ADDED
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
|
5 |
+
SEMANTIC_WEIGHT = 1
|
6 |
+
QUALITY_WEIGHT = 4
|
7 |
+
|
8 |
+
QUALITY_LIST = [
|
9 |
+
"subject consistency",
|
10 |
+
"background consistency",
|
11 |
+
"temporal flickering",
|
12 |
+
"motion smoothness",
|
13 |
+
"aesthetic quality",
|
14 |
+
"imaging quality",
|
15 |
+
"dynamic degree",
|
16 |
+
]
|
17 |
+
|
18 |
+
SEMANTIC_LIST = [
|
19 |
+
"object class",
|
20 |
+
"multiple objects",
|
21 |
+
"human action",
|
22 |
+
"color",
|
23 |
+
"spatial relationship",
|
24 |
+
"scene",
|
25 |
+
"appearance style",
|
26 |
+
"temporal style",
|
27 |
+
"overall consistency",
|
28 |
+
]
|
29 |
+
|
30 |
+
NORMALIZE_DIC = {
|
31 |
+
"subject consistency": {"Min": 0.1462, "Max": 1.0},
|
32 |
+
"background consistency": {"Min": 0.2615, "Max": 1.0},
|
33 |
+
"temporal flickering": {"Min": 0.6293, "Max": 1.0},
|
34 |
+
"motion smoothness": {"Min": 0.706, "Max": 0.9975},
|
35 |
+
"dynamic degree": {"Min": 0.0, "Max": 1.0},
|
36 |
+
"aesthetic quality": {"Min": 0.0, "Max": 1.0},
|
37 |
+
"imaging quality": {"Min": 0.0, "Max": 1.0},
|
38 |
+
"object class": {"Min": 0.0, "Max": 1.0},
|
39 |
+
"multiple objects": {"Min": 0.0, "Max": 1.0},
|
40 |
+
"human action": {"Min": 0.0, "Max": 1.0},
|
41 |
+
"color": {"Min": 0.0, "Max": 1.0},
|
42 |
+
"spatial relationship": {"Min": 0.0, "Max": 1.0},
|
43 |
+
"scene": {"Min": 0.0, "Max": 0.8222},
|
44 |
+
"appearance style": {"Min": 0.0009, "Max": 0.2855},
|
45 |
+
"temporal style": {"Min": 0.0, "Max": 0.364},
|
46 |
+
"overall consistency": {"Min": 0.0, "Max": 0.364},
|
47 |
+
}
|
48 |
+
|
49 |
+
DIM_WEIGHT = {
|
50 |
+
"subject consistency": 1,
|
51 |
+
"background consistency": 1,
|
52 |
+
"temporal flickering": 1,
|
53 |
+
"motion smoothness": 1,
|
54 |
+
"aesthetic quality": 1,
|
55 |
+
"imaging quality": 1,
|
56 |
+
"dynamic degree": 0.5,
|
57 |
+
"object class": 1,
|
58 |
+
"multiple objects": 1,
|
59 |
+
"human action": 1,
|
60 |
+
"color": 1,
|
61 |
+
"spatial relationship": 1,
|
62 |
+
"scene": 1,
|
63 |
+
"appearance style": 1,
|
64 |
+
"temporal style": 1,
|
65 |
+
"overall consistency": 1,
|
66 |
+
}
|
67 |
+
|
68 |
+
ordered_scaled_res = [
|
69 |
+
"total score",
|
70 |
+
"quality score",
|
71 |
+
"semantic score",
|
72 |
+
"subject consistency",
|
73 |
+
"background consistency",
|
74 |
+
"temporal flickering",
|
75 |
+
"motion smoothness",
|
76 |
+
"dynamic degree",
|
77 |
+
"aesthetic quality",
|
78 |
+
"imaging quality",
|
79 |
+
"object class",
|
80 |
+
"multiple objects",
|
81 |
+
"human action",
|
82 |
+
"color",
|
83 |
+
"spatial relationship",
|
84 |
+
"scene",
|
85 |
+
"appearance style",
|
86 |
+
"temporal style",
|
87 |
+
"overall consistency",
|
88 |
+
]
|
89 |
+
|
90 |
+
|
91 |
+
def parse_args():
|
92 |
+
parser = argparse.ArgumentParser()
|
93 |
+
parser.add_argument("--score_dir", required=True, type=str)
|
94 |
+
args = parser.parse_args()
|
95 |
+
return args
|
96 |
+
|
97 |
+
|
98 |
+
if __name__ == "__main__":
|
99 |
+
args = parse_args()
|
100 |
+
res_postfix = "_eval_results.json"
|
101 |
+
info_postfix = "_full_info.json"
|
102 |
+
files = os.listdir(args.score_dir)
|
103 |
+
res_files = [x for x in files if res_postfix in x]
|
104 |
+
info_files = [x for x in files if info_postfix in x]
|
105 |
+
assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
|
106 |
+
|
107 |
+
full_results = {}
|
108 |
+
for res_file in res_files:
|
109 |
+
# first check if results is normal
|
110 |
+
info_file = res_file.split(res_postfix)[0] + info_postfix
|
111 |
+
with open(os.path.join(args.score_dir, info_file), "r", encoding="utf-8") as f:
|
112 |
+
info = json.load(f)
|
113 |
+
assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
|
114 |
+
# read results
|
115 |
+
with open(os.path.join(args.score_dir, res_file), "r", encoding="utf-8") as f:
|
116 |
+
data = json.load(f)
|
117 |
+
for key, val in data.items():
|
118 |
+
full_results[key] = format(val[0], ".4f")
|
119 |
+
|
120 |
+
scaled_results = {}
|
121 |
+
dims = set()
|
122 |
+
for key, val in full_results.items():
|
123 |
+
dim = key.replace("_", " ") if "_" in key else key
|
124 |
+
scaled_score = (float(val) - NORMALIZE_DIC[dim]["Min"]) / (
|
125 |
+
NORMALIZE_DIC[dim]["Max"] - NORMALIZE_DIC[dim]["Min"]
|
126 |
+
)
|
127 |
+
scaled_score *= DIM_WEIGHT[dim]
|
128 |
+
scaled_results[dim] = scaled_score
|
129 |
+
dims.add(dim)
|
130 |
+
|
131 |
+
assert len(dims) == len(NORMALIZE_DIC), f"{set(NORMALIZE_DIC.keys())-dims} not calculated yet"
|
132 |
+
|
133 |
+
quality_score = sum([scaled_results[i] for i in QUALITY_LIST]) / sum([DIM_WEIGHT[i] for i in QUALITY_LIST])
|
134 |
+
semantic_score = sum([scaled_results[i] for i in SEMANTIC_LIST]) / sum([DIM_WEIGHT[i] for i in SEMANTIC_LIST])
|
135 |
+
scaled_results["quality score"] = quality_score
|
136 |
+
scaled_results["semantic score"] = semantic_score
|
137 |
+
scaled_results["total score"] = (quality_score * QUALITY_WEIGHT + semantic_score * SEMANTIC_WEIGHT) / (
|
138 |
+
QUALITY_WEIGHT + SEMANTIC_WEIGHT
|
139 |
+
)
|
140 |
+
|
141 |
+
formated_scaled_results = {"items": []}
|
142 |
+
for key in ordered_scaled_res:
|
143 |
+
formated_score = format(scaled_results[key] * 100, ".2f") + "%"
|
144 |
+
formated_scaled_results["items"].append({key: formated_score})
|
145 |
+
|
146 |
+
output_file_path = os.path.join(args.score_dir, "all_results.json")
|
147 |
+
with open(output_file_path, "w") as outfile:
|
148 |
+
json.dump(full_results, outfile, indent=4, sort_keys=True)
|
149 |
+
print(f"results saved to: {output_file_path}")
|
150 |
+
|
151 |
+
scaled_file_path = os.path.join(args.score_dir, "scaled_results.json")
|
152 |
+
with open(scaled_file_path, "w") as outfile:
|
153 |
+
json.dump(formated_scaled_results, outfile, indent=4, sort_keys=True)
|
154 |
+
print(f"results saved to: {scaled_file_path}")
|
eval/pab/vbench/run_vbench.py
ADDED
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
|
3 |
+
import torch
|
4 |
+
from vbench import VBench
|
5 |
+
|
6 |
+
full_info_path = "./vbench/VBench_full_info.json"
|
7 |
+
|
8 |
+
dimensions = [
|
9 |
+
"subject_consistency",
|
10 |
+
"imaging_quality",
|
11 |
+
"background_consistency",
|
12 |
+
"motion_smoothness",
|
13 |
+
"overall_consistency",
|
14 |
+
"human_action",
|
15 |
+
"multiple_objects",
|
16 |
+
"spatial_relationship",
|
17 |
+
"object_class",
|
18 |
+
"color",
|
19 |
+
"aesthetic_quality",
|
20 |
+
"appearance_style",
|
21 |
+
"temporal_flickering",
|
22 |
+
"scene",
|
23 |
+
"temporal_style",
|
24 |
+
"dynamic_degree",
|
25 |
+
]
|
26 |
+
|
27 |
+
|
28 |
+
def parse_args():
|
29 |
+
parser = argparse.ArgumentParser()
|
30 |
+
parser.add_argument("--video_path", required=True, type=str)
|
31 |
+
args = parser.parse_args()
|
32 |
+
return args
|
33 |
+
|
34 |
+
|
35 |
+
if __name__ == "__main__":
|
36 |
+
args = parse_args()
|
37 |
+
save_path = args.video_path.replace("/samples/", "/vbench_out/")
|
38 |
+
|
39 |
+
kwargs = {}
|
40 |
+
kwargs["imaging_quality_preprocessing_mode"] = "longer" # use VBench/evaluate.py default
|
41 |
+
|
42 |
+
for dimension in dimensions:
|
43 |
+
my_VBench = VBench(torch.device("cuda"), full_info_path, save_path)
|
44 |
+
my_VBench.evaluate(
|
45 |
+
videos_path=args.video_path,
|
46 |
+
name=dimension,
|
47 |
+
local=False,
|
48 |
+
read_frame=False,
|
49 |
+
dimension_list=[dimension],
|
50 |
+
mode="vbench_standard",
|
51 |
+
**kwargs,
|
52 |
+
)
|
examples/cogvideo/sample.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from videosys import CogVideoConfig, VideoSysEngine
|
2 |
+
|
3 |
+
|
4 |
+
def run_base():
|
5 |
+
config = CogVideoConfig(world_size=1)
|
6 |
+
engine = VideoSysEngine(config)
|
7 |
+
|
8 |
+
prompt = "Sunset over the sea."
|
9 |
+
video = engine.generate(prompt).video[0]
|
10 |
+
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
11 |
+
|
12 |
+
|
13 |
+
if __name__ == "__main__":
|
14 |
+
run_base()
|
examples/latte/sample.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from videosys import LatteConfig, VideoSysEngine
|
2 |
+
|
3 |
+
|
4 |
+
def run_base():
|
5 |
+
config = LatteConfig(world_size=1)
|
6 |
+
engine = VideoSysEngine(config)
|
7 |
+
|
8 |
+
prompt = "Sunset over the sea."
|
9 |
+
video = engine.generate(prompt).video[0]
|
10 |
+
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
11 |
+
|
12 |
+
|
13 |
+
def run_pab():
|
14 |
+
config = LatteConfig(world_size=1)
|
15 |
+
engine = VideoSysEngine(config)
|
16 |
+
|
17 |
+
prompt = "Sunset over the sea."
|
18 |
+
video = engine.generate(prompt).video[0]
|
19 |
+
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
run_base()
|
24 |
+
# run_pab()
|
examples/open_sora/sample.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from videosys import OpenSoraConfig, VideoSysEngine
|
2 |
+
|
3 |
+
|
4 |
+
def run_base():
|
5 |
+
config = OpenSoraConfig(world_size=1)
|
6 |
+
engine = VideoSysEngine(config)
|
7 |
+
|
8 |
+
prompt = "Sunset over the sea."
|
9 |
+
video = engine.generate(prompt).video[0]
|
10 |
+
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
11 |
+
|
12 |
+
|
13 |
+
def run_pab():
|
14 |
+
config = OpenSoraConfig(world_size=1, enable_pab=True)
|
15 |
+
engine = VideoSysEngine(config)
|
16 |
+
|
17 |
+
prompt = "Sunset over the sea."
|
18 |
+
video = engine.generate(prompt).video[0]
|
19 |
+
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
run_base()
|
24 |
+
run_pab()
|
examples/open_sora_plan/sample.py
ADDED
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from videosys import OpenSoraPlanConfig, VideoSysEngine
|
2 |
+
|
3 |
+
|
4 |
+
def run_base():
|
5 |
+
config = OpenSoraPlanConfig(world_size=1)
|
6 |
+
engine = VideoSysEngine(config)
|
7 |
+
|
8 |
+
prompt = "Sunset over the sea."
|
9 |
+
video = engine.generate(prompt).video[0]
|
10 |
+
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
11 |
+
|
12 |
+
|
13 |
+
def run_pab():
|
14 |
+
config = OpenSoraPlanConfig(world_size=1)
|
15 |
+
engine = VideoSysEngine(config)
|
16 |
+
|
17 |
+
prompt = "Sunset over the sea."
|
18 |
+
video = engine.generate(prompt).video[0]
|
19 |
+
engine.save_video(video, f"./outputs/{prompt}.mp4")
|
20 |
+
|
21 |
+
|
22 |
+
if __name__ == "__main__":
|
23 |
+
run_base()
|
24 |
+
# run_pab()
|
requirements.txt
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
gradio
|
2 |
+
click
|
3 |
+
colossalai
|
4 |
+
contexttimer
|
5 |
+
diffusers==0.30.0
|
6 |
+
einops
|
7 |
+
fabric
|
8 |
+
ftfy
|
9 |
+
imageio
|
10 |
+
imageio-ffmpeg
|
11 |
+
matplotlib
|
12 |
+
ninja
|
13 |
+
numpy
|
14 |
+
omegaconf
|
15 |
+
packaging
|
16 |
+
psutil
|
17 |
+
pydantic
|
18 |
+
ray
|
19 |
+
rich
|
20 |
+
safetensors
|
21 |
+
timm
|
22 |
+
torch>=1.13
|
23 |
+
tqdm
|
24 |
+
transformers
|
25 |
+
openai
|
setup.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import List
|
2 |
+
|
3 |
+
from setuptools import find_packages, setup
|
4 |
+
|
5 |
+
|
6 |
+
def fetch_requirements(path) -> List[str]:
|
7 |
+
"""
|
8 |
+
This function reads the requirements file.
|
9 |
+
|
10 |
+
Args:
|
11 |
+
path (str): the path to the requirements file.
|
12 |
+
|
13 |
+
Returns:
|
14 |
+
The lines in the requirements file.
|
15 |
+
"""
|
16 |
+
with open(path, "r") as fd:
|
17 |
+
return [r.strip() for r in fd.readlines()]
|
18 |
+
|
19 |
+
|
20 |
+
def fetch_readme() -> str:
|
21 |
+
"""
|
22 |
+
This function reads the README.md file in the current directory.
|
23 |
+
|
24 |
+
Returns:
|
25 |
+
The lines in the README file.
|
26 |
+
"""
|
27 |
+
with open("README.md", encoding="utf-8") as f:
|
28 |
+
return f.read()
|
29 |
+
|
30 |
+
|
31 |
+
setup(
|
32 |
+
name="videosys",
|
33 |
+
version="2.0.0",
|
34 |
+
packages=find_packages(
|
35 |
+
exclude=(
|
36 |
+
"videos",
|
37 |
+
"tests",
|
38 |
+
"figure",
|
39 |
+
"*.egg-info",
|
40 |
+
)
|
41 |
+
),
|
42 |
+
description="VideoSys",
|
43 |
+
long_description=fetch_readme(),
|
44 |
+
long_description_content_type="text/markdown",
|
45 |
+
license="Apache Software License 2.0",
|
46 |
+
install_requires=fetch_requirements("requirements.txt"),
|
47 |
+
python_requires=">=3.6",
|
48 |
+
classifiers=[
|
49 |
+
"Programming Language :: Python :: 3",
|
50 |
+
"License :: OSI Approved :: Apache Software License",
|
51 |
+
"Environment :: GPU :: NVIDIA CUDA",
|
52 |
+
"Topic :: Scientific/Engineering :: Artificial Intelligence",
|
53 |
+
"Topic :: System :: Distributed Computing",
|
54 |
+
],
|
55 |
+
)
|
tests/__init__.py
ADDED
File without changes
|
videosys/__init__.py
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .core.engine import VideoSysEngine
|
2 |
+
from .core.parallel_mgr import initialize
|
3 |
+
from .models.cogvideo.pipeline import CogVideoConfig, CogVideoPipeline
|
4 |
+
from .models.latte.pipeline import LatteConfig, LattePipeline
|
5 |
+
from .models.open_sora.pipeline import OpenSoraConfig, OpenSoraPipeline
|
6 |
+
from .models.open_sora_plan.pipeline import OpenSoraPlanConfig, OpenSoraPlanPipeline
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
"initialize",
|
10 |
+
"VideoSysEngine",
|
11 |
+
"LattePipeline",
|
12 |
+
"LatteConfig",
|
13 |
+
"OpenSoraPlanPipeline",
|
14 |
+
"OpenSoraPlanConfig",
|
15 |
+
"OpenSoraPipeline",
|
16 |
+
"OpenSoraConfig",
|
17 |
+
"CogVideoConfig",
|
18 |
+
"CogVideoPipeline",
|
19 |
+
]
|
videosys/core/__init__.py
ADDED
File without changes
|
videosys/core/comm.py
ADDED
@@ -0,0 +1,420 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Any, Optional, Tuple
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
import torch.nn.functional as F
|
6 |
+
from einops import rearrange
|
7 |
+
from torch import Tensor
|
8 |
+
from torch.distributed import ProcessGroup
|
9 |
+
|
10 |
+
from videosys.core.parallel_mgr import get_sequence_parallel_size
|
11 |
+
|
12 |
+
# ======================================================
|
13 |
+
# Model
|
14 |
+
# ======================================================
|
15 |
+
|
16 |
+
|
17 |
+
def model_sharding(model: torch.nn.Module):
|
18 |
+
global_rank = dist.get_rank()
|
19 |
+
world_size = dist.get_world_size()
|
20 |
+
for _, param in model.named_parameters():
|
21 |
+
padding_size = (world_size - param.numel() % world_size) % world_size
|
22 |
+
if padding_size > 0:
|
23 |
+
padding_param = torch.nn.functional.pad(param.data.view(-1), [0, padding_size])
|
24 |
+
else:
|
25 |
+
padding_param = param.data.view(-1)
|
26 |
+
splited_params = padding_param.split(padding_param.numel() // world_size)
|
27 |
+
splited_params = splited_params[global_rank]
|
28 |
+
param.data = splited_params
|
29 |
+
|
30 |
+
|
31 |
+
# ======================================================
|
32 |
+
# AllGather & ReduceScatter
|
33 |
+
# ======================================================
|
34 |
+
|
35 |
+
|
36 |
+
class AsyncAllGatherForTwo(torch.autograd.Function):
|
37 |
+
@staticmethod
|
38 |
+
def forward(
|
39 |
+
ctx: Any,
|
40 |
+
inputs: Tensor,
|
41 |
+
weight: Tensor,
|
42 |
+
bias: Tensor,
|
43 |
+
sp_rank: int,
|
44 |
+
sp_size: int,
|
45 |
+
group: Optional[ProcessGroup] = None,
|
46 |
+
) -> Tuple[Tensor, Any]:
|
47 |
+
"""
|
48 |
+
Returns:
|
49 |
+
outputs: Tensor
|
50 |
+
handle: Optional[Work], if overlap is True
|
51 |
+
"""
|
52 |
+
from torch.distributed._functional_collectives import all_gather_tensor
|
53 |
+
|
54 |
+
ctx.group = group
|
55 |
+
ctx.sp_rank = sp_rank
|
56 |
+
ctx.sp_size = sp_size
|
57 |
+
|
58 |
+
# all gather inputs
|
59 |
+
all_inputs = all_gather_tensor(inputs.unsqueeze(0), 0, group)
|
60 |
+
# compute local qkv
|
61 |
+
local_qkv = F.linear(inputs, weight, bias).unsqueeze(0)
|
62 |
+
|
63 |
+
# remote compute
|
64 |
+
remote_inputs = all_inputs[1 - sp_rank].view(list(local_qkv.shape[:-1]) + [-1])
|
65 |
+
# compute remote qkv
|
66 |
+
remote_qkv = F.linear(remote_inputs, weight, bias)
|
67 |
+
|
68 |
+
# concat local and remote qkv
|
69 |
+
if sp_rank == 0:
|
70 |
+
qkv = torch.cat([local_qkv, remote_qkv], dim=0)
|
71 |
+
else:
|
72 |
+
qkv = torch.cat([remote_qkv, local_qkv], dim=0)
|
73 |
+
qkv = rearrange(qkv, "sp b n c -> b (sp n) c")
|
74 |
+
|
75 |
+
ctx.save_for_backward(inputs, weight, remote_inputs)
|
76 |
+
return qkv
|
77 |
+
|
78 |
+
@staticmethod
|
79 |
+
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
80 |
+
from torch.distributed._functional_collectives import reduce_scatter_tensor
|
81 |
+
|
82 |
+
group = ctx.group
|
83 |
+
sp_rank = ctx.sp_rank
|
84 |
+
sp_size = ctx.sp_size
|
85 |
+
inputs, weight, remote_inputs = ctx.saved_tensors
|
86 |
+
|
87 |
+
# split qkv_grad
|
88 |
+
qkv_grad = grad_outputs[0]
|
89 |
+
qkv_grad = rearrange(qkv_grad, "b (sp n) c -> sp b n c", sp=sp_size)
|
90 |
+
qkv_grad = torch.chunk(qkv_grad, 2, dim=0)
|
91 |
+
if sp_rank == 0:
|
92 |
+
local_qkv_grad, remote_qkv_grad = qkv_grad
|
93 |
+
else:
|
94 |
+
remote_qkv_grad, local_qkv_grad = qkv_grad
|
95 |
+
|
96 |
+
# compute remote grad
|
97 |
+
remote_inputs_grad = torch.matmul(remote_qkv_grad, weight).squeeze(0)
|
98 |
+
weight_grad = torch.matmul(remote_qkv_grad.transpose(-1, -2), remote_inputs).squeeze(0).sum(0)
|
99 |
+
bias_grad = remote_qkv_grad.squeeze(0).sum(0).sum(0)
|
100 |
+
|
101 |
+
# launch async reduce scatter
|
102 |
+
remote_inputs_grad_zero = torch.zeros_like(remote_inputs_grad)
|
103 |
+
if sp_rank == 0:
|
104 |
+
remote_inputs_grad = torch.cat([remote_inputs_grad_zero, remote_inputs_grad], dim=0)
|
105 |
+
else:
|
106 |
+
remote_inputs_grad = torch.cat([remote_inputs_grad, remote_inputs_grad_zero], dim=0)
|
107 |
+
remote_inputs_grad = reduce_scatter_tensor(remote_inputs_grad, "sum", 0, group)
|
108 |
+
|
109 |
+
# compute local grad and wait for reduce scatter
|
110 |
+
local_input_grad = torch.matmul(local_qkv_grad, weight).squeeze(0)
|
111 |
+
weight_grad += torch.matmul(local_qkv_grad.transpose(-1, -2), inputs).squeeze(0).sum(0)
|
112 |
+
bias_grad += local_qkv_grad.squeeze(0).sum(0).sum(0)
|
113 |
+
|
114 |
+
# sum remote and local grad
|
115 |
+
inputs_grad = remote_inputs_grad + local_input_grad
|
116 |
+
return inputs_grad, weight_grad, bias_grad, None, None, None
|
117 |
+
|
118 |
+
|
119 |
+
class AllGather(torch.autograd.Function):
|
120 |
+
@staticmethod
|
121 |
+
def forward(
|
122 |
+
ctx: Any,
|
123 |
+
inputs: Tensor,
|
124 |
+
group: Optional[ProcessGroup] = None,
|
125 |
+
overlap: bool = False,
|
126 |
+
) -> Tuple[Tensor, Any]:
|
127 |
+
"""
|
128 |
+
Returns:
|
129 |
+
outputs: Tensor
|
130 |
+
handle: Optional[Work], if overlap is True
|
131 |
+
"""
|
132 |
+
assert ctx is not None or not overlap
|
133 |
+
|
134 |
+
if ctx is not None:
|
135 |
+
ctx.comm_grp = group
|
136 |
+
|
137 |
+
comm_size = dist.get_world_size(group)
|
138 |
+
if comm_size == 1:
|
139 |
+
return inputs.unsqueeze(0), None
|
140 |
+
|
141 |
+
buffer_shape = (comm_size,) + inputs.shape
|
142 |
+
outputs = torch.empty(buffer_shape, dtype=inputs.dtype, device=inputs.device)
|
143 |
+
buffer_list = list(torch.chunk(outputs, comm_size, dim=0))
|
144 |
+
if not overlap:
|
145 |
+
dist.all_gather(buffer_list, inputs, group=group)
|
146 |
+
return outputs, None
|
147 |
+
else:
|
148 |
+
handle = dist.all_gather(buffer_list, inputs, group=group, async_op=True)
|
149 |
+
return outputs, handle
|
150 |
+
|
151 |
+
@staticmethod
|
152 |
+
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
153 |
+
return (
|
154 |
+
ReduceScatter.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
|
155 |
+
None,
|
156 |
+
None,
|
157 |
+
)
|
158 |
+
|
159 |
+
|
160 |
+
class ReduceScatter(torch.autograd.Function):
|
161 |
+
@staticmethod
|
162 |
+
def forward(
|
163 |
+
ctx: Any,
|
164 |
+
inputs: Tensor,
|
165 |
+
group: ProcessGroup,
|
166 |
+
overlap: bool = False,
|
167 |
+
) -> Tuple[Tensor, Any]:
|
168 |
+
"""
|
169 |
+
Returns:
|
170 |
+
outputs: Tensor
|
171 |
+
handle: Optional[Work], if overlap is True
|
172 |
+
"""
|
173 |
+
assert ctx is not None or not overlap
|
174 |
+
|
175 |
+
if ctx is not None:
|
176 |
+
ctx.comm_grp = group
|
177 |
+
|
178 |
+
comm_size = dist.get_world_size(group)
|
179 |
+
if comm_size == 1:
|
180 |
+
return inputs.squeeze(0), None
|
181 |
+
|
182 |
+
if not inputs.is_contiguous():
|
183 |
+
inputs = inputs.contiguous()
|
184 |
+
|
185 |
+
output_shape = inputs.shape[1:]
|
186 |
+
outputs = torch.empty(output_shape, dtype=inputs.dtype, device=inputs.device)
|
187 |
+
buffer_list = list(torch.chunk(inputs, comm_size, dim=0))
|
188 |
+
if not overlap:
|
189 |
+
dist.reduce_scatter(outputs, buffer_list, group=group)
|
190 |
+
return outputs, None
|
191 |
+
else:
|
192 |
+
handle = dist.reduce_scatter(outputs, buffer_list, group=group, async_op=True)
|
193 |
+
return outputs, handle
|
194 |
+
|
195 |
+
@staticmethod
|
196 |
+
def backward(ctx: Any, *grad_outputs) -> Tuple[Tensor, None, None]:
|
197 |
+
# TODO: support async backward
|
198 |
+
return (
|
199 |
+
AllGather.forward(None, grad_outputs[0], ctx.comm_grp, False)[0],
|
200 |
+
None,
|
201 |
+
None,
|
202 |
+
)
|
203 |
+
|
204 |
+
|
205 |
+
# ======================================================
|
206 |
+
# AlltoAll
|
207 |
+
# ======================================================
|
208 |
+
|
209 |
+
|
210 |
+
def _all_to_all_func(input_, world_size, group, scatter_dim, gather_dim):
|
211 |
+
input_list = [t.contiguous() for t in torch.tensor_split(input_, world_size, scatter_dim)]
|
212 |
+
output_list = [torch.empty_like(input_list[0]) for _ in range(world_size)]
|
213 |
+
dist.all_to_all(output_list, input_list, group=group)
|
214 |
+
return torch.cat(output_list, dim=gather_dim).contiguous()
|
215 |
+
|
216 |
+
|
217 |
+
class _AllToAll(torch.autograd.Function):
|
218 |
+
"""All-to-all communication.
|
219 |
+
|
220 |
+
Args:
|
221 |
+
input_: input matrix
|
222 |
+
process_group: communication group
|
223 |
+
scatter_dim: scatter dimension
|
224 |
+
gather_dim: gather dimension
|
225 |
+
"""
|
226 |
+
|
227 |
+
@staticmethod
|
228 |
+
def forward(ctx, input_, process_group, scatter_dim, gather_dim):
|
229 |
+
ctx.process_group = process_group
|
230 |
+
ctx.scatter_dim = scatter_dim
|
231 |
+
ctx.gather_dim = gather_dim
|
232 |
+
world_size = dist.get_world_size(process_group)
|
233 |
+
|
234 |
+
return _all_to_all_func(input_, world_size, process_group, scatter_dim, gather_dim)
|
235 |
+
|
236 |
+
@staticmethod
|
237 |
+
def backward(ctx, *grad_output):
|
238 |
+
process_group = ctx.process_group
|
239 |
+
scatter_dim = ctx.gather_dim
|
240 |
+
gather_dim = ctx.scatter_dim
|
241 |
+
return_grad = _AllToAll.apply(*grad_output, process_group, scatter_dim, gather_dim)
|
242 |
+
return (return_grad, None, None, None)
|
243 |
+
|
244 |
+
|
245 |
+
def all_to_all_comm(input_, process_group=None, scatter_dim=2, gather_dim=1):
|
246 |
+
return _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
|
247 |
+
|
248 |
+
|
249 |
+
# ======================================================
|
250 |
+
# Sequence Gather & Split
|
251 |
+
# ======================================================
|
252 |
+
|
253 |
+
|
254 |
+
def _split_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
|
255 |
+
# skip if only one rank involved
|
256 |
+
world_size = dist.get_world_size(pg)
|
257 |
+
rank = dist.get_rank(pg)
|
258 |
+
if world_size == 1:
|
259 |
+
return input_
|
260 |
+
|
261 |
+
if pad > 0:
|
262 |
+
pad_size = list(input_.shape)
|
263 |
+
pad_size[dim] = pad
|
264 |
+
input_ = torch.cat([input_, torch.zeros(pad_size, dtype=input_.dtype, device=input_.device)], dim=dim)
|
265 |
+
|
266 |
+
dim_size = input_.size(dim)
|
267 |
+
assert dim_size % world_size == 0, f"dim_size ({dim_size}) is not divisible by world_size ({world_size})"
|
268 |
+
|
269 |
+
tensor_list = torch.split(input_, dim_size // world_size, dim=dim)
|
270 |
+
output = tensor_list[rank].contiguous()
|
271 |
+
return output
|
272 |
+
|
273 |
+
|
274 |
+
def _gather_sequence_func(input_, pg: dist.ProcessGroup, dim: int, pad: int):
|
275 |
+
# skip if only one rank involved
|
276 |
+
input_ = input_.contiguous()
|
277 |
+
world_size = dist.get_world_size(pg)
|
278 |
+
dist.get_rank(pg)
|
279 |
+
|
280 |
+
if world_size == 1:
|
281 |
+
return input_
|
282 |
+
|
283 |
+
# all gather
|
284 |
+
tensor_list = [torch.empty_like(input_) for _ in range(world_size)]
|
285 |
+
assert input_.device.type == "cuda"
|
286 |
+
torch.distributed.all_gather(tensor_list, input_, group=pg)
|
287 |
+
|
288 |
+
# concat
|
289 |
+
output = torch.cat(tensor_list, dim=dim)
|
290 |
+
|
291 |
+
if pad > 0:
|
292 |
+
output = output.narrow(dim, 0, output.size(dim) - pad)
|
293 |
+
|
294 |
+
return output
|
295 |
+
|
296 |
+
|
297 |
+
class _GatherForwardSplitBackward(torch.autograd.Function):
|
298 |
+
"""
|
299 |
+
Gather the input sequence.
|
300 |
+
|
301 |
+
Args:
|
302 |
+
input_: input matrix.
|
303 |
+
process_group: process group.
|
304 |
+
dim: dimension
|
305 |
+
"""
|
306 |
+
|
307 |
+
@staticmethod
|
308 |
+
def symbolic(graph, input_):
|
309 |
+
return _gather_sequence_func(input_)
|
310 |
+
|
311 |
+
@staticmethod
|
312 |
+
def forward(ctx, input_, process_group, dim, grad_scale, pad):
|
313 |
+
ctx.process_group = process_group
|
314 |
+
ctx.dim = dim
|
315 |
+
ctx.grad_scale = grad_scale
|
316 |
+
ctx.pad = pad
|
317 |
+
return _gather_sequence_func(input_, process_group, dim, pad)
|
318 |
+
|
319 |
+
@staticmethod
|
320 |
+
def backward(ctx, grad_output):
|
321 |
+
if ctx.grad_scale == "up":
|
322 |
+
grad_output = grad_output * dist.get_world_size(ctx.process_group)
|
323 |
+
elif ctx.grad_scale == "down":
|
324 |
+
grad_output = grad_output / dist.get_world_size(ctx.process_group)
|
325 |
+
|
326 |
+
return _split_sequence_func(grad_output, ctx.process_group, ctx.dim, ctx.pad), None, None, None, None
|
327 |
+
|
328 |
+
|
329 |
+
class _SplitForwardGatherBackward(torch.autograd.Function):
|
330 |
+
"""
|
331 |
+
Split sequence.
|
332 |
+
|
333 |
+
Args:
|
334 |
+
input_: input matrix.
|
335 |
+
process_group: parallel mode.
|
336 |
+
dim: dimension
|
337 |
+
"""
|
338 |
+
|
339 |
+
@staticmethod
|
340 |
+
def symbolic(graph, input_):
|
341 |
+
return _split_sequence_func(input_)
|
342 |
+
|
343 |
+
@staticmethod
|
344 |
+
def forward(ctx, input_, process_group, dim, grad_scale, pad):
|
345 |
+
ctx.process_group = process_group
|
346 |
+
ctx.dim = dim
|
347 |
+
ctx.grad_scale = grad_scale
|
348 |
+
ctx.pad = pad
|
349 |
+
return _split_sequence_func(input_, process_group, dim, pad)
|
350 |
+
|
351 |
+
@staticmethod
|
352 |
+
def backward(ctx, grad_output):
|
353 |
+
if ctx.grad_scale == "up":
|
354 |
+
grad_output = grad_output * dist.get_world_size(ctx.process_group)
|
355 |
+
elif ctx.grad_scale == "down":
|
356 |
+
grad_output = grad_output / dist.get_world_size(ctx.process_group)
|
357 |
+
return _gather_sequence_func(grad_output, ctx.process_group, ctx.pad), None, None, None, None
|
358 |
+
|
359 |
+
|
360 |
+
def split_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
|
361 |
+
return _SplitForwardGatherBackward.apply(input_, process_group, dim, grad_scale, pad)
|
362 |
+
|
363 |
+
|
364 |
+
def gather_sequence(input_, process_group, dim, grad_scale=1.0, pad=0):
|
365 |
+
return _GatherForwardSplitBackward.apply(input_, process_group, dim, grad_scale, pad)
|
366 |
+
|
367 |
+
|
368 |
+
# ==============================
|
369 |
+
# Pad
|
370 |
+
# ==============================
|
371 |
+
|
372 |
+
SPTIAL_PAD = 0
|
373 |
+
TEMPORAL_PAD = 0
|
374 |
+
|
375 |
+
|
376 |
+
def set_spatial_pad(dim_size: int):
|
377 |
+
sp_size = get_sequence_parallel_size()
|
378 |
+
pad = (sp_size - (dim_size % sp_size)) % sp_size
|
379 |
+
global SPTIAL_PAD
|
380 |
+
SPTIAL_PAD = pad
|
381 |
+
|
382 |
+
|
383 |
+
def get_spatial_pad() -> int:
|
384 |
+
return SPTIAL_PAD
|
385 |
+
|
386 |
+
|
387 |
+
def set_temporal_pad(dim_size: int):
|
388 |
+
sp_size = get_sequence_parallel_size()
|
389 |
+
pad = (sp_size - (dim_size % sp_size)) % sp_size
|
390 |
+
global TEMPORAL_PAD
|
391 |
+
TEMPORAL_PAD = pad
|
392 |
+
|
393 |
+
|
394 |
+
def get_temporal_pad() -> int:
|
395 |
+
return TEMPORAL_PAD
|
396 |
+
|
397 |
+
|
398 |
+
def all_to_all_with_pad(
|
399 |
+
input_: torch.Tensor,
|
400 |
+
process_group: dist.ProcessGroup,
|
401 |
+
scatter_dim: int = 2,
|
402 |
+
gather_dim: int = 1,
|
403 |
+
scatter_pad: int = 0,
|
404 |
+
gather_pad: int = 0,
|
405 |
+
):
|
406 |
+
if scatter_pad > 0:
|
407 |
+
pad_shape = list(input_.shape)
|
408 |
+
pad_shape[scatter_dim] = scatter_pad
|
409 |
+
pad_tensor = torch.zeros(pad_shape, device=input_.device, dtype=input_.dtype)
|
410 |
+
input_ = torch.cat([input_, pad_tensor], dim=scatter_dim)
|
411 |
+
|
412 |
+
assert (
|
413 |
+
input_.shape[scatter_dim] % dist.get_world_size(process_group) == 0
|
414 |
+
), f"Dimension to scatter ({input_.shape[scatter_dim]}) is not divisible by world size ({dist.get_world_size(process_group)})"
|
415 |
+
input_ = _AllToAll.apply(input_, process_group, scatter_dim, gather_dim)
|
416 |
+
|
417 |
+
if gather_pad > 0:
|
418 |
+
input_ = input_.narrow(gather_dim, 0, input_.size(gather_dim) - gather_pad)
|
419 |
+
|
420 |
+
return input_
|
videosys/core/engine.py
ADDED
@@ -0,0 +1,132 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from functools import partial
|
3 |
+
from typing import Any, Optional
|
4 |
+
|
5 |
+
import imageio
|
6 |
+
import torch
|
7 |
+
|
8 |
+
import videosys
|
9 |
+
|
10 |
+
from .mp_utils import ProcessWorkerWrapper, ResultHandler, WorkerMonitor, get_distributed_init_method, get_open_port
|
11 |
+
|
12 |
+
|
13 |
+
class VideoSysEngine:
|
14 |
+
"""
|
15 |
+
this is partly inspired by vllm
|
16 |
+
"""
|
17 |
+
|
18 |
+
def __init__(self, config):
|
19 |
+
self.config = config
|
20 |
+
self.parallel_worker_tasks = None
|
21 |
+
self._init_worker(config.pipeline_cls)
|
22 |
+
|
23 |
+
def _init_worker(self, pipeline_cls):
|
24 |
+
world_size = self.config.world_size
|
25 |
+
|
26 |
+
if "CUDA_VISIBLE_DEVICES" not in os.environ:
|
27 |
+
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join(str(i) for i in range(world_size))
|
28 |
+
|
29 |
+
# Disable torch async compiling which won't work with daemonic processes
|
30 |
+
os.environ["TORCHINDUCTOR_COMPILE_THREADS"] = "1"
|
31 |
+
|
32 |
+
# Set OMP_NUM_THREADS to 1 if it is not set explicitly, avoids CPU
|
33 |
+
# contention amongst the shards
|
34 |
+
if "OMP_NUM_THREADS" not in os.environ:
|
35 |
+
os.environ["OMP_NUM_THREADS"] = "1"
|
36 |
+
|
37 |
+
# NOTE: The two following lines need adaption for multi-node
|
38 |
+
assert world_size <= torch.cuda.device_count()
|
39 |
+
|
40 |
+
# change addr for multi-node
|
41 |
+
distributed_init_method = get_distributed_init_method("127.0.0.1", get_open_port())
|
42 |
+
|
43 |
+
if world_size == 1:
|
44 |
+
self.workers = []
|
45 |
+
self.worker_monitor = None
|
46 |
+
else:
|
47 |
+
result_handler = ResultHandler()
|
48 |
+
self.workers = [
|
49 |
+
ProcessWorkerWrapper(
|
50 |
+
result_handler,
|
51 |
+
partial(
|
52 |
+
self._create_pipeline,
|
53 |
+
pipeline_cls=pipeline_cls,
|
54 |
+
rank=rank,
|
55 |
+
local_rank=rank,
|
56 |
+
distributed_init_method=distributed_init_method,
|
57 |
+
),
|
58 |
+
)
|
59 |
+
for rank in range(1, world_size)
|
60 |
+
]
|
61 |
+
|
62 |
+
self.worker_monitor = WorkerMonitor(self.workers, result_handler)
|
63 |
+
result_handler.start()
|
64 |
+
self.worker_monitor.start()
|
65 |
+
|
66 |
+
self.driver_worker = self._create_pipeline(
|
67 |
+
pipeline_cls=pipeline_cls, distributed_init_method=distributed_init_method
|
68 |
+
)
|
69 |
+
|
70 |
+
# TODO: add more options here for pipeline, or wrap all options into config
|
71 |
+
def _create_pipeline(self, pipeline_cls, rank=0, local_rank=0, distributed_init_method=None):
|
72 |
+
videosys.initialize(rank=rank, world_size=self.config.world_size, init_method=distributed_init_method, seed=42)
|
73 |
+
|
74 |
+
pipeline = pipeline_cls(self.config)
|
75 |
+
return pipeline
|
76 |
+
|
77 |
+
def _run_workers(
|
78 |
+
self,
|
79 |
+
method: str,
|
80 |
+
*args,
|
81 |
+
async_run_tensor_parallel_workers_only: bool = False,
|
82 |
+
max_concurrent_workers: Optional[int] = None,
|
83 |
+
**kwargs,
|
84 |
+
) -> Any:
|
85 |
+
"""Runs the given method on all workers."""
|
86 |
+
|
87 |
+
# Start the workers first.
|
88 |
+
worker_outputs = [worker.execute_method(method, *args, **kwargs) for worker in self.workers]
|
89 |
+
|
90 |
+
if async_run_tensor_parallel_workers_only:
|
91 |
+
# Just return futures
|
92 |
+
return worker_outputs
|
93 |
+
|
94 |
+
driver_worker_method = getattr(self.driver_worker, method)
|
95 |
+
driver_worker_output = driver_worker_method(*args, **kwargs)
|
96 |
+
|
97 |
+
# Get the results of the workers.
|
98 |
+
return [driver_worker_output] + [output.get() for output in worker_outputs]
|
99 |
+
|
100 |
+
def _driver_execute_model(self, *args, **kwargs):
|
101 |
+
return self.driver_worker.generate(*args, **kwargs)
|
102 |
+
|
103 |
+
def generate(self, *args, **kwargs):
|
104 |
+
return self._run_workers("generate", *args, **kwargs)[0]
|
105 |
+
|
106 |
+
def stop_remote_worker_execution_loop(self) -> None:
|
107 |
+
if self.parallel_worker_tasks is None:
|
108 |
+
return
|
109 |
+
|
110 |
+
parallel_worker_tasks = self.parallel_worker_tasks
|
111 |
+
self.parallel_worker_tasks = None
|
112 |
+
# Ensure that workers exit model loop cleanly
|
113 |
+
# (this will raise otherwise)
|
114 |
+
self._wait_for_tasks_completion(parallel_worker_tasks)
|
115 |
+
|
116 |
+
def _wait_for_tasks_completion(self, parallel_worker_tasks: Any) -> None:
|
117 |
+
"""Wait for futures returned from _run_workers() with
|
118 |
+
async_run_remote_workers_only to complete."""
|
119 |
+
for result in parallel_worker_tasks:
|
120 |
+
result.get()
|
121 |
+
|
122 |
+
def save_video(self, video, output_path):
|
123 |
+
os.makedirs(os.path.dirname(output_path), exist_ok=True)
|
124 |
+
imageio.mimwrite(output_path, video, fps=24)
|
125 |
+
|
126 |
+
def shutdown(self):
|
127 |
+
if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
|
128 |
+
worker_monitor.close()
|
129 |
+
torch.distributed.destroy_process_group()
|
130 |
+
|
131 |
+
def __del__(self):
|
132 |
+
self.shutdown()
|
videosys/core/mp_utils.py
ADDED
@@ -0,0 +1,270 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# adapted from vllm
|
2 |
+
# https://github.com/vllm-project/vllm/blob/main/vllm/executor/multiproc_worker_utils.py
|
3 |
+
|
4 |
+
import asyncio
|
5 |
+
import multiprocessing
|
6 |
+
import os
|
7 |
+
import socket
|
8 |
+
import sys
|
9 |
+
import threading
|
10 |
+
import traceback
|
11 |
+
import uuid
|
12 |
+
from dataclasses import dataclass
|
13 |
+
from multiprocessing import Queue
|
14 |
+
from multiprocessing.connection import wait
|
15 |
+
from typing import Any, Callable, Dict, Generic, List, Optional, TextIO, TypeVar, Union
|
16 |
+
|
17 |
+
from videosys.utils.logging import create_logger
|
18 |
+
|
19 |
+
T = TypeVar("T")
|
20 |
+
_TERMINATE = "TERMINATE" # sentinel
|
21 |
+
# ANSI color codes
|
22 |
+
CYAN = "\033[1;36m"
|
23 |
+
RESET = "\033[0;0m"
|
24 |
+
JOIN_TIMEOUT_S = 2
|
25 |
+
|
26 |
+
mp_method = "spawn" # fork cann't work
|
27 |
+
mp = multiprocessing.get_context(mp_method)
|
28 |
+
|
29 |
+
logger = create_logger()
|
30 |
+
|
31 |
+
|
32 |
+
def get_distributed_init_method(ip: str, port: int) -> str:
|
33 |
+
# Brackets are not permitted in ipv4 addresses,
|
34 |
+
# see https://github.com/python/cpython/issues/103848
|
35 |
+
return f"tcp://[{ip}]:{port}" if ":" in ip else f"tcp://{ip}:{port}"
|
36 |
+
|
37 |
+
|
38 |
+
def get_open_port() -> int:
|
39 |
+
# try ipv4
|
40 |
+
try:
|
41 |
+
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
|
42 |
+
s.bind(("", 0))
|
43 |
+
return s.getsockname()[1]
|
44 |
+
except OSError:
|
45 |
+
# try ipv6
|
46 |
+
with socket.socket(socket.AF_INET6, socket.SOCK_STREAM) as s:
|
47 |
+
s.bind(("", 0))
|
48 |
+
return s.getsockname()[1]
|
49 |
+
|
50 |
+
|
51 |
+
@dataclass
|
52 |
+
class Result(Generic[T]):
|
53 |
+
"""Result of task dispatched to worker"""
|
54 |
+
|
55 |
+
task_id: uuid.UUID
|
56 |
+
value: Optional[T] = None
|
57 |
+
exception: Optional[BaseException] = None
|
58 |
+
|
59 |
+
|
60 |
+
class ResultFuture(threading.Event, Generic[T]):
|
61 |
+
"""Synchronous future for non-async case"""
|
62 |
+
|
63 |
+
def __init__(self):
|
64 |
+
super().__init__()
|
65 |
+
self.result: Optional[Result[T]] = None
|
66 |
+
|
67 |
+
def set_result(self, result: Result[T]):
|
68 |
+
self.result = result
|
69 |
+
self.set()
|
70 |
+
|
71 |
+
def get(self) -> T:
|
72 |
+
self.wait()
|
73 |
+
assert self.result is not None
|
74 |
+
if self.result.exception is not None:
|
75 |
+
raise self.result.exception
|
76 |
+
return self.result.value # type: ignore[return-value]
|
77 |
+
|
78 |
+
|
79 |
+
def _set_future_result(future: Union[ResultFuture, asyncio.Future], result: Result):
|
80 |
+
if isinstance(future, ResultFuture):
|
81 |
+
future.set_result(result)
|
82 |
+
return
|
83 |
+
loop = future.get_loop()
|
84 |
+
if not loop.is_closed():
|
85 |
+
if result.exception is not None:
|
86 |
+
loop.call_soon_threadsafe(future.set_exception, result.exception)
|
87 |
+
else:
|
88 |
+
loop.call_soon_threadsafe(future.set_result, result.value)
|
89 |
+
|
90 |
+
|
91 |
+
class ResultHandler(threading.Thread):
|
92 |
+
"""Handle results from all workers (in background thread)"""
|
93 |
+
|
94 |
+
def __init__(self) -> None:
|
95 |
+
super().__init__(daemon=True)
|
96 |
+
self.result_queue = mp.Queue()
|
97 |
+
self.tasks: Dict[uuid.UUID, Union[ResultFuture, asyncio.Future]] = {}
|
98 |
+
|
99 |
+
def run(self):
|
100 |
+
for result in iter(self.result_queue.get, _TERMINATE):
|
101 |
+
future = self.tasks.pop(result.task_id)
|
102 |
+
_set_future_result(future, result)
|
103 |
+
# Ensure that all waiters will receive an exception
|
104 |
+
for task_id, future in self.tasks.items():
|
105 |
+
_set_future_result(future, Result(task_id=task_id, exception=ChildProcessError("worker died")))
|
106 |
+
|
107 |
+
def close(self):
|
108 |
+
self.result_queue.put(_TERMINATE)
|
109 |
+
|
110 |
+
|
111 |
+
class WorkerMonitor(threading.Thread):
|
112 |
+
"""Monitor worker status (in background thread)"""
|
113 |
+
|
114 |
+
def __init__(self, workers: List["ProcessWorkerWrapper"], result_handler: ResultHandler):
|
115 |
+
super().__init__(daemon=True)
|
116 |
+
self.workers = workers
|
117 |
+
self.result_handler = result_handler
|
118 |
+
self._close = False
|
119 |
+
|
120 |
+
def run(self) -> None:
|
121 |
+
# Blocks until any worker exits
|
122 |
+
dead_sentinels = wait([w.process.sentinel for w in self.workers])
|
123 |
+
if not self._close:
|
124 |
+
self._close = True
|
125 |
+
|
126 |
+
# Kill / cleanup all workers
|
127 |
+
for worker in self.workers:
|
128 |
+
process = worker.process
|
129 |
+
if process.sentinel in dead_sentinels:
|
130 |
+
process.join(JOIN_TIMEOUT_S)
|
131 |
+
if process.exitcode is not None and process.exitcode != 0:
|
132 |
+
logger.error("Worker %s pid %s died, exit code: %s", process.name, process.pid, process.exitcode)
|
133 |
+
# Cleanup any remaining workers
|
134 |
+
logger.info("Killing local worker processes")
|
135 |
+
for worker in self.workers:
|
136 |
+
worker.kill_worker()
|
137 |
+
# Must be done after worker task queues are all closed
|
138 |
+
self.result_handler.close()
|
139 |
+
|
140 |
+
for worker in self.workers:
|
141 |
+
worker.process.join(JOIN_TIMEOUT_S)
|
142 |
+
|
143 |
+
def close(self):
|
144 |
+
if self._close:
|
145 |
+
return
|
146 |
+
self._close = True
|
147 |
+
logger.info("Terminating local worker processes")
|
148 |
+
for worker in self.workers:
|
149 |
+
worker.terminate_worker()
|
150 |
+
# Must be done after worker task queues are all closed
|
151 |
+
self.result_handler.close()
|
152 |
+
|
153 |
+
|
154 |
+
def _add_prefix(file: TextIO, worker_name: str, pid: int) -> None:
|
155 |
+
"""Prepend each output line with process-specific prefix"""
|
156 |
+
|
157 |
+
prefix = f"{CYAN}({worker_name} pid={pid}){RESET} "
|
158 |
+
file_write = file.write
|
159 |
+
|
160 |
+
def write_with_prefix(s: str):
|
161 |
+
if not s:
|
162 |
+
return
|
163 |
+
if file.start_new_line: # type: ignore[attr-defined]
|
164 |
+
file_write(prefix)
|
165 |
+
idx = 0
|
166 |
+
while (next_idx := s.find("\n", idx)) != -1:
|
167 |
+
next_idx += 1
|
168 |
+
file_write(s[idx:next_idx])
|
169 |
+
if next_idx == len(s):
|
170 |
+
file.start_new_line = True # type: ignore[attr-defined]
|
171 |
+
return
|
172 |
+
file_write(prefix)
|
173 |
+
idx = next_idx
|
174 |
+
file_write(s[idx:])
|
175 |
+
file.start_new_line = False # type: ignore[attr-defined]
|
176 |
+
|
177 |
+
file.start_new_line = True # type: ignore[attr-defined]
|
178 |
+
file.write = write_with_prefix # type: ignore[method-assign]
|
179 |
+
|
180 |
+
|
181 |
+
def _run_worker_process(
|
182 |
+
worker_factory: Callable[[], Any],
|
183 |
+
task_queue: Queue,
|
184 |
+
result_queue: Queue,
|
185 |
+
) -> None:
|
186 |
+
"""Worker process event loop"""
|
187 |
+
|
188 |
+
# Add process-specific prefix to stdout and stderr
|
189 |
+
process_name = mp.current_process().name
|
190 |
+
pid = os.getpid()
|
191 |
+
_add_prefix(sys.stdout, process_name, pid)
|
192 |
+
_add_prefix(sys.stderr, process_name, pid)
|
193 |
+
|
194 |
+
# Initialize worker
|
195 |
+
worker = worker_factory()
|
196 |
+
del worker_factory
|
197 |
+
|
198 |
+
# Accept tasks from the engine in task_queue
|
199 |
+
# and return task output in result_queue
|
200 |
+
logger.info("Worker ready; awaiting tasks")
|
201 |
+
try:
|
202 |
+
for items in iter(task_queue.get, _TERMINATE):
|
203 |
+
output = None
|
204 |
+
exception = None
|
205 |
+
task_id, method, args, kwargs = items
|
206 |
+
try:
|
207 |
+
executor = getattr(worker, method)
|
208 |
+
output = executor(*args, **kwargs)
|
209 |
+
except BaseException as e:
|
210 |
+
tb = traceback.format_exc()
|
211 |
+
logger.error("Exception in worker %s while processing method %s: %s, %s", process_name, method, e, tb)
|
212 |
+
exception = e
|
213 |
+
result_queue.put(Result(task_id=task_id, value=output, exception=exception))
|
214 |
+
except KeyboardInterrupt:
|
215 |
+
pass
|
216 |
+
except Exception:
|
217 |
+
logger.exception("Worker failed")
|
218 |
+
|
219 |
+
logger.info("Worker exiting")
|
220 |
+
|
221 |
+
|
222 |
+
class ProcessWorkerWrapper:
|
223 |
+
"""Local process wrapper for handling single-node multi-GPU."""
|
224 |
+
|
225 |
+
def __init__(self, result_handler: ResultHandler, worker_factory: Callable[[], Any]) -> None:
|
226 |
+
self._task_queue = mp.Queue()
|
227 |
+
self.result_queue = result_handler.result_queue
|
228 |
+
self.tasks = result_handler.tasks
|
229 |
+
self.process = mp.Process( # type: ignore[attr-defined]
|
230 |
+
target=_run_worker_process,
|
231 |
+
name="VideoSysWorkerProcess",
|
232 |
+
kwargs=dict(
|
233 |
+
worker_factory=worker_factory,
|
234 |
+
task_queue=self._task_queue,
|
235 |
+
result_queue=self.result_queue,
|
236 |
+
),
|
237 |
+
daemon=True,
|
238 |
+
)
|
239 |
+
|
240 |
+
self.process.start()
|
241 |
+
|
242 |
+
def _enqueue_task(self, future: Union[ResultFuture, asyncio.Future], method: str, args, kwargs):
|
243 |
+
task_id = uuid.uuid4()
|
244 |
+
self.tasks[task_id] = future
|
245 |
+
try:
|
246 |
+
self._task_queue.put((task_id, method, args, kwargs))
|
247 |
+
except BaseException as e:
|
248 |
+
del self.tasks[task_id]
|
249 |
+
raise ChildProcessError("worker died") from e
|
250 |
+
|
251 |
+
def execute_method(self, method: str, *args, **kwargs):
|
252 |
+
future: ResultFuture = ResultFuture()
|
253 |
+
self._enqueue_task(future, method, args, kwargs)
|
254 |
+
return future
|
255 |
+
|
256 |
+
async def execute_method_async(self, method: str, *args, **kwargs):
|
257 |
+
future = asyncio.get_running_loop().create_future()
|
258 |
+
self._enqueue_task(future, method, args, kwargs)
|
259 |
+
return await future
|
260 |
+
|
261 |
+
def terminate_worker(self):
|
262 |
+
try:
|
263 |
+
self._task_queue.put(_TERMINATE)
|
264 |
+
except ValueError:
|
265 |
+
self.process.kill()
|
266 |
+
self._task_queue.close()
|
267 |
+
|
268 |
+
def kill_worker(self):
|
269 |
+
self._task_queue.close()
|
270 |
+
self.process.kill()
|
videosys/core/pab_mgr.py
ADDED
@@ -0,0 +1,364 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
|
3 |
+
import numpy as np
|
4 |
+
import torch
|
5 |
+
|
6 |
+
from videosys.utils.logging import logger
|
7 |
+
|
8 |
+
PAB_MANAGER = None
|
9 |
+
|
10 |
+
|
11 |
+
class PABConfig:
|
12 |
+
def __init__(
|
13 |
+
self,
|
14 |
+
steps: int,
|
15 |
+
cross_broadcast: bool,
|
16 |
+
cross_threshold: list,
|
17 |
+
cross_gap: int,
|
18 |
+
spatial_broadcast: bool,
|
19 |
+
spatial_threshold: list,
|
20 |
+
spatial_gap: int,
|
21 |
+
temporal_broadcast: bool,
|
22 |
+
temporal_threshold: list,
|
23 |
+
temporal_gap: int,
|
24 |
+
diffusion_skip: bool,
|
25 |
+
diffusion_timestep_respacing: list,
|
26 |
+
diffusion_skip_timestep: list,
|
27 |
+
mlp_skip: bool,
|
28 |
+
mlp_spatial_skip_config: dict,
|
29 |
+
mlp_temporal_skip_config: dict,
|
30 |
+
full_broadcast: bool = False,
|
31 |
+
full_threshold: list = None,
|
32 |
+
full_gap: int = 1,
|
33 |
+
):
|
34 |
+
self.steps = steps
|
35 |
+
|
36 |
+
self.cross_broadcast = cross_broadcast
|
37 |
+
self.cross_threshold = cross_threshold
|
38 |
+
self.cross_gap = cross_gap
|
39 |
+
|
40 |
+
self.spatial_broadcast = spatial_broadcast
|
41 |
+
self.spatial_threshold = spatial_threshold
|
42 |
+
self.spatial_gap = spatial_gap
|
43 |
+
|
44 |
+
self.temporal_broadcast = temporal_broadcast
|
45 |
+
self.temporal_threshold = temporal_threshold
|
46 |
+
self.temporal_gap = temporal_gap
|
47 |
+
|
48 |
+
self.diffusion_skip = diffusion_skip
|
49 |
+
self.diffusion_timestep_respacing = diffusion_timestep_respacing
|
50 |
+
self.diffusion_skip_timestep = diffusion_skip_timestep
|
51 |
+
|
52 |
+
self.mlp_skip = mlp_skip
|
53 |
+
self.mlp_spatial_skip_config = mlp_spatial_skip_config
|
54 |
+
self.mlp_temporal_skip_config = mlp_temporal_skip_config
|
55 |
+
|
56 |
+
self.temporal_mlp_outputs = {}
|
57 |
+
self.spatial_mlp_outputs = {}
|
58 |
+
|
59 |
+
self.full_broadcast = full_broadcast
|
60 |
+
self.full_threshold = full_threshold
|
61 |
+
self.full_gap = full_gap
|
62 |
+
|
63 |
+
|
64 |
+
class PABManager:
|
65 |
+
def __init__(self, config: PABConfig):
|
66 |
+
self.config: PABConfig = config
|
67 |
+
|
68 |
+
init_prompt = f"Init PABManager. steps: {config.steps}."
|
69 |
+
init_prompt += f" spatial_broadcast: {config.spatial_broadcast}, spatial_threshold: {config.spatial_threshold}, spatial_gap: {config.spatial_gap}."
|
70 |
+
init_prompt += f" temporal_broadcast: {config.temporal_broadcast}, temporal_threshold: {config.temporal_threshold}, temporal_gap: {config.temporal_gap}."
|
71 |
+
init_prompt += f" cross_broadcast: {config.cross_broadcast}, cross_threshold: {config.cross_threshold}, cross_gap: {config.cross_gap}."
|
72 |
+
init_prompt += f" full_broadcast: {config.full_broadcast}, full_threshold: {config.full_threshold}, full_gap: {config.full_gap}."
|
73 |
+
logger.info(init_prompt)
|
74 |
+
|
75 |
+
def if_broadcast_cross(self, timestep: int, count: int):
|
76 |
+
if (
|
77 |
+
self.config.cross_broadcast
|
78 |
+
and (timestep is not None)
|
79 |
+
and (count % self.config.cross_gap != 0)
|
80 |
+
and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
|
81 |
+
):
|
82 |
+
flag = True
|
83 |
+
else:
|
84 |
+
flag = False
|
85 |
+
count = (count + 1) % self.config.steps
|
86 |
+
return flag, count
|
87 |
+
|
88 |
+
def if_broadcast_temporal(self, timestep: int, count: int):
|
89 |
+
if (
|
90 |
+
self.config.temporal_broadcast
|
91 |
+
and (timestep is not None)
|
92 |
+
and (count % self.config.temporal_gap != 0)
|
93 |
+
and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
|
94 |
+
):
|
95 |
+
flag = True
|
96 |
+
else:
|
97 |
+
flag = False
|
98 |
+
count = (count + 1) % self.config.steps
|
99 |
+
return flag, count
|
100 |
+
|
101 |
+
def if_broadcast_spatial(self, timestep: int, count: int, block_idx: int):
|
102 |
+
if (
|
103 |
+
self.config.spatial_broadcast
|
104 |
+
and (timestep is not None)
|
105 |
+
and (count % self.config.spatial_gap != 0)
|
106 |
+
and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
|
107 |
+
):
|
108 |
+
flag = True
|
109 |
+
else:
|
110 |
+
flag = False
|
111 |
+
count = (count + 1) % self.config.steps
|
112 |
+
return flag, count
|
113 |
+
|
114 |
+
def if_broadcast_full(self, timestep: int, count: int, block_idx: int):
|
115 |
+
if (
|
116 |
+
self.config.full_broadcast
|
117 |
+
and (timestep is not None)
|
118 |
+
and (count % self.config.full_gap != 0)
|
119 |
+
and (self.config.full_threshold[0] < timestep < self.config.full_threshold[1])
|
120 |
+
):
|
121 |
+
flag = True
|
122 |
+
else:
|
123 |
+
flag = False
|
124 |
+
count = (count + 1) % self.config.steps
|
125 |
+
return flag, count
|
126 |
+
|
127 |
+
@staticmethod
|
128 |
+
def _is_t_in_skip_config(all_timesteps, timestep, config):
|
129 |
+
is_t_in_skip_config = False
|
130 |
+
for key in config:
|
131 |
+
if key not in all_timesteps:
|
132 |
+
continue
|
133 |
+
index = all_timesteps.index(key)
|
134 |
+
skip_range = all_timesteps[index : index + 1 + int(config[key]["skip_count"])]
|
135 |
+
if timestep in skip_range:
|
136 |
+
is_t_in_skip_config = True
|
137 |
+
skip_range = [all_timesteps[index], all_timesteps[index + int(config[key]["skip_count"])]]
|
138 |
+
break
|
139 |
+
return is_t_in_skip_config, skip_range
|
140 |
+
|
141 |
+
def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
142 |
+
if not self.config.mlp_skip:
|
143 |
+
return False, None, False, None
|
144 |
+
|
145 |
+
if is_temporal:
|
146 |
+
cur_config = self.config.mlp_temporal_skip_config
|
147 |
+
else:
|
148 |
+
cur_config = self.config.mlp_spatial_skip_config
|
149 |
+
|
150 |
+
is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
|
151 |
+
next_flag = False
|
152 |
+
if (
|
153 |
+
self.config.mlp_skip
|
154 |
+
and (timestep is not None)
|
155 |
+
and (timestep in cur_config)
|
156 |
+
and (block_idx in cur_config[timestep]["block"])
|
157 |
+
):
|
158 |
+
flag = False
|
159 |
+
next_flag = True
|
160 |
+
count = count + 1
|
161 |
+
elif (
|
162 |
+
self.config.mlp_skip
|
163 |
+
and (timestep is not None)
|
164 |
+
and (is_t_in_skip_config)
|
165 |
+
and (block_idx in cur_config[skip_range[0]]["block"])
|
166 |
+
):
|
167 |
+
flag = True
|
168 |
+
count = 0
|
169 |
+
else:
|
170 |
+
flag = False
|
171 |
+
|
172 |
+
return flag, count, next_flag, skip_range
|
173 |
+
|
174 |
+
def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
|
175 |
+
if is_temporal:
|
176 |
+
self.config.temporal_mlp_outputs[(timestep, block_idx)] = ff_output
|
177 |
+
else:
|
178 |
+
self.config.spatial_mlp_outputs[(timestep, block_idx)] = ff_output
|
179 |
+
|
180 |
+
def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
|
181 |
+
skip_start_t = skip_range[0]
|
182 |
+
if is_temporal:
|
183 |
+
skip_output = (
|
184 |
+
self.config.temporal_mlp_outputs.get((skip_start_t, block_idx), None)
|
185 |
+
if self.config.temporal_mlp_outputs is not None
|
186 |
+
else None
|
187 |
+
)
|
188 |
+
else:
|
189 |
+
skip_output = (
|
190 |
+
self.config.spatial_mlp_outputs.get((skip_start_t, block_idx), None)
|
191 |
+
if self.config.spatial_mlp_outputs is not None
|
192 |
+
else None
|
193 |
+
)
|
194 |
+
|
195 |
+
if skip_output is not None:
|
196 |
+
if timestep == skip_range[-1]:
|
197 |
+
# TODO: save memory
|
198 |
+
if is_temporal:
|
199 |
+
del self.config.temporal_mlp_outputs[(skip_start_t, block_idx)]
|
200 |
+
else:
|
201 |
+
del self.config.spatial_mlp_outputs[(skip_start_t, block_idx)]
|
202 |
+
else:
|
203 |
+
raise ValueError(
|
204 |
+
f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
|
205 |
+
)
|
206 |
+
|
207 |
+
return skip_output
|
208 |
+
|
209 |
+
def get_spatial_mlp_outputs(self):
|
210 |
+
return self.config.spatial_mlp_outputs
|
211 |
+
|
212 |
+
def get_temporal_mlp_outputs(self):
|
213 |
+
return self.config.temporal_mlp_outputs
|
214 |
+
|
215 |
+
|
216 |
+
def set_pab_manager(config: PABConfig):
|
217 |
+
global PAB_MANAGER
|
218 |
+
PAB_MANAGER = PABManager(config)
|
219 |
+
|
220 |
+
|
221 |
+
def enable_pab():
|
222 |
+
if PAB_MANAGER is None:
|
223 |
+
return False
|
224 |
+
return (
|
225 |
+
PAB_MANAGER.config.cross_broadcast
|
226 |
+
or PAB_MANAGER.config.spatial_broadcast
|
227 |
+
or PAB_MANAGER.config.temporal_broadcast
|
228 |
+
)
|
229 |
+
|
230 |
+
|
231 |
+
def update_steps(steps: int):
|
232 |
+
if PAB_MANAGER is not None:
|
233 |
+
PAB_MANAGER.config.steps = steps
|
234 |
+
|
235 |
+
|
236 |
+
def if_broadcast_cross(timestep: int, count: int):
|
237 |
+
if not enable_pab():
|
238 |
+
return False, count
|
239 |
+
return PAB_MANAGER.if_broadcast_cross(timestep, count)
|
240 |
+
|
241 |
+
|
242 |
+
def if_broadcast_temporal(timestep: int, count: int):
|
243 |
+
if not enable_pab():
|
244 |
+
return False, count
|
245 |
+
return PAB_MANAGER.if_broadcast_temporal(timestep, count)
|
246 |
+
|
247 |
+
|
248 |
+
def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
|
249 |
+
if not enable_pab():
|
250 |
+
return False, count
|
251 |
+
return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
|
252 |
+
|
253 |
+
def if_broadcast_full(timestep: int, count: int, block_idx: int):
|
254 |
+
if not enable_pab():
|
255 |
+
return False, count
|
256 |
+
return PAB_MANAGER.if_broadcast_full(timestep, count, block_idx)
|
257 |
+
|
258 |
+
|
259 |
+
def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
|
260 |
+
if not enable_pab():
|
261 |
+
return False, count
|
262 |
+
return PAB_MANAGER.if_skip_mlp(timestep, count, block_idx, all_timesteps, is_temporal)
|
263 |
+
|
264 |
+
|
265 |
+
def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False):
|
266 |
+
return PAB_MANAGER.save_skip_output(timestep, block_idx, ff_output, is_temporal)
|
267 |
+
|
268 |
+
|
269 |
+
def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
|
270 |
+
return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
|
271 |
+
|
272 |
+
|
273 |
+
def get_diffusion_skip():
|
274 |
+
return enable_pab() and PAB_MANAGER.config.diffusion_skip
|
275 |
+
|
276 |
+
|
277 |
+
def get_diffusion_timestep_respacing():
|
278 |
+
return PAB_MANAGER.config.diffusion_timestep_respacing
|
279 |
+
|
280 |
+
|
281 |
+
def get_diffusion_skip_timestep():
|
282 |
+
return enable_pab() and PAB_MANAGER.config.diffusion_skip_timestep
|
283 |
+
|
284 |
+
|
285 |
+
def space_timesteps(time_steps, time_bins):
|
286 |
+
num_bins = len(time_bins)
|
287 |
+
bin_size = time_steps // num_bins
|
288 |
+
|
289 |
+
result = []
|
290 |
+
|
291 |
+
for i, bin_count in enumerate(time_bins):
|
292 |
+
start = i * bin_size
|
293 |
+
end = start + bin_size
|
294 |
+
|
295 |
+
bin_steps = np.linspace(start, end, bin_count, endpoint=False, dtype=int).tolist()
|
296 |
+
result.extend(bin_steps)
|
297 |
+
|
298 |
+
result_tensor = torch.tensor(result, dtype=torch.int32)
|
299 |
+
sorted_tensor = torch.sort(result_tensor, descending=True).values
|
300 |
+
|
301 |
+
return sorted_tensor
|
302 |
+
|
303 |
+
|
304 |
+
def skip_diffusion_timestep(timesteps, diffusion_skip_timestep):
|
305 |
+
if isinstance(timesteps, list):
|
306 |
+
# If timesteps is a list, we assume each element is a tensor
|
307 |
+
timesteps_np = [t.cpu().numpy() for t in timesteps]
|
308 |
+
device = timesteps[0].device
|
309 |
+
else:
|
310 |
+
# If timesteps is a tensor
|
311 |
+
timesteps_np = timesteps.cpu().numpy()
|
312 |
+
device = timesteps.device
|
313 |
+
|
314 |
+
num_bins = len(diffusion_skip_timestep)
|
315 |
+
|
316 |
+
if isinstance(timesteps_np, list):
|
317 |
+
bin_size = len(timesteps_np) // num_bins
|
318 |
+
new_timesteps = []
|
319 |
+
|
320 |
+
for i in range(num_bins):
|
321 |
+
bin_start = i * bin_size
|
322 |
+
bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
|
323 |
+
bin_timesteps = timesteps_np[bin_start:bin_end]
|
324 |
+
|
325 |
+
if diffusion_skip_timestep[i] == 0:
|
326 |
+
# If the bin is marked with 0, keep all timesteps
|
327 |
+
new_timesteps.extend(bin_timesteps)
|
328 |
+
elif diffusion_skip_timestep[i] == 1:
|
329 |
+
# If the bin is marked with 1, omit the last timestep in the bin
|
330 |
+
new_timesteps.extend(bin_timesteps[1:])
|
331 |
+
|
332 |
+
new_timesteps_tensor = [torch.tensor(t, device=device) for t in new_timesteps]
|
333 |
+
else:
|
334 |
+
bin_size = len(timesteps_np) // num_bins
|
335 |
+
new_timesteps = []
|
336 |
+
|
337 |
+
for i in range(num_bins):
|
338 |
+
bin_start = i * bin_size
|
339 |
+
bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
|
340 |
+
bin_timesteps = timesteps_np[bin_start:bin_end]
|
341 |
+
|
342 |
+
if diffusion_skip_timestep[i] == 0:
|
343 |
+
# If the bin is marked with 0, keep all timesteps
|
344 |
+
new_timesteps.extend(bin_timesteps)
|
345 |
+
elif diffusion_skip_timestep[i] == 1:
|
346 |
+
# If the bin is marked with 1, omit the last timestep in the bin
|
347 |
+
new_timesteps.extend(bin_timesteps[1:])
|
348 |
+
elif diffusion_skip_timestep[i] != 0:
|
349 |
+
# If the bin is marked with a non-zero value, randomly omit n timesteps
|
350 |
+
if len(bin_timesteps) > diffusion_skip_timestep[i]:
|
351 |
+
indices_to_remove = set(random.sample(range(len(bin_timesteps)), diffusion_skip_timestep[i]))
|
352 |
+
timesteps_to_keep = [
|
353 |
+
timestep for idx, timestep in enumerate(bin_timesteps) if idx not in indices_to_remove
|
354 |
+
]
|
355 |
+
else:
|
356 |
+
timesteps_to_keep = bin_timesteps # 如果bin_timesteps的长度小于等于n,则不删除任何元素
|
357 |
+
new_timesteps.extend(timesteps_to_keep)
|
358 |
+
|
359 |
+
new_timesteps_tensor = torch.tensor(new_timesteps, device=device)
|
360 |
+
|
361 |
+
if isinstance(timesteps, list):
|
362 |
+
return new_timesteps_tensor
|
363 |
+
else:
|
364 |
+
return new_timesteps_tensor
|
videosys/core/parallel_mgr.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Optional
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.distributed as dist
|
5 |
+
from colossalai.cluster.process_group_mesh import ProcessGroupMesh
|
6 |
+
from torch.distributed import ProcessGroup
|
7 |
+
|
8 |
+
from videosys.utils.logging import init_dist_logger, logger
|
9 |
+
from videosys.utils.utils import set_seed
|
10 |
+
|
11 |
+
PARALLEL_MANAGER = None
|
12 |
+
|
13 |
+
|
14 |
+
class ParallelManager(ProcessGroupMesh):
|
15 |
+
def __init__(self, dp_size, cp_size, sp_size):
|
16 |
+
super().__init__(dp_size, cp_size, sp_size)
|
17 |
+
dp_axis, cp_axis, sp_axis = 0, 1, 2
|
18 |
+
|
19 |
+
self.dp_size = dp_size
|
20 |
+
self.dp_group: ProcessGroup = self.get_group_along_axis(dp_axis)
|
21 |
+
self.dp_rank = dist.get_rank(self.dp_group)
|
22 |
+
|
23 |
+
self.cp_size = cp_size
|
24 |
+
self.cp_group: ProcessGroup = self.get_group_along_axis(cp_axis)
|
25 |
+
self.cp_rank = dist.get_rank(self.cp_group)
|
26 |
+
|
27 |
+
self.sp_size = sp_size
|
28 |
+
self.sp_group: ProcessGroup = self.get_group_along_axis(sp_axis)
|
29 |
+
self.sp_rank = dist.get_rank(self.sp_group)
|
30 |
+
self.enable_sp = sp_size > 1
|
31 |
+
|
32 |
+
logger.info(f"Init parallel manager with dp_size: {dp_size}, cp_size: {cp_size}, sp_size: {sp_size}")
|
33 |
+
|
34 |
+
|
35 |
+
def set_parallel_manager(dp_size, cp_size, sp_size):
|
36 |
+
global PARALLEL_MANAGER
|
37 |
+
PARALLEL_MANAGER = ParallelManager(dp_size, cp_size, sp_size)
|
38 |
+
|
39 |
+
|
40 |
+
def get_data_parallel_group():
|
41 |
+
return PARALLEL_MANAGER.dp_group
|
42 |
+
|
43 |
+
|
44 |
+
def get_data_parallel_size():
|
45 |
+
return PARALLEL_MANAGER.dp_size
|
46 |
+
|
47 |
+
|
48 |
+
def get_data_parallel_rank():
|
49 |
+
return PARALLEL_MANAGER.dp_rank
|
50 |
+
|
51 |
+
|
52 |
+
def get_sequence_parallel_group():
|
53 |
+
return PARALLEL_MANAGER.sp_group
|
54 |
+
|
55 |
+
|
56 |
+
def get_sequence_parallel_size():
|
57 |
+
return PARALLEL_MANAGER.sp_size
|
58 |
+
|
59 |
+
|
60 |
+
def get_sequence_parallel_rank():
|
61 |
+
return PARALLEL_MANAGER.sp_rank
|
62 |
+
|
63 |
+
|
64 |
+
def get_cfg_parallel_group():
|
65 |
+
return PARALLEL_MANAGER.cp_group
|
66 |
+
|
67 |
+
|
68 |
+
def get_cfg_parallel_size():
|
69 |
+
return PARALLEL_MANAGER.cp_size
|
70 |
+
|
71 |
+
|
72 |
+
def enable_sequence_parallel():
|
73 |
+
if PARALLEL_MANAGER is None:
|
74 |
+
return False
|
75 |
+
return PARALLEL_MANAGER.enable_sp
|
76 |
+
|
77 |
+
|
78 |
+
def get_parallel_manager():
|
79 |
+
return PARALLEL_MANAGER
|
80 |
+
|
81 |
+
|
82 |
+
def initialize(
|
83 |
+
rank=0,
|
84 |
+
world_size=1,
|
85 |
+
init_method=None,
|
86 |
+
seed: Optional[int] = None,
|
87 |
+
sp_size: Optional[int] = None,
|
88 |
+
enable_cp: bool = True,
|
89 |
+
):
|
90 |
+
if not dist.is_initialized():
|
91 |
+
try:
|
92 |
+
dist.destroy_process_group()
|
93 |
+
except Exception:
|
94 |
+
pass
|
95 |
+
dist.init_process_group(backend="nccl", init_method=init_method, world_size=world_size, rank=rank)
|
96 |
+
torch.cuda.set_device(rank)
|
97 |
+
init_dist_logger()
|
98 |
+
torch.backends.cuda.matmul.allow_tf32 = True
|
99 |
+
torch.backends.cudnn.allow_tf32 = True
|
100 |
+
|
101 |
+
# init sequence parallel
|
102 |
+
if sp_size is None:
|
103 |
+
sp_size = dist.get_world_size()
|
104 |
+
dp_size = 1
|
105 |
+
else:
|
106 |
+
assert dist.get_world_size() % sp_size == 0, f"world_size {dist.get_world_size()} must be divisible by sp_size"
|
107 |
+
dp_size = dist.get_world_size() // sp_size
|
108 |
+
|
109 |
+
# update cfg parallel
|
110 |
+
if enable_cp and sp_size % 2 == 0:
|
111 |
+
sp_size = sp_size // 2
|
112 |
+
cp_size = 2
|
113 |
+
else:
|
114 |
+
cp_size = 1
|
115 |
+
|
116 |
+
set_parallel_manager(dp_size, cp_size, sp_size)
|
117 |
+
|
118 |
+
if seed is not None:
|
119 |
+
set_seed(seed + get_data_parallel_rank())
|
videosys/core/pipeline.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import abstractmethod
|
2 |
+
from dataclasses import dataclass
|
3 |
+
|
4 |
+
import torch
|
5 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
6 |
+
from diffusers.utils import BaseOutput
|
7 |
+
|
8 |
+
|
9 |
+
class VideoSysPipeline(DiffusionPipeline):
|
10 |
+
def __init__(self):
|
11 |
+
super().__init__()
|
12 |
+
|
13 |
+
@staticmethod
|
14 |
+
def set_eval_and_device(device: torch.device, *modules):
|
15 |
+
for module in modules:
|
16 |
+
module.eval()
|
17 |
+
module.to(device)
|
18 |
+
|
19 |
+
@abstractmethod
|
20 |
+
def generate(self, *args, **kwargs):
|
21 |
+
pass
|
22 |
+
|
23 |
+
def __call__(self, *args, **kwargs):
|
24 |
+
"""
|
25 |
+
In diffusers, it is a convention to call the pipeline object.
|
26 |
+
But in VideoSys, we will use the generate method for better prompt.
|
27 |
+
This is a wrapper for the generate method to support the diffusers usage.
|
28 |
+
"""
|
29 |
+
return self.generate(*args, **kwargs)
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class VideoSysPipelineOutput(BaseOutput):
|
34 |
+
video: torch.Tensor
|
videosys/core/shardformer/__init__.py
ADDED
File without changes
|
videosys/core/shardformer/t5/__init__.py
ADDED
File without changes
|
videosys/core/shardformer/t5/modeling.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
|
4 |
+
|
5 |
+
class T5LayerNorm(nn.Module):
|
6 |
+
def __init__(self, hidden_size, eps=1e-6):
|
7 |
+
"""
|
8 |
+
Construct a layernorm module in the T5 style. No bias and no subtraction of mean.
|
9 |
+
"""
|
10 |
+
super().__init__()
|
11 |
+
self.weight = nn.Parameter(torch.ones(hidden_size))
|
12 |
+
self.variance_epsilon = eps
|
13 |
+
|
14 |
+
def forward(self, hidden_states):
|
15 |
+
# T5 uses a layer_norm which only scales and doesn't shift, which is also known as Root Mean
|
16 |
+
# Square Layer Normalization https://arxiv.org/abs/1910.07467 thus varience is calculated
|
17 |
+
# w/o mean and there is no bias. Additionally we want to make sure that the accumulation for
|
18 |
+
# half-precision inputs is done in fp32
|
19 |
+
|
20 |
+
variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True)
|
21 |
+
hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
|
22 |
+
|
23 |
+
# convert into half-precision if necessary
|
24 |
+
if self.weight.dtype in [torch.float16, torch.bfloat16]:
|
25 |
+
hidden_states = hidden_states.to(self.weight.dtype)
|
26 |
+
|
27 |
+
return self.weight * hidden_states
|
28 |
+
|
29 |
+
@staticmethod
|
30 |
+
def from_native_module(module, *args, **kwargs):
|
31 |
+
assert module.__class__.__name__ == "FusedRMSNorm", (
|
32 |
+
"Recovering T5LayerNorm requires the original layer to be apex's Fused RMS Norm."
|
33 |
+
"Apex's fused norm is automatically used by Hugging Face Transformers https://github.com/huggingface/transformers/blob/main/src/transformers/models/t5/modeling_t5.py#L265C5-L265C48"
|
34 |
+
)
|
35 |
+
|
36 |
+
layer_norm = T5LayerNorm(module.normalized_shape, eps=module.eps)
|
37 |
+
layer_norm.weight.data.copy_(module.weight.data)
|
38 |
+
layer_norm = layer_norm.to(module.weight.device)
|
39 |
+
return layer_norm
|
videosys/core/shardformer/t5/policy.py
ADDED
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from colossalai.shardformer.modeling.jit import get_jit_fused_dropout_add_func
|
2 |
+
from colossalai.shardformer.modeling.t5 import get_jit_fused_T5_layer_ff_forward, get_T5_layer_self_attention_forward
|
3 |
+
from colossalai.shardformer.policies.base_policy import Policy, SubModuleReplacementDescription
|
4 |
+
|
5 |
+
|
6 |
+
class T5EncoderPolicy(Policy):
|
7 |
+
def config_sanity_check(self):
|
8 |
+
assert not self.shard_config.enable_tensor_parallelism
|
9 |
+
assert not self.shard_config.enable_flash_attention
|
10 |
+
|
11 |
+
def preprocess(self):
|
12 |
+
return self.model
|
13 |
+
|
14 |
+
def module_policy(self):
|
15 |
+
from transformers.models.t5.modeling_t5 import T5LayerFF, T5LayerSelfAttention, T5Stack
|
16 |
+
|
17 |
+
policy = {}
|
18 |
+
|
19 |
+
# check whether apex is installed
|
20 |
+
try:
|
21 |
+
from apex.normalization import FusedRMSNorm # noqa
|
22 |
+
from videosys.core.shardformer.t5.modeling import T5LayerNorm
|
23 |
+
|
24 |
+
# recover hf from fused rms norm to T5 norm which is faster
|
25 |
+
self.append_or_create_submodule_replacement(
|
26 |
+
description=SubModuleReplacementDescription(
|
27 |
+
suffix="layer_norm",
|
28 |
+
target_module=T5LayerNorm,
|
29 |
+
),
|
30 |
+
policy=policy,
|
31 |
+
target_key=T5LayerFF,
|
32 |
+
)
|
33 |
+
self.append_or_create_submodule_replacement(
|
34 |
+
description=SubModuleReplacementDescription(suffix="layer_norm", target_module=T5LayerNorm),
|
35 |
+
policy=policy,
|
36 |
+
target_key=T5LayerSelfAttention,
|
37 |
+
)
|
38 |
+
self.append_or_create_submodule_replacement(
|
39 |
+
description=SubModuleReplacementDescription(suffix="final_layer_norm", target_module=T5LayerNorm),
|
40 |
+
policy=policy,
|
41 |
+
target_key=T5Stack,
|
42 |
+
)
|
43 |
+
except (ImportError, ModuleNotFoundError):
|
44 |
+
pass
|
45 |
+
|
46 |
+
# use jit operator
|
47 |
+
if self.shard_config.enable_jit_fused:
|
48 |
+
self.append_or_create_method_replacement(
|
49 |
+
description={
|
50 |
+
"forward": get_jit_fused_T5_layer_ff_forward(),
|
51 |
+
"dropout_add": get_jit_fused_dropout_add_func(),
|
52 |
+
},
|
53 |
+
policy=policy,
|
54 |
+
target_key=T5LayerFF,
|
55 |
+
)
|
56 |
+
self.append_or_create_method_replacement(
|
57 |
+
description={
|
58 |
+
"forward": get_T5_layer_self_attention_forward(),
|
59 |
+
"dropout_add": get_jit_fused_dropout_add_func(),
|
60 |
+
},
|
61 |
+
policy=policy,
|
62 |
+
target_key=T5LayerSelfAttention,
|
63 |
+
)
|
64 |
+
|
65 |
+
return policy
|
66 |
+
|
67 |
+
def postprocess(self):
|
68 |
+
return self.model
|
videosys/datasets/dataloader.py
ADDED
@@ -0,0 +1,94 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import random
|
2 |
+
from typing import Iterator, Optional
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
7 |
+
from torch.utils.data.distributed import DistributedSampler
|
8 |
+
|
9 |
+
from videosys.core.parallel_mgr import ParallelManager
|
10 |
+
|
11 |
+
|
12 |
+
class StatefulDistributedSampler(DistributedSampler):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
dataset: Dataset,
|
16 |
+
num_replicas: Optional[int] = None,
|
17 |
+
rank: Optional[int] = None,
|
18 |
+
shuffle: bool = True,
|
19 |
+
seed: int = 0,
|
20 |
+
drop_last: bool = False,
|
21 |
+
) -> None:
|
22 |
+
super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
|
23 |
+
self.start_index: int = 0
|
24 |
+
|
25 |
+
def __iter__(self) -> Iterator:
|
26 |
+
iterator = super().__iter__()
|
27 |
+
indices = list(iterator)
|
28 |
+
indices = indices[self.start_index :]
|
29 |
+
return iter(indices)
|
30 |
+
|
31 |
+
def __len__(self) -> int:
|
32 |
+
return self.num_samples - self.start_index
|
33 |
+
|
34 |
+
def set_start_index(self, start_index: int) -> None:
|
35 |
+
self.start_index = start_index
|
36 |
+
|
37 |
+
|
38 |
+
def prepare_dataloader(
|
39 |
+
dataset,
|
40 |
+
batch_size,
|
41 |
+
shuffle=False,
|
42 |
+
seed=1024,
|
43 |
+
drop_last=False,
|
44 |
+
pin_memory=False,
|
45 |
+
num_workers=0,
|
46 |
+
pg_manager: Optional[ParallelManager] = None,
|
47 |
+
**kwargs,
|
48 |
+
):
|
49 |
+
r"""
|
50 |
+
Prepare a dataloader for distributed training. The dataloader will be wrapped by
|
51 |
+
`torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
|
52 |
+
|
53 |
+
|
54 |
+
Args:
|
55 |
+
dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
|
56 |
+
shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
|
57 |
+
seed (int, optional): Random worker seed for sampling, defaults to 1024.
|
58 |
+
add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
|
59 |
+
drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
|
60 |
+
is not divisible by the batch size. If False and the size of dataset is not divisible by
|
61 |
+
the batch size, then the last batch will be smaller, defaults to False.
|
62 |
+
pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
|
63 |
+
num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
|
64 |
+
kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
|
65 |
+
`DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
:class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
|
69 |
+
"""
|
70 |
+
_kwargs = kwargs.copy()
|
71 |
+
sampler = StatefulDistributedSampler(
|
72 |
+
dataset,
|
73 |
+
num_replicas=pg_manager.size(pg_manager.dp_axis),
|
74 |
+
rank=pg_manager.coordinate(pg_manager.dp_axis),
|
75 |
+
shuffle=shuffle,
|
76 |
+
)
|
77 |
+
|
78 |
+
# Deterministic dataloader
|
79 |
+
def seed_worker(worker_id):
|
80 |
+
worker_seed = seed
|
81 |
+
np.random.seed(worker_seed)
|
82 |
+
torch.manual_seed(worker_seed)
|
83 |
+
random.seed(worker_seed)
|
84 |
+
|
85 |
+
return DataLoader(
|
86 |
+
dataset,
|
87 |
+
batch_size=batch_size,
|
88 |
+
sampler=sampler,
|
89 |
+
worker_init_fn=seed_worker,
|
90 |
+
drop_last=drop_last,
|
91 |
+
pin_memory=pin_memory,
|
92 |
+
num_workers=num_workers,
|
93 |
+
**_kwargs,
|
94 |
+
)
|
videosys/datasets/image_transform.py
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from DiT
|
2 |
+
|
3 |
+
# This source code is licensed under the license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
# --------------------------------------------------------
|
6 |
+
# References:
|
7 |
+
# DiT: https://github.com/facebookresearch/DiT
|
8 |
+
# --------------------------------------------------------
|
9 |
+
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import torchvision.transforms as transforms
|
13 |
+
from PIL import Image
|
14 |
+
|
15 |
+
|
16 |
+
def center_crop_arr(pil_image, image_size):
|
17 |
+
"""
|
18 |
+
Center cropping implementation from ADM.
|
19 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
20 |
+
"""
|
21 |
+
while min(*pil_image.size) >= 2 * image_size:
|
22 |
+
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
|
23 |
+
|
24 |
+
scale = image_size / min(*pil_image.size)
|
25 |
+
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
|
26 |
+
|
27 |
+
arr = np.array(pil_image)
|
28 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
29 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
30 |
+
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
|
31 |
+
|
32 |
+
|
33 |
+
def get_transforms_image(image_size=256):
|
34 |
+
transform = transforms.Compose(
|
35 |
+
[
|
36 |
+
transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
|
37 |
+
transforms.RandomHorizontalFlip(),
|
38 |
+
transforms.ToTensor(),
|
39 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
40 |
+
]
|
41 |
+
)
|
42 |
+
return transform
|
videosys/datasets/video_transform.py
ADDED
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Adapted from OpenSora and Latte
|
2 |
+
|
3 |
+
# This source code is licensed under the license found in the
|
4 |
+
# LICENSE file in the root directory of this source tree.
|
5 |
+
# --------------------------------------------------------
|
6 |
+
# References:
|
7 |
+
# OpenSora: https://github.com/hpcaitech/Open-Sora
|
8 |
+
# Latte: https://github.com/Vchitect/Latte
|
9 |
+
# --------------------------------------------------------
|
10 |
+
|
11 |
+
import numbers
|
12 |
+
import random
|
13 |
+
|
14 |
+
import numpy as np
|
15 |
+
import torch
|
16 |
+
from PIL import Image
|
17 |
+
|
18 |
+
|
19 |
+
def _is_tensor_video_clip(clip):
|
20 |
+
if not torch.is_tensor(clip):
|
21 |
+
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
22 |
+
|
23 |
+
if not clip.ndimension() == 4:
|
24 |
+
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
25 |
+
|
26 |
+
return True
|
27 |
+
|
28 |
+
|
29 |
+
def center_crop_arr(pil_image, image_size):
|
30 |
+
"""
|
31 |
+
Center cropping implementation from ADM.
|
32 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
33 |
+
"""
|
34 |
+
while min(*pil_image.size) >= 2 * image_size:
|
35 |
+
pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
|
36 |
+
|
37 |
+
scale = image_size / min(*pil_image.size)
|
38 |
+
pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
|
39 |
+
|
40 |
+
arr = np.array(pil_image)
|
41 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
42 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
43 |
+
return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
|
44 |
+
|
45 |
+
|
46 |
+
def crop(clip, i, j, h, w):
|
47 |
+
"""
|
48 |
+
Args:
|
49 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
50 |
+
"""
|
51 |
+
if len(clip.size()) != 4:
|
52 |
+
raise ValueError("clip should be a 4D tensor")
|
53 |
+
return clip[..., i : i + h, j : j + w]
|
54 |
+
|
55 |
+
|
56 |
+
def resize(clip, target_size, interpolation_mode):
|
57 |
+
if len(target_size) != 2:
|
58 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
59 |
+
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
|
60 |
+
|
61 |
+
|
62 |
+
def resize_scale(clip, target_size, interpolation_mode):
|
63 |
+
if len(target_size) != 2:
|
64 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
65 |
+
H, W = clip.size(-2), clip.size(-1)
|
66 |
+
scale_ = target_size[0] / min(H, W)
|
67 |
+
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
|
68 |
+
|
69 |
+
|
70 |
+
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
|
71 |
+
"""
|
72 |
+
Do spatial cropping and resizing to the video clip
|
73 |
+
Args:
|
74 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
75 |
+
i (int): i in (i,j) i.e coordinates of the upper left corner.
|
76 |
+
j (int): j in (i,j) i.e coordinates of the upper left corner.
|
77 |
+
h (int): Height of the cropped region.
|
78 |
+
w (int): Width of the cropped region.
|
79 |
+
size (tuple(int, int)): height and width of resized clip
|
80 |
+
Returns:
|
81 |
+
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
|
82 |
+
"""
|
83 |
+
if not _is_tensor_video_clip(clip):
|
84 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
85 |
+
clip = crop(clip, i, j, h, w)
|
86 |
+
clip = resize(clip, size, interpolation_mode)
|
87 |
+
return clip
|
88 |
+
|
89 |
+
|
90 |
+
def center_crop(clip, crop_size):
|
91 |
+
if not _is_tensor_video_clip(clip):
|
92 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
93 |
+
h, w = clip.size(-2), clip.size(-1)
|
94 |
+
th, tw = crop_size
|
95 |
+
if h < th or w < tw:
|
96 |
+
raise ValueError("height and width must be no smaller than crop_size")
|
97 |
+
|
98 |
+
i = int(round((h - th) / 2.0))
|
99 |
+
j = int(round((w - tw) / 2.0))
|
100 |
+
return crop(clip, i, j, th, tw)
|
101 |
+
|
102 |
+
|
103 |
+
def center_crop_using_short_edge(clip):
|
104 |
+
if not _is_tensor_video_clip(clip):
|
105 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
106 |
+
h, w = clip.size(-2), clip.size(-1)
|
107 |
+
if h < w:
|
108 |
+
th, tw = h, h
|
109 |
+
i = 0
|
110 |
+
j = int(round((w - tw) / 2.0))
|
111 |
+
else:
|
112 |
+
th, tw = w, w
|
113 |
+
i = int(round((h - th) / 2.0))
|
114 |
+
j = 0
|
115 |
+
return crop(clip, i, j, th, tw)
|
116 |
+
|
117 |
+
|
118 |
+
def random_shift_crop(clip):
|
119 |
+
"""
|
120 |
+
Slide along the long edge, with the short edge as crop size
|
121 |
+
"""
|
122 |
+
if not _is_tensor_video_clip(clip):
|
123 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
124 |
+
h, w = clip.size(-2), clip.size(-1)
|
125 |
+
|
126 |
+
if h <= w:
|
127 |
+
short_edge = h
|
128 |
+
else:
|
129 |
+
short_edge = w
|
130 |
+
|
131 |
+
th, tw = short_edge, short_edge
|
132 |
+
|
133 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
134 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
135 |
+
return crop(clip, i, j, th, tw)
|
136 |
+
|
137 |
+
|
138 |
+
def to_tensor(clip):
|
139 |
+
"""
|
140 |
+
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
141 |
+
permute the dimensions of clip tensor
|
142 |
+
Args:
|
143 |
+
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
144 |
+
Return:
|
145 |
+
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
146 |
+
"""
|
147 |
+
_is_tensor_video_clip(clip)
|
148 |
+
if not clip.dtype == torch.uint8:
|
149 |
+
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
|
150 |
+
# return clip.float().permute(3, 0, 1, 2) / 255.0
|
151 |
+
return clip.float() / 255.0
|
152 |
+
|
153 |
+
|
154 |
+
def normalize(clip, mean, std, inplace=False):
|
155 |
+
"""
|
156 |
+
Args:
|
157 |
+
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
158 |
+
mean (tuple): pixel RGB mean. Size is (3)
|
159 |
+
std (tuple): pixel standard deviation. Size is (3)
|
160 |
+
Returns:
|
161 |
+
normalized clip (torch.tensor): Size is (T, C, H, W)
|
162 |
+
"""
|
163 |
+
if not _is_tensor_video_clip(clip):
|
164 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
165 |
+
if not inplace:
|
166 |
+
clip = clip.clone()
|
167 |
+
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
|
168 |
+
# print(mean)
|
169 |
+
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
|
170 |
+
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
|
171 |
+
return clip
|
172 |
+
|
173 |
+
|
174 |
+
def hflip(clip):
|
175 |
+
"""
|
176 |
+
Args:
|
177 |
+
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
178 |
+
Returns:
|
179 |
+
flipped clip (torch.tensor): Size is (T, C, H, W)
|
180 |
+
"""
|
181 |
+
if not _is_tensor_video_clip(clip):
|
182 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
183 |
+
return clip.flip(-1)
|
184 |
+
|
185 |
+
|
186 |
+
class RandomCropVideo:
|
187 |
+
def __init__(self, size):
|
188 |
+
if isinstance(size, numbers.Number):
|
189 |
+
self.size = (int(size), int(size))
|
190 |
+
else:
|
191 |
+
self.size = size
|
192 |
+
|
193 |
+
def __call__(self, clip):
|
194 |
+
"""
|
195 |
+
Args:
|
196 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
197 |
+
Returns:
|
198 |
+
torch.tensor: randomly cropped video clip.
|
199 |
+
size is (T, C, OH, OW)
|
200 |
+
"""
|
201 |
+
i, j, h, w = self.get_params(clip)
|
202 |
+
return crop(clip, i, j, h, w)
|
203 |
+
|
204 |
+
def get_params(self, clip):
|
205 |
+
h, w = clip.shape[-2:]
|
206 |
+
th, tw = self.size
|
207 |
+
|
208 |
+
if h < th or w < tw:
|
209 |
+
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
|
210 |
+
|
211 |
+
if w == tw and h == th:
|
212 |
+
return 0, 0, h, w
|
213 |
+
|
214 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
215 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
216 |
+
|
217 |
+
return i, j, th, tw
|
218 |
+
|
219 |
+
def __repr__(self) -> str:
|
220 |
+
return f"{self.__class__.__name__}(size={self.size})"
|
221 |
+
|
222 |
+
|
223 |
+
class CenterCropResizeVideo:
|
224 |
+
"""
|
225 |
+
First use the short side for cropping length,
|
226 |
+
center crop video, then resize to the specified size
|
227 |
+
"""
|
228 |
+
|
229 |
+
def __init__(
|
230 |
+
self,
|
231 |
+
size,
|
232 |
+
interpolation_mode="bilinear",
|
233 |
+
):
|
234 |
+
if isinstance(size, tuple):
|
235 |
+
if len(size) != 2:
|
236 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
237 |
+
self.size = size
|
238 |
+
else:
|
239 |
+
self.size = (size, size)
|
240 |
+
|
241 |
+
self.interpolation_mode = interpolation_mode
|
242 |
+
|
243 |
+
def __call__(self, clip):
|
244 |
+
"""
|
245 |
+
Args:
|
246 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
247 |
+
Returns:
|
248 |
+
torch.tensor: scale resized / center cropped video clip.
|
249 |
+
size is (T, C, crop_size, crop_size)
|
250 |
+
"""
|
251 |
+
clip_center_crop = center_crop_using_short_edge(clip)
|
252 |
+
clip_center_crop_resize = resize(
|
253 |
+
clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
|
254 |
+
)
|
255 |
+
return clip_center_crop_resize
|
256 |
+
|
257 |
+
def __repr__(self) -> str:
|
258 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
259 |
+
|
260 |
+
|
261 |
+
class UCFCenterCropVideo:
|
262 |
+
"""
|
263 |
+
First scale to the specified size in equal proportion to the short edge,
|
264 |
+
then center cropping
|
265 |
+
"""
|
266 |
+
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
size,
|
270 |
+
interpolation_mode="bilinear",
|
271 |
+
):
|
272 |
+
if isinstance(size, tuple):
|
273 |
+
if len(size) != 2:
|
274 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
275 |
+
self.size = size
|
276 |
+
else:
|
277 |
+
self.size = (size, size)
|
278 |
+
|
279 |
+
self.interpolation_mode = interpolation_mode
|
280 |
+
|
281 |
+
def __call__(self, clip):
|
282 |
+
"""
|
283 |
+
Args:
|
284 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
285 |
+
Returns:
|
286 |
+
torch.tensor: scale resized / center cropped video clip.
|
287 |
+
size is (T, C, crop_size, crop_size)
|
288 |
+
"""
|
289 |
+
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
290 |
+
clip_center_crop = center_crop(clip_resize, self.size)
|
291 |
+
return clip_center_crop
|
292 |
+
|
293 |
+
def __repr__(self) -> str:
|
294 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
295 |
+
|
296 |
+
|
297 |
+
class KineticsRandomCropResizeVideo:
|
298 |
+
"""
|
299 |
+
Slide along the long edge, with the short edge as crop size. And resie to the desired size.
|
300 |
+
"""
|
301 |
+
|
302 |
+
def __init__(
|
303 |
+
self,
|
304 |
+
size,
|
305 |
+
interpolation_mode="bilinear",
|
306 |
+
):
|
307 |
+
if isinstance(size, tuple):
|
308 |
+
if len(size) != 2:
|
309 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
310 |
+
self.size = size
|
311 |
+
else:
|
312 |
+
self.size = (size, size)
|
313 |
+
|
314 |
+
self.interpolation_mode = interpolation_mode
|
315 |
+
|
316 |
+
def __call__(self, clip):
|
317 |
+
clip_random_crop = random_shift_crop(clip)
|
318 |
+
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
|
319 |
+
return clip_resize
|
320 |
+
|
321 |
+
|
322 |
+
class CenterCropVideo:
|
323 |
+
def __init__(
|
324 |
+
self,
|
325 |
+
size,
|
326 |
+
interpolation_mode="bilinear",
|
327 |
+
):
|
328 |
+
if isinstance(size, tuple):
|
329 |
+
if len(size) != 2:
|
330 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
331 |
+
self.size = size
|
332 |
+
else:
|
333 |
+
self.size = (size, size)
|
334 |
+
|
335 |
+
self.interpolation_mode = interpolation_mode
|
336 |
+
|
337 |
+
def __call__(self, clip):
|
338 |
+
"""
|
339 |
+
Args:
|
340 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
341 |
+
Returns:
|
342 |
+
torch.tensor: center cropped video clip.
|
343 |
+
size is (T, C, crop_size, crop_size)
|
344 |
+
"""
|
345 |
+
clip_center_crop = center_crop(clip, self.size)
|
346 |
+
return clip_center_crop
|
347 |
+
|
348 |
+
def __repr__(self) -> str:
|
349 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
350 |
+
|
351 |
+
|
352 |
+
class NormalizeVideo:
|
353 |
+
"""
|
354 |
+
Normalize the video clip by mean subtraction and division by standard deviation
|
355 |
+
Args:
|
356 |
+
mean (3-tuple): pixel RGB mean
|
357 |
+
std (3-tuple): pixel RGB standard deviation
|
358 |
+
inplace (boolean): whether do in-place normalization
|
359 |
+
"""
|
360 |
+
|
361 |
+
def __init__(self, mean, std, inplace=False):
|
362 |
+
self.mean = mean
|
363 |
+
self.std = std
|
364 |
+
self.inplace = inplace
|
365 |
+
|
366 |
+
def __call__(self, clip):
|
367 |
+
"""
|
368 |
+
Args:
|
369 |
+
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
|
370 |
+
"""
|
371 |
+
return normalize(clip, self.mean, self.std, self.inplace)
|
372 |
+
|
373 |
+
def __repr__(self) -> str:
|
374 |
+
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
|
375 |
+
|
376 |
+
|
377 |
+
class ToTensorVideo:
|
378 |
+
"""
|
379 |
+
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
380 |
+
permute the dimensions of clip tensor
|
381 |
+
"""
|
382 |
+
|
383 |
+
def __init__(self):
|
384 |
+
pass
|
385 |
+
|
386 |
+
def __call__(self, clip):
|
387 |
+
"""
|
388 |
+
Args:
|
389 |
+
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
390 |
+
Return:
|
391 |
+
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
392 |
+
"""
|
393 |
+
return to_tensor(clip)
|
394 |
+
|
395 |
+
def __repr__(self) -> str:
|
396 |
+
return self.__class__.__name__
|
397 |
+
|
398 |
+
|
399 |
+
class RandomHorizontalFlipVideo:
|
400 |
+
"""
|
401 |
+
Flip the video clip along the horizontal direction with a given probability
|
402 |
+
Args:
|
403 |
+
p (float): probability of the clip being flipped. Default value is 0.5
|
404 |
+
"""
|
405 |
+
|
406 |
+
def __init__(self, p=0.5):
|
407 |
+
self.p = p
|
408 |
+
|
409 |
+
def __call__(self, clip):
|
410 |
+
"""
|
411 |
+
Args:
|
412 |
+
clip (torch.tensor): Size is (T, C, H, W)
|
413 |
+
Return:
|
414 |
+
clip (torch.tensor): Size is (T, C, H, W)
|
415 |
+
"""
|
416 |
+
if random.random() < self.p:
|
417 |
+
clip = hflip(clip)
|
418 |
+
return clip
|
419 |
+
|
420 |
+
def __repr__(self) -> str:
|
421 |
+
return f"{self.__class__.__name__}(p={self.p})"
|
422 |
+
|
423 |
+
|
424 |
+
# ------------------------------------------------------------
|
425 |
+
# --------------------- Sampling ---------------------------
|
426 |
+
# ------------------------------------------------------------
|
427 |
+
class TemporalRandomCrop(object):
|
428 |
+
"""Temporally crop the given frame indices at a random location.
|
429 |
+
|
430 |
+
Args:
|
431 |
+
size (int): Desired length of frames will be seen in the model.
|
432 |
+
"""
|
433 |
+
|
434 |
+
def __init__(self, size):
|
435 |
+
self.size = size
|
436 |
+
|
437 |
+
def __call__(self, total_frames):
|
438 |
+
rand_end = max(0, total_frames - self.size - 1)
|
439 |
+
begin_index = random.randint(0, rand_end)
|
440 |
+
end_index = min(begin_index + self.size, total_frames)
|
441 |
+
return begin_index, end_index
|
videosys/diffusion/__init__.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos and Meta DiT
|
2 |
+
# DiT: https://github.com/facebookresearch/DiT/tree/main
|
3 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
4 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
5 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
6 |
+
|
7 |
+
from . import gaussian_diffusion as gd
|
8 |
+
from .respace import SpacedDiffusion, space_timesteps
|
9 |
+
|
10 |
+
|
11 |
+
def create_diffusion(
|
12 |
+
timestep_respacing,
|
13 |
+
noise_schedule="linear",
|
14 |
+
use_kl=False,
|
15 |
+
sigma_small=False,
|
16 |
+
predict_xstart=False,
|
17 |
+
learn_sigma=True,
|
18 |
+
rescale_learned_sigmas=False,
|
19 |
+
diffusion_steps=1000,
|
20 |
+
):
|
21 |
+
betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
|
22 |
+
if use_kl:
|
23 |
+
loss_type = gd.LossType.RESCALED_KL
|
24 |
+
elif rescale_learned_sigmas:
|
25 |
+
loss_type = gd.LossType.RESCALED_MSE
|
26 |
+
else:
|
27 |
+
loss_type = gd.LossType.MSE
|
28 |
+
if timestep_respacing is None or timestep_respacing == "":
|
29 |
+
timestep_respacing = [diffusion_steps]
|
30 |
+
return SpacedDiffusion(
|
31 |
+
use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
|
32 |
+
betas=betas,
|
33 |
+
model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
|
34 |
+
model_var_type=(
|
35 |
+
(gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
|
36 |
+
if not learn_sigma
|
37 |
+
else gd.ModelVarType.LEARNED_RANGE
|
38 |
+
),
|
39 |
+
loss_type=loss_type
|
40 |
+
# rescale_timesteps=rescale_timesteps,
|
41 |
+
)
|
videosys/diffusion/diffusion_utils.py
ADDED
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch as th
|
8 |
+
|
9 |
+
|
10 |
+
def normal_kl(mean1, logvar1, mean2, logvar2):
|
11 |
+
"""
|
12 |
+
Compute the KL divergence between two gaussians.
|
13 |
+
Shapes are automatically broadcasted, so batches can be compared to
|
14 |
+
scalars, among other use cases.
|
15 |
+
"""
|
16 |
+
tensor = None
|
17 |
+
for obj in (mean1, logvar1, mean2, logvar2):
|
18 |
+
if isinstance(obj, th.Tensor):
|
19 |
+
tensor = obj
|
20 |
+
break
|
21 |
+
assert tensor is not None, "at least one argument must be a Tensor"
|
22 |
+
|
23 |
+
# Force variances to be Tensors. Broadcasting helps convert scalars to
|
24 |
+
# Tensors, but it does not work for th.exp().
|
25 |
+
logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
|
26 |
+
|
27 |
+
return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
|
28 |
+
|
29 |
+
|
30 |
+
def approx_standard_normal_cdf(x):
|
31 |
+
"""
|
32 |
+
A fast approximation of the cumulative distribution function of the
|
33 |
+
standard normal.
|
34 |
+
"""
|
35 |
+
return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
|
36 |
+
|
37 |
+
|
38 |
+
def continuous_gaussian_log_likelihood(x, *, means, log_scales):
|
39 |
+
"""
|
40 |
+
Compute the log-likelihood of a continuous Gaussian distribution.
|
41 |
+
:param x: the targets
|
42 |
+
:param means: the Gaussian mean Tensor.
|
43 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
44 |
+
:return: a tensor like x of log probabilities (in nats).
|
45 |
+
"""
|
46 |
+
centered_x = x - means
|
47 |
+
inv_stdv = th.exp(-log_scales)
|
48 |
+
normalized_x = centered_x * inv_stdv
|
49 |
+
log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
|
50 |
+
return log_probs
|
51 |
+
|
52 |
+
|
53 |
+
def discretized_gaussian_log_likelihood(x, *, means, log_scales):
|
54 |
+
"""
|
55 |
+
Compute the log-likelihood of a Gaussian distribution discretizing to a
|
56 |
+
given image.
|
57 |
+
:param x: the target images. It is assumed that this was uint8 values,
|
58 |
+
rescaled to the range [-1, 1].
|
59 |
+
:param means: the Gaussian mean Tensor.
|
60 |
+
:param log_scales: the Gaussian log stddev Tensor.
|
61 |
+
:return: a tensor like x of log probabilities (in nats).
|
62 |
+
"""
|
63 |
+
assert x.shape == means.shape == log_scales.shape
|
64 |
+
centered_x = x - means
|
65 |
+
inv_stdv = th.exp(-log_scales)
|
66 |
+
plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
|
67 |
+
cdf_plus = approx_standard_normal_cdf(plus_in)
|
68 |
+
min_in = inv_stdv * (centered_x - 1.0 / 255.0)
|
69 |
+
cdf_min = approx_standard_normal_cdf(min_in)
|
70 |
+
log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
|
71 |
+
log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
|
72 |
+
cdf_delta = cdf_plus - cdf_min
|
73 |
+
log_probs = th.where(
|
74 |
+
x < -0.999,
|
75 |
+
log_cdf_plus,
|
76 |
+
th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
|
77 |
+
)
|
78 |
+
assert log_probs.shape == x.shape
|
79 |
+
return log_probs
|
videosys/diffusion/gaussian_diffusion.py
ADDED
@@ -0,0 +1,829 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
|
7 |
+
import enum
|
8 |
+
import math
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch as th
|
12 |
+
|
13 |
+
from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
|
14 |
+
|
15 |
+
|
16 |
+
def mean_flat(tensor):
|
17 |
+
"""
|
18 |
+
Take the mean over all non-batch dimensions.
|
19 |
+
"""
|
20 |
+
return tensor.mean(dim=list(range(1, len(tensor.shape))))
|
21 |
+
|
22 |
+
|
23 |
+
class ModelMeanType(enum.Enum):
|
24 |
+
"""
|
25 |
+
Which type of output the model predicts.
|
26 |
+
"""
|
27 |
+
|
28 |
+
PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
|
29 |
+
START_X = enum.auto() # the model predicts x_0
|
30 |
+
EPSILON = enum.auto() # the model predicts epsilon
|
31 |
+
|
32 |
+
|
33 |
+
class ModelVarType(enum.Enum):
|
34 |
+
"""
|
35 |
+
What is used as the model's output variance.
|
36 |
+
The LEARNED_RANGE option has been added to allow the model to predict
|
37 |
+
values between FIXED_SMALL and FIXED_LARGE, making its job easier.
|
38 |
+
"""
|
39 |
+
|
40 |
+
LEARNED = enum.auto()
|
41 |
+
FIXED_SMALL = enum.auto()
|
42 |
+
FIXED_LARGE = enum.auto()
|
43 |
+
LEARNED_RANGE = enum.auto()
|
44 |
+
|
45 |
+
|
46 |
+
class LossType(enum.Enum):
|
47 |
+
MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
|
48 |
+
RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances)
|
49 |
+
KL = enum.auto() # use the variational lower-bound
|
50 |
+
RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
|
51 |
+
|
52 |
+
def is_vb(self):
|
53 |
+
return self == LossType.KL or self == LossType.RESCALED_KL
|
54 |
+
|
55 |
+
|
56 |
+
def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
|
57 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
58 |
+
warmup_time = int(num_diffusion_timesteps * warmup_frac)
|
59 |
+
betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
|
60 |
+
return betas
|
61 |
+
|
62 |
+
|
63 |
+
def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
|
64 |
+
"""
|
65 |
+
This is the deprecated API for creating beta schedules.
|
66 |
+
See get_named_beta_schedule() for the new library of schedules.
|
67 |
+
"""
|
68 |
+
if beta_schedule == "quad":
|
69 |
+
betas = (
|
70 |
+
np.linspace(
|
71 |
+
beta_start**0.5,
|
72 |
+
beta_end**0.5,
|
73 |
+
num_diffusion_timesteps,
|
74 |
+
dtype=np.float64,
|
75 |
+
)
|
76 |
+
** 2
|
77 |
+
)
|
78 |
+
elif beta_schedule == "linear":
|
79 |
+
betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
|
80 |
+
elif beta_schedule == "warmup10":
|
81 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
|
82 |
+
elif beta_schedule == "warmup50":
|
83 |
+
betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
|
84 |
+
elif beta_schedule == "const":
|
85 |
+
betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
|
86 |
+
elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
|
87 |
+
betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
|
88 |
+
else:
|
89 |
+
raise NotImplementedError(beta_schedule)
|
90 |
+
assert betas.shape == (num_diffusion_timesteps,)
|
91 |
+
return betas
|
92 |
+
|
93 |
+
|
94 |
+
def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
|
95 |
+
"""
|
96 |
+
Get a pre-defined beta schedule for the given name.
|
97 |
+
The beta schedule library consists of beta schedules which remain similar
|
98 |
+
in the limit of num_diffusion_timesteps.
|
99 |
+
Beta schedules may be added, but should not be removed or changed once
|
100 |
+
they are committed to maintain backwards compatibility.
|
101 |
+
"""
|
102 |
+
if schedule_name == "linear":
|
103 |
+
# Linear schedule from Ho et al, extended to work for any number of
|
104 |
+
# diffusion steps.
|
105 |
+
scale = 1000 / num_diffusion_timesteps
|
106 |
+
return get_beta_schedule(
|
107 |
+
"linear",
|
108 |
+
beta_start=scale * 0.0001,
|
109 |
+
beta_end=scale * 0.02,
|
110 |
+
num_diffusion_timesteps=num_diffusion_timesteps,
|
111 |
+
)
|
112 |
+
elif schedule_name == "squaredcos_cap_v2":
|
113 |
+
return betas_for_alpha_bar(
|
114 |
+
num_diffusion_timesteps,
|
115 |
+
lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
|
119 |
+
|
120 |
+
|
121 |
+
def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
|
122 |
+
"""
|
123 |
+
Create a beta schedule that discretizes the given alpha_t_bar function,
|
124 |
+
which defines the cumulative product of (1-beta) over time from t = [0,1].
|
125 |
+
:param num_diffusion_timesteps: the number of betas to produce.
|
126 |
+
:param alpha_bar: a lambda that takes an argument t from 0 to 1 and
|
127 |
+
produces the cumulative product of (1-beta) up to that
|
128 |
+
part of the diffusion process.
|
129 |
+
:param max_beta: the maximum beta to use; use values lower than 1 to
|
130 |
+
prevent singularities.
|
131 |
+
"""
|
132 |
+
betas = []
|
133 |
+
for i in range(num_diffusion_timesteps):
|
134 |
+
t1 = i / num_diffusion_timesteps
|
135 |
+
t2 = (i + 1) / num_diffusion_timesteps
|
136 |
+
betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
|
137 |
+
return np.array(betas)
|
138 |
+
|
139 |
+
|
140 |
+
class GaussianDiffusion:
|
141 |
+
"""
|
142 |
+
Utilities for training and sampling diffusion models.
|
143 |
+
Original ported from this codebase:
|
144 |
+
https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
|
145 |
+
:param betas: a 1-D numpy array of betas for each diffusion timestep,
|
146 |
+
starting at T and going to 1.
|
147 |
+
"""
|
148 |
+
|
149 |
+
def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
|
150 |
+
self.model_mean_type = model_mean_type
|
151 |
+
self.model_var_type = model_var_type
|
152 |
+
self.loss_type = loss_type
|
153 |
+
|
154 |
+
# Use float64 for accuracy.
|
155 |
+
betas = np.array(betas, dtype=np.float64)
|
156 |
+
self.betas = betas
|
157 |
+
assert len(betas.shape) == 1, "betas must be 1-D"
|
158 |
+
assert (betas > 0).all() and (betas <= 1).all()
|
159 |
+
|
160 |
+
self.num_timesteps = int(betas.shape[0])
|
161 |
+
|
162 |
+
alphas = 1.0 - betas
|
163 |
+
self.alphas_cumprod = np.cumprod(alphas, axis=0)
|
164 |
+
self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
|
165 |
+
self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
|
166 |
+
assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
|
167 |
+
|
168 |
+
# calculations for diffusion q(x_t | x_{t-1}) and others
|
169 |
+
self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
|
170 |
+
self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
|
171 |
+
self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
|
172 |
+
self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
|
173 |
+
self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
|
174 |
+
|
175 |
+
# calculations for posterior q(x_{t-1} | x_t, x_0)
|
176 |
+
self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
177 |
+
# below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
|
178 |
+
self.posterior_log_variance_clipped = (
|
179 |
+
np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
|
180 |
+
if len(self.posterior_variance) > 1
|
181 |
+
else np.array([])
|
182 |
+
)
|
183 |
+
|
184 |
+
self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
|
185 |
+
self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
|
186 |
+
|
187 |
+
def q_mean_variance(self, x_start, t):
|
188 |
+
"""
|
189 |
+
Get the distribution q(x_t | x_0).
|
190 |
+
:param x_start: the [N x C x ...] tensor of noiseless inputs.
|
191 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
192 |
+
:return: A tuple (mean, variance, log_variance), all of x_start's shape.
|
193 |
+
"""
|
194 |
+
mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
195 |
+
variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
|
196 |
+
log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
|
197 |
+
return mean, variance, log_variance
|
198 |
+
|
199 |
+
def q_sample(self, x_start, t, noise=None):
|
200 |
+
"""
|
201 |
+
Diffuse the data for a given number of diffusion steps.
|
202 |
+
In other words, sample from q(x_t | x_0).
|
203 |
+
:param x_start: the initial data batch.
|
204 |
+
:param t: the number of diffusion steps (minus 1). Here, 0 means one step.
|
205 |
+
:param noise: if specified, the split-out normal noise.
|
206 |
+
:return: A noisy version of x_start.
|
207 |
+
"""
|
208 |
+
if noise is None:
|
209 |
+
noise = th.randn_like(x_start)
|
210 |
+
assert noise.shape == x_start.shape
|
211 |
+
return (
|
212 |
+
_extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
|
213 |
+
+ _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
|
214 |
+
)
|
215 |
+
|
216 |
+
def q_posterior_mean_variance(self, x_start, x_t, t):
|
217 |
+
"""
|
218 |
+
Compute the mean and variance of the diffusion posterior:
|
219 |
+
q(x_{t-1} | x_t, x_0)
|
220 |
+
"""
|
221 |
+
assert x_start.shape == x_t.shape
|
222 |
+
posterior_mean = (
|
223 |
+
_extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
|
224 |
+
+ _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
|
225 |
+
)
|
226 |
+
posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
|
227 |
+
posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
|
228 |
+
assert (
|
229 |
+
posterior_mean.shape[0]
|
230 |
+
== posterior_variance.shape[0]
|
231 |
+
== posterior_log_variance_clipped.shape[0]
|
232 |
+
== x_start.shape[0]
|
233 |
+
)
|
234 |
+
return posterior_mean, posterior_variance, posterior_log_variance_clipped
|
235 |
+
|
236 |
+
def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
|
237 |
+
"""
|
238 |
+
Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
|
239 |
+
the initial x, x_0.
|
240 |
+
:param model: the model, which takes a signal and a batch of timesteps
|
241 |
+
as input.
|
242 |
+
:param x: the [N x C x ...] tensor at time t.
|
243 |
+
:param t: a 1-D Tensor of timesteps.
|
244 |
+
:param clip_denoised: if True, clip the denoised signal into [-1, 1].
|
245 |
+
:param denoised_fn: if not None, a function which applies to the
|
246 |
+
x_start prediction before it is used to sample. Applies before
|
247 |
+
clip_denoised.
|
248 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
249 |
+
pass to the model. This can be used for conditioning.
|
250 |
+
:return: a dict with the following keys:
|
251 |
+
- 'mean': the model mean output.
|
252 |
+
- 'variance': the model variance output.
|
253 |
+
- 'log_variance': the log of 'variance'.
|
254 |
+
- 'pred_xstart': the prediction for x_0.
|
255 |
+
"""
|
256 |
+
if model_kwargs is None:
|
257 |
+
model_kwargs = {}
|
258 |
+
|
259 |
+
B, C = x.shape[:2]
|
260 |
+
assert t.shape == (B,)
|
261 |
+
model_output = model(x, t, **model_kwargs)
|
262 |
+
if isinstance(model_output, tuple):
|
263 |
+
model_output, extra = model_output
|
264 |
+
else:
|
265 |
+
extra = None
|
266 |
+
|
267 |
+
if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
|
268 |
+
assert model_output.shape == (B, C * 2, *x.shape[2:])
|
269 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
270 |
+
min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
|
271 |
+
max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
|
272 |
+
# The model_var_values is [-1, 1] for [min_var, max_var].
|
273 |
+
frac = (model_var_values + 1) / 2
|
274 |
+
model_log_variance = frac * max_log + (1 - frac) * min_log
|
275 |
+
model_variance = th.exp(model_log_variance)
|
276 |
+
else:
|
277 |
+
model_variance, model_log_variance = {
|
278 |
+
# for fixedlarge, we set the initial (log-)variance like so
|
279 |
+
# to get a better decoder log likelihood.
|
280 |
+
ModelVarType.FIXED_LARGE: (
|
281 |
+
np.append(self.posterior_variance[1], self.betas[1:]),
|
282 |
+
np.log(np.append(self.posterior_variance[1], self.betas[1:])),
|
283 |
+
),
|
284 |
+
ModelVarType.FIXED_SMALL: (
|
285 |
+
self.posterior_variance,
|
286 |
+
self.posterior_log_variance_clipped,
|
287 |
+
),
|
288 |
+
}[self.model_var_type]
|
289 |
+
model_variance = _extract_into_tensor(model_variance, t, x.shape)
|
290 |
+
model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
|
291 |
+
|
292 |
+
def process_xstart(x):
|
293 |
+
if denoised_fn is not None:
|
294 |
+
x = denoised_fn(x)
|
295 |
+
if clip_denoised:
|
296 |
+
return x.clamp(-1, 1)
|
297 |
+
return x
|
298 |
+
|
299 |
+
if self.model_mean_type == ModelMeanType.START_X:
|
300 |
+
pred_xstart = process_xstart(model_output)
|
301 |
+
else:
|
302 |
+
pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
|
303 |
+
model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
|
304 |
+
|
305 |
+
assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
|
306 |
+
return {
|
307 |
+
"mean": model_mean,
|
308 |
+
"variance": model_variance,
|
309 |
+
"log_variance": model_log_variance,
|
310 |
+
"pred_xstart": pred_xstart,
|
311 |
+
"extra": extra,
|
312 |
+
}
|
313 |
+
|
314 |
+
def _predict_xstart_from_eps(self, x_t, t, eps):
|
315 |
+
assert x_t.shape == eps.shape
|
316 |
+
return (
|
317 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
|
318 |
+
- _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
|
319 |
+
)
|
320 |
+
|
321 |
+
def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
|
322 |
+
return (
|
323 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
|
324 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
|
325 |
+
|
326 |
+
def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
327 |
+
"""
|
328 |
+
Compute the mean for the previous step, given a function cond_fn that
|
329 |
+
computes the gradient of a conditional log probability with respect to
|
330 |
+
x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
|
331 |
+
condition on y.
|
332 |
+
This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
|
333 |
+
"""
|
334 |
+
gradient = cond_fn(x, t, **model_kwargs)
|
335 |
+
new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
|
336 |
+
return new_mean
|
337 |
+
|
338 |
+
def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
|
339 |
+
"""
|
340 |
+
Compute what the p_mean_variance output would have been, should the
|
341 |
+
model's score function be conditioned by cond_fn.
|
342 |
+
See condition_mean() for details on cond_fn.
|
343 |
+
Unlike condition_mean(), this instead uses the conditioning strategy
|
344 |
+
from Song et al (2020).
|
345 |
+
"""
|
346 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
347 |
+
|
348 |
+
eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
|
349 |
+
eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
|
350 |
+
|
351 |
+
out = p_mean_var.copy()
|
352 |
+
out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
|
353 |
+
out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
|
354 |
+
return out
|
355 |
+
|
356 |
+
def p_sample(
|
357 |
+
self,
|
358 |
+
model,
|
359 |
+
x,
|
360 |
+
t,
|
361 |
+
clip_denoised=True,
|
362 |
+
denoised_fn=None,
|
363 |
+
cond_fn=None,
|
364 |
+
model_kwargs=None,
|
365 |
+
):
|
366 |
+
"""
|
367 |
+
Sample x_{t-1} from the model at the given timestep.
|
368 |
+
:param model: the model to sample from.
|
369 |
+
:param x: the current tensor at x_{t-1}.
|
370 |
+
:param t: the value of t, starting at 0 for the first diffusion step.
|
371 |
+
:param clip_denoised: if True, clip the x_start prediction to [-1, 1].
|
372 |
+
:param denoised_fn: if not None, a function which applies to the
|
373 |
+
x_start prediction before it is used to sample.
|
374 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
375 |
+
similarly to the model.
|
376 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
377 |
+
pass to the model. This can be used for conditioning.
|
378 |
+
:return: a dict containing the following keys:
|
379 |
+
- 'sample': a random sample from the model.
|
380 |
+
- 'pred_xstart': a prediction of x_0.
|
381 |
+
"""
|
382 |
+
out = self.p_mean_variance(
|
383 |
+
model,
|
384 |
+
x,
|
385 |
+
t,
|
386 |
+
clip_denoised=clip_denoised,
|
387 |
+
denoised_fn=denoised_fn,
|
388 |
+
model_kwargs=model_kwargs,
|
389 |
+
)
|
390 |
+
noise = th.randn_like(x)
|
391 |
+
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
|
392 |
+
if cond_fn is not None:
|
393 |
+
out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
394 |
+
sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
|
395 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
396 |
+
|
397 |
+
def p_sample_loop(
|
398 |
+
self,
|
399 |
+
model,
|
400 |
+
shape,
|
401 |
+
noise=None,
|
402 |
+
clip_denoised=True,
|
403 |
+
denoised_fn=None,
|
404 |
+
cond_fn=None,
|
405 |
+
model_kwargs=None,
|
406 |
+
device=None,
|
407 |
+
progress=False,
|
408 |
+
):
|
409 |
+
"""
|
410 |
+
Generate samples from the model.
|
411 |
+
:param model: the model module.
|
412 |
+
:param shape: the shape of the samples, (N, C, H, W).
|
413 |
+
:param noise: if specified, the noise from the encoder to sample.
|
414 |
+
Should be of the same shape as `shape`.
|
415 |
+
:param clip_denoised: if True, clip x_start predictions to [-1, 1].
|
416 |
+
:param denoised_fn: if not None, a function which applies to the
|
417 |
+
x_start prediction before it is used to sample.
|
418 |
+
:param cond_fn: if not None, this is a gradient function that acts
|
419 |
+
similarly to the model.
|
420 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
421 |
+
pass to the model. This can be used for conditioning.
|
422 |
+
:param device: if specified, the device to create the samples on.
|
423 |
+
If not specified, use a model parameter's device.
|
424 |
+
:param progress: if True, show a tqdm progress bar.
|
425 |
+
:return: a non-differentiable batch of samples.
|
426 |
+
"""
|
427 |
+
final = None
|
428 |
+
for sample in self.p_sample_loop_progressive(
|
429 |
+
model,
|
430 |
+
shape,
|
431 |
+
noise=noise,
|
432 |
+
clip_denoised=clip_denoised,
|
433 |
+
denoised_fn=denoised_fn,
|
434 |
+
cond_fn=cond_fn,
|
435 |
+
model_kwargs=model_kwargs,
|
436 |
+
device=device,
|
437 |
+
progress=progress,
|
438 |
+
):
|
439 |
+
final = sample
|
440 |
+
return final["sample"]
|
441 |
+
|
442 |
+
def p_sample_loop_progressive(
|
443 |
+
self,
|
444 |
+
model,
|
445 |
+
shape,
|
446 |
+
noise=None,
|
447 |
+
clip_denoised=True,
|
448 |
+
denoised_fn=None,
|
449 |
+
cond_fn=None,
|
450 |
+
model_kwargs=None,
|
451 |
+
device=None,
|
452 |
+
progress=False,
|
453 |
+
):
|
454 |
+
"""
|
455 |
+
Generate samples from the model and yield intermediate samples from
|
456 |
+
each timestep of diffusion.
|
457 |
+
Arguments are the same as p_sample_loop().
|
458 |
+
Returns a generator over dicts, where each dict is the return value of
|
459 |
+
p_sample().
|
460 |
+
"""
|
461 |
+
if device is None:
|
462 |
+
device = next(model.parameters()).device
|
463 |
+
assert isinstance(shape, (tuple, list))
|
464 |
+
if noise is not None:
|
465 |
+
img = noise
|
466 |
+
else:
|
467 |
+
img = th.randn(*shape, device=device)
|
468 |
+
indices = list(range(self.num_timesteps))[::-1]
|
469 |
+
|
470 |
+
if progress:
|
471 |
+
# Lazy import so that we don't depend on tqdm.
|
472 |
+
from tqdm.auto import tqdm
|
473 |
+
|
474 |
+
indices = tqdm(indices)
|
475 |
+
|
476 |
+
for i in indices:
|
477 |
+
t = th.tensor([i] * shape[0], device=device)
|
478 |
+
with th.no_grad():
|
479 |
+
out = self.p_sample(
|
480 |
+
model,
|
481 |
+
img,
|
482 |
+
t,
|
483 |
+
clip_denoised=clip_denoised,
|
484 |
+
denoised_fn=denoised_fn,
|
485 |
+
cond_fn=cond_fn,
|
486 |
+
model_kwargs=model_kwargs,
|
487 |
+
)
|
488 |
+
yield out
|
489 |
+
img = out["sample"]
|
490 |
+
|
491 |
+
def ddim_sample(
|
492 |
+
self,
|
493 |
+
model,
|
494 |
+
x,
|
495 |
+
t,
|
496 |
+
clip_denoised=True,
|
497 |
+
denoised_fn=None,
|
498 |
+
cond_fn=None,
|
499 |
+
model_kwargs=None,
|
500 |
+
eta=0.0,
|
501 |
+
):
|
502 |
+
"""
|
503 |
+
Sample x_{t-1} from the model using DDIM.
|
504 |
+
Same usage as p_sample().
|
505 |
+
"""
|
506 |
+
out = self.p_mean_variance(
|
507 |
+
model,
|
508 |
+
x,
|
509 |
+
t,
|
510 |
+
clip_denoised=clip_denoised,
|
511 |
+
denoised_fn=denoised_fn,
|
512 |
+
model_kwargs=model_kwargs,
|
513 |
+
)
|
514 |
+
if cond_fn is not None:
|
515 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
516 |
+
|
517 |
+
# Usually our model outputs epsilon, but we re-derive it
|
518 |
+
# in case we used x_start or x_prev prediction.
|
519 |
+
eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
|
520 |
+
|
521 |
+
alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
|
522 |
+
alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
|
523 |
+
sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
|
524 |
+
# Equation 12.
|
525 |
+
noise = th.randn_like(x)
|
526 |
+
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
|
527 |
+
nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
|
528 |
+
sample = mean_pred + nonzero_mask * sigma * noise
|
529 |
+
return {"sample": sample, "pred_xstart": out["pred_xstart"]}
|
530 |
+
|
531 |
+
def ddim_reverse_sample(
|
532 |
+
self,
|
533 |
+
model,
|
534 |
+
x,
|
535 |
+
t,
|
536 |
+
clip_denoised=True,
|
537 |
+
denoised_fn=None,
|
538 |
+
cond_fn=None,
|
539 |
+
model_kwargs=None,
|
540 |
+
eta=0.0,
|
541 |
+
):
|
542 |
+
"""
|
543 |
+
Sample x_{t+1} from the model using DDIM reverse ODE.
|
544 |
+
"""
|
545 |
+
assert eta == 0.0, "Reverse ODE only for deterministic path"
|
546 |
+
out = self.p_mean_variance(
|
547 |
+
model,
|
548 |
+
x,
|
549 |
+
t,
|
550 |
+
clip_denoised=clip_denoised,
|
551 |
+
denoised_fn=denoised_fn,
|
552 |
+
model_kwargs=model_kwargs,
|
553 |
+
)
|
554 |
+
if cond_fn is not None:
|
555 |
+
out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
|
556 |
+
# Usually our model outputs epsilon, but we re-derive it
|
557 |
+
# in case we used x_start or x_prev prediction.
|
558 |
+
eps = (
|
559 |
+
_extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
|
560 |
+
) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
|
561 |
+
alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
|
562 |
+
|
563 |
+
# Equation 12. reversed
|
564 |
+
mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
|
565 |
+
|
566 |
+
return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
|
567 |
+
|
568 |
+
def ddim_sample_loop(
|
569 |
+
self,
|
570 |
+
model,
|
571 |
+
shape,
|
572 |
+
noise=None,
|
573 |
+
clip_denoised=True,
|
574 |
+
denoised_fn=None,
|
575 |
+
cond_fn=None,
|
576 |
+
model_kwargs=None,
|
577 |
+
device=None,
|
578 |
+
progress=False,
|
579 |
+
eta=0.0,
|
580 |
+
):
|
581 |
+
"""
|
582 |
+
Generate samples from the model using DDIM.
|
583 |
+
Same usage as p_sample_loop().
|
584 |
+
"""
|
585 |
+
final = None
|
586 |
+
for sample in self.ddim_sample_loop_progressive(
|
587 |
+
model,
|
588 |
+
shape,
|
589 |
+
noise=noise,
|
590 |
+
clip_denoised=clip_denoised,
|
591 |
+
denoised_fn=denoised_fn,
|
592 |
+
cond_fn=cond_fn,
|
593 |
+
model_kwargs=model_kwargs,
|
594 |
+
device=device,
|
595 |
+
progress=progress,
|
596 |
+
eta=eta,
|
597 |
+
):
|
598 |
+
final = sample
|
599 |
+
return final["sample"]
|
600 |
+
|
601 |
+
def ddim_sample_loop_progressive(
|
602 |
+
self,
|
603 |
+
model,
|
604 |
+
shape,
|
605 |
+
noise=None,
|
606 |
+
clip_denoised=True,
|
607 |
+
denoised_fn=None,
|
608 |
+
cond_fn=None,
|
609 |
+
model_kwargs=None,
|
610 |
+
device=None,
|
611 |
+
progress=False,
|
612 |
+
eta=0.0,
|
613 |
+
):
|
614 |
+
"""
|
615 |
+
Use DDIM to sample from the model and yield intermediate samples from
|
616 |
+
each timestep of DDIM.
|
617 |
+
Same usage as p_sample_loop_progressive().
|
618 |
+
"""
|
619 |
+
if device is None:
|
620 |
+
device = next(model.parameters()).device
|
621 |
+
assert isinstance(shape, (tuple, list))
|
622 |
+
if noise is not None:
|
623 |
+
img = noise
|
624 |
+
else:
|
625 |
+
img = th.randn(*shape, device=device)
|
626 |
+
indices = list(range(self.num_timesteps))[::-1]
|
627 |
+
|
628 |
+
if progress:
|
629 |
+
# Lazy import so that we don't depend on tqdm.
|
630 |
+
from tqdm.auto import tqdm
|
631 |
+
|
632 |
+
indices = tqdm(indices)
|
633 |
+
|
634 |
+
for i in indices:
|
635 |
+
t = th.tensor([i] * shape[0], device=device)
|
636 |
+
with th.no_grad():
|
637 |
+
out = self.ddim_sample(
|
638 |
+
model,
|
639 |
+
img,
|
640 |
+
t,
|
641 |
+
clip_denoised=clip_denoised,
|
642 |
+
denoised_fn=denoised_fn,
|
643 |
+
cond_fn=cond_fn,
|
644 |
+
model_kwargs=model_kwargs,
|
645 |
+
eta=eta,
|
646 |
+
)
|
647 |
+
yield out
|
648 |
+
img = out["sample"]
|
649 |
+
|
650 |
+
def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
|
651 |
+
"""
|
652 |
+
Get a term for the variational lower-bound.
|
653 |
+
The resulting units are bits (rather than nats, as one might expect).
|
654 |
+
This allows for comparison to other papers.
|
655 |
+
:return: a dict with the following keys:
|
656 |
+
- 'output': a shape [N] tensor of NLLs or KLs.
|
657 |
+
- 'pred_xstart': the x_0 predictions.
|
658 |
+
"""
|
659 |
+
true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
|
660 |
+
out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
|
661 |
+
kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
|
662 |
+
kl = mean_flat(kl) / np.log(2.0)
|
663 |
+
|
664 |
+
decoder_nll = -discretized_gaussian_log_likelihood(
|
665 |
+
x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
|
666 |
+
)
|
667 |
+
assert decoder_nll.shape == x_start.shape
|
668 |
+
decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
|
669 |
+
|
670 |
+
# At the first timestep return the decoder NLL,
|
671 |
+
# otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
|
672 |
+
output = th.where((t == 0), decoder_nll, kl)
|
673 |
+
return {"output": output, "pred_xstart": out["pred_xstart"]}
|
674 |
+
|
675 |
+
def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
|
676 |
+
"""
|
677 |
+
Compute training losses for a single timestep.
|
678 |
+
:param model: the model to evaluate loss on.
|
679 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
680 |
+
:param t: a batch of timestep indices.
|
681 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
682 |
+
pass to the model. This can be used for conditioning.
|
683 |
+
:param noise: if specified, the specific Gaussian noise to try to remove.
|
684 |
+
:return: a dict with the key "loss" containing a tensor of shape [N].
|
685 |
+
Some mean or variance settings may also have other keys.
|
686 |
+
"""
|
687 |
+
if model_kwargs is None:
|
688 |
+
model_kwargs = {}
|
689 |
+
if noise is None:
|
690 |
+
noise = th.randn_like(x_start)
|
691 |
+
x_t = self.q_sample(x_start, t, noise=noise)
|
692 |
+
|
693 |
+
terms = {}
|
694 |
+
|
695 |
+
if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
|
696 |
+
terms["loss"] = self._vb_terms_bpd(
|
697 |
+
model=model,
|
698 |
+
x_start=x_start,
|
699 |
+
x_t=x_t,
|
700 |
+
t=t,
|
701 |
+
clip_denoised=False,
|
702 |
+
model_kwargs=model_kwargs,
|
703 |
+
)["output"]
|
704 |
+
if self.loss_type == LossType.RESCALED_KL:
|
705 |
+
terms["loss"] *= self.num_timesteps
|
706 |
+
elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
|
707 |
+
model_output = model(x_t, t, **model_kwargs)
|
708 |
+
|
709 |
+
if self.model_var_type in [
|
710 |
+
ModelVarType.LEARNED,
|
711 |
+
ModelVarType.LEARNED_RANGE,
|
712 |
+
]:
|
713 |
+
B, C = x_t.shape[:2]
|
714 |
+
assert model_output.shape == (B, C * 2, *x_t.shape[2:])
|
715 |
+
model_output, model_var_values = th.split(model_output, C, dim=1)
|
716 |
+
# Learn the variance using the variational bound, but don't let
|
717 |
+
# it affect our mean prediction.
|
718 |
+
frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
|
719 |
+
terms["vb"] = self._vb_terms_bpd(
|
720 |
+
model=lambda *args, r=frozen_out: r,
|
721 |
+
x_start=x_start,
|
722 |
+
x_t=x_t,
|
723 |
+
t=t,
|
724 |
+
clip_denoised=False,
|
725 |
+
)["output"]
|
726 |
+
if self.loss_type == LossType.RESCALED_MSE:
|
727 |
+
# Divide by 1000 for equivalence with initial implementation.
|
728 |
+
# Without a factor of 1/1000, the VB term hurts the MSE term.
|
729 |
+
terms["vb"] *= self.num_timesteps / 1000.0
|
730 |
+
|
731 |
+
target = {
|
732 |
+
ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
|
733 |
+
ModelMeanType.START_X: x_start,
|
734 |
+
ModelMeanType.EPSILON: noise,
|
735 |
+
}[self.model_mean_type]
|
736 |
+
assert model_output.shape == target.shape == x_start.shape
|
737 |
+
terms["mse"] = mean_flat((target - model_output) ** 2)
|
738 |
+
if "vb" in terms:
|
739 |
+
terms["loss"] = terms["mse"] + terms["vb"]
|
740 |
+
else:
|
741 |
+
terms["loss"] = terms["mse"]
|
742 |
+
else:
|
743 |
+
raise NotImplementedError(self.loss_type)
|
744 |
+
|
745 |
+
return terms
|
746 |
+
|
747 |
+
def _prior_bpd(self, x_start):
|
748 |
+
"""
|
749 |
+
Get the prior KL term for the variational lower-bound, measured in
|
750 |
+
bits-per-dim.
|
751 |
+
This term can't be optimized, as it only depends on the encoder.
|
752 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
753 |
+
:return: a batch of [N] KL values (in bits), one per batch element.
|
754 |
+
"""
|
755 |
+
batch_size = x_start.shape[0]
|
756 |
+
t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
|
757 |
+
qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
|
758 |
+
kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
|
759 |
+
return mean_flat(kl_prior) / np.log(2.0)
|
760 |
+
|
761 |
+
def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
|
762 |
+
"""
|
763 |
+
Compute the entire variational lower-bound, measured in bits-per-dim,
|
764 |
+
as well as other related quantities.
|
765 |
+
:param model: the model to evaluate loss on.
|
766 |
+
:param x_start: the [N x C x ...] tensor of inputs.
|
767 |
+
:param clip_denoised: if True, clip denoised samples.
|
768 |
+
:param model_kwargs: if not None, a dict of extra keyword arguments to
|
769 |
+
pass to the model. This can be used for conditioning.
|
770 |
+
:return: a dict containing the following keys:
|
771 |
+
- total_bpd: the total variational lower-bound, per batch element.
|
772 |
+
- prior_bpd: the prior term in the lower-bound.
|
773 |
+
- vb: an [N x T] tensor of terms in the lower-bound.
|
774 |
+
- xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
|
775 |
+
- mse: an [N x T] tensor of epsilon MSEs for each timestep.
|
776 |
+
"""
|
777 |
+
device = x_start.device
|
778 |
+
batch_size = x_start.shape[0]
|
779 |
+
|
780 |
+
vb = []
|
781 |
+
xstart_mse = []
|
782 |
+
mse = []
|
783 |
+
for t in list(range(self.num_timesteps))[::-1]:
|
784 |
+
t_batch = th.tensor([t] * batch_size, device=device)
|
785 |
+
noise = th.randn_like(x_start)
|
786 |
+
x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
|
787 |
+
# Calculate VLB term at the current timestep
|
788 |
+
with th.no_grad():
|
789 |
+
out = self._vb_terms_bpd(
|
790 |
+
model,
|
791 |
+
x_start=x_start,
|
792 |
+
x_t=x_t,
|
793 |
+
t=t_batch,
|
794 |
+
clip_denoised=clip_denoised,
|
795 |
+
model_kwargs=model_kwargs,
|
796 |
+
)
|
797 |
+
vb.append(out["output"])
|
798 |
+
xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
|
799 |
+
eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
|
800 |
+
mse.append(mean_flat((eps - noise) ** 2))
|
801 |
+
|
802 |
+
vb = th.stack(vb, dim=1)
|
803 |
+
xstart_mse = th.stack(xstart_mse, dim=1)
|
804 |
+
mse = th.stack(mse, dim=1)
|
805 |
+
|
806 |
+
prior_bpd = self._prior_bpd(x_start)
|
807 |
+
total_bpd = vb.sum(dim=1) + prior_bpd
|
808 |
+
return {
|
809 |
+
"total_bpd": total_bpd,
|
810 |
+
"prior_bpd": prior_bpd,
|
811 |
+
"vb": vb,
|
812 |
+
"xstart_mse": xstart_mse,
|
813 |
+
"mse": mse,
|
814 |
+
}
|
815 |
+
|
816 |
+
|
817 |
+
def _extract_into_tensor(arr, timesteps, broadcast_shape):
|
818 |
+
"""
|
819 |
+
Extract values from a 1-D numpy array for a batch of indices.
|
820 |
+
:param arr: the 1-D numpy array.
|
821 |
+
:param timesteps: a tensor of indices into the array to extract.
|
822 |
+
:param broadcast_shape: a larger shape of K dimensions with the batch
|
823 |
+
dimension equal to the length of timesteps.
|
824 |
+
:return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
|
825 |
+
"""
|
826 |
+
res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
|
827 |
+
while len(res.shape) < len(broadcast_shape):
|
828 |
+
res = res[..., None]
|
829 |
+
return res + th.zeros(broadcast_shape, device=timesteps.device)
|
videosys/diffusion/respace.py
ADDED
@@ -0,0 +1,119 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
import numpy as np
|
7 |
+
import torch as th
|
8 |
+
|
9 |
+
from .gaussian_diffusion import GaussianDiffusion
|
10 |
+
|
11 |
+
|
12 |
+
def space_timesteps(num_timesteps, section_counts):
|
13 |
+
"""
|
14 |
+
Create a list of timesteps to use from an original diffusion process,
|
15 |
+
given the number of timesteps we want to take from equally-sized portions
|
16 |
+
of the original process.
|
17 |
+
For example, if there's 300 timesteps and the section counts are [10,15,20]
|
18 |
+
then the first 100 timesteps are strided to be 10 timesteps, the second 100
|
19 |
+
are strided to be 15 timesteps, and the final 100 are strided to be 20.
|
20 |
+
If the stride is a string starting with "ddim", then the fixed striding
|
21 |
+
from the DDIM paper is used, and only one section is allowed.
|
22 |
+
:param num_timesteps: the number of diffusion steps in the original
|
23 |
+
process to divide up.
|
24 |
+
:param section_counts: either a list of numbers, or a string containing
|
25 |
+
comma-separated numbers, indicating the step count
|
26 |
+
per section. As a special case, use "ddimN" where N
|
27 |
+
is a number of steps to use the striding from the
|
28 |
+
DDIM paper.
|
29 |
+
:return: a set of diffusion steps from the original process to use.
|
30 |
+
"""
|
31 |
+
if isinstance(section_counts, str):
|
32 |
+
if section_counts.startswith("ddim"):
|
33 |
+
desired_count = int(section_counts[len("ddim") :])
|
34 |
+
for i in range(1, num_timesteps):
|
35 |
+
if len(range(0, num_timesteps, i)) == desired_count:
|
36 |
+
return set(range(0, num_timesteps, i))
|
37 |
+
raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
|
38 |
+
section_counts = [int(x) for x in section_counts.split(",")]
|
39 |
+
size_per = num_timesteps // len(section_counts)
|
40 |
+
extra = num_timesteps % len(section_counts)
|
41 |
+
start_idx = 0
|
42 |
+
all_steps = []
|
43 |
+
for i, section_count in enumerate(section_counts):
|
44 |
+
size = size_per + (1 if i < extra else 0)
|
45 |
+
if size < section_count:
|
46 |
+
raise ValueError(f"cannot divide section of {size} steps into {section_count}")
|
47 |
+
if section_count <= 1:
|
48 |
+
frac_stride = 1
|
49 |
+
else:
|
50 |
+
frac_stride = (size - 1) / (section_count - 1)
|
51 |
+
cur_idx = 0.0
|
52 |
+
taken_steps = []
|
53 |
+
for _ in range(section_count):
|
54 |
+
taken_steps.append(start_idx + round(cur_idx))
|
55 |
+
cur_idx += frac_stride
|
56 |
+
all_steps += taken_steps
|
57 |
+
start_idx += size
|
58 |
+
return set(all_steps)
|
59 |
+
|
60 |
+
|
61 |
+
class SpacedDiffusion(GaussianDiffusion):
|
62 |
+
"""
|
63 |
+
A diffusion process which can skip steps in a base diffusion process.
|
64 |
+
:param use_timesteps: a collection (sequence or set) of timesteps from the
|
65 |
+
original diffusion process to retain.
|
66 |
+
:param kwargs: the kwargs to create the base diffusion process.
|
67 |
+
"""
|
68 |
+
|
69 |
+
def __init__(self, use_timesteps, **kwargs):
|
70 |
+
self.use_timesteps = set(use_timesteps)
|
71 |
+
self.timestep_map = []
|
72 |
+
self.original_num_steps = len(kwargs["betas"])
|
73 |
+
|
74 |
+
base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
|
75 |
+
last_alpha_cumprod = 1.0
|
76 |
+
new_betas = []
|
77 |
+
for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
|
78 |
+
if i in self.use_timesteps:
|
79 |
+
new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
|
80 |
+
last_alpha_cumprod = alpha_cumprod
|
81 |
+
self.timestep_map.append(i)
|
82 |
+
kwargs["betas"] = np.array(new_betas)
|
83 |
+
super().__init__(**kwargs)
|
84 |
+
|
85 |
+
def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
|
86 |
+
return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
|
87 |
+
|
88 |
+
def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
|
89 |
+
return super().training_losses(self._wrap_model(model), *args, **kwargs)
|
90 |
+
|
91 |
+
def condition_mean(self, cond_fn, *args, **kwargs):
|
92 |
+
return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
|
93 |
+
|
94 |
+
def condition_score(self, cond_fn, *args, **kwargs):
|
95 |
+
return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
|
96 |
+
|
97 |
+
def _wrap_model(self, model):
|
98 |
+
if isinstance(model, _WrappedModel):
|
99 |
+
return model
|
100 |
+
return _WrappedModel(model, self.timestep_map, self.original_num_steps)
|
101 |
+
|
102 |
+
def _scale_timesteps(self, t):
|
103 |
+
# Scaling is done by the wrapped model.
|
104 |
+
return t
|
105 |
+
|
106 |
+
|
107 |
+
class _WrappedModel:
|
108 |
+
def __init__(self, model, timestep_map, original_num_steps):
|
109 |
+
self.model = model
|
110 |
+
self.timestep_map = timestep_map
|
111 |
+
# self.rescale_timesteps = rescale_timesteps
|
112 |
+
self.original_num_steps = original_num_steps
|
113 |
+
|
114 |
+
def __call__(self, x, ts, **kwargs):
|
115 |
+
map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
|
116 |
+
new_ts = map_tensor[ts]
|
117 |
+
# if self.rescale_timesteps:
|
118 |
+
# new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
|
119 |
+
return self.model(x, new_ts, **kwargs)
|
videosys/diffusion/timestep_sampler.py
ADDED
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Modified from OpenAI's diffusion repos
|
2 |
+
# GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
|
3 |
+
# ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
|
4 |
+
# IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
|
5 |
+
|
6 |
+
from abc import ABC, abstractmethod
|
7 |
+
|
8 |
+
import numpy as np
|
9 |
+
import torch as th
|
10 |
+
import torch.distributed as dist
|
11 |
+
|
12 |
+
|
13 |
+
def create_named_schedule_sampler(name, diffusion):
|
14 |
+
"""
|
15 |
+
Create a ScheduleSampler from a library of pre-defined samplers.
|
16 |
+
:param name: the name of the sampler.
|
17 |
+
:param diffusion: the diffusion object to sample for.
|
18 |
+
"""
|
19 |
+
if name == "uniform":
|
20 |
+
return UniformSampler(diffusion)
|
21 |
+
elif name == "loss-second-moment":
|
22 |
+
return LossSecondMomentResampler(diffusion)
|
23 |
+
else:
|
24 |
+
raise NotImplementedError(f"unknown schedule sampler: {name}")
|
25 |
+
|
26 |
+
|
27 |
+
class ScheduleSampler(ABC):
|
28 |
+
"""
|
29 |
+
A distribution over timesteps in the diffusion process, intended to reduce
|
30 |
+
variance of the objective.
|
31 |
+
By default, samplers perform unbiased importance sampling, in which the
|
32 |
+
objective's mean is unchanged.
|
33 |
+
However, subclasses may override sample() to change how the resampled
|
34 |
+
terms are reweighted, allowing for actual changes in the objective.
|
35 |
+
"""
|
36 |
+
|
37 |
+
@abstractmethod
|
38 |
+
def weights(self):
|
39 |
+
"""
|
40 |
+
Get a numpy array of weights, one per diffusion step.
|
41 |
+
The weights needn't be normalized, but must be positive.
|
42 |
+
"""
|
43 |
+
|
44 |
+
def sample(self, batch_size, device):
|
45 |
+
"""
|
46 |
+
Importance-sample timesteps for a batch.
|
47 |
+
:param batch_size: the number of timesteps.
|
48 |
+
:param device: the torch device to save to.
|
49 |
+
:return: a tuple (timesteps, weights):
|
50 |
+
- timesteps: a tensor of timestep indices.
|
51 |
+
- weights: a tensor of weights to scale the resulting losses.
|
52 |
+
"""
|
53 |
+
w = self.weights()
|
54 |
+
p = w / np.sum(w)
|
55 |
+
indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
|
56 |
+
indices = th.from_numpy(indices_np).long().to(device)
|
57 |
+
weights_np = 1 / (len(p) * p[indices_np])
|
58 |
+
weights = th.from_numpy(weights_np).float().to(device)
|
59 |
+
return indices, weights
|
60 |
+
|
61 |
+
|
62 |
+
class UniformSampler(ScheduleSampler):
|
63 |
+
def __init__(self, diffusion):
|
64 |
+
self.diffusion = diffusion
|
65 |
+
self._weights = np.ones([diffusion.num_timesteps])
|
66 |
+
|
67 |
+
def weights(self):
|
68 |
+
return self._weights
|
69 |
+
|
70 |
+
|
71 |
+
class LossAwareSampler(ScheduleSampler):
|
72 |
+
def update_with_local_losses(self, local_ts, local_losses):
|
73 |
+
"""
|
74 |
+
Update the reweighting using losses from a model.
|
75 |
+
Call this method from each rank with a batch of timesteps and the
|
76 |
+
corresponding losses for each of those timesteps.
|
77 |
+
This method will perform synchronization to make sure all of the ranks
|
78 |
+
maintain the exact same reweighting.
|
79 |
+
:param local_ts: an integer Tensor of timesteps.
|
80 |
+
:param local_losses: a 1D Tensor of losses.
|
81 |
+
"""
|
82 |
+
batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())]
|
83 |
+
dist.all_gather(
|
84 |
+
batch_sizes,
|
85 |
+
th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
|
86 |
+
)
|
87 |
+
|
88 |
+
# Pad all_gather batches to be the maximum batch size.
|
89 |
+
batch_sizes = [x.item() for x in batch_sizes]
|
90 |
+
max_bs = max(batch_sizes)
|
91 |
+
|
92 |
+
timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
|
93 |
+
loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
|
94 |
+
dist.all_gather(timestep_batches, local_ts)
|
95 |
+
dist.all_gather(loss_batches, local_losses)
|
96 |
+
timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
|
97 |
+
losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
|
98 |
+
self.update_with_all_losses(timesteps, losses)
|
99 |
+
|
100 |
+
@abstractmethod
|
101 |
+
def update_with_all_losses(self, ts, losses):
|
102 |
+
"""
|
103 |
+
Update the reweighting using losses from a model.
|
104 |
+
Sub-classes should override this method to update the reweighting
|
105 |
+
using losses from the model.
|
106 |
+
This method directly updates the reweighting without synchronizing
|
107 |
+
between workers. It is called by update_with_local_losses from all
|
108 |
+
ranks with identical arguments. Thus, it should have deterministic
|
109 |
+
behavior to maintain state across workers.
|
110 |
+
:param ts: a list of int timesteps.
|
111 |
+
:param losses: a list of float losses, one per timestep.
|
112 |
+
"""
|
113 |
+
|
114 |
+
|
115 |
+
class LossSecondMomentResampler(LossAwareSampler):
|
116 |
+
def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
|
117 |
+
self.diffusion = diffusion
|
118 |
+
self.history_per_term = history_per_term
|
119 |
+
self.uniform_prob = uniform_prob
|
120 |
+
self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
|
121 |
+
self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
|
122 |
+
|
123 |
+
def weights(self):
|
124 |
+
if not self._warmed_up():
|
125 |
+
return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
|
126 |
+
weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
|
127 |
+
weights /= np.sum(weights)
|
128 |
+
weights *= 1 - self.uniform_prob
|
129 |
+
weights += self.uniform_prob / len(weights)
|
130 |
+
return weights
|
131 |
+
|
132 |
+
def update_with_all_losses(self, ts, losses):
|
133 |
+
for t, loss in zip(ts, losses):
|
134 |
+
if self._loss_counts[t] == self.history_per_term:
|
135 |
+
# Shift out the oldest loss term.
|
136 |
+
self._loss_history[t, :-1] = self._loss_history[t, 1:]
|
137 |
+
self._loss_history[t, -1] = loss
|
138 |
+
else:
|
139 |
+
self._loss_history[t, self._loss_counts[t]] = loss
|
140 |
+
self._loss_counts[t] += 1
|
141 |
+
|
142 |
+
def _warmed_up(self):
|
143 |
+
return (self._loss_counts == self.history_per_term).all()
|
videosys/models/__init__.py
ADDED
File without changes
|