oahzxl committed on
Commit
ab7be96
•
1 Parent(s): 654833d
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. .gitignore +167 -0
  2. .isort.cfg +7 -0
  3. .pre-commit-config.yaml +39 -0
  4. app.py +102 -205
  5. docs/dsp.md +0 -25
  6. docs/pab.md +0 -121
  7. eval/pab/commom_metrics/README.md +0 -6
  8. eval/pab/commom_metrics/calculate_lpips.py +0 -97
  9. eval/pab/commom_metrics/calculate_psnr.py +0 -90
  10. eval/pab/commom_metrics/calculate_ssim.py +0 -116
  11. eval/pab/commom_metrics/eval.py +0 -160
  12. eval/pab/experiments/attention_ablation.py +0 -60
  13. eval/pab/experiments/components_ablation.py +0 -46
  14. eval/pab/experiments/latte.py +0 -57
  15. eval/pab/experiments/opensora.py +0 -44
  16. eval/pab/experiments/opensora_plan.py +0 -57
  17. eval/pab/experiments/utils.py +0 -22
  18. eval/pab/vbench/VBench_full_info.json +0 -0
  19. eval/pab/vbench/cal_vbench.py +0 -154
  20. eval/pab/vbench/run_vbench.py +0 -52
  21. examples/cogvideo/sample.py +0 -14
  22. examples/latte/sample.py +0 -24
  23. examples/open_sora/sample.py +0 -24
  24. examples/open_sora_plan/sample.py +0 -24
  25. videosys/__init__.py +9 -13
  26. videosys/core/engine.py +2 -4
  27. videosys/core/pab_mgr.py +43 -175
  28. videosys/datasets/dataloader.py +0 -94
  29. videosys/datasets/image_transform.py +0 -42
  30. videosys/datasets/video_transform.py +0 -441
  31. videosys/diffusion/__init__.py +0 -41
  32. videosys/diffusion/diffusion_utils.py +0 -79
  33. videosys/diffusion/gaussian_diffusion.py +0 -829
  34. videosys/diffusion/respace.py +0 -119
  35. videosys/diffusion/timestep_sampler.py +0 -143
  36. {eval/pab/commom_metrics → videosys/models/autoencoders}/__init__.py +0 -0
  37. videosys/models/{cogvideo/autoencoder_kl.py → autoencoders/autoencoder_kl_cogvideox.py} +328 -94
  38. videosys/models/{open_sora/vae.py → autoencoders/autoencoder_kl_open_sora.py} +2 -9
  39. videosys/models/{open_sora_plan/ae.py → autoencoders/autoencoder_kl_open_sora_plan.py} +797 -14
  40. videosys/models/cogvideo/__init__.py +0 -6
  41. videosys/models/cogvideo/modules.py +0 -317
  42. videosys/models/cogvideo/retrieve_timesteps.py +0 -74
  43. videosys/models/latte/__init__.py +0 -7
  44. {eval/pab/experiments → videosys/models/modules}/__init__.py +0 -0
  45. videosys/models/modules/activations.py +3 -0
  46. videosys/{modules/attn.py → models/modules/attentions.py} +45 -131
  47. videosys/models/modules/downsampling.py +71 -0
  48. videosys/models/{open_sora/modules.py → modules/embeddings.py} +171 -209
  49. videosys/models/modules/normalization.py +102 -0
  50. videosys/models/modules/upsampling.py +67 -0
.gitignore ADDED
@@ -0,0 +1,167 @@
1
+ outputs/
2
+ processed/
3
+ profile/
4
+
5
+ # Byte-compiled / optimized / DLL files
6
+ __pycache__/
7
+ *.py[cod]
8
+ *$py.class
9
+
10
+ # C extensions
11
+ *.so
12
+
13
+ # Distribution / packaging
14
+ .Python
15
+ build/
16
+ develop-eggs/
17
+ dist/
18
+ downloads/
19
+ eggs/
20
+ .eggs/
21
+ lib/
22
+ lib64/
23
+ parts/
24
+ sdist/
25
+ var/
26
+ wheels/
27
+ pip-wheel-metadata/
28
+ share/python-wheels/
29
+ *.egg-info/
30
+ .installed.cfg
31
+ *.egg
32
+ MANIFEST
33
+
34
+ # PyInstaller
35
+ # Usually these files are written by a python script from a template
36
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
37
+ *.manifest
38
+ *.spec
39
+
40
+ # Installer logs
41
+ pip-log.txt
42
+ pip-delete-this-directory.txt
43
+
44
+ # Unit test / coverage reports
45
+ htmlcov/
46
+ .tox/
47
+ .nox/
48
+ .coverage
49
+ .coverage.*
50
+ .cache
51
+ nosetests.xml
52
+ coverage.xml
53
+ *.cover
54
+ *.py,cover
55
+ .hypothesis/
56
+ .pytest_cache/
57
+
58
+ # Translations
59
+ *.mo
60
+ *.pot
61
+
62
+ # Django stuff:
63
+ *.log
64
+ local_settings.py
65
+ db.sqlite3
66
+ db.sqlite3-journal
67
+
68
+ # Flask stuff:
69
+ instance/
70
+ .webassets-cache
71
+
72
+ # Scrapy stuff:
73
+ .scrapy
74
+
75
+ # Sphinx documentation
76
+ docs/_build/
77
+ docs/.build/
78
+
79
+ # PyBuilder
80
+ target/
81
+
82
+ # Jupyter Notebook
83
+ .ipynb_checkpoints
84
+
85
+ # IPython
86
+ profile_default/
87
+ ipython_config.py
88
+
89
+ # pyenv
90
+ .python-version
91
+
92
+ # pipenv
93
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
94
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
95
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
96
+ # install all needed dependencies.
97
+ #Pipfile.lock
98
+
99
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
100
+ __pypackages__/
101
+
102
+ # Celery stuff
103
+ celerybeat-schedule
104
+ celerybeat.pid
105
+
106
+ # SageMath parsed files
107
+ *.sage.py
108
+
109
+ # Environments
110
+ .env
111
+ .venv
112
+ env/
113
+ venv/
114
+ ENV/
115
+ env.bak/
116
+ venv.bak/
117
+
118
+ # Spyder project settings
119
+ .spyderproject
120
+ .spyproject
121
+
122
+ # Rope project settings
123
+ .ropeproject
124
+
125
+ # mkdocs documentation
126
+ /site
127
+
128
+ # mypy
129
+ .mypy_cache/
130
+ .dmypy.json
131
+ dmypy.json
132
+
133
+ # Pyre type checker
134
+ .pyre/
135
+
136
+ # IDE
137
+ .idea/
138
+ .vscode/
139
+
140
+ # macos
141
+ *.DS_Store
142
+ #data/
143
+
144
+ docs/.build
145
+
146
+ # pytorch checkpoint
147
+ *.pt
148
+
149
+ # ignore any kernel build files
150
+ .o
151
+ .so
152
+
153
+ # ignore python interface defition file
154
+ .pyi
155
+
156
+ # ignore coverage test file
157
+ coverage.lcov
158
+ coverage.xml
159
+
160
+ # ignore testmon and coverage files
161
+ .coverage
162
+ .testmondata*
163
+
164
+ pretrained
165
+ samples
166
+ cache_dir
167
+ test_outputs
.isort.cfg ADDED
@@ -0,0 +1,7 @@
1
+ [settings]
2
+ line_length = 120
3
+ multi_line_output=3
4
+ include_trailing_comma = true
5
+ ignore_comments = true
6
+ profile = black
7
+ honor_noqa = true
.pre-commit-config.yaml ADDED
@@ -0,0 +1,39 @@
1
+ repos:
2
+
3
+ - repo: https://github.com/PyCQA/autoflake
4
+ rev: v2.2.1
5
+ hooks:
6
+ - id: autoflake
7
+ name: autoflake (python)
8
+ args: ['--in-place', '--remove-unused-variables', '--remove-all-unused-imports', '--ignore-init-module-imports']
9
+
10
+ - repo: https://github.com/pycqa/isort
11
+ rev: 5.12.0
12
+ hooks:
13
+ - id: isort
14
+ name: sort all imports (python)
15
+
16
+ - repo: https://github.com/psf/black-pre-commit-mirror
17
+ rev: 23.9.1
18
+ hooks:
19
+ - id: black
20
+ name: black formatter
21
+ args: ['--line-length=120', '--target-version=py37', '--target-version=py38', '--target-version=py39','--target-version=py310']
22
+
23
+ - repo: https://github.com/pre-commit/mirrors-clang-format
24
+ rev: v13.0.1
25
+ hooks:
26
+ - id: clang-format
27
+ name: clang formatter
28
+ types_or: [c++, c]
29
+
30
+ - repo: https://github.com/pre-commit/pre-commit-hooks
31
+ rev: v4.3.0
32
+ hooks:
33
+ - id: check-yaml
34
+ - id: check-merge-conflict
35
+ - id: check-case-conflict
36
+ - id: trailing-whitespace
37
+ - id: end-of-file-fixer
38
+ - id: mixed-line-ending
39
+ args: ['--fix=lf']
app.py CHANGED
@@ -2,131 +2,107 @@ import os
2
 
3
  os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
4
 
5
- import torch
6
- from openai import OpenAI
7
- from time import time
8
- import tempfile
9
- import uuid
10
  import logging
 
 
 
11
  import gradio as gr
12
- from videosys import CogVideoConfig, VideoSysEngine
13
- from videosys.models.cogvideo.pipeline import CogVideoPABConfig
14
  import psutil
15
- import GPUtil
16
-
17
 
 
18
 
19
  logging.basicConfig(level=logging.INFO)
20
  logger = logging.getLogger(__name__)
21
 
22
- dtype = torch.bfloat16
23
- sys_prompt = """You are part of a team of bots that creates videos. You work with an assistant bot that will draw anything you say in square brackets.
24
-
25
- For example , outputting " a beautiful morning in the woods with the sun peaking through the trees " will trigger your partner bot to output an video of a forest morning , as described. You will be prompted by people looking to create detailed , amazing videos. The way to accomplish this is to take their short prompts and make them extremely detailed and descriptive.
26
- There are a few rules to follow:
27
-
28
- You will only ever output a single video description per user request.
29
-
30
- When modifications are requested , you should not simply make the description longer . You should refactor the entire description to integrate the suggestions.
31
- Other times the user will not want modifications , but instead want a new image . In this case , you should ignore your previous conversation with the user.
32
-
33
- Video descriptions must have the same num of words as examples below. Extra words will be ignored.
34
- """
35
 
36
- def convert_prompt(prompt: str, retry_times: int = 3) -> str:
37
- if not os.environ.get("OPENAI_API_KEY"):
38
- return prompt
39
- client = OpenAI()
40
- text = prompt.strip()
41
-
42
- for i in range(retry_times):
43
- response = client.chat.completions.create(
44
- messages=[
45
- {"role": "system", "content": sys_prompt},
46
- {
47
- "role": "user",
48
- "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "a girl is on the beach"',
49
- },
50
- {
51
- "role": "assistant",
52
- "content": "A radiant woman stands on a deserted beach, arms outstretched, wearing a beige trench coat, white blouse, light blue jeans, and chic boots, against a backdrop of soft sky and sea. Moments later, she is seen mid-twirl, arms exuberant, with the lighting suggesting dawn or dusk. Then, she runs along the beach, her attire complemented by an off-white scarf and black ankle boots, the tranquil sea behind her. Finally, she holds a paper airplane, her pose reflecting joy and freedom, with the ocean's gentle waves and the sky's soft pastel hues enhancing the serene ambiance.",
53
- },
54
- {
55
- "role": "user",
56
- "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : "A man jogging on a football field"',
57
- },
58
- {
59
- "role": "assistant",
60
- "content": "A determined man in athletic attire, including a blue long-sleeve shirt, black shorts, and blue socks, jogs around a snow-covered soccer field, showcasing his solitary exercise in a quiet, overcast setting. His long dreadlocks, focused expression, and the serene winter backdrop highlight his dedication to fitness. As he moves, his attire, consisting of a blue sports sweatshirt, black athletic pants, gloves, and sneakers, grips the snowy ground. He is seen running past a chain-link fence enclosing the playground area, with a basketball hoop and children's slide, suggesting a moment of solitary exercise amidst the empty field.",
61
- },
62
- {
63
- "role": "user",
64
- "content": 'Create an imaginative video descriptive caption or modify an earlier caption for the user input : " A woman is dancing, HD footage, close-up"',
65
- },
66
- {
67
- "role": "assistant",
68
- "content": "A young woman with her hair in an updo and wearing a teal hoodie stands against a light backdrop, initially looking over her shoulder with a contemplative expression. She then confidently makes a subtle dance move, suggesting rhythm and movement. Next, she appears poised and focused, looking directly at the camera. Her expression shifts to one of introspection as she gazes downward slightly. Finally, she dances with confidence, her left hand over her heart, symbolizing a poignant moment, all while dressed in the same teal hoodie against a plain, light-colored background.",
69
- },
70
- {
71
- "role": "user",
72
- "content": f'Create an imaginative video descriptive caption or modify an earlier caption in ENGLISH for the user input: "{text}"',
73
- },
74
- ],
75
- model="glm-4-0520",
76
- temperature=0.01,
77
- top_p=0.7,
78
- stream=False,
79
- max_tokens=250,
80
- )
81
- if response.choices:
82
- return response.choices[0].message.content
83
- return prompt
84
 
85
- def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_gap=2):
86
- pab_config = CogVideoPABConfig(full_threshold=pab_threshold, full_gap=pab_gap)
87
- config = CogVideoConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
88
  engine = VideoSysEngine(config)
89
  return engine
90
 
 
91
  def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
92
- try:
93
- video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
94
 
95
- with tempfile.NamedTemporaryFile(delete=False, suffix=".mp4") as temp_file:
96
- temp_file.name
97
- unique_filename = f"{uuid.uuid4().hex}.mp4"
98
- output_path = os.path.join("./temp_outputs", unique_filename)
99
 
100
- engine.save_video(video, output_path)
101
- return output_path
102
- except Exception as e:
103
- logger.error(f"An error occurred: {str(e)}")
104
- return None
105
 
106
 
107
  def get_server_status():
108
  cpu_percent = psutil.cpu_percent()
109
  memory = psutil.virtual_memory()
110
- disk = psutil.disk_usage('/')
111
  gpus = GPUtil.getGPUs()
112
  gpu_info = []
113
  for gpu in gpus:
114
- gpu_info.append({
115
- 'id': gpu.id,
116
- 'name': gpu.name,
117
- 'load': f"{gpu.load*100:.1f}%",
118
- 'memory_used': f"{gpu.memoryUsed}MB",
119
- 'memory_total': f"{gpu.memoryTotal}MB"
120
- })
121
-
 
122
  return {
123
- 'cpu': f"{cpu_percent}%",
124
- 'memory': f"{memory.percent}%",
125
- 'disk': f"{disk.percent}%",
126
- 'gpu': gpu_info
127
  }
128
 
129
 
 
 
 
 
130
 
131
  css = """
132
  body {
@@ -137,16 +113,17 @@ body {
137
  padding: 20px;
138
  }
139
 
 
140
  .container {
141
  display: flex;
142
  flex-direction: column;
143
- gap: 20px;
144
  }
145
 
146
  .row {
147
  display: flex;
148
  flex-wrap: wrap;
149
- gap: 18px;
150
  }
151
 
152
  .column {
@@ -186,12 +163,6 @@ body {
186
  font-size: 0.9em !important;
187
  line-height: 1.2 !important;
188
  }
189
- .server-status button {
190
- padding: 1px 8px !important;
191
- height: 22px !important;
192
- font-size: 0.9em !important;
193
- margin-top: 2px !important;
194
- }
195
  .server-status .textbox {
196
  gap: 0 !important;
197
  }
@@ -215,150 +186,76 @@ body {
215
  """
216
 
217
  with gr.Blocks(css=css) as demo:
218
- gr.HTML("""
 
219
  <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
220
- VideoSys Huggingface Space🤗
221
  </div>
222
  <div style="text-align: center; font-size: 15px;">
223
  🌐 Github: <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys">https://github.com/NUS-HPC-AI-Lab/VideoSys</a><br>
224
-
225
- ⚠️ This demo is for academic research and experiential use only.
226
  Users should strictly adhere to local laws and ethics.<br>
227
-
228
  💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.<br><br>
229
  </div>
230
  </div>
231
- """)
 
232
 
233
  with gr.Row():
234
  with gr.Column():
235
- prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=5)
236
- with gr.Row():
237
- gr.Markdown(
238
- "✨Upon pressing the enhanced prompt button, we will use [GLM-4 Model](https://github.com/THUDM/GLM-4) to polish the prompt and overwrite the original one."
239
- )
240
- enhance_button = gr.Button("✨ Enhance Prompt(Optional)")
241
 
242
  with gr.Column():
243
- gr.Markdown(
244
- "**Optional Parameters** (default values are recommended)<br>"
245
- "Turn Inference Steps larger if you want more detailed video, but it will be slower.<br>"
246
- "50 steps are recommended for most cases. will cause 120 seconds for inference.<br>"
247
- )
248
  with gr.Row():
249
  num_inference_steps = gr.Number(label="Inference Steps", value=50)
250
  guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
251
- pab_gap = gr.Number(label="PAB Gap", value=2, precision=0)
252
- pab_threshold = gr.Textbox(label="PAB Threshold", value="100,850", lines=1)
253
  with gr.Row():
254
- generate_button = gr.Button("🎬 Generate Video")
 
 
 
 
 
255
  generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
 
256
  with gr.Column(elem_classes="server-status"):
257
  gr.Markdown("#### Server Status")
258
-
259
  with gr.Row():
260
  cpu_status = gr.Textbox(label="CPU", scale=1)
261
  memory_status = gr.Textbox(label="Memory", scale=1)
262
-
263
  with gr.Row():
264
  disk_status = gr.Textbox(label="Disk", scale=1)
265
  gpu_status = gr.Textbox(label="GPU Memory", scale=1)
266
-
267
  with gr.Row():
268
- refresh_button = gr.Button("Refresh", size="sm")
269
 
270
  with gr.Column():
271
- with gr.Row():
272
- video_output = gr.Video(label="CogVideoX", width=720, height=480)
273
- with gr.Row():
274
- download_video_button = gr.File(label="📥 Download Video", visible=False)
275
- elapsed_time = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
276
  with gr.Row():
277
  video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
278
  with gr.Row():
279
- download_video_button_vs = gr.File(label="📥 Download Video", visible=False)
280
- elapsed_time_vs = gr.Textbox(label="Elapsed Time", value="0s", visible=False)
281
- # with gr.Column():
282
- # task_status = gr.Textbox(label="Task Status", visible=False)
283
-
284
-
285
-
286
-
287
- def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
288
- engine = load_model()
289
- t = time()
290
- video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
291
- elapsed_time = time() - t
292
- video_update = gr.update(visible=True, value=video_path)
293
- elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
294
-
295
- return video_path, video_update, elapsed_time
296
-
297
- def generate_vs(prompt, num_inference_steps, guidance_scale, threshold, gap, progress=gr.Progress(track_tqdm=True)):
298
- threshold = [int(i) for i in threshold.split(",")]
299
- gap = int(gap)
300
- engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_gap=gap)
301
- t = time()
302
- video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
303
- elapsed_time = time() - t
304
- video_update = gr.update(visible=True, value=video_path)
305
- elapsed_time = gr.update(visible=True, value=f"{elapsed_time:.2f}s")
306
-
307
- return video_path, video_update, elapsed_time
308
-
309
- def enhance_prompt_func(prompt):
310
- return convert_prompt(prompt, retry_times=1)
311
-
312
- def get_server_status():
313
- cpu_percent = psutil.cpu_percent()
314
- memory = psutil.virtual_memory()
315
- disk = psutil.disk_usage('/')
316
- try:
317
- gpus = GPUtil.getGPUs()
318
- if gpus:
319
- gpu = gpus[0]
320
- gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
321
- else:
322
- gpu_memory = "No GPU found"
323
- except:
324
- gpu_memory = "GPU information unavailable"
325
-
326
- return {
327
- 'cpu': f"{cpu_percent}%",
328
- 'memory': f"{memory.percent}%",
329
- 'disk': f"{disk.percent}%",
330
- 'gpu_memory': gpu_memory
331
- }
332
-
333
-
334
- def update_server_status():
335
- status = get_server_status()
336
- return (
337
- status['cpu'],
338
- status['memory'],
339
- status['disk'],
340
- status['gpu_memory']
341
- )
342
 
343
-
344
  generate_button.click(
345
  generate_vanilla,
346
  inputs=[prompt, num_inference_steps, guidance_scale],
347
- outputs=[video_output, download_video_button, elapsed_time],
348
  )
349
 
350
  generate_button_vs.click(
351
  generate_vs,
352
- inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold, pab_gap],
353
- outputs=[video_output_vs, download_video_button_vs, elapsed_time_vs],
354
  )
355
 
356
- enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
357
-
358
-
359
  refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
360
  demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)
361
 
362
  if __name__ == "__main__":
363
  demo.queue(max_size=10, default_concurrency_limit=1)
364
- demo.launch()
 
2
 
3
  os.environ["GRADIO_TEMP_DIR"] = os.path.join(os.getcwd(), ".tmp_outputs")
4
 
 
 
 
 
 
5
  import logging
6
+ import uuid
7
+
8
+ import GPUtil
9
  import gradio as gr
 
 
10
  import psutil
11
+ import torch
 
12
 
13
+ from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine
14
 
15
  logging.basicConfig(level=logging.INFO)
16
  logger = logging.getLogger(__name__)
17
 
18
+ dtype = torch.float16
19
 
20
 
21
+ def load_model(enable_video_sys=False, pab_threshold=[100, 850], pab_range=2):
22
+ pab_config = CogVideoXPABConfig(spatial_threshold=pab_threshold, spatial_range=pab_range)
23
+ config = CogVideoXConfig(world_size=1, enable_pab=enable_video_sys, pab_config=pab_config)
24
  engine = VideoSysEngine(config)
25
  return engine
26
 
27
+
28
  def generate(engine, prompt, num_inference_steps=50, guidance_scale=6.0):
29
+ video = engine.generate(prompt, num_inference_steps=num_inference_steps, guidance_scale=guidance_scale).video[0]
 
30
 
31
+ unique_filename = f"{uuid.uuid4().hex}.mp4"
32
+ output_path = os.path.join("./.tmp_outputs", unique_filename)
 
 
33
 
34
+ engine.save_video(video, output_path)
35
+ return output_path
 
 
 
36
 
37
 
38
  def get_server_status():
39
  cpu_percent = psutil.cpu_percent()
40
  memory = psutil.virtual_memory()
41
+ disk = psutil.disk_usage("/")
42
  gpus = GPUtil.getGPUs()
43
  gpu_info = []
44
  for gpu in gpus:
45
+ gpu_info.append(
46
+ {
47
+ "id": gpu.id,
48
+ "name": gpu.name,
49
+ "load": f"{gpu.load*100:.1f}%",
50
+ "memory_used": f"{gpu.memoryUsed}MB",
51
+ "memory_total": f"{gpu.memoryTotal}MB",
52
+ }
53
+ )
54
+
55
+ return {"cpu": f"{cpu_percent}%", "memory": f"{memory.percent}%", "disk": f"{disk.percent}%", "gpu": gpu_info}
56
+
57
+
58
+ def generate_vanilla(prompt, num_inference_steps, guidance_scale, progress=gr.Progress(track_tqdm=True)):
59
+ engine = load_model()
60
+ video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
61
+ return video_path
62
+
63
+
64
+ def generate_vs(
65
+ prompt,
66
+ num_inference_steps,
67
+ guidance_scale,
68
+ threshold_start,
69
+ threshold_end,
70
+ gap,
71
+ progress=gr.Progress(track_tqdm=True),
72
+ ):
73
+ threshold = [int(threshold_end), int(threshold_start)]
74
+ gap = int(gap)
75
+ engine = load_model(enable_video_sys=True, pab_threshold=threshold, pab_range=gap)
76
+ video_path = generate(engine, prompt, num_inference_steps, guidance_scale)
77
+ return video_path
78
+
79
+
80
+ def get_server_status():
81
+ cpu_percent = psutil.cpu_percent()
82
+ memory = psutil.virtual_memory()
83
+ disk = psutil.disk_usage("/")
84
+ try:
85
+ gpus = GPUtil.getGPUs()
86
+ if gpus:
87
+ gpu = gpus[0]
88
+ gpu_memory = f"{gpu.memoryUsed}/{gpu.memoryTotal}MB ({gpu.memoryUtil*100:.1f}%)"
89
+ else:
90
+ gpu_memory = "No GPU found"
91
+ except:
92
+ gpu_memory = "GPU information unavailable"
93
+
94
  return {
95
+ "cpu": f"{cpu_percent}%",
96
+ "memory": f"{memory.percent}%",
97
+ "disk": f"{disk.percent}%",
98
+ "gpu_memory": gpu_memory,
99
  }
100
 
101
 
102
+ def update_server_status():
103
+ status = get_server_status()
104
+ return (status["cpu"], status["memory"], status["disk"], status["gpu_memory"])
105
+
106
 
107
  css = """
108
  body {
 
113
  padding: 20px;
114
  }
115
 
116
+
117
  .container {
118
  display: flex;
119
  flex-direction: column;
120
+ gap: 10px;
121
  }
122
 
123
  .row {
124
  display: flex;
125
  flex-wrap: wrap;
126
+ gap: 10px;
127
  }
128
 
129
  .column {
 
163
  font-size: 0.9em !important;
164
  line-height: 1.2 !important;
165
  }
 
 
 
 
 
 
166
  .server-status .textbox {
167
  gap: 0 !important;
168
  }
 
186
  """
187
 
188
  with gr.Blocks(css=css) as demo:
189
+ gr.HTML(
190
+ """
191
  <div style="text-align: center; font-size: 32px; font-weight: bold; margin-bottom: 20px;">
192
+ VideoSys for CogVideoX🤗
193
  </div>
194
  <div style="text-align: center; font-size: 15px;">
195
  🌐 Github: <a href="https://github.com/NUS-HPC-AI-Lab/VideoSys">https://github.com/NUS-HPC-AI-Lab/VideoSys</a><br>
196
+
197
+ ⚠️ This demo is for academic research and experiential use only.
198
  Users should strictly adhere to local laws and ethics.<br>
199
+
200
  💡 This demo only demonstrates single-device inference. To experience the full power of VideoSys, please deploy it with multiple devices.<br><br>
201
  </div>
202
  </div>
203
+ """
204
+ )
205
 
206
  with gr.Row():
207
  with gr.Column():
208
+ prompt = gr.Textbox(label="Prompt (Less than 200 Words)", value="Sunset over the sea.", lines=4)
 
 
 
 
 
209
 
210
  with gr.Column():
211
+ gr.Markdown("**Generation Parameters**<br>")
 
 
 
 
212
  with gr.Row():
213
  num_inference_steps = gr.Number(label="Inference Steps", value=50)
214
  guidance_scale = gr.Number(label="Guidance Scale", value=6.0)
 
 
215
  with gr.Row():
216
+ pab_range = gr.Number(
217
+ label="PAB Broadcast Range", value=2, precision=0, info="Broadcast timesteps range."
218
+ )
219
+ pab_threshold_start = gr.Number(label="PAB Start Timestep", value=850, info="Start from step 1000.")
220
+ pab_threshold_end = gr.Number(label="PAB End Timestep", value=100, info="End at step 0.")
221
+ with gr.Row():
222
  generate_button_vs = gr.Button("⚡️ Generate Video with VideoSys (Faster)")
223
+ generate_button = gr.Button("🎬 Generate Video (Original)")
224
  with gr.Column(elem_classes="server-status"):
225
  gr.Markdown("#### Server Status")
226
+
227
  with gr.Row():
228
  cpu_status = gr.Textbox(label="CPU", scale=1)
229
  memory_status = gr.Textbox(label="Memory", scale=1)
230
+
231
  with gr.Row():
232
  disk_status = gr.Textbox(label="Disk", scale=1)
233
  gpu_status = gr.Textbox(label="GPU Memory", scale=1)
234
+
235
  with gr.Row():
236
+ refresh_button = gr.Button("Refresh")
237
 
238
  with gr.Column():
 
 
 
 
 
239
  with gr.Row():
240
  video_output_vs = gr.Video(label="CogVideoX with VideoSys", width=720, height=480)
241
  with gr.Row():
242
+ video_output = gr.Video(label="CogVideoX", width=720, height=480)
243
 
 
244
  generate_button.click(
245
  generate_vanilla,
246
  inputs=[prompt, num_inference_steps, guidance_scale],
247
+ outputs=[video_output],
248
  )
249
 
250
  generate_button_vs.click(
251
  generate_vs,
252
+ inputs=[prompt, num_inference_steps, guidance_scale, pab_threshold_start, pab_threshold_end, pab_range],
253
+ outputs=[video_output_vs],
254
  )
255
 
 
 
 
256
  refresh_button.click(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status])
257
  demo.load(update_server_status, outputs=[cpu_status, memory_status, disk_status, gpu_status], every=1)
258
 
259
  if __name__ == "__main__":
260
  demo.queue(max_size=10, default_concurrency_limit=1)
261
+ demo.launch()
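The new app.py above is the main consumer of the renamed VideoSys API introduced in this commit (CogVideoConfig → CogVideoXConfig, the PAB "gap" option → "range"). As a quick reference, here is a minimal sketch that collects that usage in one place; every class and method name comes from the diff above, while the prompt and output path are arbitrary examples rather than part of the commit.

```python
# Minimal sketch of the renamed API as exercised by the new app.py above.
# CogVideoXConfig, CogVideoXPABConfig and VideoSysEngine are imported exactly
# as in the diff; argument values mirror the load_model() / generate() defaults.
from videosys import CogVideoXConfig, CogVideoXPABConfig, VideoSysEngine

# PAB broadcasts spatial attention between timesteps 100 and 850, reusing each
# computed output for 2 subsequent steps (spatial_range).
pab_config = CogVideoXPABConfig(spatial_threshold=[100, 850], spatial_range=2)
config = CogVideoXConfig(world_size=1, enable_pab=True, pab_config=pab_config)
engine = VideoSysEngine(config)

video = engine.generate("Sunset over the sea.", num_inference_steps=50, guidance_scale=6.0).video[0]
engine.save_video(video, "./outputs/sunset.mp4")  # output path is just an example
```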
docs/dsp.md DELETED
@@ -1,25 +0,0 @@
1
- # DSP
2
-
3
- paper: https://arxiv.org/abs/2403.10266
4
-
5
- ![dsp_overview](../assets/figures/dsp_overview.png)
6
-
7
-
8
- DSP (Dynamic Sequence Parallelism) is a novel, elegant, and highly efficient sequence parallelism method for [OpenSora](https://github.com/hpcaitech/Open-Sora), [Latte](https://github.com/Vchitect/Latte), and other multi-dimensional transformer architectures.
9
-
10
- The key idea is to dynamically switch the parallelism dimension according to the current computation stage, leveraging the characteristics of multi-dimensional transformers. Compared with previous methods that split the head and sequence dimensions, it reduces communication cost by at least 75%.
11
-
12
- It achieves a **3x** speedup for training and a **2x** speedup for inference in OpenSora compared with the state-of-the-art sequence parallelism ([DeepSpeed Ulysses](https://arxiv.org/abs/2309.14509)). For a 10-second (80-frame) 512x512 video, the inference latency of OpenSora is:
13
-
14
- | Method | 1xH800 | 8xH800 (DS Ulysses) | 8xH800 (DSP) |
15
- | ------ | ------ | ------ | ------ |
16
- | Latency(s) | 106 | 45 | 22 |
17
-
18
- The following is DSP's end-to-end throughput for training of OpenSora:
19
-
20
- ![dsp_overview](../assets/figures/dsp_exp.png)
21
-
22
-
23
- ### Usage
24
-
25
- DSP is currently supported for: OpenSora, OpenSoraPlan and Latte. To enable DSP, you just need to launch with multiple GPUs.
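A hypothetical sketch of what that looks like with the engine API used elsewhere in this commit is below; the world_size keyword is borrowed from CogVideoXConfig in app.py, and whether OpenSoraConfig accepts it the same way is an assumption, not something this diff shows.

```python
# Hypothetical sketch, not part of this commit: the removed docs say DSP is
# enabled simply by launching on multiple GPUs. With the engine API used in
# app.py that would correspond to raising world_size; the exact keyword for
# OpenSoraConfig is assumed here.
from videosys import OpenSoraConfig, VideoSysEngine

config = OpenSoraConfig(world_size=8)  # one worker process per GPU
engine = VideoSysEngine(config)

video = engine.generate("Sunset over the sea.").video[0]
engine.save_video(video, "./outputs/sunset.mp4")
```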
docs/pab.md DELETED
@@ -1,121 +0,0 @@
1
- # Pyramid Attention Broadcast(PAB)
2
-
3
- [[paper](https://arxiv.org/abs/2408.12588)][[blog](https://oahzxl.github.io/PAB)]
4
-
5
- Pyramid Attention Broadcast(PAB)(#pyramid-attention-broadcastpab)
6
- - [Pyramid Attention Broadcast(PAB)](#pyramid-attention-broadcastpab)
7
- - [Insights](#insights)
8
- - [Pyramid Attention Broadcast (PAB) Mechanism](#pyramid-attention-broadcast-pab-mechanism)
9
- - [Experimental Results](#experimental-results)
10
- - [Usage](#usage)
11
- - [Supported Models](#supported-models)
12
- - [Configuration for PAB](#configuration-for-pab)
13
- - [Parameters](#parameters)
14
- - [Example Configuration](#example-configuration)
15
-
16
-
17
- We introduce Pyramid Attention Broadcast (PAB), the first approach that achieves real-time DiT-based video generation. By mitigating redundant attention computation, PAB achieves up to 21.6 FPS with 10.6x acceleration, without sacrificing quality across popular DiT-based video generation models including Open-Sora, Open-Sora-Plan, and Latte. Notably, as a training-free approach, PAB can empower any future DiT-based video generation model with real-time capabilities.
18
-
19
- ## Insights
20
-
21
- ![method](../assets/figures/pab_motivation.png)
22
-
23
- Our study reveals two key insights about the three **attention mechanisms** within video diffusion transformers:
24
- - First, attention differences across time steps exhibit a U-shaped pattern, with significant variations occurring during the first and last 15% of steps, while the middle 70% of steps show very stable, minor differences.
25
- - Second, within the stable middle segment, the variability differs among attention types:
26
- - **Spatial attention** varies the most, involving high-frequency elements like edges and textures;
27
- - **Temporal attention** exhibits mid-frequency variations related to movements and dynamics in videos;
28
- - **Cross-modal attention** is the most stable, linking text with video content, analogous to low-frequency signals reflecting textual semantics.
29
-
30
- ## Pyramid Attention Broadcast (PAB) Mechanism
31
-
32
- ![method](../assets/figures/pab_method.png)
33
-
34
- Building on these insights, we propose a **pyramid attention broadcast (PAB)** mechanism to minimize unnecessary computation and optimize the utility of each attention module, as shown in the figure below.
35
-
36
- In the middle segment, we broadcast one step's attention outputs to its subsequent several steps, thereby significantly reducing the computational cost on attention modules.
37
-
38
- To broadcast more efficiently with minimal impact on quality, we set different broadcast ranges for the different attention types based on their stability and variation.
39
- **The smaller the variation in attention, the broader the potential broadcast range.**
40
-
41
-
42
- ## Experimental Results
43
- Here are the results of our experiments; more results are shown at https://oahzxl.github.io/PAB:
44
-
45
- ![pab_vis](../assets/figures/pab_vis.png)
46
-
47
-
48
- ## Usage
49
-
50
- ### Supported Models
51
-
52
- PAB currently supports Open-Sora, Open-Sora-Plan, and Latte.
53
-
54
- ### Configuration for PAB
55
-
56
- To efficiently use the Pyramid Attention Broadcast (PAB) mechanism, configure the following parameters to control the broadcasting for different attention types. This helps reduce computational costs by skipping certain steps based on attention stability.
57
-
58
- #### Parameters
59
-
60
- - **spatial_broadcast**: Enable or disable broadcasting for spatial attention.
61
- - Type: `True` or `False`
62
-
63
- - **spatial_threshold**: Set the range of diffusion steps within which spatial attention is applied.
64
- - Format: `[min_value, max_value]`
65
-
66
- - **spatial_gap**: Number of blocks in model to skip during broadcasting for spatial attention.
67
- - Type: Integer
68
-
69
- - **temporal_broadcast**: Enable or disable broadcasting for temporal attention.
70
- - Type: `True` or `False`
71
-
72
- - **temporal_threshold**: Set the range of diffusion steps within which temporal attention is applied.
73
- - Format: `[min_value, max_value]`
74
-
75
- - **temporal_gap**: Number of steps to skip during broadcasting for temporal attention.
76
- - Type: Integer
77
-
78
- - **cross_broadcast**: Enable or disable broadcasting for cross-modal attention.
79
- - Type: `True` or `False`
80
-
81
- - **cross_threshold**: Set the range of diffusion steps within which cross-modal attention is applied.
82
- - Format: `[min_value, max_value]`
83
-
84
- - **cross_gap**: Number of steps to skip during broadcasting for cross-modal attention.
85
- - Type: Integer
86
-
87
- #### Example Configuration
88
-
89
- ```yaml
90
- spatial_broadcast: True
91
- spatial_threshold: [100, 800]
92
- spatial_gap: 2
93
-
94
- temporal_broadcast: True
95
- temporal_threshold: [100, 800]
96
- temporal_gap: 3
97
-
98
- cross_broadcast: True
99
- cross_threshold: [100, 900]
100
- cross_gap: 5
101
- ```
102
-
103
- Explanation:
104
-
105
- - **Spatial Attention**:
106
- - Broadcasting enabled (`spatial_broadcast: True`)
107
- - Applied within the threshold range of 100 to 800
108
- - Skips every 2 steps (`spatial_gap: 2`)
109
- - Active within the first 28 steps (`spatial_block: [0, 28]`)
110
-
111
- - **Temporal Attention**:
112
- - Broadcasting enabled (`temporal_broadcast: True`)
113
- - Applied within the threshold range of 100 to 800
114
- - Skips every 3 steps (`temporal_gap: 3`)
115
-
116
- - **Cross-Modal Attention**:
117
- - Broadcasting enabled (`cross_broadcast: True`)
118
- - Applied within the threshold range of 100 to 900
119
- - Skips every 5 steps (`cross_gap: 5`)
120
-
121
- Adjust these settings based on your specific needs to optimize the performance of each attention mechanism.
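The same settings can also be passed programmatically. The sketch below expresses the YAML example above through the pre-rename Python config objects used by the PAB eval scripts removed later in this diff (eval/pab/experiments/*.py); the keyword names follow the parameter list above and those scripts, and passing the thresholds directly to the constructor is an assumption based on them.

```python
# Sketch: the YAML example above expressed via OpenSoraPABConfig, using the
# keyword names documented above (spatial_threshold, spatial_gap, ...) and the
# import paths from eval/pab/experiments/attention_ablation.py in this diff.
from videosys import OpenSoraConfig, OpenSoraPipeline
from videosys.models.open_sora import OpenSoraPABConfig

pab_config = OpenSoraPABConfig(
    spatial_broadcast=True, spatial_threshold=[100, 800], spatial_gap=2,
    temporal_broadcast=True, temporal_threshold=[100, 800], temporal_gap=3,
    cross_broadcast=True, cross_threshold=[100, 900], cross_gap=5,
)
config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
pipeline = OpenSoraPipeline(config)  # generation then proceeds as in the eval scripts
```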
eval/pab/commom_metrics/README.md DELETED
@@ -1,6 +0,0 @@
1
- Common metrics
2
-
3
- Include LPIPS, PSNR and SSIM.
4
-
5
- The code is adapted from [common_metrics_on_video_quality
6
- ](https://github.com/JunyaoHu/common_metrics_on_video_quality).
eval/pab/commom_metrics/calculate_lpips.py DELETED
@@ -1,97 +0,0 @@
1
- import lpips
2
- import numpy as np
3
- import torch
4
-
5
- spatial = True # Return a spatial map of perceptual distance.
6
-
7
- # Linearly calibrated models (LPIPS)
8
- loss_fn = lpips.LPIPS(net="alex", spatial=spatial) # Can also set net = 'squeeze' or 'vgg'
9
- # loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False) # Can also set net = 'squeeze' or 'vgg'
10
-
11
-
12
- def trans(x):
13
- # if greyscale images add channel
14
- if x.shape[-3] == 1:
15
- x = x.repeat(1, 1, 3, 1, 1)
16
-
17
- # value range [0, 1] -> [-1, 1]
18
- x = x * 2 - 1
19
-
20
- return x
21
-
22
-
23
- def calculate_lpips(videos1, videos2, device):
24
- # image should be RGB, IMPORTANT: normalized to [-1,1]
25
-
26
- assert videos1.shape == videos2.shape
27
-
28
- # videos [batch_size, timestamps, channel, h, w]
29
-
30
- # support grayscale input, if grayscale -> channel*3
31
- # value range [0, 1] -> [-1, 1]
32
- videos1 = trans(videos1)
33
- videos2 = trans(videos2)
34
-
35
- lpips_results = []
36
-
37
- for video_num in range(videos1.shape[0]):
38
- # get a video
39
- # video [timestamps, channel, h, w]
40
- video1 = videos1[video_num]
41
- video2 = videos2[video_num]
42
-
43
- lpips_results_of_a_video = []
44
- for clip_timestamp in range(len(video1)):
45
- # get a img
46
- # img [timestamps[x], channel, h, w]
47
- # img [channel, h, w] tensor
48
-
49
- img1 = video1[clip_timestamp].unsqueeze(0).to(device)
50
- img2 = video2[clip_timestamp].unsqueeze(0).to(device)
51
-
52
- loss_fn.to(device)
53
-
54
- # calculate lpips of a video
55
- lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
56
- lpips_results.append(lpips_results_of_a_video)
57
-
58
- lpips_results = np.array(lpips_results)
59
-
60
- lpips = {}
61
- lpips_std = {}
62
-
63
- for clip_timestamp in range(len(video1)):
64
- lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
65
- lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])
66
-
67
- result = {
68
- "value": lpips,
69
- "value_std": lpips_std,
70
- "video_setting": video1.shape,
71
- "video_setting_name": "time, channel, heigth, width",
72
- }
73
-
74
- return result
75
-
76
-
77
- # test code / using example
78
-
79
-
80
- def main():
81
- NUMBER_OF_VIDEOS = 8
82
- VIDEO_LENGTH = 50
83
- CHANNEL = 3
84
- SIZE = 64
85
- videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
86
- videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
87
- device = torch.device("cuda")
88
- # device = torch.device("cpu")
89
-
90
- import json
91
-
92
- result = calculate_lpips(videos1, videos2, device)
93
- print(json.dumps(result, indent=4))
94
-
95
-
96
- if __name__ == "__main__":
97
- main()
 
eval/pab/commom_metrics/calculate_psnr.py DELETED
@@ -1,90 +0,0 @@
1
- import math
2
-
3
- import numpy as np
4
- import torch
5
-
6
-
7
- def img_psnr(img1, img2):
8
- # [0,1]
9
- # compute mse
10
- # mse = np.mean((img1-img2)**2)
11
- mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
12
- # compute psnr
13
- if mse < 1e-10:
14
- return 100
15
- psnr = 20 * math.log10(1 / math.sqrt(mse))
16
- return psnr
17
-
18
-
19
- def trans(x):
20
- return x
21
-
22
-
23
- def calculate_psnr(videos1, videos2):
24
- # videos [batch_size, timestamps, channel, h, w]
25
-
26
- assert videos1.shape == videos2.shape
27
-
28
- videos1 = trans(videos1)
29
- videos2 = trans(videos2)
30
-
31
- psnr_results = []
32
-
33
- for video_num in range(videos1.shape[0]):
34
- # get a video
35
- # video [timestamps, channel, h, w]
36
- video1 = videos1[video_num]
37
- video2 = videos2[video_num]
38
-
39
- psnr_results_of_a_video = []
40
- for clip_timestamp in range(len(video1)):
41
- # get a img
42
- # img [timestamps[x], channel, h, w]
43
- # img [channel, h, w] numpy
44
-
45
- img1 = video1[clip_timestamp].numpy()
46
- img2 = video2[clip_timestamp].numpy()
47
-
48
- # calculate psnr of a video
49
- psnr_results_of_a_video.append(img_psnr(img1, img2))
50
-
51
- psnr_results.append(psnr_results_of_a_video)
52
-
53
- psnr_results = np.array(psnr_results)
54
-
55
- psnr = {}
56
- psnr_std = {}
57
-
58
- for clip_timestamp in range(len(video1)):
59
- psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
60
- psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])
61
-
62
- result = {
63
- "value": psnr,
64
- "value_std": psnr_std,
65
- "video_setting": video1.shape,
66
- "video_setting_name": "time, channel, heigth, width",
67
- }
68
-
69
- return result
70
-
71
-
72
- # test code / using example
73
-
74
-
75
- def main():
76
- NUMBER_OF_VIDEOS = 8
77
- VIDEO_LENGTH = 50
78
- CHANNEL = 3
79
- SIZE = 64
80
- videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
81
- videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
82
-
83
- import json
84
-
85
- result = calculate_psnr(videos1, videos2)
86
- print(json.dumps(result, indent=4))
87
-
88
-
89
- if __name__ == "__main__":
90
- main()
eval/pab/commom_metrics/calculate_ssim.py DELETED
@@ -1,116 +0,0 @@
1
- import cv2
2
- import numpy as np
3
- import torch
4
-
5
-
6
- def ssim(img1, img2):
7
- C1 = 0.01**2
8
- C2 = 0.03**2
9
- img1 = img1.astype(np.float64)
10
- img2 = img2.astype(np.float64)
11
- kernel = cv2.getGaussianKernel(11, 1.5)
12
- window = np.outer(kernel, kernel.transpose())
13
- mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5] # valid
14
- mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
15
- mu1_sq = mu1**2
16
- mu2_sq = mu2**2
17
- mu1_mu2 = mu1 * mu2
18
- sigma1_sq = cv2.filter2D(img1**2, -1, window)[5:-5, 5:-5] - mu1_sq
19
- sigma2_sq = cv2.filter2D(img2**2, -1, window)[5:-5, 5:-5] - mu2_sq
20
- sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
21
- ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2))
22
- return ssim_map.mean()
23
-
24
-
25
- def calculate_ssim_function(img1, img2):
26
- # [0,1]
27
- # ssim is the only metric extremely sensitive to gray being compared to b/w
28
- if not img1.shape == img2.shape:
29
- raise ValueError("Input images must have the same dimensions.")
30
- if img1.ndim == 2:
31
- return ssim(img1, img2)
32
- elif img1.ndim == 3:
33
- if img1.shape[0] == 3:
34
- ssims = []
35
- for i in range(3):
36
- ssims.append(ssim(img1[i], img2[i]))
37
- return np.array(ssims).mean()
38
- elif img1.shape[0] == 1:
39
- return ssim(np.squeeze(img1), np.squeeze(img2))
40
- else:
41
- raise ValueError("Wrong input image dimensions.")
42
-
43
-
44
- def trans(x):
45
- return x
46
-
47
-
48
- def calculate_ssim(videos1, videos2):
49
- # videos [batch_size, timestamps, channel, h, w]
50
-
51
- assert videos1.shape == videos2.shape
52
-
53
- videos1 = trans(videos1)
54
- videos2 = trans(videos2)
55
-
56
- ssim_results = []
57
-
58
- for video_num in range(videos1.shape[0]):
59
- # get a video
60
- # video [timestamps, channel, h, w]
61
- video1 = videos1[video_num]
62
- video2 = videos2[video_num]
63
-
64
- ssim_results_of_a_video = []
65
- for clip_timestamp in range(len(video1)):
66
- # get a img
67
- # img [timestamps[x], channel, h, w]
68
- # img [channel, h, w] numpy
69
-
70
- img1 = video1[clip_timestamp].numpy()
71
- img2 = video2[clip_timestamp].numpy()
72
-
73
- # calculate ssim of a video
74
- ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))
75
-
76
- ssim_results.append(ssim_results_of_a_video)
77
-
78
- ssim_results = np.array(ssim_results)
79
-
80
- ssim = {}
81
- ssim_std = {}
82
-
83
- for clip_timestamp in range(len(video1)):
84
- ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
85
- ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])
86
-
87
- result = {
88
- "value": ssim,
89
- "value_std": ssim_std,
90
- "video_setting": video1.shape,
91
- "video_setting_name": "time, channel, heigth, width",
92
- }
93
-
94
- return result
95
-
96
-
97
- # test code / using example
98
-
99
-
100
- def main():
101
- NUMBER_OF_VIDEOS = 8
102
- VIDEO_LENGTH = 50
103
- CHANNEL = 3
104
- SIZE = 64
105
- videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
106
- videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
107
- torch.device("cuda")
108
-
109
- import json
110
-
111
- result = calculate_ssim(videos1, videos2)
112
- print(json.dumps(result, indent=4))
113
-
114
-
115
- if __name__ == "__main__":
116
- main()
eval/pab/commom_metrics/eval.py DELETED
@@ -1,160 +0,0 @@
1
- import argparse
2
- import os
3
-
4
- import imageio
5
- import torch
6
- import torchvision.transforms.functional as F
7
- import tqdm
8
- from calculate_lpips import calculate_lpips
9
- from calculate_psnr import calculate_psnr
10
- from calculate_ssim import calculate_ssim
11
-
12
-
13
- def load_videos(directory, video_ids, file_extension):
14
- videos = []
15
- for video_id in video_ids:
16
- video_path = os.path.join(directory, f"{video_id}.{file_extension}")
17
- if os.path.exists(video_path):
18
- video = load_video(video_path) # Define load_video based on how videos are stored
19
- videos.append(video)
20
- else:
21
- raise ValueError(f"Video {video_id}.{file_extension} not found in {directory}")
22
- return videos
23
-
24
-
25
- def load_video(video_path):
26
- """
27
- Load a video from the given path and convert it to a PyTorch tensor.
28
- """
29
- # Read the video using imageio
30
- reader = imageio.get_reader(video_path, "ffmpeg")
31
-
32
- # Extract frames and convert to a list of tensors
33
- frames = []
34
- for frame in reader:
35
- # Convert the frame to a tensor and permute the dimensions to match (C, H, W)
36
- frame_tensor = torch.tensor(frame).cuda().permute(2, 0, 1)
37
- frames.append(frame_tensor)
38
-
39
- # Stack the list of tensors into a single tensor with shape (T, C, H, W)
40
- video_tensor = torch.stack(frames)
41
-
42
- return video_tensor
43
-
44
-
45
- def resize_video(video, target_height, target_width):
46
- resized_frames = []
47
- for frame in video:
48
- resized_frame = F.resize(frame, [target_height, target_width])
49
- resized_frames.append(resized_frame)
50
- return torch.stack(resized_frames)
51
-
52
-
53
- def preprocess_eval_video(eval_video, generated_video_shape):
54
- T_gen, _, H_gen, W_gen = generated_video_shape
55
- T_eval, _, H_eval, W_eval = eval_video.shape
56
-
57
- if T_eval < T_gen:
58
- raise ValueError(f"Eval video time steps ({T_eval}) are less than generated video time steps ({T_gen}).")
59
-
60
- if H_eval < H_gen or W_eval < W_gen:
61
- # Resize the video maintaining the aspect ratio
62
- resize_height = max(H_gen, int(H_gen * (H_eval / W_eval)))
63
- resize_width = max(W_gen, int(W_gen * (W_eval / H_eval)))
64
- eval_video = resize_video(eval_video, resize_height, resize_width)
65
- # Recalculate the dimensions
66
- T_eval, _, H_eval, W_eval = eval_video.shape
67
-
68
- # Center crop
69
- start_h = (H_eval - H_gen) // 2
70
- start_w = (W_eval - W_gen) // 2
71
- cropped_video = eval_video[:T_gen, :, start_h : start_h + H_gen, start_w : start_w + W_gen]
72
-
73
- return cropped_video
74
-
75
-
76
- def main(args):
77
- device = "cuda"
78
- gt_video_dir = args.gt_video_dir
79
- generated_video_dir = args.generated_video_dir
80
-
81
- video_ids = []
82
- file_extension = "mp4"
83
- for f in os.listdir(generated_video_dir):
84
- if f.endswith(f".{file_extension}"):
85
- video_ids.append(f.replace(f".{file_extension}", ""))
86
- if not video_ids:
87
- raise ValueError("No videos found in the generated video dataset. Exiting.")
88
-
89
- print(f"Find {len(video_ids)} videos")
90
- prompt_interval = 1
91
- batch_size = 16
92
- calculate_lpips_flag, calculate_psnr_flag, calculate_ssim_flag = True, True, True
93
-
94
- lpips_results = []
95
- psnr_results = []
96
- ssim_results = []
97
-
98
- total_len = len(video_ids) // batch_size + (1 if len(video_ids) % batch_size != 0 else 0)
99
-
100
- for idx, video_id in enumerate(tqdm.tqdm(range(total_len))):
101
- gt_videos_tensor = []
102
- generated_videos_tensor = []
103
- for i in range(batch_size):
104
- video_idx = idx * batch_size + i
105
- if video_idx >= len(video_ids):
106
- break
107
- video_id = video_ids[video_idx]
108
- generated_video = load_video(os.path.join(generated_video_dir, f"{video_id}.{file_extension}"))
109
- generated_videos_tensor.append(generated_video)
110
- eval_video = load_video(os.path.join(gt_video_dir, f"{video_id}.{file_extension}"))
111
- gt_videos_tensor.append(eval_video)
112
- gt_videos_tensor = (torch.stack(gt_videos_tensor) / 255.0).cpu()
113
- generated_videos_tensor = (torch.stack(generated_videos_tensor) / 255.0).cpu()
114
-
115
- if calculate_lpips_flag:
116
- result = calculate_lpips(gt_videos_tensor, generated_videos_tensor, device=device)
117
- result = result["value"].values()
118
- result = sum(result) / len(result)
119
- lpips_results.append(result)
120
-
121
- if calculate_psnr_flag:
122
- result = calculate_psnr(gt_videos_tensor, generated_videos_tensor)
123
- result = result["value"].values()
124
- result = sum(result) / len(result)
125
- psnr_results.append(result)
126
-
127
- if calculate_ssim_flag:
128
- result = calculate_ssim(gt_videos_tensor, generated_videos_tensor)
129
- result = result["value"].values()
130
- result = sum(result) / len(result)
131
- ssim_results.append(result)
132
-
133
- if (idx + 1) % prompt_interval == 0:
134
- out_str = ""
135
- for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
136
- result = sum(results) / len(results)
137
- out_str += f"{name}: {result:.4f}, "
138
- print(f"Processed {idx + 1} videos. {out_str[:-2]}")
139
-
140
- out_str = ""
141
- for results, name in zip([lpips_results, psnr_results, ssim_results], ["lpips", "psnr", "ssim"]):
142
- result = sum(results) / len(results)
143
- out_str += f"{name}: {result:.4f}, "
144
- out_str = out_str[:-2]
145
-
146
- # save
147
- with open(f"./{os.path.basename(generated_video_dir)}.txt", "w+") as f:
148
- f.write(out_str)
149
-
150
- print(f"Processed all videos. {out_str}")
151
-
152
-
153
- if __name__ == "__main__":
154
- parser = argparse.ArgumentParser()
155
- parser.add_argument("--gt_video_dir", type=str)
156
- parser.add_argument("--generated_video_dir", type=str)
157
-
158
- args = parser.parse_args()
159
-
160
- main(args)
 
eval/pab/experiments/attention_ablation.py DELETED
@@ -1,60 +0,0 @@
1
- from utils import generate_func, read_prompt_list
2
-
3
- import videosys
4
- from videosys import OpenSoraConfig, OpenSoraPipeline
5
- from videosys.models.open_sora import OpenSoraPABConfig
6
-
7
-
8
- def attention_ablation_func(pab_kwargs, prompt_list, output_dir):
9
- pab_config = OpenSoraPABConfig(**pab_kwargs)
10
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
11
- pipeline = OpenSoraPipeline(config)
12
-
13
- generate_func(pipeline, prompt_list, output_dir)
14
-
15
-
16
- def main(prompt_list):
17
- # spatial
18
- gap_list = [2, 3, 4, 5]
19
- for gap in gap_list:
20
- pab_kwargs = {
21
- "spatial_broadcast": True,
22
- "spatial_gap": gap,
23
- "temporal_broadcast": False,
24
- "cross_broadcast": False,
25
- "mlp_skip": False,
26
- }
27
- output_dir = f"./samples/attention_ablation/spatial_g{gap}"
28
- attention_ablation_func(pab_kwargs, prompt_list, output_dir)
29
-
30
- # temporal
31
- gap_list = [3, 4, 5, 6]
32
- for gap in gap_list:
33
- pab_kwargs = {
34
- "spatial_broadcast": False,
35
- "temporal_broadcast": True,
36
- "temporal_gap": gap,
37
- "cross_broadcast": False,
38
- "mlp_skip": False,
39
- }
40
- output_dir = f"./samples/attention_ablation/temporal_g{gap}"
41
- attention_ablation_func(pab_kwargs, prompt_list, output_dir)
42
-
43
- # cross
44
- gap_list = [5, 6, 7, 8]
45
- for gap in gap_list:
46
- pab_kwargs = {
47
- "spatial_broadcast": False,
48
- "temporal_broadcast": False,
49
- "cross_broadcast": True,
50
- "cross_gap": gap,
51
- "mlp_skip": False,
52
- }
53
- output_dir = f"./samples/attention_ablation/cross_g{gap}"
54
- attention_ablation_func(pab_kwargs, prompt_list, output_dir)
55
-
56
-
57
- if __name__ == "__main__":
58
- videosys.initialize(42)
59
- prompt_list = read_prompt_list("vbench/VBench_full_info.json")
60
- main(prompt_list)
eval/pab/experiments/components_ablation.py DELETED
@@ -1,46 +0,0 @@
1
- from utils import generate_func, read_prompt_list
2
-
3
- import videosys
4
- from videosys import OpenSoraConfig, OpenSoraPipeline
5
- from videosys.models.open_sora import OpenSoraPABConfig
6
-
7
-
8
- def wo_spatial(prompt_list):
9
- pab_config = OpenSoraPABConfig(spatial_broadcast=False)
10
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
11
- pipeline = OpenSoraPipeline(config)
12
-
13
- generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_spatial")
14
-
15
-
16
- def wo_temporal(prompt_list):
17
- pab_config = OpenSoraPABConfig(temporal_broadcast=False)
18
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
19
- pipeline = OpenSoraPipeline(config)
20
-
21
- generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_temporal")
22
-
23
-
24
- def wo_cross(prompt_list):
25
- pab_config = OpenSoraPABConfig(cross_broadcast=False)
26
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
27
- pipeline = OpenSoraPipeline(config)
28
-
29
- generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_cross")
30
-
31
-
32
- def wo_mlp(prompt_list):
33
- pab_config = OpenSoraPABConfig(mlp_skip=False)
34
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
35
- pipeline = OpenSoraPipeline(config)
36
-
37
- generate_func(pipeline, prompt_list, "./samples/components_ablation/wo_mlp")
38
-
39
-
40
- if __name__ == "__main__":
41
- videosys.initialize(42)
42
- prompt_list = read_prompt_list("./vbench/VBench_full_info.json")
43
- wo_spatial(prompt_list)
44
- wo_temporal(prompt_list)
45
- wo_cross(prompt_list)
46
- wo_mlp(prompt_list)
eval/pab/experiments/latte.py DELETED
@@ -1,57 +0,0 @@
1
- from utils import generate_func, read_prompt_list
2
-
3
- import videosys
4
- from videosys import LatteConfig, LattePipeline
5
- from videosys.models.latte import LattePABConfig
6
-
7
-
8
- def eval_base(prompt_list):
9
- config = LatteConfig()
10
- pipeline = LattePipeline(config)
11
-
12
- generate_func(pipeline, prompt_list, "./samples/latte_base", loop=5)
13
-
14
-
15
- def eval_pab1(prompt_list):
16
- pab_config = LattePABConfig(
17
- spatial_gap=2,
18
- temporal_gap=3,
19
- cross_gap=6,
20
- )
21
- config = LatteConfig(enable_pab=True, pab_config=pab_config)
22
- pipeline = LattePipeline(config)
23
-
24
- generate_func(pipeline, prompt_list, "./samples/latte_pab1", loop=5)
25
-
26
-
27
- def eval_pab2(prompt_list):
28
- pab_config = LattePABConfig(
29
- spatial_gap=3,
30
- temporal_gap=4,
31
- cross_gap=7,
32
- )
33
- config = LatteConfig(enable_pab=True, pab_config=pab_config)
34
- pipeline = LattePipeline(config)
35
-
36
- generate_func(pipeline, prompt_list, "./samples/latte_pab2", loop=5)
37
-
38
-
39
- def eval_pab3(prompt_list):
40
- pab_config = LattePABConfig(
41
- spatial_gap=4,
42
- temporal_gap=6,
43
- cross_gap=9,
44
- )
45
- config = LatteConfig(enable_pab=True, pab_config=pab_config)
46
- pipeline = LattePipeline(config)
47
-
48
- generate_func(pipeline, prompt_list, "./samples/latte_pab3", loop=5)
49
-
50
-
51
- if __name__ == "__main__":
52
- videosys.initialize(42)
53
- prompt_list = read_prompt_list("vbench/VBench_full_info.json")
54
- eval_base(prompt_list)
55
- eval_pab1(prompt_list)
56
- eval_pab2(prompt_list)
57
- eval_pab3(prompt_list)
eval/pab/experiments/opensora.py DELETED
@@ -1,44 +0,0 @@
1
- from utils import generate_func, read_prompt_list
2
-
3
- import videosys
4
- from videosys import OpenSoraConfig, OpenSoraPipeline
5
- from videosys.models.open_sora import OpenSoraPABConfig
6
-
7
-
8
- def eval_base(prompt_list):
9
- config = OpenSoraConfig()
10
- pipeline = OpenSoraPipeline(config)
11
-
12
- generate_func(pipeline, prompt_list, "./samples/opensora_base", loop=5)
13
-
14
-
15
- def eval_pab1(prompt_list):
16
- config = OpenSoraConfig(enable_pab=True)
17
- pipeline = OpenSoraPipeline(config)
18
-
19
- generate_func(pipeline, prompt_list, "./samples/opensora_pab1", loop=5)
20
-
21
-
22
- def eval_pab2(prompt_list):
23
- pab_config = OpenSoraPABConfig(spatial_gap=3, temporal_gap=5, cross_gap=7)
24
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
25
- pipeline = OpenSoraPipeline(config)
26
-
27
- generate_func(pipeline, prompt_list, "./samples/opensora_pab2", loop=5)
28
-
29
-
30
- def eval_pab3(prompt_list):
31
- pab_config = OpenSoraPABConfig(spatial_gap=5, temporal_gap=7, cross_gap=9)
32
- config = OpenSoraConfig(enable_pab=True, pab_config=pab_config)
33
- pipeline = OpenSoraPipeline(config)
34
-
35
- generate_func(pipeline, prompt_list, "./samples/opensora_pab3", loop=5)
36
-
37
-
38
- if __name__ == "__main__":
39
- videosys.initialize(42)
40
- prompt_list = read_prompt_list("vbench/VBench_full_info.json")
41
- eval_base(prompt_list)
42
- eval_pab1(prompt_list)
43
- eval_pab2(prompt_list)
44
- eval_pab3(prompt_list)
eval/pab/experiments/opensora_plan.py DELETED
@@ -1,57 +0,0 @@
1
- from utils import generate_func, read_prompt_list
2
-
3
- import videosys
4
- from videosys import OpenSoraPlanConfig, OpenSoraPlanPipeline
5
- from videosys.models.open_sora_plan import OpenSoraPlanPABConfig
6
-
7
-
8
- def eval_base(prompt_list):
9
- config = OpenSoraPlanConfig()
10
- pipeline = OpenSoraPlanPipeline(config)
11
-
12
- generate_func(pipeline, prompt_list, "./samples/opensoraplan_base", loop=5)
13
-
14
-
15
- def eval_pab1(prompt_list):
16
- pab_config = OpenSoraPlanPABConfig(
17
- spatial_gap=2,
18
- temporal_gap=4,
19
- cross_gap=6,
20
- )
21
- config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
22
- pipeline = OpenSoraPlanPipeline(config)
23
-
24
- generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab1", loop=5)
25
-
26
-
27
- def eval_pab2(prompt_list):
28
- pab_config = OpenSoraPlanPABConfig(
29
- spatial_gap=3,
30
- temporal_gap=5,
31
- cross_gap=7,
32
- )
33
- config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
34
- pipeline = OpenSoraPlanPipeline(config)
35
-
36
- generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab2", loop=5)
37
-
38
-
39
- def eval_pab3(prompt_list):
40
- pab_config = OpenSoraPlanPABConfig(
41
- spatial_gap=5,
42
- temporal_gap=7,
43
- cross_gap=9,
44
- )
45
- config = OpenSoraPlanConfig(enable_pab=True, pab_config=pab_config)
46
- pipeline = OpenSoraPlanPipeline(config)
47
-
48
- generate_func(pipeline, prompt_list, "./samples/opensoraplan_pab3", loop=5)
49
-
50
-
51
- if __name__ == "__main__":
52
- videosys.initialize(42)
53
- prompt_list = read_prompt_list("vbench/VBench_full_info.json")
54
- eval_base(prompt_list)
55
- eval_pab1(prompt_list)
56
- eval_pab2(prompt_list)
57
- eval_pab3(prompt_list)
eval/pab/experiments/utils.py DELETED
@@ -1,22 +0,0 @@
1
- import json
2
- import os
3
-
4
- import tqdm
5
-
6
- from videosys.utils.utils import set_seed
7
-
8
-
9
- def generate_func(pipeline, prompt_list, output_dir, loop: int = 5, kwargs: dict = {}):
10
- kwargs["verbose"] = False
11
- for prompt in tqdm.tqdm(prompt_list):
12
- for l in range(loop):
13
- set_seed(l)
14
- video = pipeline.generate(prompt, **kwargs).video[0]
15
- pipeline.save_video(video, os.path.join(output_dir, f"{prompt}-{l}.mp4"))
16
-
17
-
18
- def read_prompt_list(prompt_list_path):
19
- with open(prompt_list_path, "r") as f:
20
- prompt_list = json.load(f)
21
- prompt_list = [prompt["prompt_en"] for prompt in prompt_list]
22
- return prompt_list
eval/pab/vbench/VBench_full_info.json DELETED
The diff for this file is too large to render. See raw diff
 
eval/pab/vbench/cal_vbench.py DELETED
@@ -1,154 +0,0 @@
1
- import argparse
2
- import json
3
- import os
4
-
5
- SEMANTIC_WEIGHT = 1
6
- QUALITY_WEIGHT = 4
7
-
8
- QUALITY_LIST = [
9
- "subject consistency",
10
- "background consistency",
11
- "temporal flickering",
12
- "motion smoothness",
13
- "aesthetic quality",
14
- "imaging quality",
15
- "dynamic degree",
16
- ]
17
-
18
- SEMANTIC_LIST = [
19
- "object class",
20
- "multiple objects",
21
- "human action",
22
- "color",
23
- "spatial relationship",
24
- "scene",
25
- "appearance style",
26
- "temporal style",
27
- "overall consistency",
28
- ]
29
-
30
- NORMALIZE_DIC = {
31
- "subject consistency": {"Min": 0.1462, "Max": 1.0},
32
- "background consistency": {"Min": 0.2615, "Max": 1.0},
33
- "temporal flickering": {"Min": 0.6293, "Max": 1.0},
34
- "motion smoothness": {"Min": 0.706, "Max": 0.9975},
35
- "dynamic degree": {"Min": 0.0, "Max": 1.0},
36
- "aesthetic quality": {"Min": 0.0, "Max": 1.0},
37
- "imaging quality": {"Min": 0.0, "Max": 1.0},
38
- "object class": {"Min": 0.0, "Max": 1.0},
39
- "multiple objects": {"Min": 0.0, "Max": 1.0},
40
- "human action": {"Min": 0.0, "Max": 1.0},
41
- "color": {"Min": 0.0, "Max": 1.0},
42
- "spatial relationship": {"Min": 0.0, "Max": 1.0},
43
- "scene": {"Min": 0.0, "Max": 0.8222},
44
- "appearance style": {"Min": 0.0009, "Max": 0.2855},
45
- "temporal style": {"Min": 0.0, "Max": 0.364},
46
- "overall consistency": {"Min": 0.0, "Max": 0.364},
47
- }
48
-
49
- DIM_WEIGHT = {
50
- "subject consistency": 1,
51
- "background consistency": 1,
52
- "temporal flickering": 1,
53
- "motion smoothness": 1,
54
- "aesthetic quality": 1,
55
- "imaging quality": 1,
56
- "dynamic degree": 0.5,
57
- "object class": 1,
58
- "multiple objects": 1,
59
- "human action": 1,
60
- "color": 1,
61
- "spatial relationship": 1,
62
- "scene": 1,
63
- "appearance style": 1,
64
- "temporal style": 1,
65
- "overall consistency": 1,
66
- }
67
-
68
- ordered_scaled_res = [
69
- "total score",
70
- "quality score",
71
- "semantic score",
72
- "subject consistency",
73
- "background consistency",
74
- "temporal flickering",
75
- "motion smoothness",
76
- "dynamic degree",
77
- "aesthetic quality",
78
- "imaging quality",
79
- "object class",
80
- "multiple objects",
81
- "human action",
82
- "color",
83
- "spatial relationship",
84
- "scene",
85
- "appearance style",
86
- "temporal style",
87
- "overall consistency",
88
- ]
89
-
90
-
91
- def parse_args():
92
- parser = argparse.ArgumentParser()
93
- parser.add_argument("--score_dir", required=True, type=str)
94
- args = parser.parse_args()
95
- return args
96
-
97
-
98
- if __name__ == "__main__":
99
- args = parse_args()
100
- res_postfix = "_eval_results.json"
101
- info_postfix = "_full_info.json"
102
- files = os.listdir(args.score_dir)
103
- res_files = [x for x in files if res_postfix in x]
104
- info_files = [x for x in files if info_postfix in x]
105
- assert len(res_files) == len(info_files), f"got {len(res_files)} res files, but {len(info_files)} info files"
106
-
107
- full_results = {}
108
- for res_file in res_files:
109
- # first check if results is normal
110
- info_file = res_file.split(res_postfix)[0] + info_postfix
111
- with open(os.path.join(args.score_dir, info_file), "r", encoding="utf-8") as f:
112
- info = json.load(f)
113
- assert len(info[0]["video_list"]) > 0, f"Error: {info_file} has 0 video list"
114
- # read results
115
- with open(os.path.join(args.score_dir, res_file), "r", encoding="utf-8") as f:
116
- data = json.load(f)
117
- for key, val in data.items():
118
- full_results[key] = format(val[0], ".4f")
119
-
120
- scaled_results = {}
121
- dims = set()
122
- for key, val in full_results.items():
123
- dim = key.replace("_", " ") if "_" in key else key
124
- scaled_score = (float(val) - NORMALIZE_DIC[dim]["Min"]) / (
125
- NORMALIZE_DIC[dim]["Max"] - NORMALIZE_DIC[dim]["Min"]
126
- )
127
- scaled_score *= DIM_WEIGHT[dim]
128
- scaled_results[dim] = scaled_score
129
- dims.add(dim)
130
-
131
- assert len(dims) == len(NORMALIZE_DIC), f"{set(NORMALIZE_DIC.keys())-dims} not calculated yet"
132
-
133
- quality_score = sum([scaled_results[i] for i in QUALITY_LIST]) / sum([DIM_WEIGHT[i] for i in QUALITY_LIST])
134
- semantic_score = sum([scaled_results[i] for i in SEMANTIC_LIST]) / sum([DIM_WEIGHT[i] for i in SEMANTIC_LIST])
135
- scaled_results["quality score"] = quality_score
136
- scaled_results["semantic score"] = semantic_score
137
- scaled_results["total score"] = (quality_score * QUALITY_WEIGHT + semantic_score * SEMANTIC_WEIGHT) / (
138
- QUALITY_WEIGHT + SEMANTIC_WEIGHT
139
- )
140
-
141
- formated_scaled_results = {"items": []}
142
- for key in ordered_scaled_res:
143
- formated_score = format(scaled_results[key] * 100, ".2f") + "%"
144
- formated_scaled_results["items"].append({key: formated_score})
145
-
146
- output_file_path = os.path.join(args.score_dir, "all_results.json")
147
- with open(output_file_path, "w") as outfile:
148
- json.dump(full_results, outfile, indent=4, sort_keys=True)
149
- print(f"results saved to: {output_file_path}")
150
-
151
- scaled_file_path = os.path.join(args.score_dir, "scaled_results.json")
152
- with open(scaled_file_path, "w") as outfile:
153
- json.dump(formated_scaled_results, outfile, indent=4, sort_keys=True)
154
- print(f"results saved to: {scaled_file_path}")
eval/pab/vbench/run_vbench.py DELETED
@@ -1,52 +0,0 @@
1
- import argparse
2
-
3
- import torch
4
- from vbench import VBench
5
-
6
- full_info_path = "./vbench/VBench_full_info.json"
7
-
8
- dimensions = [
9
- "subject_consistency",
10
- "imaging_quality",
11
- "background_consistency",
12
- "motion_smoothness",
13
- "overall_consistency",
14
- "human_action",
15
- "multiple_objects",
16
- "spatial_relationship",
17
- "object_class",
18
- "color",
19
- "aesthetic_quality",
20
- "appearance_style",
21
- "temporal_flickering",
22
- "scene",
23
- "temporal_style",
24
- "dynamic_degree",
25
- ]
26
-
27
-
28
- def parse_args():
29
- parser = argparse.ArgumentParser()
30
- parser.add_argument("--video_path", required=True, type=str)
31
- args = parser.parse_args()
32
- return args
33
-
34
-
35
- if __name__ == "__main__":
36
- args = parse_args()
37
- save_path = args.video_path.replace("/samples/", "/vbench_out/")
38
-
39
- kwargs = {}
40
- kwargs["imaging_quality_preprocessing_mode"] = "longer" # use VBench/evaluate.py default
41
-
42
- for dimension in dimensions:
43
- my_VBench = VBench(torch.device("cuda"), full_info_path, save_path)
44
- my_VBench.evaluate(
45
- videos_path=args.video_path,
46
- name=dimension,
47
- local=False,
48
- read_frame=False,
49
- dimension_list=[dimension],
50
- mode="vbench_standard",
51
- **kwargs,
52
- )
examples/cogvideo/sample.py DELETED
@@ -1,14 +0,0 @@
1
- from videosys import CogVideoConfig, VideoSysEngine
2
-
3
-
4
- def run_base():
5
- config = CogVideoConfig(world_size=1)
6
- engine = VideoSysEngine(config)
7
-
8
- prompt = "Sunset over the sea."
9
- video = engine.generate(prompt).video[0]
10
- engine.save_video(video, f"./outputs/{prompt}.mp4")
11
-
12
-
13
- if __name__ == "__main__":
14
- run_base()
examples/latte/sample.py DELETED
@@ -1,24 +0,0 @@
1
- from videosys import LatteConfig, VideoSysEngine
2
-
3
-
4
- def run_base():
5
- config = LatteConfig(world_size=1)
6
- engine = VideoSysEngine(config)
7
-
8
- prompt = "Sunset over the sea."
9
- video = engine.generate(prompt).video[0]
10
- engine.save_video(video, f"./outputs/{prompt}.mp4")
11
-
12
-
13
- def run_pab():
14
- config = LatteConfig(world_size=1)
15
- engine = VideoSysEngine(config)
16
-
17
- prompt = "Sunset over the sea."
18
- video = engine.generate(prompt).video[0]
19
- engine.save_video(video, f"./outputs/{prompt}.mp4")
20
-
21
-
22
- if __name__ == "__main__":
23
- run_base()
24
- # run_pab()
examples/open_sora/sample.py DELETED
@@ -1,24 +0,0 @@
1
- from videosys import OpenSoraConfig, VideoSysEngine
2
-
3
-
4
- def run_base():
5
- config = OpenSoraConfig(world_size=1)
6
- engine = VideoSysEngine(config)
7
-
8
- prompt = "Sunset over the sea."
9
- video = engine.generate(prompt).video[0]
10
- engine.save_video(video, f"./outputs/{prompt}.mp4")
11
-
12
-
13
- def run_pab():
14
- config = OpenSoraConfig(world_size=1, enable_pab=True)
15
- engine = VideoSysEngine(config)
16
-
17
- prompt = "Sunset over the sea."
18
- video = engine.generate(prompt).video[0]
19
- engine.save_video(video, f"./outputs/{prompt}.mp4")
20
-
21
-
22
- if __name__ == "__main__":
23
- run_base()
24
- run_pab()
examples/open_sora_plan/sample.py DELETED
@@ -1,24 +0,0 @@
1
- from videosys import OpenSoraPlanConfig, VideoSysEngine
2
-
3
-
4
- def run_base():
5
- config = OpenSoraPlanConfig(world_size=1)
6
- engine = VideoSysEngine(config)
7
-
8
- prompt = "Sunset over the sea."
9
- video = engine.generate(prompt).video[0]
10
- engine.save_video(video, f"./outputs/{prompt}.mp4")
11
-
12
-
13
- def run_pab():
14
- config = OpenSoraPlanConfig(world_size=1)
15
- engine = VideoSysEngine(config)
16
-
17
- prompt = "Sunset over the sea."
18
- video = engine.generate(prompt).video[0]
19
- engine.save_video(video, f"./outputs/{prompt}.mp4")
20
-
21
-
22
- if __name__ == "__main__":
23
- run_base()
24
- # run_pab()
videosys/__init__.py CHANGED
@@ -1,19 +1,15 @@
1
  from .core.engine import VideoSysEngine
2
  from .core.parallel_mgr import initialize
3
- from .models.cogvideo.pipeline import CogVideoConfig, CogVideoPipeline
4
- from .models.latte.pipeline import LatteConfig, LattePipeline
5
- from .models.open_sora.pipeline import OpenSoraConfig, OpenSoraPipeline
6
- from .models.open_sora_plan.pipeline import OpenSoraPlanConfig, OpenSoraPlanPipeline
7
 
8
  __all__ = [
9
  "initialize",
10
  "VideoSysEngine",
11
- "LattePipeline",
12
- "LatteConfig",
13
- "OpenSoraPlanPipeline",
14
- "OpenSoraPlanConfig",
15
- "OpenSoraPipeline",
16
- "OpenSoraConfig",
17
- "CogVideoConfig",
18
- "CogVideoPipeline",
19
- ]
 
1
  from .core.engine import VideoSysEngine
2
  from .core.parallel_mgr import initialize
3
+ from .pipelines.cogvideox import CogVideoXConfig, CogVideoXPABConfig, CogVideoXPipeline
4
+ from .pipelines.latte import LatteConfig, LattePABConfig, LattePipeline
5
+ from .pipelines.open_sora import OpenSoraConfig, OpenSoraPABConfig, OpenSoraPipeline
6
+ from .pipelines.open_sora_plan import OpenSoraPlanConfig, OpenSoraPlanPABConfig, OpenSoraPlanPipeline
7
 
8
  __all__ = [
9
  "initialize",
10
  "VideoSysEngine",
11
+ "LattePipeline", "LatteConfig", "LattePABConfig",
12
+ "OpenSoraPlanPipeline", "OpenSoraPlanConfig", "OpenSoraPlanPABConfig",
13
+ "OpenSoraPipeline", "OpenSoraConfig", "OpenSoraPABConfig",
14
+ "CogVideoXConfig", "CogVideoXPipeline", "CogVideoXPABConfig"
15
+ ] # fmt: skip
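
With the pipeline configs and PAB configs now re-exported at the package root, the top-level usage should stay close to the removed examples/*/sample.py scripts. A minimal sketch; whether the new pipelines keep these exact constructor arguments (world_size, enable_pab) is an assumption based on the deleted examples:

from videosys import OpenSoraConfig, VideoSysEngine

def run_pab():
    # mirrors the removed examples/open_sora/sample.py run_pab();
    # OpenSoraPABConfig is also re-exported for tuning the broadcast settings
    config = OpenSoraConfig(world_size=1, enable_pab=True)
    engine = VideoSysEngine(config)

    prompt = "Sunset over the sea."
    video = engine.generate(prompt).video[0]
    engine.save_video(video, f"./outputs/{prompt}.mp4")

if __name__ == "__main__":
    run_pab()
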
 
 
 
 
videosys/core/engine.py CHANGED
@@ -2,7 +2,6 @@ import os
2
  from functools import partial
3
  from typing import Any, Optional
4
 
5
- import imageio
6
  import torch
7
 
8
  import videosys
@@ -120,8 +119,7 @@ class VideoSysEngine:
120
  result.get()
121
 
122
  def save_video(self, video, output_path):
123
- os.makedirs(os.path.dirname(output_path), exist_ok=True)
124
- imageio.mimwrite(output_path, video, fps=24)
125
 
126
  def shutdown(self):
127
  if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
@@ -129,4 +127,4 @@ class VideoSysEngine:
129
  torch.distributed.destroy_process_group()
130
 
131
  def __del__(self):
132
- self.shutdown()
 
2
  from functools import partial
3
  from typing import Any, Optional
4
 
 
5
  import torch
6
 
7
  import videosys
 
119
  result.get()
120
 
121
  def save_video(self, video, output_path):
122
+ return self.driver_worker.save_video(video, output_path)
 
123
 
124
  def shutdown(self):
125
  if (worker_monitor := getattr(self, "worker_monitor", None)) is not None:
 
127
  torch.distributed.destroy_process_group()
128
 
129
  def __del__(self):
130
+ self.shutdown()
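
save_video no longer writes the file in the engine; it is forwarded to the driver worker, whose implementation is not part of this hunk. A minimal sketch of what that worker-side method presumably looks like, reusing the imageio logic deleted above (fps=24 carried over from the removed lines):

import os
import imageio

class VideoSysPipeline:  # hypothetical stand-in for the driver worker / pipeline class
    def save_video(self, video, output_path):
        # assumed counterpart of the code removed from VideoSysEngine.save_video
        os.makedirs(os.path.dirname(output_path), exist_ok=True)
        imageio.mimwrite(output_path, video, fps=24)
        return output_path
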
videosys/core/pab_mgr.py CHANGED
@@ -1,8 +1,3 @@
1
- import random
2
-
3
- import numpy as np
4
- import torch
5
-
6
  from videosys.utils.logging import logger
7
 
8
  PAB_MANAGER = None
@@ -12,71 +7,56 @@ class PABConfig:
12
  def __init__(
13
  self,
14
  steps: int,
15
- cross_broadcast: bool,
16
- cross_threshold: list,
17
- cross_gap: int,
18
- spatial_broadcast: bool,
19
- spatial_threshold: list,
20
- spatial_gap: int,
21
- temporal_broadcast: bool,
22
- temporal_threshold: list,
23
- temporal_gap: int,
24
- diffusion_skip: bool,
25
- diffusion_timestep_respacing: list,
26
- diffusion_skip_timestep: list,
27
- mlp_skip: bool,
28
- mlp_spatial_skip_config: dict,
29
- mlp_temporal_skip_config: dict,
30
- full_broadcast: bool = False,
31
- full_threshold: list = None,
32
- full_gap: int = 1,
33
  ):
34
  self.steps = steps
35
 
36
  self.cross_broadcast = cross_broadcast
37
  self.cross_threshold = cross_threshold
38
- self.cross_gap = cross_gap
39
 
40
  self.spatial_broadcast = spatial_broadcast
41
  self.spatial_threshold = spatial_threshold
42
- self.spatial_gap = spatial_gap
43
 
44
  self.temporal_broadcast = temporal_broadcast
45
  self.temporal_threshold = temporal_threshold
46
- self.temporal_gap = temporal_gap
47
-
48
- self.diffusion_skip = diffusion_skip
49
- self.diffusion_timestep_respacing = diffusion_timestep_respacing
50
- self.diffusion_skip_timestep = diffusion_skip_timestep
51
 
52
- self.mlp_skip = mlp_skip
53
- self.mlp_spatial_skip_config = mlp_spatial_skip_config
54
- self.mlp_temporal_skip_config = mlp_temporal_skip_config
55
-
56
- self.temporal_mlp_outputs = {}
57
- self.spatial_mlp_outputs = {}
58
-
59
- self.full_broadcast = full_broadcast
60
- self.full_threshold = full_threshold
61
- self.full_gap = full_gap
62
 
63
 
64
  class PABManager:
65
  def __init__(self, config: PABConfig):
66
  self.config: PABConfig = config
67
 
68
- init_prompt = f"Init PABManager. steps: {config.steps}."
69
- init_prompt += f" spatial_broadcast: {config.spatial_broadcast}, spatial_threshold: {config.spatial_threshold}, spatial_gap: {config.spatial_gap}."
70
- init_prompt += f" temporal_broadcast: {config.temporal_broadcast}, temporal_threshold: {config.temporal_threshold}, temporal_gap: {config.temporal_gap}."
71
- init_prompt += f" cross_broadcast: {config.cross_broadcast}, cross_threshold: {config.cross_threshold}, cross_gap: {config.cross_gap}."
72
- init_prompt += f" full_broadcast: {config.full_broadcast}, full_threshold: {config.full_threshold}, full_gap: {config.full_gap}."
73
  logger.info(init_prompt)
74
 
75
  def if_broadcast_cross(self, timestep: int, count: int):
76
  if (
77
  self.config.cross_broadcast
78
  and (timestep is not None)
79
- and (count % self.config.cross_gap != 0)
80
  and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
81
  ):
82
  flag = True
@@ -89,7 +69,7 @@ class PABManager:
89
  if (
90
  self.config.temporal_broadcast
91
  and (timestep is not None)
92
- and (count % self.config.temporal_gap != 0)
93
  and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
94
  ):
95
  flag = True
@@ -102,7 +82,7 @@ class PABManager:
102
  if (
103
  self.config.spatial_broadcast
104
  and (timestep is not None)
105
- and (count % self.config.spatial_gap != 0)
106
  and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
107
  ):
108
  flag = True
@@ -111,19 +91,6 @@ class PABManager:
111
  count = (count + 1) % self.config.steps
112
  return flag, count
113
 
114
- def if_broadcast_full(self, timestep: int, count: int, block_idx: int):
115
- if (
116
- self.config.full_broadcast
117
- and (timestep is not None)
118
- and (count % self.config.full_gap != 0)
119
- and (self.config.full_threshold[0] < timestep < self.config.full_threshold[1])
120
- ):
121
- flag = True
122
- else:
123
- flag = False
124
- count = (count + 1) % self.config.steps
125
- return flag, count
126
-
127
  @staticmethod
128
  def _is_t_in_skip_config(all_timesteps, timestep, config):
129
  is_t_in_skip_config = False
@@ -139,18 +106,18 @@ class PABManager:
139
  return is_t_in_skip_config, skip_range
140
 
141
  def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
142
- if not self.config.mlp_skip:
143
  return False, None, False, None
144
 
145
  if is_temporal:
146
- cur_config = self.config.mlp_temporal_skip_config
147
  else:
148
- cur_config = self.config.mlp_spatial_skip_config
149
 
150
  is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
151
  next_flag = False
152
  if (
153
- self.config.mlp_skip
154
  and (timestep is not None)
155
  and (timestep in cur_config)
156
  and (block_idx in cur_config[timestep]["block"])
@@ -159,7 +126,7 @@ class PABManager:
159
  next_flag = True
160
  count = count + 1
161
  elif (
162
- self.config.mlp_skip
163
  and (timestep is not None)
164
  and (is_t_in_skip_config)
165
  and (block_idx in cur_config[skip_range[0]]["block"])
@@ -173,22 +140,22 @@ class PABManager:
173
 
174
  def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
175
  if is_temporal:
176
- self.config.temporal_mlp_outputs[(timestep, block_idx)] = ff_output
177
  else:
178
- self.config.spatial_mlp_outputs[(timestep, block_idx)] = ff_output
179
 
180
  def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
181
  skip_start_t = skip_range[0]
182
  if is_temporal:
183
  skip_output = (
184
- self.config.temporal_mlp_outputs.get((skip_start_t, block_idx), None)
185
- if self.config.temporal_mlp_outputs is not None
186
  else None
187
  )
188
  else:
189
  skip_output = (
190
- self.config.spatial_mlp_outputs.get((skip_start_t, block_idx), None)
191
- if self.config.spatial_mlp_outputs is not None
192
  else None
193
  )
194
 
@@ -196,9 +163,9 @@ class PABManager:
196
  if timestep == skip_range[-1]:
197
  # TODO: save memory
198
  if is_temporal:
199
- del self.config.temporal_mlp_outputs[(skip_start_t, block_idx)]
200
  else:
201
- del self.config.spatial_mlp_outputs[(skip_start_t, block_idx)]
202
  else:
203
  raise ValueError(
204
  f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
@@ -207,10 +174,10 @@ class PABManager:
207
  return skip_output
208
 
209
  def get_spatial_mlp_outputs(self):
210
- return self.config.spatial_mlp_outputs
211
 
212
  def get_temporal_mlp_outputs(self):
213
- return self.config.temporal_mlp_outputs
214
 
215
 
216
  def set_pab_manager(config: PABConfig):
@@ -250,11 +217,6 @@ def if_broadcast_spatial(timestep: int, count: int, block_idx: int):
250
  return False, count
251
  return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
252
 
253
- def if_broadcast_full(timestep: int, count: int, block_idx: int):
254
- if not enable_pab():
255
- return False, count
256
- return PAB_MANAGER.if_broadcast_full(timestep, count, block_idx)
257
-
258
 
259
  def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
260
  if not enable_pab():
@@ -268,97 +230,3 @@ def save_mlp_output(timestep: int, block_idx: int, ff_output, is_temporal=False)
268
 
269
  def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
270
  return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
271
-
272
-
273
- def get_diffusion_skip():
274
- return enable_pab() and PAB_MANAGER.config.diffusion_skip
275
-
276
-
277
- def get_diffusion_timestep_respacing():
278
- return PAB_MANAGER.config.diffusion_timestep_respacing
279
-
280
-
281
- def get_diffusion_skip_timestep():
282
- return enable_pab() and PAB_MANAGER.config.diffusion_skip_timestep
283
-
284
-
285
- def space_timesteps(time_steps, time_bins):
286
- num_bins = len(time_bins)
287
- bin_size = time_steps // num_bins
288
-
289
- result = []
290
-
291
- for i, bin_count in enumerate(time_bins):
292
- start = i * bin_size
293
- end = start + bin_size
294
-
295
- bin_steps = np.linspace(start, end, bin_count, endpoint=False, dtype=int).tolist()
296
- result.extend(bin_steps)
297
-
298
- result_tensor = torch.tensor(result, dtype=torch.int32)
299
- sorted_tensor = torch.sort(result_tensor, descending=True).values
300
-
301
- return sorted_tensor
302
-
303
-
304
- def skip_diffusion_timestep(timesteps, diffusion_skip_timestep):
305
- if isinstance(timesteps, list):
306
- # If timesteps is a list, we assume each element is a tensor
307
- timesteps_np = [t.cpu().numpy() for t in timesteps]
308
- device = timesteps[0].device
309
- else:
310
- # If timesteps is a tensor
311
- timesteps_np = timesteps.cpu().numpy()
312
- device = timesteps.device
313
-
314
- num_bins = len(diffusion_skip_timestep)
315
-
316
- if isinstance(timesteps_np, list):
317
- bin_size = len(timesteps_np) // num_bins
318
- new_timesteps = []
319
-
320
- for i in range(num_bins):
321
- bin_start = i * bin_size
322
- bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
323
- bin_timesteps = timesteps_np[bin_start:bin_end]
324
-
325
- if diffusion_skip_timestep[i] == 0:
326
- # If the bin is marked with 0, keep all timesteps
327
- new_timesteps.extend(bin_timesteps)
328
- elif diffusion_skip_timestep[i] == 1:
329
- # If the bin is marked with 1, omit the last timestep in the bin
330
- new_timesteps.extend(bin_timesteps[1:])
331
-
332
- new_timesteps_tensor = [torch.tensor(t, device=device) for t in new_timesteps]
333
- else:
334
- bin_size = len(timesteps_np) // num_bins
335
- new_timesteps = []
336
-
337
- for i in range(num_bins):
338
- bin_start = i * bin_size
339
- bin_end = (i + 1) * bin_size if i != num_bins - 1 else len(timesteps_np)
340
- bin_timesteps = timesteps_np[bin_start:bin_end]
341
-
342
- if diffusion_skip_timestep[i] == 0:
343
- # If the bin is marked with 0, keep all timesteps
344
- new_timesteps.extend(bin_timesteps)
345
- elif diffusion_skip_timestep[i] == 1:
346
- # If the bin is marked with 1, omit the last timestep in the bin
347
- new_timesteps.extend(bin_timesteps[1:])
348
- elif diffusion_skip_timestep[i] != 0:
349
- # If the bin is marked with a non-zero value, randomly omit n timesteps
350
- if len(bin_timesteps) > diffusion_skip_timestep[i]:
351
- indices_to_remove = set(random.sample(range(len(bin_timesteps)), diffusion_skip_timestep[i]))
352
- timesteps_to_keep = [
353
- timestep for idx, timestep in enumerate(bin_timesteps) if idx not in indices_to_remove
354
- ]
355
- else:
356
- timesteps_to_keep = bin_timesteps  # if bin_timesteps has no more than n elements, keep all of them
357
- new_timesteps.extend(timesteps_to_keep)
358
-
359
- new_timesteps_tensor = torch.tensor(new_timesteps, device=device)
360
-
361
- if isinstance(timesteps, list):
362
- return new_timesteps_tensor
363
- else:
364
- return new_timesteps_tensor
 
 
 
 
 
 
1
  from videosys.utils.logging import logger
2
 
3
  PAB_MANAGER = None
 
7
  def __init__(
8
  self,
9
  steps: int,
10
+ cross_broadcast: bool = False,
11
+ cross_threshold: list = None,
12
+ cross_range: int = None,
13
+ spatial_broadcast: bool = False,
14
+ spatial_threshold: list = None,
15
+ spatial_range: int = None,
16
+ temporal_broadcast: bool = False,
17
+ temporal_threshold: list = None,
18
+ temporal_range: int = None,
19
+ mlp_broadcast: bool = False,
20
+ mlp_spatial_broadcast_config: dict = None,
21
+ mlp_temporal_broadcast_config: dict = None,
 
 
 
 
 
 
22
  ):
23
  self.steps = steps
24
 
25
  self.cross_broadcast = cross_broadcast
26
  self.cross_threshold = cross_threshold
27
+ self.cross_range = cross_range
28
 
29
  self.spatial_broadcast = spatial_broadcast
30
  self.spatial_threshold = spatial_threshold
31
+ self.spatial_range = spatial_range
32
 
33
  self.temporal_broadcast = temporal_broadcast
34
  self.temporal_threshold = temporal_threshold
35
+ self.temporal_range = temporal_range
 
 
 
 
36
 
37
+ self.mlp_broadcast = mlp_broadcast
38
+ self.mlp_spatial_broadcast_config = mlp_spatial_broadcast_config
39
+ self.mlp_temporal_broadcast_config = mlp_temporal_broadcast_config
40
+ self.mlp_temporal_outputs = {}
41
+ self.mlp_spatial_outputs = {}
 
 
 
 
 
42
 
43
 
44
  class PABManager:
45
  def __init__(self, config: PABConfig):
46
  self.config: PABConfig = config
47
 
48
+ init_prompt = f"Init Pyramid Attention Broadcast. steps: {config.steps}."
49
+ init_prompt += f" spatial broadcast: {config.spatial_broadcast}, spatial range: {config.spatial_range}, spatial threshold: {config.spatial_threshold}."
50
+ init_prompt += f" temporal broadcast: {config.temporal_broadcast}, temporal range: {config.temporal_range}, temporal_threshold: {config.temporal_threshold}."
51
+ init_prompt += f" cross broadcast: {config.cross_broadcast}, cross range: {config.cross_range}, cross threshold: {config.cross_threshold}."
52
+ init_prompt += f" mlp broadcast: {config.mlp_broadcast}."
53
  logger.info(init_prompt)
54
 
55
  def if_broadcast_cross(self, timestep: int, count: int):
56
  if (
57
  self.config.cross_broadcast
58
  and (timestep is not None)
59
+ and (count % self.config.cross_range != 0)
60
  and (self.config.cross_threshold[0] < timestep < self.config.cross_threshold[1])
61
  ):
62
  flag = True
 
69
  if (
70
  self.config.temporal_broadcast
71
  and (timestep is not None)
72
+ and (count % self.config.temporal_range != 0)
73
  and (self.config.temporal_threshold[0] < timestep < self.config.temporal_threshold[1])
74
  ):
75
  flag = True
 
82
  if (
83
  self.config.spatial_broadcast
84
  and (timestep is not None)
85
+ and (count % self.config.spatial_range != 0)
86
  and (self.config.spatial_threshold[0] < timestep < self.config.spatial_threshold[1])
87
  ):
88
  flag = True
 
91
  count = (count + 1) % self.config.steps
92
  return flag, count
93
 
 
 
 
 
 
 
 
 
 
 
 
 
 
94
  @staticmethod
95
  def _is_t_in_skip_config(all_timesteps, timestep, config):
96
  is_t_in_skip_config = False
 
106
  return is_t_in_skip_config, skip_range
107
 
108
  def if_skip_mlp(self, timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
109
+ if not self.config.mlp_broadcast:
110
  return False, None, False, None
111
 
112
  if is_temporal:
113
+ cur_config = self.config.mlp_temporal_broadcast_config
114
  else:
115
+ cur_config = self.config.mlp_spatial_broadcast_config
116
 
117
  is_t_in_skip_config, skip_range = self._is_t_in_skip_config(all_timesteps, timestep, cur_config)
118
  next_flag = False
119
  if (
120
+ self.config.mlp_broadcast
121
  and (timestep is not None)
122
  and (timestep in cur_config)
123
  and (block_idx in cur_config[timestep]["block"])
 
126
  next_flag = True
127
  count = count + 1
128
  elif (
129
+ self.config.mlp_broadcast
130
  and (timestep is not None)
131
  and (is_t_in_skip_config)
132
  and (block_idx in cur_config[skip_range[0]]["block"])
 
140
 
141
  def save_skip_output(self, timestep, block_idx, ff_output, is_temporal=False):
142
  if is_temporal:
143
+ self.config.mlp_temporal_outputs[(timestep, block_idx)] = ff_output
144
  else:
145
+ self.config.mlp_spatial_outputs[(timestep, block_idx)] = ff_output
146
 
147
  def get_mlp_output(self, skip_range, timestep, block_idx, is_temporal=False):
148
  skip_start_t = skip_range[0]
149
  if is_temporal:
150
  skip_output = (
151
+ self.config.mlp_temporal_outputs.get((skip_start_t, block_idx), None)
152
+ if self.config.mlp_temporal_outputs is not None
153
  else None
154
  )
155
  else:
156
  skip_output = (
157
+ self.config.mlp_spatial_outputs.get((skip_start_t, block_idx), None)
158
+ if self.config.mlp_spatial_outputs is not None
159
  else None
160
  )
161
 
 
163
  if timestep == skip_range[-1]:
164
  # TODO: save memory
165
  if is_temporal:
166
+ del self.config.mlp_temporal_outputs[(skip_start_t, block_idx)]
167
  else:
168
+ del self.config.mlp_spatial_outputs[(skip_start_t, block_idx)]
169
  else:
170
  raise ValueError(
171
  f"No stored MLP output found | t {timestep} |[{skip_range[0]}, {skip_range[-1]}] | block {block_idx}"
 
174
  return skip_output
175
 
176
  def get_spatial_mlp_outputs(self):
177
+ return self.config.mlp_spatial_outputs
178
 
179
  def get_temporal_mlp_outputs(self):
180
+ return self.config.mlp_temporal_outputs
181
 
182
 
183
  def set_pab_manager(config: PABConfig):
 
217
  return False, count
218
  return PAB_MANAGER.if_broadcast_spatial(timestep, count, block_idx)
219
 
 
 
 
 
 
220
 
221
  def if_broadcast_mlp(timestep: int, count: int, block_idx: int, all_timesteps, is_temporal=False):
222
  if not enable_pab():
 
230
 
231
  def get_mlp_output(skip_range, timestep, block_idx: int, is_temporal=False):
232
  return PAB_MANAGER.get_mlp_output(skip_range, timestep, block_idx, is_temporal)
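
The net effect of this refactor: the *_gap fields become *_range, mlp_skip/* becomes mlp_broadcast/*, and every field now has a default, so per-model configs only need to set what they use. A minimal sketch of how the renamed fields drive a broadcast decision (values are illustrative, not library defaults):

from videosys.core.pab_mgr import PABConfig, PABManager

config = PABConfig(
    steps=30,
    spatial_broadcast=True,
    spatial_threshold=[100, 800],  # broadcast only for timesteps inside this window
    spatial_range=2,               # recompute every 2nd step, reuse the cached output otherwise
)
mgr = PABManager(config)

# arguments are (timestep, count, block_idx)
flag, count = mgr.if_broadcast_spatial(500, 1, 0)
# flag is True here: count % spatial_range != 0 and 100 < 500 < 800,
# so this block reuses its previously computed spatial attention output
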
videosys/datasets/dataloader.py DELETED
@@ -1,94 +0,0 @@
1
- import random
2
- from typing import Iterator, Optional
3
-
4
- import numpy as np
5
- import torch
6
- from torch.utils.data import DataLoader, Dataset, DistributedSampler
7
- from torch.utils.data.distributed import DistributedSampler
8
-
9
- from videosys.core.parallel_mgr import ParallelManager
10
-
11
-
12
- class StatefulDistributedSampler(DistributedSampler):
13
- def __init__(
14
- self,
15
- dataset: Dataset,
16
- num_replicas: Optional[int] = None,
17
- rank: Optional[int] = None,
18
- shuffle: bool = True,
19
- seed: int = 0,
20
- drop_last: bool = False,
21
- ) -> None:
22
- super().__init__(dataset, num_replicas, rank, shuffle, seed, drop_last)
23
- self.start_index: int = 0
24
-
25
- def __iter__(self) -> Iterator:
26
- iterator = super().__iter__()
27
- indices = list(iterator)
28
- indices = indices[self.start_index :]
29
- return iter(indices)
30
-
31
- def __len__(self) -> int:
32
- return self.num_samples - self.start_index
33
-
34
- def set_start_index(self, start_index: int) -> None:
35
- self.start_index = start_index
36
-
37
-
38
- def prepare_dataloader(
39
- dataset,
40
- batch_size,
41
- shuffle=False,
42
- seed=1024,
43
- drop_last=False,
44
- pin_memory=False,
45
- num_workers=0,
46
- pg_manager: Optional[ParallelManager] = None,
47
- **kwargs,
48
- ):
49
- r"""
50
- Prepare a dataloader for distributed training. The dataloader will be wrapped by
51
- `torch.utils.data.DataLoader` and `StatefulDistributedSampler`.
52
-
53
-
54
- Args:
55
- dataset (`torch.utils.data.Dataset`): The dataset to be loaded.
56
- shuffle (bool, optional): Whether to shuffle the dataset. Defaults to False.
57
- seed (int, optional): Random worker seed for sampling, defaults to 1024.
58
- add_sampler: Whether to add ``DistributedDataParallelSampler`` to the dataset. Defaults to True.
59
- drop_last (bool, optional): Set to True to drop the last incomplete batch, if the dataset size
60
- is not divisible by the batch size. If False and the size of dataset is not divisible by
61
- the batch size, then the last batch will be smaller, defaults to False.
62
- pin_memory (bool, optional): Whether to pin memory address in CPU memory. Defaults to False.
63
- num_workers (int, optional): Number of worker threads for this dataloader. Defaults to 0.
64
- kwargs (dict): optional parameters for ``torch.utils.data.DataLoader``, more details could be found in
65
- `DataLoader <https://pytorch.org/docs/stable/_modules/torch/utils/data/dataloader.html#DataLoader>`_.
66
-
67
- Returns:
68
- :class:`torch.utils.data.DataLoader`: A DataLoader used for training or testing.
69
- """
70
- _kwargs = kwargs.copy()
71
- sampler = StatefulDistributedSampler(
72
- dataset,
73
- num_replicas=pg_manager.size(pg_manager.dp_axis),
74
- rank=pg_manager.coordinate(pg_manager.dp_axis),
75
- shuffle=shuffle,
76
- )
77
-
78
- # Deterministic dataloader
79
- def seed_worker(worker_id):
80
- worker_seed = seed
81
- np.random.seed(worker_seed)
82
- torch.manual_seed(worker_seed)
83
- random.seed(worker_seed)
84
-
85
- return DataLoader(
86
- dataset,
87
- batch_size=batch_size,
88
- sampler=sampler,
89
- worker_init_fn=seed_worker,
90
- drop_last=drop_last,
91
- pin_memory=pin_memory,
92
- num_workers=num_workers,
93
- **_kwargs,
94
- )
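
For the record, the removed StatefulDistributedSampler existed to make training resumable mid-epoch: set_start_index skips samples already consumed before a restart. A minimal sketch of that pattern in isolation (single-replica setup, placeholder dataset):

import torch
from torch.utils.data import DataLoader, TensorDataset
from videosys.datasets.dataloader import StatefulDistributedSampler  # import path prior to this commit

dataset = TensorDataset(torch.arange(100))
sampler = StatefulDistributedSampler(dataset, num_replicas=1, rank=0, shuffle=False)
sampler.set_start_index(40)  # resume: skip the first 40 samples of this epoch
loader = DataLoader(dataset, batch_size=8, sampler=sampler)
# len(sampler) == 60 and iteration starts at sample index 40
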
videosys/datasets/image_transform.py DELETED
@@ -1,42 +0,0 @@
1
- # Adapted from DiT
2
-
3
- # This source code is licensed under the license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- # --------------------------------------------------------
6
- # References:
7
- # DiT: https://github.com/facebookresearch/DiT
8
- # --------------------------------------------------------
9
-
10
-
11
- import numpy as np
12
- import torchvision.transforms as transforms
13
- from PIL import Image
14
-
15
-
16
- def center_crop_arr(pil_image, image_size):
17
- """
18
- Center cropping implementation from ADM.
19
- https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
20
- """
21
- while min(*pil_image.size) >= 2 * image_size:
22
- pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
23
-
24
- scale = image_size / min(*pil_image.size)
25
- pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
26
-
27
- arr = np.array(pil_image)
28
- crop_y = (arr.shape[0] - image_size) // 2
29
- crop_x = (arr.shape[1] - image_size) // 2
30
- return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
31
-
32
-
33
- def get_transforms_image(image_size=256):
34
- transform = transforms.Compose(
35
- [
36
- transforms.Lambda(lambda pil_image: center_crop_arr(pil_image, image_size)),
37
- transforms.RandomHorizontalFlip(),
38
- transforms.ToTensor(),
39
- transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
40
- ]
41
- )
42
- return transform
videosys/datasets/video_transform.py DELETED
@@ -1,441 +0,0 @@
1
- # Adapted from OpenSora and Latte
2
-
3
- # This source code is licensed under the license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- # --------------------------------------------------------
6
- # References:
7
- # OpenSora: https://github.com/hpcaitech/Open-Sora
8
- # Latte: https://github.com/Vchitect/Latte
9
- # --------------------------------------------------------
10
-
11
- import numbers
12
- import random
13
-
14
- import numpy as np
15
- import torch
16
- from PIL import Image
17
-
18
-
19
- def _is_tensor_video_clip(clip):
20
- if not torch.is_tensor(clip):
21
- raise TypeError("clip should be Tensor. Got %s" % type(clip))
22
-
23
- if not clip.ndimension() == 4:
24
- raise ValueError("clip should be 4D. Got %dD" % clip.dim())
25
-
26
- return True
27
-
28
-
29
- def center_crop_arr(pil_image, image_size):
30
- """
31
- Center cropping implementation from ADM.
32
- https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
33
- """
34
- while min(*pil_image.size) >= 2 * image_size:
35
- pil_image = pil_image.resize(tuple(x // 2 for x in pil_image.size), resample=Image.BOX)
36
-
37
- scale = image_size / min(*pil_image.size)
38
- pil_image = pil_image.resize(tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC)
39
-
40
- arr = np.array(pil_image)
41
- crop_y = (arr.shape[0] - image_size) // 2
42
- crop_x = (arr.shape[1] - image_size) // 2
43
- return Image.fromarray(arr[crop_y : crop_y + image_size, crop_x : crop_x + image_size])
44
-
45
-
46
- def crop(clip, i, j, h, w):
47
- """
48
- Args:
49
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
50
- """
51
- if len(clip.size()) != 4:
52
- raise ValueError("clip should be a 4D tensor")
53
- return clip[..., i : i + h, j : j + w]
54
-
55
-
56
- def resize(clip, target_size, interpolation_mode):
57
- if len(target_size) != 2:
58
- raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
59
- return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=False)
60
-
61
-
62
- def resize_scale(clip, target_size, interpolation_mode):
63
- if len(target_size) != 2:
64
- raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
65
- H, W = clip.size(-2), clip.size(-1)
66
- scale_ = target_size[0] / min(H, W)
67
- return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=False)
68
-
69
-
70
- def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
71
- """
72
- Do spatial cropping and resizing to the video clip
73
- Args:
74
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
75
- i (int): i in (i,j) i.e coordinates of the upper left corner.
76
- j (int): j in (i,j) i.e coordinates of the upper left corner.
77
- h (int): Height of the cropped region.
78
- w (int): Width of the cropped region.
79
- size (tuple(int, int)): height and width of resized clip
80
- Returns:
81
- clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
82
- """
83
- if not _is_tensor_video_clip(clip):
84
- raise ValueError("clip should be a 4D torch.tensor")
85
- clip = crop(clip, i, j, h, w)
86
- clip = resize(clip, size, interpolation_mode)
87
- return clip
88
-
89
-
90
- def center_crop(clip, crop_size):
91
- if not _is_tensor_video_clip(clip):
92
- raise ValueError("clip should be a 4D torch.tensor")
93
- h, w = clip.size(-2), clip.size(-1)
94
- th, tw = crop_size
95
- if h < th or w < tw:
96
- raise ValueError("height and width must be no smaller than crop_size")
97
-
98
- i = int(round((h - th) / 2.0))
99
- j = int(round((w - tw) / 2.0))
100
- return crop(clip, i, j, th, tw)
101
-
102
-
103
- def center_crop_using_short_edge(clip):
104
- if not _is_tensor_video_clip(clip):
105
- raise ValueError("clip should be a 4D torch.tensor")
106
- h, w = clip.size(-2), clip.size(-1)
107
- if h < w:
108
- th, tw = h, h
109
- i = 0
110
- j = int(round((w - tw) / 2.0))
111
- else:
112
- th, tw = w, w
113
- i = int(round((h - th) / 2.0))
114
- j = 0
115
- return crop(clip, i, j, th, tw)
116
-
117
-
118
- def random_shift_crop(clip):
119
- """
120
- Slide along the long edge, with the short edge as crop size
121
- """
122
- if not _is_tensor_video_clip(clip):
123
- raise ValueError("clip should be a 4D torch.tensor")
124
- h, w = clip.size(-2), clip.size(-1)
125
-
126
- if h <= w:
127
- short_edge = h
128
- else:
129
- short_edge = w
130
-
131
- th, tw = short_edge, short_edge
132
-
133
- i = torch.randint(0, h - th + 1, size=(1,)).item()
134
- j = torch.randint(0, w - tw + 1, size=(1,)).item()
135
- return crop(clip, i, j, th, tw)
136
-
137
-
138
- def to_tensor(clip):
139
- """
140
- Convert tensor data type from uint8 to float, divide value by 255.0 and
141
- permute the dimensions of clip tensor
142
- Args:
143
- clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
144
- Return:
145
- clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
146
- """
147
- _is_tensor_video_clip(clip)
148
- if not clip.dtype == torch.uint8:
149
- raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
150
- # return clip.float().permute(3, 0, 1, 2) / 255.0
151
- return clip.float() / 255.0
152
-
153
-
154
- def normalize(clip, mean, std, inplace=False):
155
- """
156
- Args:
157
- clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
158
- mean (tuple): pixel RGB mean. Size is (3)
159
- std (tuple): pixel standard deviation. Size is (3)
160
- Returns:
161
- normalized clip (torch.tensor): Size is (T, C, H, W)
162
- """
163
- if not _is_tensor_video_clip(clip):
164
- raise ValueError("clip should be a 4D torch.tensor")
165
- if not inplace:
166
- clip = clip.clone()
167
- mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
168
- # print(mean)
169
- std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
170
- clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
171
- return clip
172
-
173
-
174
- def hflip(clip):
175
- """
176
- Args:
177
- clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
178
- Returns:
179
- flipped clip (torch.tensor): Size is (T, C, H, W)
180
- """
181
- if not _is_tensor_video_clip(clip):
182
- raise ValueError("clip should be a 4D torch.tensor")
183
- return clip.flip(-1)
184
-
185
-
186
- class RandomCropVideo:
187
- def __init__(self, size):
188
- if isinstance(size, numbers.Number):
189
- self.size = (int(size), int(size))
190
- else:
191
- self.size = size
192
-
193
- def __call__(self, clip):
194
- """
195
- Args:
196
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
197
- Returns:
198
- torch.tensor: randomly cropped video clip.
199
- size is (T, C, OH, OW)
200
- """
201
- i, j, h, w = self.get_params(clip)
202
- return crop(clip, i, j, h, w)
203
-
204
- def get_params(self, clip):
205
- h, w = clip.shape[-2:]
206
- th, tw = self.size
207
-
208
- if h < th or w < tw:
209
- raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
210
-
211
- if w == tw and h == th:
212
- return 0, 0, h, w
213
-
214
- i = torch.randint(0, h - th + 1, size=(1,)).item()
215
- j = torch.randint(0, w - tw + 1, size=(1,)).item()
216
-
217
- return i, j, th, tw
218
-
219
- def __repr__(self) -> str:
220
- return f"{self.__class__.__name__}(size={self.size})"
221
-
222
-
223
- class CenterCropResizeVideo:
224
- """
225
- First use the short side for cropping length,
226
- center crop video, then resize to the specified size
227
- """
228
-
229
- def __init__(
230
- self,
231
- size,
232
- interpolation_mode="bilinear",
233
- ):
234
- if isinstance(size, tuple):
235
- if len(size) != 2:
236
- raise ValueError(f"size should be tuple (height, width), instead got {size}")
237
- self.size = size
238
- else:
239
- self.size = (size, size)
240
-
241
- self.interpolation_mode = interpolation_mode
242
-
243
- def __call__(self, clip):
244
- """
245
- Args:
246
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
247
- Returns:
248
- torch.tensor: scale resized / center cropped video clip.
249
- size is (T, C, crop_size, crop_size)
250
- """
251
- clip_center_crop = center_crop_using_short_edge(clip)
252
- clip_center_crop_resize = resize(
253
- clip_center_crop, target_size=self.size, interpolation_mode=self.interpolation_mode
254
- )
255
- return clip_center_crop_resize
256
-
257
- def __repr__(self) -> str:
258
- return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
259
-
260
-
261
- class UCFCenterCropVideo:
262
- """
263
- First scale to the specified size in equal proportion to the short edge,
264
- then center cropping
265
- """
266
-
267
- def __init__(
268
- self,
269
- size,
270
- interpolation_mode="bilinear",
271
- ):
272
- if isinstance(size, tuple):
273
- if len(size) != 2:
274
- raise ValueError(f"size should be tuple (height, width), instead got {size}")
275
- self.size = size
276
- else:
277
- self.size = (size, size)
278
-
279
- self.interpolation_mode = interpolation_mode
280
-
281
- def __call__(self, clip):
282
- """
283
- Args:
284
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
285
- Returns:
286
- torch.tensor: scale resized / center cropped video clip.
287
- size is (T, C, crop_size, crop_size)
288
- """
289
- clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
290
- clip_center_crop = center_crop(clip_resize, self.size)
291
- return clip_center_crop
292
-
293
- def __repr__(self) -> str:
294
- return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
295
-
296
-
297
- class KineticsRandomCropResizeVideo:
298
- """
299
- Slide along the long edge, with the short edge as crop size, then resize to the desired size.
300
- """
301
-
302
- def __init__(
303
- self,
304
- size,
305
- interpolation_mode="bilinear",
306
- ):
307
- if isinstance(size, tuple):
308
- if len(size) != 2:
309
- raise ValueError(f"size should be tuple (height, width), instead got {size}")
310
- self.size = size
311
- else:
312
- self.size = (size, size)
313
-
314
- self.interpolation_mode = interpolation_mode
315
-
316
- def __call__(self, clip):
317
- clip_random_crop = random_shift_crop(clip)
318
- clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
319
- return clip_resize
320
-
321
-
322
- class CenterCropVideo:
323
- def __init__(
324
- self,
325
- size,
326
- interpolation_mode="bilinear",
327
- ):
328
- if isinstance(size, tuple):
329
- if len(size) != 2:
330
- raise ValueError(f"size should be tuple (height, width), instead got {size}")
331
- self.size = size
332
- else:
333
- self.size = (size, size)
334
-
335
- self.interpolation_mode = interpolation_mode
336
-
337
- def __call__(self, clip):
338
- """
339
- Args:
340
- clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
341
- Returns:
342
- torch.tensor: center cropped video clip.
343
- size is (T, C, crop_size, crop_size)
344
- """
345
- clip_center_crop = center_crop(clip, self.size)
346
- return clip_center_crop
347
-
348
- def __repr__(self) -> str:
349
- return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
350
-
351
-
352
- class NormalizeVideo:
353
- """
354
- Normalize the video clip by mean subtraction and division by standard deviation
355
- Args:
356
- mean (3-tuple): pixel RGB mean
357
- std (3-tuple): pixel RGB standard deviation
358
- inplace (boolean): whether do in-place normalization
359
- """
360
-
361
- def __init__(self, mean, std, inplace=False):
362
- self.mean = mean
363
- self.std = std
364
- self.inplace = inplace
365
-
366
- def __call__(self, clip):
367
- """
368
- Args:
369
- clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
370
- """
371
- return normalize(clip, self.mean, self.std, self.inplace)
372
-
373
- def __repr__(self) -> str:
374
- return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
375
-
376
-
377
- class ToTensorVideo:
378
- """
379
- Convert tensor data type from uint8 to float, divide value by 255.0 and
380
- permute the dimensions of clip tensor
381
- """
382
-
383
- def __init__(self):
384
- pass
385
-
386
- def __call__(self, clip):
387
- """
388
- Args:
389
- clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
390
- Return:
391
- clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
392
- """
393
- return to_tensor(clip)
394
-
395
- def __repr__(self) -> str:
396
- return self.__class__.__name__
397
-
398
-
399
- class RandomHorizontalFlipVideo:
400
- """
401
- Flip the video clip along the horizontal direction with a given probability
402
- Args:
403
- p (float): probability of the clip being flipped. Default value is 0.5
404
- """
405
-
406
- def __init__(self, p=0.5):
407
- self.p = p
408
-
409
- def __call__(self, clip):
410
- """
411
- Args:
412
- clip (torch.tensor): Size is (T, C, H, W)
413
- Return:
414
- clip (torch.tensor): Size is (T, C, H, W)
415
- """
416
- if random.random() < self.p:
417
- clip = hflip(clip)
418
- return clip
419
-
420
- def __repr__(self) -> str:
421
- return f"{self.__class__.__name__}(p={self.p})"
422
-
423
-
424
- # ------------------------------------------------------------
425
- # --------------------- Sampling ---------------------------
426
- # ------------------------------------------------------------
427
- class TemporalRandomCrop(object):
428
- """Temporally crop the given frame indices at a random location.
429
-
430
- Args:
431
- size (int): Desired length of frames will be seen in the model.
432
- """
433
-
434
- def __init__(self, size):
435
- self.size = size
436
-
437
- def __call__(self, total_frames):
438
- rand_end = max(0, total_frames - self.size - 1)
439
- begin_index = random.randint(0, rand_end)
440
- end_index = min(begin_index + self.size, total_frames)
441
- return begin_index, end_index
videosys/diffusion/__init__.py DELETED
@@ -1,41 +0,0 @@
1
- # Modified from OpenAI's diffusion repos and Meta DiT
2
- # DiT: https://github.com/facebookresearch/DiT/tree/main
3
- # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
4
- # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
5
- # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
6
-
7
- from . import gaussian_diffusion as gd
8
- from .respace import SpacedDiffusion, space_timesteps
9
-
10
-
11
- def create_diffusion(
12
- timestep_respacing,
13
- noise_schedule="linear",
14
- use_kl=False,
15
- sigma_small=False,
16
- predict_xstart=False,
17
- learn_sigma=True,
18
- rescale_learned_sigmas=False,
19
- diffusion_steps=1000,
20
- ):
21
- betas = gd.get_named_beta_schedule(noise_schedule, diffusion_steps)
22
- if use_kl:
23
- loss_type = gd.LossType.RESCALED_KL
24
- elif rescale_learned_sigmas:
25
- loss_type = gd.LossType.RESCALED_MSE
26
- else:
27
- loss_type = gd.LossType.MSE
28
- if timestep_respacing is None or timestep_respacing == "":
29
- timestep_respacing = [diffusion_steps]
30
- return SpacedDiffusion(
31
- use_timesteps=space_timesteps(diffusion_steps, timestep_respacing),
32
- betas=betas,
33
- model_mean_type=(gd.ModelMeanType.EPSILON if not predict_xstart else gd.ModelMeanType.START_X),
34
- model_var_type=(
35
- (gd.ModelVarType.FIXED_LARGE if not sigma_small else gd.ModelVarType.FIXED_SMALL)
36
- if not learn_sigma
37
- else gd.ModelVarType.LEARNED_RANGE
38
- ),
39
- loss_type=loss_type
40
- # rescale_timesteps=rescale_timesteps,
41
- )
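
create_diffusion was the factory for the (now removed) in-repo diffusion stack. A minimal sketch of how it was typically called; the "250" respacing string follows the upstream DiT/IDDPM convention and is an assumption here:

from videosys.diffusion import create_diffusion  # import path prior to this commit

diffusion = create_diffusion(timestep_respacing="")       # "" -> full 1000-step schedule
diffusion_250 = create_diffusion(timestep_respacing="250")  # DiT-style respacing to 250 sampling steps
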
videosys/diffusion/diffusion_utils.py DELETED
@@ -1,79 +0,0 @@
1
- # Modified from OpenAI's diffusion repos
2
- # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
- # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
- # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
-
6
- import numpy as np
7
- import torch as th
8
-
9
-
10
- def normal_kl(mean1, logvar1, mean2, logvar2):
11
- """
12
- Compute the KL divergence between two gaussians.
13
- Shapes are automatically broadcasted, so batches can be compared to
14
- scalars, among other use cases.
15
- """
16
- tensor = None
17
- for obj in (mean1, logvar1, mean2, logvar2):
18
- if isinstance(obj, th.Tensor):
19
- tensor = obj
20
- break
21
- assert tensor is not None, "at least one argument must be a Tensor"
22
-
23
- # Force variances to be Tensors. Broadcasting helps convert scalars to
24
- # Tensors, but it does not work for th.exp().
25
- logvar1, logvar2 = [x if isinstance(x, th.Tensor) else th.tensor(x).to(tensor) for x in (logvar1, logvar2)]
26
-
27
- return 0.5 * (-1.0 + logvar2 - logvar1 + th.exp(logvar1 - logvar2) + ((mean1 - mean2) ** 2) * th.exp(-logvar2))
28
-
29
-
30
- def approx_standard_normal_cdf(x):
31
- """
32
- A fast approximation of the cumulative distribution function of the
33
- standard normal.
34
- """
35
- return 0.5 * (1.0 + th.tanh(np.sqrt(2.0 / np.pi) * (x + 0.044715 * th.pow(x, 3))))
36
-
37
-
38
- def continuous_gaussian_log_likelihood(x, *, means, log_scales):
39
- """
40
- Compute the log-likelihood of a continuous Gaussian distribution.
41
- :param x: the targets
42
- :param means: the Gaussian mean Tensor.
43
- :param log_scales: the Gaussian log stddev Tensor.
44
- :return: a tensor like x of log probabilities (in nats).
45
- """
46
- centered_x = x - means
47
- inv_stdv = th.exp(-log_scales)
48
- normalized_x = centered_x * inv_stdv
49
- log_probs = th.distributions.Normal(th.zeros_like(x), th.ones_like(x)).log_prob(normalized_x)
50
- return log_probs
51
-
52
-
53
- def discretized_gaussian_log_likelihood(x, *, means, log_scales):
54
- """
55
- Compute the log-likelihood of a Gaussian distribution discretizing to a
56
- given image.
57
- :param x: the target images. It is assumed that this was uint8 values,
58
- rescaled to the range [-1, 1].
59
- :param means: the Gaussian mean Tensor.
60
- :param log_scales: the Gaussian log stddev Tensor.
61
- :return: a tensor like x of log probabilities (in nats).
62
- """
63
- assert x.shape == means.shape == log_scales.shape
64
- centered_x = x - means
65
- inv_stdv = th.exp(-log_scales)
66
- plus_in = inv_stdv * (centered_x + 1.0 / 255.0)
67
- cdf_plus = approx_standard_normal_cdf(plus_in)
68
- min_in = inv_stdv * (centered_x - 1.0 / 255.0)
69
- cdf_min = approx_standard_normal_cdf(min_in)
70
- log_cdf_plus = th.log(cdf_plus.clamp(min=1e-12))
71
- log_one_minus_cdf_min = th.log((1.0 - cdf_min).clamp(min=1e-12))
72
- cdf_delta = cdf_plus - cdf_min
73
- log_probs = th.where(
74
- x < -0.999,
75
- log_cdf_plus,
76
- th.where(x > 0.999, log_one_minus_cdf_min, th.log(cdf_delta.clamp(min=1e-12))),
77
- )
78
- assert log_probs.shape == x.shape
79
- return log_probs
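
For reference, normal_kl() above is the closed-form KL divergence between two diagonal Gaussians parameterized by mean and log-variance. A self-contained check of that closed form against torch.distributions (not part of the diff, shown only to make the formula concrete):

import torch
from torch.distributions import Normal, kl_divergence

def normal_kl(mean1, logvar1, mean2, logvar2):
    # Same closed form as the deleted helper: KL( N(mean1, e^logvar1) || N(mean2, e^logvar2) )
    return 0.5 * (
        -1.0
        + logvar2
        - logvar1
        + torch.exp(logvar1 - logvar2)
        + (mean1 - mean2) ** 2 * torch.exp(-logvar2)
    )

mean1, logvar1, mean2, logvar2 = (torch.randn(4) for _ in range(4))
ref = kl_divergence(
    Normal(mean1, (0.5 * logvar1).exp()),  # scale = exp(0.5 * logvar)
    Normal(mean2, (0.5 * logvar2).exp()),
)
assert torch.allclose(normal_kl(mean1, logvar1, mean2, logvar2), ref, atol=1e-5)
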
videosys/diffusion/gaussian_diffusion.py DELETED
@@ -1,829 +0,0 @@
1
- # Modified from OpenAI's diffusion repos
2
- # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
- # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
- # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
-
6
-
7
- import enum
8
- import math
9
-
10
- import numpy as np
11
- import torch as th
12
-
13
- from .diffusion_utils import discretized_gaussian_log_likelihood, normal_kl
14
-
15
-
16
- def mean_flat(tensor):
17
- """
18
- Take the mean over all non-batch dimensions.
19
- """
20
- return tensor.mean(dim=list(range(1, len(tensor.shape))))
21
-
22
-
23
- class ModelMeanType(enum.Enum):
24
- """
25
- Which type of output the model predicts.
26
- """
27
-
28
- PREVIOUS_X = enum.auto() # the model predicts x_{t-1}
29
- START_X = enum.auto() # the model predicts x_0
30
- EPSILON = enum.auto() # the model predicts epsilon
31
-
32
-
33
- class ModelVarType(enum.Enum):
34
- """
35
- What is used as the model's output variance.
36
- The LEARNED_RANGE option has been added to allow the model to predict
37
- values between FIXED_SMALL and FIXED_LARGE, making its job easier.
38
- """
39
-
40
- LEARNED = enum.auto()
41
- FIXED_SMALL = enum.auto()
42
- FIXED_LARGE = enum.auto()
43
- LEARNED_RANGE = enum.auto()
44
-
45
-
46
- class LossType(enum.Enum):
47
- MSE = enum.auto() # use raw MSE loss (and KL when learning variances)
48
- RESCALED_MSE = enum.auto() # use raw MSE loss (with RESCALED_KL when learning variances)
49
- KL = enum.auto() # use the variational lower-bound
50
- RESCALED_KL = enum.auto() # like KL, but rescale to estimate the full VLB
51
-
52
- def is_vb(self):
53
- return self == LossType.KL or self == LossType.RESCALED_KL
54
-
55
-
56
- def _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, warmup_frac):
57
- betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
58
- warmup_time = int(num_diffusion_timesteps * warmup_frac)
59
- betas[:warmup_time] = np.linspace(beta_start, beta_end, warmup_time, dtype=np.float64)
60
- return betas
61
-
62
-
63
- def get_beta_schedule(beta_schedule, *, beta_start, beta_end, num_diffusion_timesteps):
64
- """
65
- This is the deprecated API for creating beta schedules.
66
- See get_named_beta_schedule() for the new library of schedules.
67
- """
68
- if beta_schedule == "quad":
69
- betas = (
70
- np.linspace(
71
- beta_start**0.5,
72
- beta_end**0.5,
73
- num_diffusion_timesteps,
74
- dtype=np.float64,
75
- )
76
- ** 2
77
- )
78
- elif beta_schedule == "linear":
79
- betas = np.linspace(beta_start, beta_end, num_diffusion_timesteps, dtype=np.float64)
80
- elif beta_schedule == "warmup10":
81
- betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.1)
82
- elif beta_schedule == "warmup50":
83
- betas = _warmup_beta(beta_start, beta_end, num_diffusion_timesteps, 0.5)
84
- elif beta_schedule == "const":
85
- betas = beta_end * np.ones(num_diffusion_timesteps, dtype=np.float64)
86
- elif beta_schedule == "jsd": # 1/T, 1/(T-1), 1/(T-2), ..., 1
87
- betas = 1.0 / np.linspace(num_diffusion_timesteps, 1, num_diffusion_timesteps, dtype=np.float64)
88
- else:
89
- raise NotImplementedError(beta_schedule)
90
- assert betas.shape == (num_diffusion_timesteps,)
91
- return betas
92
-
93
-
94
- def get_named_beta_schedule(schedule_name, num_diffusion_timesteps):
95
- """
96
- Get a pre-defined beta schedule for the given name.
97
- The beta schedule library consists of beta schedules which remain similar
98
- in the limit of num_diffusion_timesteps.
99
- Beta schedules may be added, but should not be removed or changed once
100
- they are committed to maintain backwards compatibility.
101
- """
102
- if schedule_name == "linear":
103
- # Linear schedule from Ho et al, extended to work for any number of
104
- # diffusion steps.
105
- scale = 1000 / num_diffusion_timesteps
106
- return get_beta_schedule(
107
- "linear",
108
- beta_start=scale * 0.0001,
109
- beta_end=scale * 0.02,
110
- num_diffusion_timesteps=num_diffusion_timesteps,
111
- )
112
- elif schedule_name == "squaredcos_cap_v2":
113
- return betas_for_alpha_bar(
114
- num_diffusion_timesteps,
115
- lambda t: math.cos((t + 0.008) / 1.008 * math.pi / 2) ** 2,
116
- )
117
- else:
118
- raise NotImplementedError(f"unknown beta schedule: {schedule_name}")
119
-
120
-
121
- def betas_for_alpha_bar(num_diffusion_timesteps, alpha_bar, max_beta=0.999):
122
- """
123
- Create a beta schedule that discretizes the given alpha_t_bar function,
124
- which defines the cumulative product of (1-beta) over time from t = [0,1].
125
- :param num_diffusion_timesteps: the number of betas to produce.
126
- :param alpha_bar: a lambda that takes an argument t from 0 to 1 and
127
- produces the cumulative product of (1-beta) up to that
128
- part of the diffusion process.
129
- :param max_beta: the maximum beta to use; use values lower than 1 to
130
- prevent singularities.
131
- """
132
- betas = []
133
- for i in range(num_diffusion_timesteps):
134
- t1 = i / num_diffusion_timesteps
135
- t2 = (i + 1) / num_diffusion_timesteps
136
- betas.append(min(1 - alpha_bar(t2) / alpha_bar(t1), max_beta))
137
- return np.array(betas)
138
-
139
-
140
- class GaussianDiffusion:
141
- """
142
- Utilities for training and sampling diffusion models.
143
- Original ported from this codebase:
144
- https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/diffusion_utils_2.py#L42
145
- :param betas: a 1-D numpy array of betas for each diffusion timestep,
146
- starting at T and going to 1.
147
- """
148
-
149
- def __init__(self, *, betas, model_mean_type, model_var_type, loss_type):
150
- self.model_mean_type = model_mean_type
151
- self.model_var_type = model_var_type
152
- self.loss_type = loss_type
153
-
154
- # Use float64 for accuracy.
155
- betas = np.array(betas, dtype=np.float64)
156
- self.betas = betas
157
- assert len(betas.shape) == 1, "betas must be 1-D"
158
- assert (betas > 0).all() and (betas <= 1).all()
159
-
160
- self.num_timesteps = int(betas.shape[0])
161
-
162
- alphas = 1.0 - betas
163
- self.alphas_cumprod = np.cumprod(alphas, axis=0)
164
- self.alphas_cumprod_prev = np.append(1.0, self.alphas_cumprod[:-1])
165
- self.alphas_cumprod_next = np.append(self.alphas_cumprod[1:], 0.0)
166
- assert self.alphas_cumprod_prev.shape == (self.num_timesteps,)
167
-
168
- # calculations for diffusion q(x_t | x_{t-1}) and others
169
- self.sqrt_alphas_cumprod = np.sqrt(self.alphas_cumprod)
170
- self.sqrt_one_minus_alphas_cumprod = np.sqrt(1.0 - self.alphas_cumprod)
171
- self.log_one_minus_alphas_cumprod = np.log(1.0 - self.alphas_cumprod)
172
- self.sqrt_recip_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod)
173
- self.sqrt_recipm1_alphas_cumprod = np.sqrt(1.0 / self.alphas_cumprod - 1)
174
-
175
- # calculations for posterior q(x_{t-1} | x_t, x_0)
176
- self.posterior_variance = betas * (1.0 - self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
177
- # below: log calculation clipped because the posterior variance is 0 at the beginning of the diffusion chain
178
- self.posterior_log_variance_clipped = (
179
- np.log(np.append(self.posterior_variance[1], self.posterior_variance[1:]))
180
- if len(self.posterior_variance) > 1
181
- else np.array([])
182
- )
183
-
184
- self.posterior_mean_coef1 = betas * np.sqrt(self.alphas_cumprod_prev) / (1.0 - self.alphas_cumprod)
185
- self.posterior_mean_coef2 = (1.0 - self.alphas_cumprod_prev) * np.sqrt(alphas) / (1.0 - self.alphas_cumprod)
186
-
187
- def q_mean_variance(self, x_start, t):
188
- """
189
- Get the distribution q(x_t | x_0).
190
- :param x_start: the [N x C x ...] tensor of noiseless inputs.
191
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
192
- :return: A tuple (mean, variance, log_variance), all of x_start's shape.
193
- """
194
- mean = _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
195
- variance = _extract_into_tensor(1.0 - self.alphas_cumprod, t, x_start.shape)
196
- log_variance = _extract_into_tensor(self.log_one_minus_alphas_cumprod, t, x_start.shape)
197
- return mean, variance, log_variance
198
-
199
- def q_sample(self, x_start, t, noise=None):
200
- """
201
- Diffuse the data for a given number of diffusion steps.
202
- In other words, sample from q(x_t | x_0).
203
- :param x_start: the initial data batch.
204
- :param t: the number of diffusion steps (minus 1). Here, 0 means one step.
205
- :param noise: if specified, the split-out normal noise.
206
- :return: A noisy version of x_start.
207
- """
208
- if noise is None:
209
- noise = th.randn_like(x_start)
210
- assert noise.shape == x_start.shape
211
- return (
212
- _extract_into_tensor(self.sqrt_alphas_cumprod, t, x_start.shape) * x_start
213
- + _extract_into_tensor(self.sqrt_one_minus_alphas_cumprod, t, x_start.shape) * noise
214
- )
215
-
216
- def q_posterior_mean_variance(self, x_start, x_t, t):
217
- """
218
- Compute the mean and variance of the diffusion posterior:
219
- q(x_{t-1} | x_t, x_0)
220
- """
221
- assert x_start.shape == x_t.shape
222
- posterior_mean = (
223
- _extract_into_tensor(self.posterior_mean_coef1, t, x_t.shape) * x_start
224
- + _extract_into_tensor(self.posterior_mean_coef2, t, x_t.shape) * x_t
225
- )
226
- posterior_variance = _extract_into_tensor(self.posterior_variance, t, x_t.shape)
227
- posterior_log_variance_clipped = _extract_into_tensor(self.posterior_log_variance_clipped, t, x_t.shape)
228
- assert (
229
- posterior_mean.shape[0]
230
- == posterior_variance.shape[0]
231
- == posterior_log_variance_clipped.shape[0]
232
- == x_start.shape[0]
233
- )
234
- return posterior_mean, posterior_variance, posterior_log_variance_clipped
235
-
236
- def p_mean_variance(self, model, x, t, clip_denoised=True, denoised_fn=None, model_kwargs=None):
237
- """
238
- Apply the model to get p(x_{t-1} | x_t), as well as a prediction of
239
- the initial x, x_0.
240
- :param model: the model, which takes a signal and a batch of timesteps
241
- as input.
242
- :param x: the [N x C x ...] tensor at time t.
243
- :param t: a 1-D Tensor of timesteps.
244
- :param clip_denoised: if True, clip the denoised signal into [-1, 1].
245
- :param denoised_fn: if not None, a function which applies to the
246
- x_start prediction before it is used to sample. Applies before
247
- clip_denoised.
248
- :param model_kwargs: if not None, a dict of extra keyword arguments to
249
- pass to the model. This can be used for conditioning.
250
- :return: a dict with the following keys:
251
- - 'mean': the model mean output.
252
- - 'variance': the model variance output.
253
- - 'log_variance': the log of 'variance'.
254
- - 'pred_xstart': the prediction for x_0.
255
- """
256
- if model_kwargs is None:
257
- model_kwargs = {}
258
-
259
- B, C = x.shape[:2]
260
- assert t.shape == (B,)
261
- model_output = model(x, t, **model_kwargs)
262
- if isinstance(model_output, tuple):
263
- model_output, extra = model_output
264
- else:
265
- extra = None
266
-
267
- if self.model_var_type in [ModelVarType.LEARNED, ModelVarType.LEARNED_RANGE]:
268
- assert model_output.shape == (B, C * 2, *x.shape[2:])
269
- model_output, model_var_values = th.split(model_output, C, dim=1)
270
- min_log = _extract_into_tensor(self.posterior_log_variance_clipped, t, x.shape)
271
- max_log = _extract_into_tensor(np.log(self.betas), t, x.shape)
272
- # The model_var_values is [-1, 1] for [min_var, max_var].
273
- frac = (model_var_values + 1) / 2
274
- model_log_variance = frac * max_log + (1 - frac) * min_log
275
- model_variance = th.exp(model_log_variance)
276
- else:
277
- model_variance, model_log_variance = {
278
- # for fixedlarge, we set the initial (log-)variance like so
279
- # to get a better decoder log likelihood.
280
- ModelVarType.FIXED_LARGE: (
281
- np.append(self.posterior_variance[1], self.betas[1:]),
282
- np.log(np.append(self.posterior_variance[1], self.betas[1:])),
283
- ),
284
- ModelVarType.FIXED_SMALL: (
285
- self.posterior_variance,
286
- self.posterior_log_variance_clipped,
287
- ),
288
- }[self.model_var_type]
289
- model_variance = _extract_into_tensor(model_variance, t, x.shape)
290
- model_log_variance = _extract_into_tensor(model_log_variance, t, x.shape)
291
-
292
- def process_xstart(x):
293
- if denoised_fn is not None:
294
- x = denoised_fn(x)
295
- if clip_denoised:
296
- return x.clamp(-1, 1)
297
- return x
298
-
299
- if self.model_mean_type == ModelMeanType.START_X:
300
- pred_xstart = process_xstart(model_output)
301
- else:
302
- pred_xstart = process_xstart(self._predict_xstart_from_eps(x_t=x, t=t, eps=model_output))
303
- model_mean, _, _ = self.q_posterior_mean_variance(x_start=pred_xstart, x_t=x, t=t)
304
-
305
- assert model_mean.shape == model_log_variance.shape == pred_xstart.shape == x.shape
306
- return {
307
- "mean": model_mean,
308
- "variance": model_variance,
309
- "log_variance": model_log_variance,
310
- "pred_xstart": pred_xstart,
311
- "extra": extra,
312
- }
313
-
314
- def _predict_xstart_from_eps(self, x_t, t, eps):
315
- assert x_t.shape == eps.shape
316
- return (
317
- _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t
318
- - _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape) * eps
319
- )
320
-
321
- def _predict_eps_from_xstart(self, x_t, t, pred_xstart):
322
- return (
323
- _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x_t.shape) * x_t - pred_xstart
324
- ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x_t.shape)
325
-
326
- def condition_mean(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
327
- """
328
- Compute the mean for the previous step, given a function cond_fn that
329
- computes the gradient of a conditional log probability with respect to
330
- x. In particular, cond_fn computes grad(log(p(y|x))), and we want to
331
- condition on y.
332
- This uses the conditioning strategy from Sohl-Dickstein et al. (2015).
333
- """
334
- gradient = cond_fn(x, t, **model_kwargs)
335
- new_mean = p_mean_var["mean"].float() + p_mean_var["variance"] * gradient.float()
336
- return new_mean
337
-
338
- def condition_score(self, cond_fn, p_mean_var, x, t, model_kwargs=None):
339
- """
340
- Compute what the p_mean_variance output would have been, should the
341
- model's score function be conditioned by cond_fn.
342
- See condition_mean() for details on cond_fn.
343
- Unlike condition_mean(), this instead uses the conditioning strategy
344
- from Song et al (2020).
345
- """
346
- alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
347
-
348
- eps = self._predict_eps_from_xstart(x, t, p_mean_var["pred_xstart"])
349
- eps = eps - (1 - alpha_bar).sqrt() * cond_fn(x, t, **model_kwargs)
350
-
351
- out = p_mean_var.copy()
352
- out["pred_xstart"] = self._predict_xstart_from_eps(x, t, eps)
353
- out["mean"], _, _ = self.q_posterior_mean_variance(x_start=out["pred_xstart"], x_t=x, t=t)
354
- return out
355
-
356
- def p_sample(
357
- self,
358
- model,
359
- x,
360
- t,
361
- clip_denoised=True,
362
- denoised_fn=None,
363
- cond_fn=None,
364
- model_kwargs=None,
365
- ):
366
- """
367
- Sample x_{t-1} from the model at the given timestep.
368
- :param model: the model to sample from.
369
- :param x: the current tensor at x_{t-1}.
370
- :param t: the value of t, starting at 0 for the first diffusion step.
371
- :param clip_denoised: if True, clip the x_start prediction to [-1, 1].
372
- :param denoised_fn: if not None, a function which applies to the
373
- x_start prediction before it is used to sample.
374
- :param cond_fn: if not None, this is a gradient function that acts
375
- similarly to the model.
376
- :param model_kwargs: if not None, a dict of extra keyword arguments to
377
- pass to the model. This can be used for conditioning.
378
- :return: a dict containing the following keys:
379
- - 'sample': a random sample from the model.
380
- - 'pred_xstart': a prediction of x_0.
381
- """
382
- out = self.p_mean_variance(
383
- model,
384
- x,
385
- t,
386
- clip_denoised=clip_denoised,
387
- denoised_fn=denoised_fn,
388
- model_kwargs=model_kwargs,
389
- )
390
- noise = th.randn_like(x)
391
- nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
392
- if cond_fn is not None:
393
- out["mean"] = self.condition_mean(cond_fn, out, x, t, model_kwargs=model_kwargs)
394
- sample = out["mean"] + nonzero_mask * th.exp(0.5 * out["log_variance"]) * noise
395
- return {"sample": sample, "pred_xstart": out["pred_xstart"]}
396
-
397
- def p_sample_loop(
398
- self,
399
- model,
400
- shape,
401
- noise=None,
402
- clip_denoised=True,
403
- denoised_fn=None,
404
- cond_fn=None,
405
- model_kwargs=None,
406
- device=None,
407
- progress=False,
408
- ):
409
- """
410
- Generate samples from the model.
411
- :param model: the model module.
412
- :param shape: the shape of the samples, (N, C, H, W).
413
- :param noise: if specified, the noise from the encoder to sample.
414
- Should be of the same shape as `shape`.
415
- :param clip_denoised: if True, clip x_start predictions to [-1, 1].
416
- :param denoised_fn: if not None, a function which applies to the
417
- x_start prediction before it is used to sample.
418
- :param cond_fn: if not None, this is a gradient function that acts
419
- similarly to the model.
420
- :param model_kwargs: if not None, a dict of extra keyword arguments to
421
- pass to the model. This can be used for conditioning.
422
- :param device: if specified, the device to create the samples on.
423
- If not specified, use a model parameter's device.
424
- :param progress: if True, show a tqdm progress bar.
425
- :return: a non-differentiable batch of samples.
426
- """
427
- final = None
428
- for sample in self.p_sample_loop_progressive(
429
- model,
430
- shape,
431
- noise=noise,
432
- clip_denoised=clip_denoised,
433
- denoised_fn=denoised_fn,
434
- cond_fn=cond_fn,
435
- model_kwargs=model_kwargs,
436
- device=device,
437
- progress=progress,
438
- ):
439
- final = sample
440
- return final["sample"]
441
-
442
- def p_sample_loop_progressive(
443
- self,
444
- model,
445
- shape,
446
- noise=None,
447
- clip_denoised=True,
448
- denoised_fn=None,
449
- cond_fn=None,
450
- model_kwargs=None,
451
- device=None,
452
- progress=False,
453
- ):
454
- """
455
- Generate samples from the model and yield intermediate samples from
456
- each timestep of diffusion.
457
- Arguments are the same as p_sample_loop().
458
- Returns a generator over dicts, where each dict is the return value of
459
- p_sample().
460
- """
461
- if device is None:
462
- device = next(model.parameters()).device
463
- assert isinstance(shape, (tuple, list))
464
- if noise is not None:
465
- img = noise
466
- else:
467
- img = th.randn(*shape, device=device)
468
- indices = list(range(self.num_timesteps))[::-1]
469
-
470
- if progress:
471
- # Lazy import so that we don't depend on tqdm.
472
- from tqdm.auto import tqdm
473
-
474
- indices = tqdm(indices)
475
-
476
- for i in indices:
477
- t = th.tensor([i] * shape[0], device=device)
478
- with th.no_grad():
479
- out = self.p_sample(
480
- model,
481
- img,
482
- t,
483
- clip_denoised=clip_denoised,
484
- denoised_fn=denoised_fn,
485
- cond_fn=cond_fn,
486
- model_kwargs=model_kwargs,
487
- )
488
- yield out
489
- img = out["sample"]
490
-
491
- def ddim_sample(
492
- self,
493
- model,
494
- x,
495
- t,
496
- clip_denoised=True,
497
- denoised_fn=None,
498
- cond_fn=None,
499
- model_kwargs=None,
500
- eta=0.0,
501
- ):
502
- """
503
- Sample x_{t-1} from the model using DDIM.
504
- Same usage as p_sample().
505
- """
506
- out = self.p_mean_variance(
507
- model,
508
- x,
509
- t,
510
- clip_denoised=clip_denoised,
511
- denoised_fn=denoised_fn,
512
- model_kwargs=model_kwargs,
513
- )
514
- if cond_fn is not None:
515
- out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
516
-
517
- # Usually our model outputs epsilon, but we re-derive it
518
- # in case we used x_start or x_prev prediction.
519
- eps = self._predict_eps_from_xstart(x, t, out["pred_xstart"])
520
-
521
- alpha_bar = _extract_into_tensor(self.alphas_cumprod, t, x.shape)
522
- alpha_bar_prev = _extract_into_tensor(self.alphas_cumprod_prev, t, x.shape)
523
- sigma = eta * th.sqrt((1 - alpha_bar_prev) / (1 - alpha_bar)) * th.sqrt(1 - alpha_bar / alpha_bar_prev)
524
- # Equation 12.
525
- noise = th.randn_like(x)
526
- mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_prev) + th.sqrt(1 - alpha_bar_prev - sigma**2) * eps
527
- nonzero_mask = (t != 0).float().view(-1, *([1] * (len(x.shape) - 1))) # no noise when t == 0
528
- sample = mean_pred + nonzero_mask * sigma * noise
529
- return {"sample": sample, "pred_xstart": out["pred_xstart"]}
530
-
531
- def ddim_reverse_sample(
532
- self,
533
- model,
534
- x,
535
- t,
536
- clip_denoised=True,
537
- denoised_fn=None,
538
- cond_fn=None,
539
- model_kwargs=None,
540
- eta=0.0,
541
- ):
542
- """
543
- Sample x_{t+1} from the model using DDIM reverse ODE.
544
- """
545
- assert eta == 0.0, "Reverse ODE only for deterministic path"
546
- out = self.p_mean_variance(
547
- model,
548
- x,
549
- t,
550
- clip_denoised=clip_denoised,
551
- denoised_fn=denoised_fn,
552
- model_kwargs=model_kwargs,
553
- )
554
- if cond_fn is not None:
555
- out = self.condition_score(cond_fn, out, x, t, model_kwargs=model_kwargs)
556
- # Usually our model outputs epsilon, but we re-derive it
557
- # in case we used x_start or x_prev prediction.
558
- eps = (
559
- _extract_into_tensor(self.sqrt_recip_alphas_cumprod, t, x.shape) * x - out["pred_xstart"]
560
- ) / _extract_into_tensor(self.sqrt_recipm1_alphas_cumprod, t, x.shape)
561
- alpha_bar_next = _extract_into_tensor(self.alphas_cumprod_next, t, x.shape)
562
-
563
- # Equation 12. reversed
564
- mean_pred = out["pred_xstart"] * th.sqrt(alpha_bar_next) + th.sqrt(1 - alpha_bar_next) * eps
565
-
566
- return {"sample": mean_pred, "pred_xstart": out["pred_xstart"]}
567
-
568
- def ddim_sample_loop(
569
- self,
570
- model,
571
- shape,
572
- noise=None,
573
- clip_denoised=True,
574
- denoised_fn=None,
575
- cond_fn=None,
576
- model_kwargs=None,
577
- device=None,
578
- progress=False,
579
- eta=0.0,
580
- ):
581
- """
582
- Generate samples from the model using DDIM.
583
- Same usage as p_sample_loop().
584
- """
585
- final = None
586
- for sample in self.ddim_sample_loop_progressive(
587
- model,
588
- shape,
589
- noise=noise,
590
- clip_denoised=clip_denoised,
591
- denoised_fn=denoised_fn,
592
- cond_fn=cond_fn,
593
- model_kwargs=model_kwargs,
594
- device=device,
595
- progress=progress,
596
- eta=eta,
597
- ):
598
- final = sample
599
- return final["sample"]
600
-
601
- def ddim_sample_loop_progressive(
602
- self,
603
- model,
604
- shape,
605
- noise=None,
606
- clip_denoised=True,
607
- denoised_fn=None,
608
- cond_fn=None,
609
- model_kwargs=None,
610
- device=None,
611
- progress=False,
612
- eta=0.0,
613
- ):
614
- """
615
- Use DDIM to sample from the model and yield intermediate samples from
616
- each timestep of DDIM.
617
- Same usage as p_sample_loop_progressive().
618
- """
619
- if device is None:
620
- device = next(model.parameters()).device
621
- assert isinstance(shape, (tuple, list))
622
- if noise is not None:
623
- img = noise
624
- else:
625
- img = th.randn(*shape, device=device)
626
- indices = list(range(self.num_timesteps))[::-1]
627
-
628
- if progress:
629
- # Lazy import so that we don't depend on tqdm.
630
- from tqdm.auto import tqdm
631
-
632
- indices = tqdm(indices)
633
-
634
- for i in indices:
635
- t = th.tensor([i] * shape[0], device=device)
636
- with th.no_grad():
637
- out = self.ddim_sample(
638
- model,
639
- img,
640
- t,
641
- clip_denoised=clip_denoised,
642
- denoised_fn=denoised_fn,
643
- cond_fn=cond_fn,
644
- model_kwargs=model_kwargs,
645
- eta=eta,
646
- )
647
- yield out
648
- img = out["sample"]
649
-
650
- def _vb_terms_bpd(self, model, x_start, x_t, t, clip_denoised=True, model_kwargs=None):
651
- """
652
- Get a term for the variational lower-bound.
653
- The resulting units are bits (rather than nats, as one might expect).
654
- This allows for comparison to other papers.
655
- :return: a dict with the following keys:
656
- - 'output': a shape [N] tensor of NLLs or KLs.
657
- - 'pred_xstart': the x_0 predictions.
658
- """
659
- true_mean, _, true_log_variance_clipped = self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)
660
- out = self.p_mean_variance(model, x_t, t, clip_denoised=clip_denoised, model_kwargs=model_kwargs)
661
- kl = normal_kl(true_mean, true_log_variance_clipped, out["mean"], out["log_variance"])
662
- kl = mean_flat(kl) / np.log(2.0)
663
-
664
- decoder_nll = -discretized_gaussian_log_likelihood(
665
- x_start, means=out["mean"], log_scales=0.5 * out["log_variance"]
666
- )
667
- assert decoder_nll.shape == x_start.shape
668
- decoder_nll = mean_flat(decoder_nll) / np.log(2.0)
669
-
670
- # At the first timestep return the decoder NLL,
671
- # otherwise return KL(q(x_{t-1}|x_t,x_0) || p(x_{t-1}|x_t))
672
- output = th.where((t == 0), decoder_nll, kl)
673
- return {"output": output, "pred_xstart": out["pred_xstart"]}
674
-
675
- def training_losses(self, model, x_start, t, model_kwargs=None, noise=None):
676
- """
677
- Compute training losses for a single timestep.
678
- :param model: the model to evaluate loss on.
679
- :param x_start: the [N x C x ...] tensor of inputs.
680
- :param t: a batch of timestep indices.
681
- :param model_kwargs: if not None, a dict of extra keyword arguments to
682
- pass to the model. This can be used for conditioning.
683
- :param noise: if specified, the specific Gaussian noise to try to remove.
684
- :return: a dict with the key "loss" containing a tensor of shape [N].
685
- Some mean or variance settings may also have other keys.
686
- """
687
- if model_kwargs is None:
688
- model_kwargs = {}
689
- if noise is None:
690
- noise = th.randn_like(x_start)
691
- x_t = self.q_sample(x_start, t, noise=noise)
692
-
693
- terms = {}
694
-
695
- if self.loss_type == LossType.KL or self.loss_type == LossType.RESCALED_KL:
696
- terms["loss"] = self._vb_terms_bpd(
697
- model=model,
698
- x_start=x_start,
699
- x_t=x_t,
700
- t=t,
701
- clip_denoised=False,
702
- model_kwargs=model_kwargs,
703
- )["output"]
704
- if self.loss_type == LossType.RESCALED_KL:
705
- terms["loss"] *= self.num_timesteps
706
- elif self.loss_type == LossType.MSE or self.loss_type == LossType.RESCALED_MSE:
707
- model_output = model(x_t, t, **model_kwargs)
708
-
709
- if self.model_var_type in [
710
- ModelVarType.LEARNED,
711
- ModelVarType.LEARNED_RANGE,
712
- ]:
713
- B, C = x_t.shape[:2]
714
- assert model_output.shape == (B, C * 2, *x_t.shape[2:])
715
- model_output, model_var_values = th.split(model_output, C, dim=1)
716
- # Learn the variance using the variational bound, but don't let
717
- # it affect our mean prediction.
718
- frozen_out = th.cat([model_output.detach(), model_var_values], dim=1)
719
- terms["vb"] = self._vb_terms_bpd(
720
- model=lambda *args, r=frozen_out: r,
721
- x_start=x_start,
722
- x_t=x_t,
723
- t=t,
724
- clip_denoised=False,
725
- )["output"]
726
- if self.loss_type == LossType.RESCALED_MSE:
727
- # Divide by 1000 for equivalence with initial implementation.
728
- # Without a factor of 1/1000, the VB term hurts the MSE term.
729
- terms["vb"] *= self.num_timesteps / 1000.0
730
-
731
- target = {
732
- ModelMeanType.PREVIOUS_X: self.q_posterior_mean_variance(x_start=x_start, x_t=x_t, t=t)[0],
733
- ModelMeanType.START_X: x_start,
734
- ModelMeanType.EPSILON: noise,
735
- }[self.model_mean_type]
736
- assert model_output.shape == target.shape == x_start.shape
737
- terms["mse"] = mean_flat((target - model_output) ** 2)
738
- if "vb" in terms:
739
- terms["loss"] = terms["mse"] + terms["vb"]
740
- else:
741
- terms["loss"] = terms["mse"]
742
- else:
743
- raise NotImplementedError(self.loss_type)
744
-
745
- return terms
746
-
747
- def _prior_bpd(self, x_start):
748
- """
749
- Get the prior KL term for the variational lower-bound, measured in
750
- bits-per-dim.
751
- This term can't be optimized, as it only depends on the encoder.
752
- :param x_start: the [N x C x ...] tensor of inputs.
753
- :return: a batch of [N] KL values (in bits), one per batch element.
754
- """
755
- batch_size = x_start.shape[0]
756
- t = th.tensor([self.num_timesteps - 1] * batch_size, device=x_start.device)
757
- qt_mean, _, qt_log_variance = self.q_mean_variance(x_start, t)
758
- kl_prior = normal_kl(mean1=qt_mean, logvar1=qt_log_variance, mean2=0.0, logvar2=0.0)
759
- return mean_flat(kl_prior) / np.log(2.0)
760
-
761
- def calc_bpd_loop(self, model, x_start, clip_denoised=True, model_kwargs=None):
762
- """
763
- Compute the entire variational lower-bound, measured in bits-per-dim,
764
- as well as other related quantities.
765
- :param model: the model to evaluate loss on.
766
- :param x_start: the [N x C x ...] tensor of inputs.
767
- :param clip_denoised: if True, clip denoised samples.
768
- :param model_kwargs: if not None, a dict of extra keyword arguments to
769
- pass to the model. This can be used for conditioning.
770
- :return: a dict containing the following keys:
771
- - total_bpd: the total variational lower-bound, per batch element.
772
- - prior_bpd: the prior term in the lower-bound.
773
- - vb: an [N x T] tensor of terms in the lower-bound.
774
- - xstart_mse: an [N x T] tensor of x_0 MSEs for each timestep.
775
- - mse: an [N x T] tensor of epsilon MSEs for each timestep.
776
- """
777
- device = x_start.device
778
- batch_size = x_start.shape[0]
779
-
780
- vb = []
781
- xstart_mse = []
782
- mse = []
783
- for t in list(range(self.num_timesteps))[::-1]:
784
- t_batch = th.tensor([t] * batch_size, device=device)
785
- noise = th.randn_like(x_start)
786
- x_t = self.q_sample(x_start=x_start, t=t_batch, noise=noise)
787
- # Calculate VLB term at the current timestep
788
- with th.no_grad():
789
- out = self._vb_terms_bpd(
790
- model,
791
- x_start=x_start,
792
- x_t=x_t,
793
- t=t_batch,
794
- clip_denoised=clip_denoised,
795
- model_kwargs=model_kwargs,
796
- )
797
- vb.append(out["output"])
798
- xstart_mse.append(mean_flat((out["pred_xstart"] - x_start) ** 2))
799
- eps = self._predict_eps_from_xstart(x_t, t_batch, out["pred_xstart"])
800
- mse.append(mean_flat((eps - noise) ** 2))
801
-
802
- vb = th.stack(vb, dim=1)
803
- xstart_mse = th.stack(xstart_mse, dim=1)
804
- mse = th.stack(mse, dim=1)
805
-
806
- prior_bpd = self._prior_bpd(x_start)
807
- total_bpd = vb.sum(dim=1) + prior_bpd
808
- return {
809
- "total_bpd": total_bpd,
810
- "prior_bpd": prior_bpd,
811
- "vb": vb,
812
- "xstart_mse": xstart_mse,
813
- "mse": mse,
814
- }
815
-
816
-
817
- def _extract_into_tensor(arr, timesteps, broadcast_shape):
818
- """
819
- Extract values from a 1-D numpy array for a batch of indices.
820
- :param arr: the 1-D numpy array.
821
- :param timesteps: a tensor of indices into the array to extract.
822
- :param broadcast_shape: a larger shape of K dimensions with the batch
823
- dimension equal to the length of timesteps.
824
- :return: a tensor of shape [batch_size, 1, ...] where the shape has K dims.
825
- """
826
- res = th.from_numpy(arr).to(device=timesteps.device)[timesteps].float()
827
- while len(res.shape) < len(broadcast_shape):
828
- res = res[..., None]
829
- return res + th.zeros(broadcast_shape, device=timesteps.device)
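
Everything in the deleted GaussianDiffusion class is derived from the cumulative product of (1 - beta) computed in __init__. A minimal numpy sketch of the forward process q(x_t | x_0) using the same bookkeeping, mirroring q_sample() above (illustrative only; the "linear" schedule constants are the ones from get_named_beta_schedule):

import numpy as np

T = 1000
scale = 1000 / T                                    # linear schedule, as above
betas = np.linspace(scale * 0.0001, scale * 0.02, T, dtype=np.float64)
alphas_cumprod = np.cumprod(1.0 - betas)

def q_sample(x0, t, noise):
    # x_t = sqrt(alpha_bar_t) * x_0 + sqrt(1 - alpha_bar_t) * eps
    return np.sqrt(alphas_cumprod[t]) * x0 + np.sqrt(1.0 - alphas_cumprod[t]) * noise

x0 = np.random.randn(8)
x_noisy = q_sample(x0, t=T - 1, noise=np.random.randn(8))  # near-pure noise at the last step
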
videosys/diffusion/respace.py DELETED
@@ -1,119 +0,0 @@
1
- # Modified from OpenAI's diffusion repos
2
- # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
- # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
- # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
-
6
- import numpy as np
7
- import torch as th
8
-
9
- from .gaussian_diffusion import GaussianDiffusion
10
-
11
-
12
- def space_timesteps(num_timesteps, section_counts):
13
- """
14
- Create a list of timesteps to use from an original diffusion process,
15
- given the number of timesteps we want to take from equally-sized portions
16
- of the original process.
17
- For example, if there's 300 timesteps and the section counts are [10,15,20]
18
- then the first 100 timesteps are strided to be 10 timesteps, the second 100
19
- are strided to be 15 timesteps, and the final 100 are strided to be 20.
20
- If the stride is a string starting with "ddim", then the fixed striding
21
- from the DDIM paper is used, and only one section is allowed.
22
- :param num_timesteps: the number of diffusion steps in the original
23
- process to divide up.
24
- :param section_counts: either a list of numbers, or a string containing
25
- comma-separated numbers, indicating the step count
26
- per section. As a special case, use "ddimN" where N
27
- is a number of steps to use the striding from the
28
- DDIM paper.
29
- :return: a set of diffusion steps from the original process to use.
30
- """
31
- if isinstance(section_counts, str):
32
- if section_counts.startswith("ddim"):
33
- desired_count = int(section_counts[len("ddim") :])
34
- for i in range(1, num_timesteps):
35
- if len(range(0, num_timesteps, i)) == desired_count:
36
- return set(range(0, num_timesteps, i))
37
- raise ValueError(f"cannot create exactly {num_timesteps} steps with an integer stride")
38
- section_counts = [int(x) for x in section_counts.split(",")]
39
- size_per = num_timesteps // len(section_counts)
40
- extra = num_timesteps % len(section_counts)
41
- start_idx = 0
42
- all_steps = []
43
- for i, section_count in enumerate(section_counts):
44
- size = size_per + (1 if i < extra else 0)
45
- if size < section_count:
46
- raise ValueError(f"cannot divide section of {size} steps into {section_count}")
47
- if section_count <= 1:
48
- frac_stride = 1
49
- else:
50
- frac_stride = (size - 1) / (section_count - 1)
51
- cur_idx = 0.0
52
- taken_steps = []
53
- for _ in range(section_count):
54
- taken_steps.append(start_idx + round(cur_idx))
55
- cur_idx += frac_stride
56
- all_steps += taken_steps
57
- start_idx += size
58
- return set(all_steps)
59
-
60
-
61
- class SpacedDiffusion(GaussianDiffusion):
62
- """
63
- A diffusion process which can skip steps in a base diffusion process.
64
- :param use_timesteps: a collection (sequence or set) of timesteps from the
65
- original diffusion process to retain.
66
- :param kwargs: the kwargs to create the base diffusion process.
67
- """
68
-
69
- def __init__(self, use_timesteps, **kwargs):
70
- self.use_timesteps = set(use_timesteps)
71
- self.timestep_map = []
72
- self.original_num_steps = len(kwargs["betas"])
73
-
74
- base_diffusion = GaussianDiffusion(**kwargs) # pylint: disable=missing-kwoa
75
- last_alpha_cumprod = 1.0
76
- new_betas = []
77
- for i, alpha_cumprod in enumerate(base_diffusion.alphas_cumprod):
78
- if i in self.use_timesteps:
79
- new_betas.append(1 - alpha_cumprod / last_alpha_cumprod)
80
- last_alpha_cumprod = alpha_cumprod
81
- self.timestep_map.append(i)
82
- kwargs["betas"] = np.array(new_betas)
83
- super().__init__(**kwargs)
84
-
85
- def p_mean_variance(self, model, *args, **kwargs): # pylint: disable=signature-differs
86
- return super().p_mean_variance(self._wrap_model(model), *args, **kwargs)
87
-
88
- def training_losses(self, model, *args, **kwargs): # pylint: disable=signature-differs
89
- return super().training_losses(self._wrap_model(model), *args, **kwargs)
90
-
91
- def condition_mean(self, cond_fn, *args, **kwargs):
92
- return super().condition_mean(self._wrap_model(cond_fn), *args, **kwargs)
93
-
94
- def condition_score(self, cond_fn, *args, **kwargs):
95
- return super().condition_score(self._wrap_model(cond_fn), *args, **kwargs)
96
-
97
- def _wrap_model(self, model):
98
- if isinstance(model, _WrappedModel):
99
- return model
100
- return _WrappedModel(model, self.timestep_map, self.original_num_steps)
101
-
102
- def _scale_timesteps(self, t):
103
- # Scaling is done by the wrapped model.
104
- return t
105
-
106
-
107
- class _WrappedModel:
108
- def __init__(self, model, timestep_map, original_num_steps):
109
- self.model = model
110
- self.timestep_map = timestep_map
111
- # self.rescale_timesteps = rescale_timesteps
112
- self.original_num_steps = original_num_steps
113
-
114
- def __call__(self, x, ts, **kwargs):
115
- map_tensor = th.tensor(self.timestep_map, device=ts.device, dtype=ts.dtype)
116
- new_ts = map_tensor[ts]
117
- # if self.rescale_timesteps:
118
- # new_ts = new_ts.float() * (1000.0 / self.original_num_steps)
119
- return self.model(x, new_ts, **kwargs)
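
space_timesteps() above is what makes respacing work: it selects a subset of the original timesteps, and SpacedDiffusion then recomputes betas so that subset behaves like a shorter chain. Two illustrative calls, with results that follow directly from the deleted implementation (the import path is the pre-removal one):

# Pre-removal import path; kept here only to illustrate the deleted behaviour.
from videosys.diffusion.respace import space_timesteps

# "ddimN": search for an integer stride that yields exactly N steps.
assert space_timesteps(1000, "ddim50") == set(range(0, 1000, 20))

# List form: split 300 steps into 3 equal sections, keep 10/15/20 steps per section.
assert len(space_timesteps(300, [10, 15, 20])) == 45
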
videosys/diffusion/timestep_sampler.py DELETED
@@ -1,143 +0,0 @@
1
- # Modified from OpenAI's diffusion repos
2
- # GLIDE: https://github.com/openai/glide-text2im/blob/main/glide_text2im/gaussian_diffusion.py
3
- # ADM: https://github.com/openai/guided-diffusion/blob/main/guided_diffusion
4
- # IDDPM: https://github.com/openai/improved-diffusion/blob/main/improved_diffusion/gaussian_diffusion.py
5
-
6
- from abc import ABC, abstractmethod
7
-
8
- import numpy as np
9
- import torch as th
10
- import torch.distributed as dist
11
-
12
-
13
- def create_named_schedule_sampler(name, diffusion):
14
- """
15
- Create a ScheduleSampler from a library of pre-defined samplers.
16
- :param name: the name of the sampler.
17
- :param diffusion: the diffusion object to sample for.
18
- """
19
- if name == "uniform":
20
- return UniformSampler(diffusion)
21
- elif name == "loss-second-moment":
22
- return LossSecondMomentResampler(diffusion)
23
- else:
24
- raise NotImplementedError(f"unknown schedule sampler: {name}")
25
-
26
-
27
- class ScheduleSampler(ABC):
28
- """
29
- A distribution over timesteps in the diffusion process, intended to reduce
30
- variance of the objective.
31
- By default, samplers perform unbiased importance sampling, in which the
32
- objective's mean is unchanged.
33
- However, subclasses may override sample() to change how the resampled
34
- terms are reweighted, allowing for actual changes in the objective.
35
- """
36
-
37
- @abstractmethod
38
- def weights(self):
39
- """
40
- Get a numpy array of weights, one per diffusion step.
41
- The weights needn't be normalized, but must be positive.
42
- """
43
-
44
- def sample(self, batch_size, device):
45
- """
46
- Importance-sample timesteps for a batch.
47
- :param batch_size: the number of timesteps.
48
- :param device: the torch device to save to.
49
- :return: a tuple (timesteps, weights):
50
- - timesteps: a tensor of timestep indices.
51
- - weights: a tensor of weights to scale the resulting losses.
52
- """
53
- w = self.weights()
54
- p = w / np.sum(w)
55
- indices_np = np.random.choice(len(p), size=(batch_size,), p=p)
56
- indices = th.from_numpy(indices_np).long().to(device)
57
- weights_np = 1 / (len(p) * p[indices_np])
58
- weights = th.from_numpy(weights_np).float().to(device)
59
- return indices, weights
60
-
61
-
62
- class UniformSampler(ScheduleSampler):
63
- def __init__(self, diffusion):
64
- self.diffusion = diffusion
65
- self._weights = np.ones([diffusion.num_timesteps])
66
-
67
- def weights(self):
68
- return self._weights
69
-
70
-
71
- class LossAwareSampler(ScheduleSampler):
72
- def update_with_local_losses(self, local_ts, local_losses):
73
- """
74
- Update the reweighting using losses from a model.
75
- Call this method from each rank with a batch of timesteps and the
76
- corresponding losses for each of those timesteps.
77
- This method will perform synchronization to make sure all of the ranks
78
- maintain the exact same reweighting.
79
- :param local_ts: an integer Tensor of timesteps.
80
- :param local_losses: a 1D Tensor of losses.
81
- """
82
- batch_sizes = [th.tensor([0], dtype=th.int32, device=local_ts.device) for _ in range(dist.get_world_size())]
83
- dist.all_gather(
84
- batch_sizes,
85
- th.tensor([len(local_ts)], dtype=th.int32, device=local_ts.device),
86
- )
87
-
88
- # Pad all_gather batches to be the maximum batch size.
89
- batch_sizes = [x.item() for x in batch_sizes]
90
- max_bs = max(batch_sizes)
91
-
92
- timestep_batches = [th.zeros(max_bs).to(local_ts) for bs in batch_sizes]
93
- loss_batches = [th.zeros(max_bs).to(local_losses) for bs in batch_sizes]
94
- dist.all_gather(timestep_batches, local_ts)
95
- dist.all_gather(loss_batches, local_losses)
96
- timesteps = [x.item() for y, bs in zip(timestep_batches, batch_sizes) for x in y[:bs]]
97
- losses = [x.item() for y, bs in zip(loss_batches, batch_sizes) for x in y[:bs]]
98
- self.update_with_all_losses(timesteps, losses)
99
-
100
- @abstractmethod
101
- def update_with_all_losses(self, ts, losses):
102
- """
103
- Update the reweighting using losses from a model.
104
- Sub-classes should override this method to update the reweighting
105
- using losses from the model.
106
- This method directly updates the reweighting without synchronizing
107
- between workers. It is called by update_with_local_losses from all
108
- ranks with identical arguments. Thus, it should have deterministic
109
- behavior to maintain state across workers.
110
- :param ts: a list of int timesteps.
111
- :param losses: a list of float losses, one per timestep.
112
- """
113
-
114
-
115
- class LossSecondMomentResampler(LossAwareSampler):
116
- def __init__(self, diffusion, history_per_term=10, uniform_prob=0.001):
117
- self.diffusion = diffusion
118
- self.history_per_term = history_per_term
119
- self.uniform_prob = uniform_prob
120
- self._loss_history = np.zeros([diffusion.num_timesteps, history_per_term], dtype=np.float64)
121
- self._loss_counts = np.zeros([diffusion.num_timesteps], dtype=np.int)
122
-
123
- def weights(self):
124
- if not self._warmed_up():
125
- return np.ones([self.diffusion.num_timesteps], dtype=np.float64)
126
- weights = np.sqrt(np.mean(self._loss_history**2, axis=-1))
127
- weights /= np.sum(weights)
128
- weights *= 1 - self.uniform_prob
129
- weights += self.uniform_prob / len(weights)
130
- return weights
131
-
132
- def update_with_all_losses(self, ts, losses):
133
- for t, loss in zip(ts, losses):
134
- if self._loss_counts[t] == self.history_per_term:
135
- # Shift out the oldest loss term.
136
- self._loss_history[t, :-1] = self._loss_history[t, 1:]
137
- self._loss_history[t, -1] = loss
138
- else:
139
- self._loss_history[t, self._loss_counts[t]] = loss
140
- self._loss_counts[t] += 1
141
-
142
- def _warmed_up(self):
143
- return (self._loss_counts == self.history_per_term).all()
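
All samplers above reduce to one identity: draw timestep t with probability p(t) proportional to a weight w(t), then scale that timestep's loss by 1 / (T * p(t)) so the expected objective is unchanged. A self-contained numpy sketch of that reweighting, mirroring ScheduleSampler.sample() (the function name here is illustrative):

import numpy as np

def sample_timesteps(weights, batch_size, rng=np.random.default_rng(0)):
    # Importance-sample timestep indices; the returned loss weights keep the objective unbiased.
    p = weights / weights.sum()
    t = rng.choice(len(p), size=batch_size, p=p)
    return t, 1.0 / (len(p) * p[t])

w = np.ones(1000)                          # UniformSampler: every timestep equally likely
t, loss_w = sample_timesteps(w, batch_size=4)
assert np.allclose(loss_w, 1.0)            # uniform weights leave the losses unscaled
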
{eval/pab/commom_metrics β†’ videosys/models/autoencoders}/__init__.py RENAMED
File without changes
videosys/models/{cogvideo/autoencoder_kl.py β†’ autoencoders/autoencoder_kl_cogvideox.py} RENAMED
@@ -20,16 +20,16 @@ from diffusers.models.activations import get_activation
20
  from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
21
  from diffusers.models.modeling_outputs import AutoencoderKLOutput
22
  from diffusers.models.modeling_utils import ModelMixin
23
- from diffusers.utils import logging
24
  from diffusers.utils.accelerate_utils import apply_forward_hook
25
 
26
- from .modules import CogVideoXDownsample3D, CogVideoXUpsample3D
27
 
28
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
29
 
30
 
31
  class CogVideoXSafeConv3d(nn.Conv3d):
32
- """
33
  A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
34
  """
35
 
@@ -61,12 +61,12 @@ class CogVideoXCausalConv3d(nn.Module):
61
  r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
62
 
63
  Args:
64
- in_channels (int): Number of channels in the input tensor.
65
- out_channels (int): Number of output channels.
66
- kernel_size (Union[int, Tuple[int, int, int]]): Size of the convolutional kernel.
67
- stride (int, optional): Stride of the convolution. Default is 1.
68
- dilation (int, optional): Dilation rate of the convolution. Default is 1.
69
- pad_mode (str, optional): Padding mode. Default is "constant".
70
  """
71
 
72
  def __init__(
@@ -111,19 +111,10 @@ class CogVideoXCausalConv3d(nn.Module):
111
  self.conv_cache = None
112
 
113
  def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
114
- dim = self.temporal_dim
115
  kernel_size = self.time_kernel_size
116
- if kernel_size == 1:
117
- return inputs
118
-
119
- inputs = inputs.transpose(0, dim)
120
-
121
- if self.conv_cache is not None:
122
- inputs = torch.cat([self.conv_cache.transpose(0, dim).to(inputs.device), inputs], dim=0)
123
- else:
124
- inputs = torch.cat([inputs[:1]] * (kernel_size - 1) + [inputs], dim=0)
125
-
126
- inputs = inputs.transpose(0, dim).contiguous()
127
  return inputs
128
 
129
  def _clear_fake_context_parallel_cache(self):
@@ -131,16 +122,17 @@ class CogVideoXCausalConv3d(nn.Module):
131
  self.conv_cache = None
132
 
133
  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
134
- input_parallel = self.fake_context_parallel_forward(inputs)
135
 
136
  self._clear_fake_context_parallel_cache()
137
- self.conv_cache = input_parallel[:, :, -self.time_kernel_size + 1 :].contiguous().detach().clone().cpu()
 
 
138
 
139
  padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
140
- input_parallel = F.pad(input_parallel, padding_2d, mode="constant", value=0)
141
 
142
- output_parallel = self.conv(input_parallel)
143
- output = output_parallel
144
  return output
145
 
146
 
@@ -156,6 +148,8 @@ class CogVideoXSpatialNorm3D(nn.Module):
156
  The number of channels for input to group normalization layer, and output of the spatial norm layer.
157
  zq_channels (`int`):
158
  The number of channels for the quantized vector as described in the paper.
 
 
159
  """
160
 
161
  def __init__(
@@ -190,17 +184,26 @@ class CogVideoXResnetBlock3D(nn.Module):
190
  A 3D ResNet block used in the CogVideoX model.
191
 
192
  Args:
193
- in_channels (int): Number of input channels.
194
- out_channels (Optional[int], optional):
195
- Number of output channels. If None, defaults to `in_channels`. Default is None.
196
- dropout (float, optional): Dropout rate. Default is 0.0.
197
- temb_channels (int, optional): Number of time embedding channels. Default is 512.
198
- groups (int, optional): Number of groups for group normalization. Default is 32.
199
- eps (float, optional): Epsilon value for normalization layers. Default is 1e-6.
200
- non_linearity (str, optional): Activation function to use. Default is "swish".
201
- conv_shortcut (bool, optional): If True, use a convolutional shortcut. Default is False.
202
- spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
203
- pad_mode (str, optional): Padding mode. Default is "first".
 
 
 
 
 
 
 
 
 
204
  """
205
 
206
  def __init__(
@@ -302,18 +305,28 @@ class CogVideoXDownBlock3D(nn.Module):
302
  A downsampling block used in the CogVideoX model.
303
 
304
  Args:
305
- in_channels (int): Number of input channels.
306
- out_channels (int): Number of output channels.
307
- temb_channels (int): Number of time embedding channels.
308
- dropout (float, optional): Dropout rate. Default is 0.0.
309
- num_layers (int, optional): Number of layers in the block. Default is 1.
310
- resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
311
- resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
312
- resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
313
- add_downsample (bool, optional): If True, add a downsampling layer at the end of the block. Default is True.
314
- downsample_padding (int, optional): Padding for the downsampling layer. Default is 0.
315
- compress_time (bool, optional): If True, apply temporal compression. Default is False.
316
- pad_mode (str, optional): Padding mode. Default is "first".
 
 
 
 
 
 
 
 
 
 
317
  """
318
 
319
  _supports_gradient_checkpointing = True
@@ -398,15 +411,24 @@ class CogVideoXMidBlock3D(nn.Module):
398
  A middle block used in the CogVideoX model.
399
 
400
  Args:
401
- in_channels (int): Number of input channels.
402
- temb_channels (int): Number of time embedding channels.
403
- dropout (float, optional): Dropout rate. Default is 0.0.
404
- num_layers (int, optional): Number of layers in the block. Default is 1.
405
- resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
406
- resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
407
- resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
408
- spatial_norm_dim (Optional[int], optional): Dimension of the spatial normalization. Default is None.
409
- pad_mode (str, optional): Padding mode. Default is "first".
 
 
 
 
 
 
 
 
 
410
  """
411
 
412
  _supports_gradient_checkpointing = True
@@ -473,19 +495,30 @@ class CogVideoXUpBlock3D(nn.Module):
473
  An upsampling block used in the CogVideoX model.
474
 
475
  Args:
476
- in_channels (int): Number of input channels.
477
- out_channels (int): Number of output channels.
478
- temb_channels (int): Number of time embedding channels.
479
- dropout (float, optional): Dropout rate. Default is 0.0.
480
- num_layers (int, optional): Number of layers in the block. Default is 1.
481
- resnet_eps (float, optional): Epsilon value for the ResNet layers. Default is 1e-6.
482
- resnet_act_fn (str, optional): Activation function for the ResNet layers. Default is "swish".
483
- resnet_groups (int, optional): Number of groups for group normalization in the ResNet layers. Default is 32.
484
- spatial_norm_dim (int, optional): Dimension of the spatial normalization. Default is 16.
485
- add_upsample (bool, optional): If True, add an upsampling layer at the end of the block. Default is True.
486
- upsample_padding (int, optional): Padding for the upsampling layer. Default is 1.
487
- compress_time (bool, optional): If True, apply temporal compression. Default is False.
488
- pad_mode (str, optional): Padding mode. Default is "first".
 
 
 
 
 
 
 
 
 
 
 
489
  """
490
 
491
  def __init__(
@@ -576,14 +609,12 @@ class CogVideoXEncoder3D(nn.Module):
576
  options.
577
  block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
578
  The number of output channels for each block.
 
 
579
  layers_per_block (`int`, *optional*, defaults to 2):
580
  The number of layers per block.
581
  norm_num_groups (`int`, *optional*, defaults to 32):
582
  The number of groups for normalization.
583
- act_fn (`str`, *optional*, defaults to `"silu"`):
584
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
585
- double_z (`bool`, *optional*, defaults to `True`):
586
- Whether to double the number of output channels for the last block.
587
  """
588
 
589
  _supports_gradient_checkpointing = True
@@ -712,14 +743,12 @@ class CogVideoXDecoder3D(nn.Module):
712
  The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
713
  block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
714
  The number of output channels for each block.
 
 
715
  layers_per_block (`int`, *optional*, defaults to 2):
716
  The number of layers per block.
717
  norm_num_groups (`int`, *optional*, defaults to 32):
718
  The number of groups for normalization.
719
- act_fn (`str`, *optional*, defaults to `"silu"`):
720
- The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
721
- norm_type (`str`, *optional*, defaults to `"group"`):
722
- The normalization type to use. Can be either `"group"` or `"spatial"`.
723
  """
724
 
725
  _supports_gradient_checkpointing = True
@@ -860,7 +889,7 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
860
  Tuple of block output channels.
861
  act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
862
  sample_size (`int`, *optional*, defaults to `32`): Sample input size.
863
- scaling_factor (`float`, *optional*, defaults to 0.18215):
864
  The component-wise standard deviation of the trained latent space computed using the first batch of the
865
  training set. This is used to scale the latent space to have unit variance when training the diffusion
866
  model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
@@ -900,7 +929,8 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
900
  norm_eps: float = 1e-6,
901
  norm_num_groups: int = 32,
902
  temporal_compression_ratio: float = 4,
903
- sample_size: int = 256,
 
904
  scaling_factor: float = 1.15258426,
905
  shift_factor: Optional[float] = None,
906
  latents_mean: Optional[Tuple[float]] = None,
@@ -939,25 +969,105 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
939
  self.use_slicing = False
940
  self.use_tiling = False
941
 
942
- self.tile_sample_min_size = self.config.sample_size
943
- sample_size = (
944
- self.config.sample_size[0]
945
- if isinstance(self.config.sample_size, (list, tuple))
946
- else self.config.sample_size
947
  )
948
- self.tile_latent_min_size = int(sample_size / (2 ** (len(self.config.block_out_channels) - 1)))
949
- self.tile_overlap_factor = 0.25
950
 
951
  def _set_gradient_checkpointing(self, module, value=False):
952
  if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
953
  module.gradient_checkpointing = value
954
 
955
- def clear_fake_context_parallel_cache(self):
956
  for name, module in self.named_modules():
957
  if isinstance(module, CogVideoXCausalConv3d):
958
  logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
959
  module._clear_fake_context_parallel_cache()
960
961
  @apply_forward_hook
962
  def encode(
963
  self, x: torch.Tensor, return_dict: bool = True
@@ -982,8 +1092,34 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
982
  return (posterior,)
983
  return AutoencoderKLOutput(latent_dist=posterior)
984
985
  @apply_forward_hook
986
- def decode(self, z: torch.FloatTensor, return_dict: bool = True) -> Union[DecoderOutput, torch.FloatTensor]:
987
  """
988
  Decode a batch of images.
989
 
@@ -996,13 +1132,111 @@ class AutoencoderKLCogVideoX(ModelMixin, ConfigMixin, FromOriginalModelMixin):
996
  [`~models.vae.DecoderOutput`] or `tuple`:
997
  If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
998
  returned.
999
1000
  """
1001
- if self.post_quant_conv is not None:
1002
- z = self.post_quant_conv(z)
1003
- dec = self.decoder(z)
1004
  if not return_dict:
1005
  return (dec,)
 
1006
  return DecoderOutput(sample=dec)
1007
 
1008
  def forward(
 
20
  from diffusers.models.autoencoders.vae import DecoderOutput, DiagonalGaussianDistribution
21
  from diffusers.models.modeling_outputs import AutoencoderKLOutput
22
  from diffusers.models.modeling_utils import ModelMixin
 
23
  from diffusers.utils.accelerate_utils import apply_forward_hook
24
 
25
+ from videosys.utils.logging import logger
26
 
27
+ from ..modules.downsampling import CogVideoXDownsample3D
28
+ from ..modules.upsampling import CogVideoXUpsample3D
29
 
30
 
31
  class CogVideoXSafeConv3d(nn.Conv3d):
32
+ r"""
33
  A 3D convolution layer that splits the input tensor into smaller parts to avoid OOM in CogVideoX Model.
34
  """
35
 
 
61
  r"""A 3D causal convolution layer that pads the input tensor to ensure causality in CogVideoX Model.
62
 
63
  Args:
64
+ in_channels (`int`): Number of channels in the input tensor.
65
+ out_channels (`int`): Number of output channels produced by the convolution.
66
+ kernel_size (`int` or `Tuple[int, int, int]`): Kernel size of the convolutional kernel.
67
+ stride (`int`, defaults to `1`): Stride of the convolution.
68
+ dilation (`int`, defaults to `1`): Dilation rate of the convolution.
69
+ pad_mode (`str`, defaults to `"constant"`): Padding mode.
70
  """
71
 
72
  def __init__(
 
111
  self.conv_cache = None
112
 
113
  def fake_context_parallel_forward(self, inputs: torch.Tensor) -> torch.Tensor:
 
114
  kernel_size = self.time_kernel_size
115
+ if kernel_size > 1:
116
+ cached_inputs = [self.conv_cache] if self.conv_cache is not None else [inputs[:, :, :1]] * (kernel_size - 1)
117
+ inputs = torch.cat(cached_inputs + [inputs], dim=2)
118
  return inputs
119
 
120
  def _clear_fake_context_parallel_cache(self):
 
122
  self.conv_cache = None
123
 
124
  def forward(self, inputs: torch.Tensor) -> torch.Tensor:
125
+ inputs = self.fake_context_parallel_forward(inputs)
126
 
127
  self._clear_fake_context_parallel_cache()
128
+ # Note: we could move these to the cpu for a lower maximum memory usage but it's only a few
129
+ # hundred megabytes and so let's not do it for now
130
+ self.conv_cache = inputs[:, :, -self.time_kernel_size + 1 :].clone()
131
 
132
  padding_2d = (self.width_pad, self.width_pad, self.height_pad, self.height_pad)
133
+ inputs = F.pad(inputs, padding_2d, mode="constant", value=0)
134
 
135
+ output = self.conv(inputs)
 
136
  return output
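The causal 3D convolution above keeps a small temporal cache (`conv_cache`) so consecutive chunks of frames are convolved as if they were one continuous clip: the first chunk is padded by repeating its first frame, later chunks are padded with the trailing frames of the previous chunk. A minimal sketch of that padding scheme, independent of the surrounding classes (the helper name here is illustrative, not part of the diff):

```python
import torch
from typing import Optional

def causal_time_pad(chunk: torch.Tensor, cache: Optional[torch.Tensor], time_kernel_size: int):
    """Pad a (B, C, T, H, W) chunk along time for a causal conv and return the new cache."""
    if time_kernel_size > 1:
        if cache is None:
            # first chunk: repeat the first frame so no future information is used
            pad = [chunk[:, :, :1]] * (time_kernel_size - 1)
        else:
            # later chunks: reuse the trailing frames of the previous chunk
            pad = [cache]
        chunk = torch.cat(pad + [chunk], dim=2)
    # keep the last (kernel - 1) frames for the next chunk, as conv_cache does above
    cache = chunk[:, :, -(time_kernel_size - 1):].clone() if time_kernel_size > 1 else None
    return chunk, cache

x1, x2 = torch.randn(1, 4, 2, 8, 8), torch.randn(1, 4, 2, 8, 8)
p1, cache = causal_time_pad(x1, None, time_kernel_size=3)
p2, cache = causal_time_pad(x2, cache, time_kernel_size=3)
print(p1.shape[2], p2.shape[2])  # 4 4 -- each 2-frame chunk sees 2 extra leading frames
```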
137
 
138
 
 
148
  The number of channels for input to group normalization layer, and output of the spatial norm layer.
149
  zq_channels (`int`):
150
  The number of channels for the quantized vector as described in the paper.
151
+ groups (`int`):
152
+ Number of groups to separate the channels into for group normalization.
153
  """
154
 
155
  def __init__(
 
184
  A 3D ResNet block used in the CogVideoX model.
185
 
186
  Args:
187
+ in_channels (`int`):
188
+ Number of input channels.
189
+ out_channels (`int`, *optional*):
190
+ Number of output channels. If None, defaults to `in_channels`.
191
+ dropout (`float`, defaults to `0.0`):
192
+ Dropout rate.
193
+ temb_channels (`int`, defaults to `512`):
194
+ Number of time embedding channels.
195
+ groups (`int`, defaults to `32`):
196
+ Number of groups to separate the channels into for group normalization.
197
+ eps (`float`, defaults to `1e-6`):
198
+ Epsilon value for normalization layers.
199
+ non_linearity (`str`, defaults to `"swish"`):
200
+ Activation function to use.
201
+ conv_shortcut (bool, defaults to `False`):
202
+ Whether or not to use a convolution shortcut.
203
+ spatial_norm_dim (`int`, *optional*):
204
+ The dimension to use for spatial norm if it is to be used instead of group norm.
205
+ pad_mode (str, defaults to `"first"`):
206
+ Padding mode.
207
  """
208
 
209
  def __init__(
 
305
  A downsampling block used in the CogVideoX model.
306
 
307
  Args:
308
+ in_channels (`int`):
309
+ Number of input channels.
310
+ out_channels (`int`, *optional*):
311
+ Number of output channels. If None, defaults to `in_channels`.
312
+ temb_channels (`int`, defaults to `512`):
313
+ Number of time embedding channels.
314
+ num_layers (`int`, defaults to `1`):
315
+ Number of resnet layers.
316
+ dropout (`float`, defaults to `0.0`):
317
+ Dropout rate.
318
+ resnet_eps (`float`, defaults to `1e-6`):
319
+ Epsilon value for normalization layers.
320
+ resnet_act_fn (`str`, defaults to `"swish"`):
321
+ Activation function to use.
322
+ resnet_groups (`int`, defaults to `32`):
323
+ Number of groups to separate the channels into for group normalization.
324
+ add_downsample (`bool`, defaults to `True`):
325
+ Whether or not to use a downsampling layer. If not used, the output dimension is the same as the input dimension.
326
+ compress_time (`bool`, defaults to `False`):
327
+ Whether or not to downsample across temporal dimension.
328
+ pad_mode (str, defaults to `"first"`):
329
+ Padding mode.
330
  """
331
 
332
  _supports_gradient_checkpointing = True
 
411
  A middle block used in the CogVideoX model.
412
 
413
  Args:
414
+ in_channels (`int`):
415
+ Number of input channels.
416
+ temb_channels (`int`, defaults to `512`):
417
+ Number of time embedding channels.
418
+ dropout (`float`, defaults to `0.0`):
419
+ Dropout rate.
420
+ num_layers (`int`, defaults to `1`):
421
+ Number of resnet layers.
422
+ resnet_eps (`float`, defaults to `1e-6`):
423
+ Epsilon value for normalization layers.
424
+ resnet_act_fn (`str`, defaults to `"swish"`):
425
+ Activation function to use.
426
+ resnet_groups (`int`, defaults to `32`):
427
+ Number of groups to separate the channels into for group normalization.
428
+ spatial_norm_dim (`int`, *optional*):
429
+ The dimension to use for spatial norm if it is to be used instead of group norm.
430
+ pad_mode (str, defaults to `"first"`):
431
+ Padding mode.
432
  """
433
 
434
  _supports_gradient_checkpointing = True
 
495
  An upsampling block used in the CogVideoX model.
496
 
497
  Args:
498
+ in_channels (`int`):
499
+ Number of input channels.
500
+ out_channels (`int`, *optional*):
501
+ Number of output channels. If None, defaults to `in_channels`.
502
+ temb_channels (`int`, defaults to `512`):
503
+ Number of time embedding channels.
504
+ dropout (`float`, defaults to `0.0`):
505
+ Dropout rate.
506
+ num_layers (`int`, defaults to `1`):
507
+ Number of resnet layers.
508
+ resnet_eps (`float`, defaults to `1e-6`):
509
+ Epsilon value for normalization layers.
510
+ resnet_act_fn (`str`, defaults to `"swish"`):
511
+ Activation function to use.
512
+ resnet_groups (`int`, defaults to `32`):
513
+ Number of groups to separate the channels into for group normalization.
514
+ spatial_norm_dim (`int`, defaults to `16`):
515
+ The dimension to use for spatial norm if it is to be used instead of group norm.
516
+ add_upsample (`bool`, defaults to `True`):
517
+ Whether or not to use an upsampling layer. If not used, the output dimension is the same as the input dimension.
518
+ compress_time (`bool`, defaults to `False`):
519
+ Whether or not to upsample across the temporal dimension.
520
+ pad_mode (str, defaults to `"first"`):
521
+ Padding mode.
522
  """
523
 
524
  def __init__(
 
609
  options.
610
  block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
611
  The number of output channels for each block.
612
+ act_fn (`str`, *optional*, defaults to `"silu"`):
613
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
614
  layers_per_block (`int`, *optional*, defaults to 2):
615
  The number of layers per block.
616
  norm_num_groups (`int`, *optional*, defaults to 32):
617
  The number of groups for normalization.
618
  """
619
 
620
  _supports_gradient_checkpointing = True
 
743
  The types of up blocks to use. See `~diffusers.models.unet_2d_blocks.get_up_block` for available options.
744
  block_out_channels (`Tuple[int, ...]`, *optional*, defaults to `(64,)`):
745
  The number of output channels for each block.
746
+ act_fn (`str`, *optional*, defaults to `"silu"`):
747
+ The activation function to use. See `~diffusers.models.activations.get_activation` for available options.
748
  layers_per_block (`int`, *optional*, defaults to 2):
749
  The number of layers per block.
750
  norm_num_groups (`int`, *optional*, defaults to 32):
751
  The number of groups for normalization.
752
  """
753
 
754
  _supports_gradient_checkpointing = True
 
889
  Tuple of block output channels.
890
  act_fn (`str`, *optional*, defaults to `"silu"`): The activation function to use.
891
  sample_size (`int`, *optional*, defaults to `32`): Sample input size.
892
+ scaling_factor (`float`, *optional*, defaults to `1.15258426`):
893
  The component-wise standard deviation of the trained latent space computed using the first batch of the
894
  training set. This is used to scale the latent space to have unit variance when training the diffusion
895
  model. The latents are scaled with the formula `z = z * scaling_factor` before being passed to the
 
929
  norm_eps: float = 1e-6,
930
  norm_num_groups: int = 32,
931
  temporal_compression_ratio: float = 4,
932
+ sample_height: int = 480,
933
+ sample_width: int = 720,
934
  scaling_factor: float = 1.15258426,
935
  shift_factor: Optional[float] = None,
936
  latents_mean: Optional[Tuple[float]] = None,
 
969
  self.use_slicing = False
970
  self.use_tiling = False
971
 
972
+ # Can be increased to decode more latent frames at once, but this comes at an additional memory cost and is not
973
+ # recommended, because the temporal parts of the VAE here are tricky to reason about.
974
+ # If you decode X latent frames together, the number of output frames is:
975
+ # (X + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) => X + 6 frames
976
+ #
977
+ # Example with num_latent_frames_batch_size = 2:
978
+ # - 12 latent frames: (0, 1), (2, 3), (4, 5), (6, 7), (8, 9), (10, 11) are processed together
979
+ # => (12 // 2 frame slices) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
980
+ # => 6 * 8 = 48 frames
981
+ # - 13 latent frames: (0, 1, 2) (special case), (3, 4), (5, 6), (7, 8), (9, 10), (11, 12) are processed together
982
+ # => (1 frame slice) * ((3 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale)) +
983
+ # ((13 - 3) // 2) * ((2 num_latent_frames_batch_size) + (2 conv cache) + (2 time upscale_1) + (4 time upscale_2) - (2 causal conv downscale))
984
+ # => 1 * 9 + 5 * 8 = 49 frames
985
+ # It has been implemented this way so as to not have "magic values" in the code base that would be hard to explain. Note that
986
+ # setting it to anything other than 2 would give poor results because the VAE hasn't been trained to be adaptive with different
987
+ # number of temporal frames.
988
+ self.num_latent_frames_batch_size = 2
989
+
990
+ # We make the minimum height and width of sample for tiling half that of the generally supported
991
+ self.tile_sample_min_height = sample_height // 2
992
+ self.tile_sample_min_width = sample_width // 2
993
+ self.tile_latent_min_height = int(
994
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
995
  )
996
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
997
+
998
+ # These are experimental overlap factors that were chosen based on experimentation and seem to work best for
999
+ # 720x480 (WxH) resolution. The above resolution is the strongly recommended generation resolution in CogVideoX
1000
+ # and so the tiling implementation has only been tested on those specific resolutions.
1001
+ self.tile_overlap_factor_height = 1 / 6
1002
+ self.tile_overlap_factor_width = 1 / 5
1003
 
1004
  def _set_gradient_checkpointing(self, module, value=False):
1005
  if isinstance(module, (CogVideoXEncoder3D, CogVideoXDecoder3D)):
1006
  module.gradient_checkpointing = value
1007
 
1008
+ def _clear_fake_context_parallel_cache(self):
1009
  for name, module in self.named_modules():
1010
  if isinstance(module, CogVideoXCausalConv3d):
1011
  logger.debug(f"Clearing fake Context Parallel cache for layer: {name}")
1012
  module._clear_fake_context_parallel_cache()
1013
 
1014
+ def enable_tiling(
1015
+ self,
1016
+ tile_sample_min_height: Optional[int] = None,
1017
+ tile_sample_min_width: Optional[int] = None,
1018
+ tile_overlap_factor_height: Optional[float] = None,
1019
+ tile_overlap_factor_width: Optional[float] = None,
1020
+ ) -> None:
1021
+ r"""
1022
+ Enable tiled VAE decoding. When this option is enabled, the VAE will split the input tensor into tiles to
1023
+ compute decoding and encoding in several steps. This is useful for saving a large amount of memory and to allow
1024
+ processing larger images.
1025
+
1026
+ Args:
1027
+ tile_sample_min_height (`int`, *optional*):
1028
+ The minimum height required for a sample to be separated into tiles across the height dimension.
1029
+ tile_sample_min_width (`int`, *optional*):
1030
+ The minimum width required for a sample to be separated into tiles across the width dimension.
1031
+ tile_overlap_factor_height (`float`, *optional*):
1032
+ The minimum amount of overlap between two consecutive vertical tiles. This is to ensure that there are
1033
+ no tiling artifacts produced across the height dimension. Must be between 0 and 1. Setting a higher
1034
+ value might cause more tiles to be processed, leading to a slowdown of the decoding process.
1035
+ tile_overlap_factor_width (`float`, *optional*):
1036
+ The minimum amount of overlap between two consecutive horizontal tiles. This is to ensure that there
1037
+ are no tiling artifacts produced across the width dimension. Must be between 0 and 1. Setting a higher
1038
+ value might cause more tiles to be processed, leading to a slowdown of the decoding process.
1039
+ """
1040
+ self.use_tiling = True
1041
+ self.tile_sample_min_height = tile_sample_min_height or self.tile_sample_min_height
1042
+ self.tile_sample_min_width = tile_sample_min_width or self.tile_sample_min_width
1043
+ self.tile_latent_min_height = int(
1044
+ self.tile_sample_min_height / (2 ** (len(self.config.block_out_channels) - 1))
1045
+ )
1046
+ self.tile_latent_min_width = int(self.tile_sample_min_width / (2 ** (len(self.config.block_out_channels) - 1)))
1047
+ self.tile_overlap_factor_height = tile_overlap_factor_height or self.tile_overlap_factor_height
1048
+ self.tile_overlap_factor_width = tile_overlap_factor_width or self.tile_overlap_factor_width
1049
+
1050
+ def disable_tiling(self) -> None:
1051
+ r"""
1052
+ Disable tiled VAE decoding. If `enable_tiling` was previously enabled, this method will go back to computing
1053
+ decoding in one step.
1054
+ """
1055
+ self.use_tiling = False
1056
+
1057
+ def enable_slicing(self) -> None:
1058
+ r"""
1059
+ Enable sliced VAE decoding. When this option is enabled, the VAE will split the input tensor in slices to
1060
+ compute decoding in several steps. This is useful to save some memory and allow larger batch sizes.
1061
+ """
1062
+ self.use_slicing = True
1063
+
1064
+ def disable_slicing(self) -> None:
1065
+ r"""
1066
+ Disable sliced VAE decoding. If `enable_slicing` was previously enabled, this method will go back to computing
1067
+ decoding in one step.
1068
+ """
1069
+ self.use_slicing = False
1070
+
1071
  @apply_forward_hook
1072
  def encode(
1073
  self, x: torch.Tensor, return_dict: bool = True
 
1092
  return (posterior,)
1093
  return AutoencoderKLOutput(latent_dist=posterior)
1094
 
1095
+ def _decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1096
+ batch_size, num_channels, num_frames, height, width = z.shape
1097
+
1098
+ if self.use_tiling and (width > self.tile_latent_min_width or height > self.tile_latent_min_height):
1099
+ return self.tiled_decode(z, return_dict=return_dict)
1100
+
1101
+ frame_batch_size = self.num_latent_frames_batch_size
1102
+ dec = []
1103
+ for i in range(num_frames // frame_batch_size):
1104
+ remaining_frames = num_frames % frame_batch_size
1105
+ start_frame = frame_batch_size * i + (0 if i == 0 else remaining_frames)
1106
+ end_frame = frame_batch_size * (i + 1) + remaining_frames
1107
+ z_intermediate = z[:, :, start_frame:end_frame]
1108
+ if self.post_quant_conv is not None:
1109
+ z_intermediate = self.post_quant_conv(z_intermediate)
1110
+ z_intermediate = self.decoder(z_intermediate)
1111
+ dec.append(z_intermediate)
1112
+
1113
+ self._clear_fake_context_parallel_cache()
1114
+ dec = torch.cat(dec, dim=2)
1115
+
1116
+ if not return_dict:
1117
+ return (dec,)
1118
+
1119
+ return DecoderOutput(sample=dec)
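`_decode` above walks the latent frames in groups of `num_latent_frames_batch_size`, folding any remainder into the first group so every frame is decoded exactly once. A sketch of just the index arithmetic, matching the loop above:

```python
def frame_slices(num_frames: int, frame_batch_size: int = 2):
    """Return the (start, end) latent-frame ranges decoded together."""
    remaining = num_frames % frame_batch_size
    slices = []
    for i in range(num_frames // frame_batch_size):
        start = frame_batch_size * i + (0 if i == 0 else remaining)
        end = frame_batch_size * (i + 1) + remaining
        slices.append((start, end))
    return slices

print(frame_slices(13))  # [(0, 3), (3, 5), (5, 7), (7, 9), (9, 11), (11, 13)]
```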
1120
+
1121
  @apply_forward_hook
1122
+ def decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1123
  """
1124
  Decode a batch of images.
1125
 
 
1132
  [`~models.vae.DecoderOutput`] or `tuple`:
1133
  If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1134
  returned.
1135
+ """
1136
+ if self.use_slicing and z.shape[0] > 1:
1137
+ decoded_slices = [self._decode(z_slice).sample for z_slice in z.split(1)]
1138
+ decoded = torch.cat(decoded_slices)
1139
+ else:
1140
+ decoded = self._decode(z).sample
1141
+
1142
+ if not return_dict:
1143
+ return (decoded,)
1144
+ return DecoderOutput(sample=decoded)
1145
+
1146
+ def blend_v(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1147
+ blend_extent = min(a.shape[3], b.shape[3], blend_extent)
1148
+ for y in range(blend_extent):
1149
+ b[:, :, :, y, :] = a[:, :, :, -blend_extent + y, :] * (1 - y / blend_extent) + b[:, :, :, y, :] * (
1150
+ y / blend_extent
1151
+ )
1152
+ return b
1153
+
1154
+ def blend_h(self, a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
1155
+ blend_extent = min(a.shape[4], b.shape[4], blend_extent)
1156
+ for x in range(blend_extent):
1157
+ b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (
1158
+ x / blend_extent
1159
+ )
1160
+ return b
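`blend_v` and `blend_h` above crossfade the overlapping strip between two neighbouring tiles with a linear ramp, so the seam fades from the first tile into the second. A toy example of the horizontal blend (same logic as the method above, run on tiny tensors):

```python
import torch

def blend_h(a: torch.Tensor, b: torch.Tensor, blend_extent: int) -> torch.Tensor:
    # a, b: (B, C, T, H, W); blend the left edge of b with the right edge of a
    blend_extent = min(a.shape[4], b.shape[4], blend_extent)
    for x in range(blend_extent):
        b[:, :, :, :, x] = a[:, :, :, :, -blend_extent + x] * (1 - x / blend_extent) + b[:, :, :, :, x] * (x / blend_extent)
    return b

a = torch.zeros(1, 1, 1, 2, 6)   # left tile is all zeros
b = torch.ones(1, 1, 1, 2, 6)    # right tile is all ones
out = blend_h(a, b.clone(), blend_extent=4)
print(out[0, 0, 0, 0])           # ramps 0.00, 0.25, 0.50, 0.75, then 1, 1
```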
1161
+
1162
+ def tiled_decode(self, z: torch.Tensor, return_dict: bool = True) -> Union[DecoderOutput, torch.Tensor]:
1163
+ r"""
1164
+ Decode a batch of images using a tiled decoder.
1165
+
1166
+ Args:
1167
+ z (`torch.Tensor`): Input batch of latent vectors.
1168
+ return_dict (`bool`, *optional*, defaults to `True`):
1169
+ Whether or not to return a [`~models.vae.DecoderOutput`] instead of a plain tuple.
1170
 
1171
+ Returns:
1172
+ [`~models.vae.DecoderOutput`] or `tuple`:
1173
+ If return_dict is True, a [`~models.vae.DecoderOutput`] is returned, otherwise a plain `tuple` is
1174
+ returned.
1175
  """
1176
+ # Rough memory assessment:
1177
+ # - In CogVideoX-2B, there are a total of 24 CausalConv3d layers.
1178
+ # - The biggest intermediate dimensions are: [1, 128, 9, 480, 720].
1179
+ # - Assume fp16 (2 bytes per value).
1180
+ # Memory required: 1 * 128 * 9 * 480 * 720 * 24 * 2 / 1024**3 = 17.8 GB
1181
+ #
1182
+ # Memory assessment when using tiling:
1183
+ # - Assume everything as above but now HxW is 240x360 by tiling in half
1184
+ # Memory required: 1 * 128 * 9 * 240 * 360 * 24 * 2 / 1024**3 = 4.5 GB
1185
+
1186
+ batch_size, num_channels, num_frames, height, width = z.shape
1187
+
1188
+ overlap_height = int(self.tile_latent_min_height * (1 - self.tile_overlap_factor_height))
1189
+ overlap_width = int(self.tile_latent_min_width * (1 - self.tile_overlap_factor_width))
1190
+ blend_extent_height = int(self.tile_sample_min_height * self.tile_overlap_factor_height)
1191
+ blend_extent_width = int(self.tile_sample_min_width * self.tile_overlap_factor_width)
1192
+ row_limit_height = self.tile_sample_min_height - blend_extent_height
1193
+ row_limit_width = self.tile_sample_min_width - blend_extent_width
1194
+ frame_batch_size = self.num_latent_frames_batch_size
1195
+
1196
+ # Split z into overlapping tiles and decode them separately.
1197
+ # The tiles have an overlap to avoid seams between tiles.
1198
+ rows = []
1199
+ for i in range(0, height, overlap_height):
1200
+ row = []
1201
+ for j in range(0, width, overlap_width):
1202
+ time = []
1203
+ for k in range(num_frames // frame_batch_size):
1204
+ remaining_frames = num_frames % frame_batch_size
1205
+ start_frame = frame_batch_size * k + (0 if k == 0 else remaining_frames)
1206
+ end_frame = frame_batch_size * (k + 1) + remaining_frames
1207
+ tile = z[
1208
+ :,
1209
+ :,
1210
+ start_frame:end_frame,
1211
+ i : i + self.tile_latent_min_height,
1212
+ j : j + self.tile_latent_min_width,
1213
+ ]
1214
+ if self.post_quant_conv is not None:
1215
+ tile = self.post_quant_conv(tile)
1216
+ tile = self.decoder(tile)
1217
+ time.append(tile)
1218
+ self._clear_fake_context_parallel_cache()
1219
+ row.append(torch.cat(time, dim=2))
1220
+ rows.append(row)
1221
+
1222
+ result_rows = []
1223
+ for i, row in enumerate(rows):
1224
+ result_row = []
1225
+ for j, tile in enumerate(row):
1226
+ # blend the above tile and the left tile
1227
+ # to the current tile and add the current tile to the result row
1228
+ if i > 0:
1229
+ tile = self.blend_v(rows[i - 1][j], tile, blend_extent_height)
1230
+ if j > 0:
1231
+ tile = self.blend_h(row[j - 1], tile, blend_extent_width)
1232
+ result_row.append(tile[:, :, :, :row_limit_height, :row_limit_width])
1233
+ result_rows.append(torch.cat(result_row, dim=4))
1234
+
1235
+ dec = torch.cat(result_rows, dim=3)
1236
+
1237
  if not return_dict:
1238
  return (dec,)
1239
+
1240
  return DecoderOutput(sample=dec)
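The rough memory assessment at the top of `tiled_decode` can be reproduced directly; the layer count and intermediate shape below are taken from that comment rather than measured:

```python
def activation_gib(batch, channels, frames, height, width, num_layers=24, bytes_per_val=2):
    """Approximate fp16 activation memory across all CausalConv3d layers, in GiB."""
    return batch * channels * frames * height * width * num_layers * bytes_per_val / 1024**3

print(activation_gib(1, 128, 9, 480, 720))  # ~17.8 GiB without tiling
print(activation_gib(1, 128, 9, 240, 360))  # ~4.45 GiB with tiles at half height/width (the ~4.5 GB above)
```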
1241
 
1242
  def forward(
videosys/models/{open_sora/vae.py β†’ autoencoders/autoencoder_kl_open_sora.py} RENAMED
@@ -18,8 +18,6 @@ from diffusers.models import AutoencoderKL, AutoencoderKLTemporalDecoder
18
  from einops import rearrange
19
  from transformers import PretrainedConfig, PreTrainedModel
20
 
21
- from .utils import load_checkpoint
22
-
23
 
24
  class DiagonalGaussianDistribution(object):
25
  def __init__(
@@ -474,7 +472,7 @@ class VAE_Temporal(nn.Module):
474
  return recon_video, posterior, z
475
 
476
 
477
- def VAE_Temporal_SD(from_pretrained=None, **kwargs):
478
  model = VAE_Temporal(
479
  in_out_channels=4,
480
  latent_embed_dim=4,
@@ -485,8 +483,6 @@ def VAE_Temporal_SD(from_pretrained=None, **kwargs):
485
  temporal_downsample=(False, True, True),
486
  **kwargs,
487
  )
488
- if from_pretrained is not None:
489
- load_checkpoint(model, from_pretrained)
490
  return model
491
 
492
 
@@ -634,7 +630,7 @@ class VideoAutoencoderPipeline(PreTrainedModel):
634
  micro_batch_size=4,
635
  subfolder="vae",
636
  )
637
- self.temporal_vae = VAE_Temporal_SD(from_pretrained=None)
638
  self.cal_loss = config.cal_loss
639
  self.micro_frame_size = config.micro_frame_size
640
  self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]
@@ -763,7 +759,4 @@ def OpenSoraVAE_V1_2(
763
  else:
764
  config = VideoAutoencoderPipelineConfig(**kwargs)
765
  model = VideoAutoencoderPipeline(config)
766
-
767
- if from_pretrained:
768
- load_checkpoint(model, from_pretrained)
769
  return model
 
18
  from einops import rearrange
19
  from transformers import PretrainedConfig, PreTrainedModel
20
 
21
 
22
  class DiagonalGaussianDistribution(object):
23
  def __init__(
 
472
  return recon_video, posterior, z
473
 
474
 
475
+ def VAE_Temporal_SD(**kwargs):
476
  model = VAE_Temporal(
477
  in_out_channels=4,
478
  latent_embed_dim=4,
 
483
  temporal_downsample=(False, True, True),
484
  **kwargs,
485
  )
486
  return model
487
 
488
 
 
630
  micro_batch_size=4,
631
  subfolder="vae",
632
  )
633
+ self.temporal_vae = VAE_Temporal_SD()
634
  self.cal_loss = config.cal_loss
635
  self.micro_frame_size = config.micro_frame_size
636
  self.micro_z_frame_size = self.temporal_vae.get_latent_size([config.micro_frame_size, None, None])[0]
 
759
  else:
760
  config = VideoAutoencoderPipelineConfig(**kwargs)
761
  model = VideoAutoencoderPipeline(config)
762
  return model
videosys/models/{open_sora_plan/ae.py β†’ autoencoders/autoencoder_kl_open_sora_plan.py} RENAMED
@@ -6,20 +6,24 @@
6
  # References:
7
  # Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
8
  # --------------------------------------------------------
9
-
10
  import glob
11
- import importlib
12
  import os
13
  from typing import Optional, Tuple, Union
14
 
15
  import numpy as np
16
  import torch
17
  from diffusers import ConfigMixin, ModelMixin
18
  from diffusers.configuration_utils import ConfigMixin, register_to_config
19
  from diffusers.models.modeling_utils import ModelMixin
 
20
  from einops import rearrange
21
  from torch import nn
22
 
 
 
24
  def Normalize(in_channels, num_groups=32):
25
  return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
@@ -80,13 +84,7 @@ class DiagonalGaussianDistribution(object):
80
 
81
 
82
  def resolve_str_to_obj(str_val, append=True):
83
- if append:
84
- str_val = "videosys.models.open_sora_plan.modules." + str_val
85
- if "opensora.models.ae.videobase." in str_val:
86
- str_val = str_val.replace("opensora.models.ae.videobase.", "videosys.models.open_sora_plan.")
87
- module_name, class_name = str_val.rsplit(".", 1)
88
- module = importlib.import_module(module_name)
89
- return getattr(module, class_name)
90
 
91
 
92
  class VideoBaseAE_PL(ModelMixin, ConfigMixin):
@@ -130,7 +128,6 @@ class VideoBaseAE_PL(ModelMixin, ConfigMixin):
130
  model.init_from_ckpt(last_ckpt_file)
131
  return model
132
  else:
133
- print(f"Loading model from {pretrained_model_name_or_path}")
134
  return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
135
 
136
 
@@ -431,8 +428,6 @@ class CausalVAEModel(VideoBaseAE_PL):
431
  self.learning_rate = lr
432
  self.lr_g_factor = 1.0
433
 
434
- self.loss = resolve_str_to_obj(loss_type, append=False)(**loss_params)
435
-
436
  self.encoder = Encoder(
437
  z_channels=z_channels,
438
  hidden_size=hidden_size,
@@ -471,8 +466,6 @@ class CausalVAEModel(VideoBaseAE_PL):
471
  quant_conv_cls = resolve_str_to_obj(q_conv)
472
  self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
473
  self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
474
- if hasattr(self.loss, "discriminator"):
475
- self.automatic_optimization = False
476
 
477
  def encode(self, x):
478
  if self.use_tiling and (
@@ -855,3 +848,793 @@ def getae_wrapper(ae):
855
  ae = videobase_ae.get(ae, None)
856
  assert ae is not None
857
  return ae
6
  # References:
7
  # Open-Sora-Plan: https://github.com/PKU-YuanGroup/Open-Sora-Plan
8
  # --------------------------------------------------------
 
9
  import glob
 
10
  import os
11
  from typing import Optional, Tuple, Union
12
 
13
  import numpy as np
14
  import torch
15
+ import torch.distributed as dist
16
+ import torch.nn as nn
17
+ import torch.nn.functional as F
18
  from diffusers import ConfigMixin, ModelMixin
19
  from diffusers.configuration_utils import ConfigMixin, register_to_config
20
  from diffusers.models.modeling_utils import ModelMixin
21
+ from diffusers.utils import logging
22
  from einops import rearrange
23
  from torch import nn
24
 
25
+ logging.set_verbosity_error()
26
+
27
 
28
  def Normalize(in_channels, num_groups=32):
29
  return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
 
84
 
85
 
86
  def resolve_str_to_obj(str_val, append=True):
87
+ return globals()[str_val]
88
 
89
 
90
  class VideoBaseAE_PL(ModelMixin, ConfigMixin):
 
128
  model.init_from_ckpt(last_ckpt_file)
129
  return model
130
  else:
 
131
  return super().from_pretrained(pretrained_model_name_or_path, **kwargs)
132
 
133
 
 
428
  self.learning_rate = lr
429
  self.lr_g_factor = 1.0
430
 
431
  self.encoder = Encoder(
432
  z_channels=z_channels,
433
  hidden_size=hidden_size,
 
466
  quant_conv_cls = resolve_str_to_obj(q_conv)
467
  self.quant_conv = quant_conv_cls(2 * z_channels, 2 * embed_dim, 1)
468
  self.post_quant_conv = quant_conv_cls(embed_dim, z_channels, 1)
469
 
470
  def encode(self, x):
471
  if self.use_tiling and (
 
848
  ae = videobase_ae.get(ae, None)
849
  assert ae is not None
850
  return ae
851
+
852
+
853
+ def video_to_image(func):
854
+ def wrapper(self, x, *args, **kwargs):
855
+ if x.dim() == 5:
856
+ t = x.shape[2]
857
+ x = rearrange(x, "b c t h w -> (b t) c h w")
858
+ x = func(self, x, *args, **kwargs)
859
+ x = rearrange(x, "(b t) c h w -> b c t h w", t=t)
860
+ return x
861
+
862
+ return wrapper
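`video_to_image` lets the 2D layers below (`Conv2d`, `AttnBlock`, `ResnetBlock2D`, ...) accept 5D video tensors by folding the time axis into the batch axis and unfolding it again afterwards. A standalone sketch of the same trick (the helper name is illustrative):

```python
import torch
from einops import rearrange

def apply_per_frame(layer, x: torch.Tensor) -> torch.Tensor:
    """Run a 2D layer on a (B, C, T, H, W) video by folding time into the batch dimension."""
    t = x.shape[2]
    x = rearrange(x, "b c t h w -> (b t) c h w")
    x = layer(x)
    return rearrange(x, "(b t) c h w -> b c t h w", t=t)

conv = torch.nn.Conv2d(3, 8, kernel_size=3, padding=1)
video = torch.randn(2, 3, 5, 16, 16)
print(apply_per_frame(conv, video).shape)  # torch.Size([2, 8, 5, 16, 16])
```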
863
+
864
+
865
+ class Block(nn.Module):
866
+ def __init__(self, *args, **kwargs) -> None:
867
+ super().__init__(*args, **kwargs)
868
+
869
+
870
+ class LinearAttention(Block):
871
+ def __init__(self, dim, heads=4, dim_head=32):
872
+ super().__init__()
873
+ self.heads = heads
874
+ hidden_dim = dim_head * heads
875
+ self.to_qkv = nn.Conv2d(dim, hidden_dim * 3, 1, bias=False)
876
+ self.to_out = nn.Conv2d(hidden_dim, dim, 1)
877
+
878
+ def forward(self, x):
879
+ b, c, h, w = x.shape
880
+ qkv = self.to_qkv(x)
881
+ q, k, v = rearrange(qkv, "b (qkv heads c) h w -> qkv b heads c (h w)", heads=self.heads, qkv=3)
882
+ k = k.softmax(dim=-1)
883
+ context = torch.einsum("bhdn,bhen->bhde", k, v)
884
+ out = torch.einsum("bhde,bhdn->bhen", context, q)
885
+ out = rearrange(out, "b heads c (h w) -> b (heads c) h w", heads=self.heads, h=h, w=w)
886
+ return self.to_out(out)
887
+
888
+
889
+ class LinAttnBlock(LinearAttention):
890
+ """to match AttnBlock usage"""
891
+
892
+ def __init__(self, in_channels):
893
+ super().__init__(dim=in_channels, heads=1, dim_head=in_channels)
894
+
895
+
896
+ class AttnBlock3D(Block):
897
+ """Compatible with old versions, there are issues, use with caution."""
898
+
899
+ def __init__(self, in_channels):
900
+ super().__init__()
901
+ self.in_channels = in_channels
902
+
903
+ self.norm = Normalize(in_channels)
904
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
905
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
906
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
907
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
908
+
909
+ def forward(self, x):
910
+ h_ = x
911
+ h_ = self.norm(h_)
912
+ q = self.q(h_)
913
+ k = self.k(h_)
914
+ v = self.v(h_)
915
+
916
+ # compute attention
917
+ b, c, t, h, w = q.shape
918
+ q = q.reshape(b * t, c, h * w)
919
+ q = q.permute(0, 2, 1) # b,hw,c
920
+ k = k.reshape(b * t, c, h * w) # b,c,hw
921
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
922
+ w_ = w_ * (int(c) ** (-0.5))
923
+ w_ = torch.nn.functional.softmax(w_, dim=2)
924
+
925
+ # attend to values
926
+ v = v.reshape(b * t, c, h * w)
927
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
928
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
929
+ h_ = h_.reshape(b, c, t, h, w)
930
+
931
+ h_ = self.proj_out(h_)
932
+
933
+ return x + h_
934
+
935
+
936
+ class AttnBlock3DFix(nn.Module):
937
+ """
938
+ Thanks to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/172.
939
+ """
940
+
941
+ def __init__(self, in_channels):
942
+ super().__init__()
943
+ self.in_channels = in_channels
944
+
945
+ self.norm = Normalize(in_channels)
946
+ self.q = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
947
+ self.k = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
948
+ self.v = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
949
+ self.proj_out = CausalConv3d(in_channels, in_channels, kernel_size=1, stride=1)
950
+
951
+ def forward(self, x):
952
+ h_ = x
953
+ h_ = self.norm(h_)
954
+ q = self.q(h_)
955
+ k = self.k(h_)
956
+ v = self.v(h_)
957
+
958
+ # compute attention
959
+ # q: (b c t h w) -> (b t c h w) -> (b*t c h*w) -> (b*t h*w c)
960
+ b, c, t, h, w = q.shape
961
+ q = q.permute(0, 2, 1, 3, 4)
962
+ q = q.reshape(b * t, c, h * w)
963
+ q = q.permute(0, 2, 1)
964
+
965
+ # k: (b c t h w) -> (b t c h w) -> (b*t c h*w)
966
+ k = k.permute(0, 2, 1, 3, 4)
967
+ k = k.reshape(b * t, c, h * w)
968
+
969
+ # w: (b*t hw hw)
970
+ w_ = torch.bmm(q, k)
971
+ w_ = w_ * (int(c) ** (-0.5))
972
+ w_ = torch.nn.functional.softmax(w_, dim=2)
973
+
974
+ # attend to values
975
+ # v: (b c t h w) -> (b t c h w) -> (bt c hw)
976
+ # w_: (bt hw hw) -> (bt hw hw)
977
+ v = v.permute(0, 2, 1, 3, 4)
978
+ v = v.reshape(b * t, c, h * w)
979
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
980
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
981
+
982
+ # h_: (b*t c hw) -> (b t c h w) -> (b c t h w)
983
+ h_ = h_.reshape(b, t, c, h, w)
984
+ h_ = h_.permute(0, 2, 1, 3, 4)
985
+
986
+ h_ = self.proj_out(h_)
987
+
988
+ return x + h_
989
+
990
+
991
+ class AttnBlock(Block):
992
+ def __init__(self, in_channels):
993
+ super().__init__()
994
+ self.in_channels = in_channels
995
+
996
+ self.norm = Normalize(in_channels)
997
+ self.q = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
998
+ self.k = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
999
+ self.v = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
1000
+ self.proj_out = torch.nn.Conv2d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
1001
+
1002
+ @video_to_image
1003
+ def forward(self, x):
1004
+ h_ = x
1005
+ h_ = self.norm(h_)
1006
+ q = self.q(h_)
1007
+ k = self.k(h_)
1008
+ v = self.v(h_)
1009
+
1010
+ # compute attention
1011
+ b, c, h, w = q.shape
1012
+ q = q.reshape(b, c, h * w)
1013
+ q = q.permute(0, 2, 1) # b,hw,c
1014
+ k = k.reshape(b, c, h * w) # b,c,hw
1015
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
1016
+ w_ = w_ * (int(c) ** (-0.5))
1017
+ w_ = torch.nn.functional.softmax(w_, dim=2)
1018
+
1019
+ # attend to values
1020
+ v = v.reshape(b, c, h * w)
1021
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
1022
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
1023
+ h_ = h_.reshape(b, c, h, w)
1024
+
1025
+ h_ = self.proj_out(h_)
1026
+
1027
+ return x + h_
1028
+
1029
+
1030
+ class TemporalAttnBlock(Block):
1031
+ def __init__(self, in_channels):
1032
+ super().__init__()
1033
+ self.in_channels = in_channels
1034
+
1035
+ self.norm = Normalize(in_channels)
1036
+ self.q = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
1037
+ self.k = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
1038
+ self.v = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
1039
+ self.proj_out = torch.nn.Conv3d(in_channels, in_channels, kernel_size=1, stride=1, padding=0)
1040
+
1041
+ def forward(self, x):
1042
+ h_ = x
1043
+ h_ = self.norm(h_)
1044
+ q = self.q(h_)
1045
+ k = self.k(h_)
1046
+ v = self.v(h_)
1047
+
1048
+ # compute attention
1049
+ b, c, t, h, w = q.shape
1050
+ q = rearrange(q, "b c t h w -> (b h w) t c")
1051
+ k = rearrange(k, "b c t h w -> (b h w) c t")
1052
+ v = rearrange(v, "b c t h w -> (b h w) c t")
1053
+ w_ = torch.bmm(q, k)
1054
+ w_ = w_ * (int(c) ** (-0.5))
1055
+ w_ = torch.nn.functional.softmax(w_, dim=2)
1056
+
1057
+ # attend to values
1058
+ w_ = w_.permute(0, 2, 1)
1059
+ h_ = torch.bmm(v, w_)
1060
+ h_ = rearrange(h_, "(b h w) c t -> b c t h w", h=h, w=w)
1061
+ h_ = self.proj_out(h_)
1062
+
1063
+ return x + h_
1064
+
1065
+
1066
+ def make_attn(in_channels, attn_type="vanilla"):
1067
+ assert attn_type in ["vanilla", "linear", "none", "vanilla3D"], f"attn_type {attn_type} unknown"
1068
+ print(f"making attention of type '{attn_type}' with {in_channels} in_channels")
1069
+ print(attn_type)
1070
+ if attn_type == "vanilla":
1071
+ return AttnBlock(in_channels)
1072
+ elif attn_type == "vanilla3D":
1073
+ return AttnBlock3D(in_channels)
1074
+ elif attn_type == "none":
1075
+ return nn.Identity(in_channels)
1076
+ else:
1077
+ return LinAttnBlock(in_channels)
1078
+
1079
+
1080
+ class Conv2d(nn.Conv2d):
1081
+ def __init__(
1082
+ self,
1083
+ in_channels: int,
1084
+ out_channels: int,
1085
+ kernel_size: Union[int, Tuple[int]] = 3,
1086
+ stride: Union[int, Tuple[int]] = 1,
1087
+ padding: Union[str, int, Tuple[int]] = 0,
1088
+ dilation: Union[int, Tuple[int]] = 1,
1089
+ groups: int = 1,
1090
+ bias: bool = True,
1091
+ padding_mode: str = "zeros",
1092
+ device=None,
1093
+ dtype=None,
1094
+ ) -> None:
1095
+ super().__init__(
1096
+ in_channels,
1097
+ out_channels,
1098
+ kernel_size,
1099
+ stride,
1100
+ padding,
1101
+ dilation,
1102
+ groups,
1103
+ bias,
1104
+ padding_mode,
1105
+ device,
1106
+ dtype,
1107
+ )
1108
+
1109
+ @video_to_image
1110
+ def forward(self, x):
1111
+ return super().forward(x)
1112
+
1113
+
1114
+ class CausalConv3d(nn.Module):
1115
+ def __init__(
1116
+ self, chan_in, chan_out, kernel_size: Union[int, Tuple[int, int, int]], init_method="random", **kwargs
1117
+ ):
1118
+ super().__init__()
1119
+ self.kernel_size = cast_tuple(kernel_size, 3)
1120
+ self.time_kernel_size = self.kernel_size[0]
1121
+ self.chan_in = chan_in
1122
+ self.chan_out = chan_out
1123
+ stride = kwargs.pop("stride", 1)
1124
+ padding = kwargs.pop("padding", 0)
1125
+ padding = list(cast_tuple(padding, 3))
1126
+ padding[0] = 0
1127
+ stride = cast_tuple(stride, 3)
1128
+ self.conv = nn.Conv3d(chan_in, chan_out, self.kernel_size, stride=stride, padding=padding)
1129
+ self._init_weights(init_method)
1130
+
1131
+ def _init_weights(self, init_method):
1132
+ torch.tensor(self.kernel_size)
1133
+ if init_method == "avg":
1134
+ assert self.kernel_size[1] == 1 and self.kernel_size[2] == 1, "only support temporal up/down sample"
1135
+ assert self.chan_in == self.chan_out, "chan_in must be equal to chan_out"
1136
+ weight = torch.zeros((self.chan_out, self.chan_in, *self.kernel_size))
1137
+
1138
+ eyes = torch.concat(
1139
+ [
1140
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
1141
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
1142
+ torch.eye(self.chan_in).unsqueeze(-1) * 1 / 3,
1143
+ ],
1144
+ dim=-1,
1145
+ )
1146
+ weight[:, :, :, 0, 0] = eyes
1147
+
1148
+ self.conv.weight = nn.Parameter(
1149
+ weight,
1150
+ requires_grad=True,
1151
+ )
1152
+ elif init_method == "zero":
1153
+ self.conv.weight = nn.Parameter(
1154
+ torch.zeros((self.chan_out, self.chan_in, *self.kernel_size)),
1155
+ requires_grad=True,
1156
+ )
1157
+ if self.conv.bias is not None:
1158
+ nn.init.constant_(self.conv.bias, 0)
1159
+
1160
+ def forward(self, x):
1161
+ # 1 + 16 16 as video, 1 as image
1162
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.time_kernel_size - 1, 1, 1)) # b c t h w
1163
+ x = torch.concatenate((first_frame_pad, x), dim=2) # 3 + 16
1164
+ return self.conv(x)
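As with the CogVideoX causal conv earlier in this commit, this `CausalConv3d` pads only at the start of the time axis by repeating the first frame `time_kernel_size - 1` times, so no frame sees later frames and the temporal length is preserved for stride 1. A minimal shape check under those assumptions (`TinyCausalConv3d` is a stand-in, not part of the diff):

```python
import torch
import torch.nn as nn

class TinyCausalConv3d(nn.Module):
    """Minimal stand-in: causal padding in time, 'same' padding in space."""
    def __init__(self, chan_in, chan_out, kernel_size=3):
        super().__init__()
        self.time_kernel_size = kernel_size
        self.conv = nn.Conv3d(chan_in, chan_out, kernel_size, padding=(0, 1, 1))

    def forward(self, x):
        first_frame_pad = x[:, :, :1].repeat(1, 1, self.time_kernel_size - 1, 1, 1)
        return self.conv(torch.cat([first_frame_pad, x], dim=2))

x = torch.randn(1, 4, 17, 32, 32)           # 1 image frame + 16 video frames
print(TinyCausalConv3d(4, 8)(x).shape)       # torch.Size([1, 8, 17, 32, 32])
```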
1165
+
1166
+
1167
+ class GroupNorm(Block):
1168
+ def __init__(self, num_channels, num_groups=32, eps=1e-6, *args, **kwargs) -> None:
1169
+ super().__init__(*args, **kwargs)
1170
+ self.norm = torch.nn.GroupNorm(num_groups=num_groups, num_channels=num_channels, eps=1e-6, affine=True)
1171
+
1172
+ def forward(self, x):
1173
+ return self.norm(x)
1174
+
1175
+
1176
+ def Normalize(in_channels, num_groups=32):
1177
+ return torch.nn.GroupNorm(num_groups=num_groups, num_channels=in_channels, eps=1e-6, affine=True)
1178
+
1179
+
1180
+ class ActNorm(nn.Module):
1181
+ def __init__(self, num_features, logdet=False, affine=True, allow_reverse_init=False):
1182
+ assert affine
1183
+ super().__init__()
1184
+ self.logdet = logdet
1185
+ self.loc = nn.Parameter(torch.zeros(1, num_features, 1, 1))
1186
+ self.scale = nn.Parameter(torch.ones(1, num_features, 1, 1))
1187
+ self.allow_reverse_init = allow_reverse_init
1188
+
1189
+ self.register_buffer("initialized", torch.tensor(0, dtype=torch.uint8))
1190
+
1191
+ def initialize(self, input):
1192
+ with torch.no_grad():
1193
+ flatten = input.permute(1, 0, 2, 3).contiguous().view(input.shape[1], -1)
1194
+ mean = flatten.mean(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
1195
+ std = flatten.std(1).unsqueeze(1).unsqueeze(2).unsqueeze(3).permute(1, 0, 2, 3)
1196
+
1197
+ self.loc.data.copy_(-mean)
1198
+ self.scale.data.copy_(1 / (std + 1e-6))
1199
+
1200
+ def forward(self, input, reverse=False):
1201
+ if reverse:
1202
+ return self.reverse(input)
1203
+ if len(input.shape) == 2:
1204
+ input = input[:, :, None, None]
1205
+ squeeze = True
1206
+ else:
1207
+ squeeze = False
1208
+
1209
+ _, _, height, width = input.shape
1210
+
1211
+ if self.training and self.initialized.item() == 0:
1212
+ self.initialize(input)
1213
+ self.initialized.fill_(1)
1214
+
1215
+ h = self.scale * (input + self.loc)
1216
+
1217
+ if squeeze:
1218
+ h = h.squeeze(-1).squeeze(-1)
1219
+
1220
+ if self.logdet:
1221
+ log_abs = torch.log(torch.abs(self.scale))
1222
+ logdet = height * width * torch.sum(log_abs)
1223
+ logdet = logdet * torch.ones(input.shape[0]).to(input)
1224
+ return h, logdet
1225
+
1226
+ return h
1227
+
1228
+ def reverse(self, output):
1229
+ if self.training and self.initialized.item() == 0:
1230
+ if not self.allow_reverse_init:
1231
+ raise RuntimeError(
1232
+ "Initializing ActNorm in reverse direction is "
1233
+ "disabled by default. Use allow_reverse_init=True to enable."
1234
+ )
1235
+ else:
1236
+ self.initialize(output)
1237
+ self.initialized.fill_(1)
1238
+
1239
+ if len(output.shape) == 2:
1240
+ output = output[:, :, None, None]
1241
+ squeeze = True
1242
+ else:
1243
+ squeeze = False
1244
+
1245
+ h = output / self.scale - self.loc
1246
+
1247
+ if squeeze:
1248
+ h = h.squeeze(-1).squeeze(-1)
1249
+ return h
1250
+
1251
+
1252
+ def nonlinearity(x):
1253
+ return x * torch.sigmoid(x)
1254
+
1255
+
1256
+ def cast_tuple(t, length=1):
1257
+ return t if isinstance(t, tuple) else ((t,) * length)
1258
+
1259
+
1260
+ def shift_dim(x, src_dim=-1, dest_dim=-1, make_contiguous=True):
1261
+ n_dims = len(x.shape)
1262
+ if src_dim < 0:
1263
+ src_dim = n_dims + src_dim
1264
+ if dest_dim < 0:
1265
+ dest_dim = n_dims + dest_dim
1266
+ assert 0 <= src_dim < n_dims and 0 <= dest_dim < n_dims
1267
+ dims = list(range(n_dims))
1268
+ del dims[src_dim]
1269
+ permutation = []
1270
+ ctr = 0
1271
+ for i in range(n_dims):
1272
+ if i == dest_dim:
1273
+ permutation.append(src_dim)
1274
+ else:
1275
+ permutation.append(dims[ctr])
1276
+ ctr += 1
1277
+ x = x.permute(permutation)
1278
+ if make_contiguous:
1279
+ x = x.contiguous()
1280
+ return x
1281
+
1282
+
1283
+ class Codebook(nn.Module):
1284
+ def __init__(self, n_codes, embedding_dim):
1285
+ super().__init__()
1286
+ self.register_buffer("embeddings", torch.randn(n_codes, embedding_dim))
1287
+ self.register_buffer("N", torch.zeros(n_codes))
1288
+ self.register_buffer("z_avg", self.embeddings.data.clone())
1289
+
1290
+ self.n_codes = n_codes
1291
+ self.embedding_dim = embedding_dim
1292
+ self._need_init = True
1293
+
1294
+ def _tile(self, x):
1295
+ d, ew = x.shape
1296
+ if d < self.n_codes:
1297
+ n_repeats = (self.n_codes + d - 1) // d
1298
+ std = 0.01 / np.sqrt(ew)
1299
+ x = x.repeat(n_repeats, 1)
1300
+ x = x + torch.randn_like(x) * std
1301
+ return x
1302
+
1303
+ def _init_embeddings(self, z):
1304
+ # z: [b, c, t, h, w]
1305
+ self._need_init = False
1306
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
1307
+ y = self._tile(flat_inputs)
1308
+
1309
+ y.shape[0]
1310
+ _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
1311
+ if dist.is_initialized():
1312
+ dist.broadcast(_k_rand, 0)
1313
+ self.embeddings.data.copy_(_k_rand)
1314
+ self.z_avg.data.copy_(_k_rand)
1315
+ self.N.data.copy_(torch.ones(self.n_codes))
1316
+
1317
+ def forward(self, z):
1318
+ # z: [b, c, t, h, w]
1319
+ if self._need_init and self.training:
1320
+ self._init_embeddings(z)
1321
+ flat_inputs = shift_dim(z, 1, -1).flatten(end_dim=-2)
1322
+ distances = (
1323
+ (flat_inputs**2).sum(dim=1, keepdim=True)
1324
+ - 2 * flat_inputs @ self.embeddings.t()
1325
+ + (self.embeddings.t() ** 2).sum(dim=0, keepdim=True)
1326
+ )
1327
+
1328
+ encoding_indices = torch.argmin(distances, dim=1)
1329
+ encode_onehot = F.one_hot(encoding_indices, self.n_codes).type_as(flat_inputs)
1330
+ encoding_indices = encoding_indices.view(z.shape[0], *z.shape[2:])
1331
+
1332
+ embeddings = F.embedding(encoding_indices, self.embeddings)
1333
+ embeddings = shift_dim(embeddings, -1, 1)
1334
+
1335
+ commitment_loss = 0.25 * F.mse_loss(z, embeddings.detach())
1336
+
1337
+ # EMA codebook update
1338
+ if self.training:
1339
+ n_total = encode_onehot.sum(dim=0)
1340
+ encode_sum = flat_inputs.t() @ encode_onehot
1341
+ if dist.is_initialized():
1342
+ dist.all_reduce(n_total)
1343
+ dist.all_reduce(encode_sum)
1344
+
1345
+ self.N.data.mul_(0.99).add_(n_total, alpha=0.01)
1346
+ self.z_avg.data.mul_(0.99).add_(encode_sum.t(), alpha=0.01)
1347
+
1348
+ n = self.N.sum()
1349
+ weights = (self.N + 1e-7) / (n + self.n_codes * 1e-7) * n
1350
+ encode_normalized = self.z_avg / weights.unsqueeze(1)
1351
+ self.embeddings.data.copy_(encode_normalized)
1352
+
1353
+ y = self._tile(flat_inputs)
1354
+ _k_rand = y[torch.randperm(y.shape[0])][: self.n_codes]
1355
+ if dist.is_initialized():
1356
+ dist.broadcast(_k_rand, 0)
1357
+
1358
+ usage = (self.N.view(self.n_codes, 1) >= 1).float()
1359
+ self.embeddings.data.mul_(usage).add_(_k_rand * (1 - usage))
1360
+
1361
+ embeddings_st = (embeddings - z).detach() + z
1362
+
1363
+ avg_probs = torch.mean(encode_onehot, dim=0)
1364
+ perplexity = torch.exp(-torch.sum(avg_probs * torch.log(avg_probs + 1e-10)))
1365
+
1366
+ return dict(
1367
+ embeddings=embeddings_st,
1368
+ encodings=encoding_indices,
1369
+ commitment_loss=commitment_loss,
1370
+ perplexity=perplexity,
1371
+ )
1372
+
1373
+ def dictionary_lookup(self, encodings):
1374
+ embeddings = F.embedding(encodings, self.embeddings)
1375
+ return embeddings
1376
+
1377
+
1378
+ class ResnetBlock2D(Block):
1379
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
1380
+ super().__init__()
1381
+ self.in_channels = in_channels
1382
+ self.out_channels = in_channels if out_channels is None else out_channels
1383
+ self.use_conv_shortcut = conv_shortcut
1384
+
1385
+ self.norm1 = Normalize(in_channels)
1386
+ self.conv1 = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
1387
+ self.norm2 = Normalize(out_channels)
1388
+ self.dropout = torch.nn.Dropout(dropout)
1389
+ self.conv2 = torch.nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1)
1390
+ if self.in_channels != self.out_channels:
1391
+ if self.use_conv_shortcut:
1392
+ self.conv_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
1393
+ else:
1394
+ self.nin_shortcut = torch.nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=1, padding=0)
1395
+
1396
+ @video_to_image
1397
+ def forward(self, x):
1398
+ h = x
1399
+ h = self.norm1(h)
1400
+ h = nonlinearity(h)
1401
+ h = self.conv1(h)
1402
+ h = self.norm2(h)
1403
+ h = nonlinearity(h)
1404
+ h = self.dropout(h)
1405
+ h = self.conv2(h)
1406
+ if self.in_channels != self.out_channels:
1407
+ if self.use_conv_shortcut:
1408
+ x = self.conv_shortcut(x)
1409
+ else:
1410
+ x = self.nin_shortcut(x)
1411
+ x = x + h
1412
+ return x
1413
+
1414
+
1415
+ class ResnetBlock3D(Block):
1416
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False, dropout):
1417
+ super().__init__()
1418
+ self.in_channels = in_channels
1419
+ self.out_channels = in_channels if out_channels is None else out_channels
1420
+ self.use_conv_shortcut = conv_shortcut
1421
+
1422
+ self.norm1 = Normalize(in_channels)
1423
+ self.conv1 = CausalConv3d(in_channels, out_channels, 3, padding=1)
1424
+ self.norm2 = Normalize(out_channels)
1425
+ self.dropout = torch.nn.Dropout(dropout)
1426
+ self.conv2 = CausalConv3d(out_channels, out_channels, 3, padding=1)
1427
+ if self.in_channels != self.out_channels:
1428
+ if self.use_conv_shortcut:
1429
+ self.conv_shortcut = CausalConv3d(in_channels, out_channels, 3, padding=1)
1430
+ else:
1431
+ self.nin_shortcut = CausalConv3d(in_channels, out_channels, 1, padding=0)
1432
+
1433
+ def forward(self, x):
1434
+ h = x
1435
+ h = self.norm1(h)
1436
+ h = nonlinearity(h)
1437
+ h = self.conv1(h)
1438
+ h = self.norm2(h)
1439
+ h = nonlinearity(h)
1440
+ h = self.dropout(h)
1441
+ h = self.conv2(h)
1442
+ if self.in_channels != self.out_channels:
1443
+ if self.use_conv_shortcut:
1444
+ x = self.conv_shortcut(x)
1445
+ else:
1446
+ x = self.nin_shortcut(x)
1447
+ return x + h
1448
+
1449
+
1450
+ class Upsample(Block):
1451
+ def __init__(self, in_channels, out_channels):
1452
+ super().__init__()
1453
+ self.with_conv = True
1454
+ if self.with_conv:
1455
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1)
1456
+
1457
+ @video_to_image
1458
+ def forward(self, x):
1459
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
1460
+ if self.with_conv:
1461
+ x = self.conv(x)
1462
+ return x
1463
+
1464
+
1465
+ class Downsample(Block):
1466
+ def __init__(self, in_channels, out_channels):
1467
+ super().__init__()
1468
+ self.with_conv = True
1469
+ if self.with_conv:
1470
+ # no asymmetric padding in torch conv, must do it ourselves
1471
+ self.conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, padding=0)
1472
+
1473
+ @video_to_image
1474
+ def forward(self, x):
1475
+ if self.with_conv:
1476
+ pad = (0, 1, 0, 1)
1477
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
1478
+ x = self.conv(x)
1479
+ else:
1480
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
1481
+ return x
1482
+
1483
+
1484
+ class SpatialDownsample2x(Block):
1485
+ def __init__(
1486
+ self,
1487
+ chan_in,
1488
+ chan_out,
1489
+ kernel_size: Union[int, Tuple[int]] = (3, 3),
1490
+ stride: Union[int, Tuple[int]] = (2, 2),
1491
+ ):
1492
+ super().__init__()
1493
+ kernel_size = cast_tuple(kernel_size, 2)
1494
+ stride = cast_tuple(stride, 2)
1495
+ self.chan_in = chan_in
1496
+ self.chan_out = chan_out
1497
+ self.kernel_size = kernel_size
1498
+ self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=0)
1499
+
1500
+ def forward(self, x):
1501
+ pad = (0, 1, 0, 1, 0, 0)
1502
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
1503
+ x = self.conv(x)
1504
+ return x
1505
+
1506
+
1507
+ class SpatialUpsample2x(Block):
1508
+ def __init__(
1509
+ self,
1510
+ chan_in,
1511
+ chan_out,
1512
+ kernel_size: Union[int, Tuple[int]] = (3, 3),
1513
+ stride: Union[int, Tuple[int]] = (1, 1),
1514
+ ):
1515
+ super().__init__()
1516
+ self.chan_in = chan_in
1517
+ self.chan_out = chan_out
1518
+ self.kernel_size = kernel_size
1519
+ self.conv = CausalConv3d(self.chan_in, self.chan_out, (1,) + self.kernel_size, stride=(1,) + stride, padding=1)
1520
+
1521
+ def forward(self, x):
1522
+ t = x.shape[2]
1523
+ x = rearrange(x, "b c t h w -> b (c t) h w")
1524
+ x = F.interpolate(x, scale_factor=(2, 2), mode="nearest")
1525
+ x = rearrange(x, "b (c t) h w -> b c t h w", t=t)
1526
+ x = self.conv(x)
1527
+ return x
1528
+
1529
+
1530
+ class TimeDownsample2x(Block):
1531
+ def __init__(self, chan_in, chan_out, kernel_size: int = 3):
1532
+ super().__init__()
1533
+ self.kernel_size = kernel_size
1534
+ self.conv = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
1535
+
1536
+ def forward(self, x):
1537
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size - 1, 1, 1))
1538
+ x = torch.concatenate((first_frame_pad, x), dim=2)
1539
+ return self.conv(x)
1540
+
1541
+
1542
+ class TimeUpsample2x(Block):
1543
+ def __init__(self, chan_in, chan_out):
1544
+ super().__init__()
1545
+
1546
+ def forward(self, x):
1547
+ if x.size(2) > 1:
1548
+ x, x_ = x[:, :, :1], x[:, :, 1:]
1549
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
1550
+ x = torch.concat([x, x_], dim=2)
1551
+ return x
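`TimeUpsample2x` keeps the first frame untouched and trilinearly upsamples the remaining frames by 2x in time, so T input frames become `1 + 2 * (T - 1)` output frames; a "1 + n" frame layout therefore maps to "1 + 2n". A shape-only sketch of that behaviour:

```python
import torch
import torch.nn.functional as F

def time_upsample_2x(x: torch.Tensor) -> torch.Tensor:
    # x: (B, C, T, H, W); the first frame is kept as-is, the rest is upsampled 2x in time
    if x.size(2) > 1:
        first, rest = x[:, :, :1], x[:, :, 1:]
        rest = F.interpolate(rest, scale_factor=(2, 1, 1), mode="trilinear")
        x = torch.cat([first, rest], dim=2)
    return x

x = torch.randn(1, 4, 9, 8, 8)
print(time_upsample_2x(x).shape)  # torch.Size([1, 4, 17, 8, 8])
```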
1552
+
1553
+
1554
+ class TimeDownsampleRes2x(nn.Module):
1555
+ def __init__(
1556
+ self,
1557
+ in_channels,
1558
+ out_channels,
1559
+ kernel_size: int = 3,
1560
+ mix_factor: float = 2.0,
1561
+ ):
1562
+ super().__init__()
1563
+ self.kernel_size = cast_tuple(kernel_size, 3)
1564
+ self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
1565
+ self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
1566
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
1567
+
1568
+ def forward(self, x):
1569
+ alpha = torch.sigmoid(self.mix_factor)
1570
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
1571
+ x = torch.concatenate((first_frame_pad, x), dim=2)
1572
+ return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(x)
1573
+
1574
+
1575
+ class TimeUpsampleRes2x(nn.Module):
1576
+ def __init__(
1577
+ self,
1578
+ in_channels,
1579
+ out_channels,
1580
+ kernel_size: int = 3,
1581
+ mix_factor: float = 2.0,
1582
+ ):
1583
+ super().__init__()
1584
+ self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
1585
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
1586
+
1587
+ def forward(self, x):
1588
+ alpha = torch.sigmoid(self.mix_factor)
1589
+ if x.size(2) > 1:
1590
+ x, x_ = x[:, :, :1], x[:, :, 1:]
1591
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
1592
+ x = torch.concat([x, x_], dim=2)
1593
+ return alpha * x + (1 - alpha) * self.conv(x)
1594
+
1595
+
1596
+ class TimeDownsampleResAdv2x(nn.Module):
1597
+ def __init__(
1598
+ self,
1599
+ in_channels,
1600
+ out_channels,
1601
+ kernel_size: int = 3,
1602
+ mix_factor: float = 1.5,
1603
+ ):
1604
+ super().__init__()
1605
+ self.kernel_size = cast_tuple(kernel_size, 3)
1606
+ self.avg_pool = nn.AvgPool3d((kernel_size, 1, 1), stride=(2, 1, 1))
1607
+ self.attn = TemporalAttnBlock(in_channels)
1608
+ self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
1609
+ self.conv = nn.Conv3d(in_channels, out_channels, self.kernel_size, stride=(2, 1, 1), padding=(0, 1, 1))
1610
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
1611
+
1612
+ def forward(self, x):
1613
+ first_frame_pad = x[:, :, :1, :, :].repeat((1, 1, self.kernel_size[0] - 1, 1, 1))
1614
+ x = torch.concatenate((first_frame_pad, x), dim=2)
1615
+ alpha = torch.sigmoid(self.mix_factor)
1616
+ return alpha * self.avg_pool(x) + (1 - alpha) * self.conv(self.attn((self.res(x))))
1617
+
1618
+
1619
+ class TimeUpsampleResAdv2x(nn.Module):
1620
+ def __init__(
1621
+ self,
1622
+ in_channels,
1623
+ out_channels,
1624
+ kernel_size: int = 3,
1625
+ mix_factor: float = 1.5,
1626
+ ):
1627
+ super().__init__()
1628
+ self.res = ResnetBlock3D(in_channels=in_channels, out_channels=in_channels, dropout=0.0)
1629
+ self.attn = TemporalAttnBlock(in_channels)
1630
+ self.norm = Normalize(in_channels=in_channels)
1631
+ self.conv = CausalConv3d(in_channels, out_channels, kernel_size, padding=1)
1632
+ self.mix_factor = torch.nn.Parameter(torch.Tensor([mix_factor]))
1633
+
1634
+ def forward(self, x):
1635
+ if x.size(2) > 1:
1636
+ x, x_ = x[:, :, :1], x[:, :, 1:]
1637
+ x_ = F.interpolate(x_, scale_factor=(2, 1, 1), mode="trilinear")
1638
+ x = torch.concat([x, x_], dim=2)
1639
+ alpha = torch.sigmoid(self.mix_factor)
1640
+ return alpha * x + (1 - alpha) * self.conv(self.attn(self.res(x)))
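
A minimal shape-check sketch for the temporal resampling blocks added above. The import path follows this commit's new file layout (videosys/models/autoencoders/autoencoder_kl_open_sora_plan.py) and is an assumption, as are the tensor sizes.

import torch
from videosys.models.autoencoders.autoencoder_kl_open_sora_plan import TimeDownsample2x, TimeUpsample2x  # assumed path

x = torch.randn(1, 8, 9, 32, 32)  # (B, C, T, H, W)
down = TimeDownsample2x(chan_in=8, chan_out=8)
y = down(x)  # first frame replicated (kernel_size - 1) times, then AvgPool3d halves T: (1, 8, 5, 32, 32)
up = TimeUpsample2x(chan_in=8, chan_out=8)
z = up(y)    # first frame kept, remaining frames trilinearly interpolated 2x in time: (1, 8, 9, 32, 32)
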
videosys/models/cogvideo/__init__.py DELETED
@@ -1,6 +0,0 @@
1
- from .pipeline import CogVideoConfig, CogVideoPipeline
2
-
3
- __all__ = [
4
- "CogVideoConfig",
5
- "CogVideoPipeline",
6
- ]
videosys/models/cogvideo/modules.py DELETED
@@ -1,317 +0,0 @@
1
- # Adapted from CogVideo
2
-
3
- # This source code is licensed under the license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- # --------------------------------------------------------
6
- # References:
7
- # CogVideo: https://github.com/THUDM/CogVideo
8
- # diffusers: https://github.com/huggingface/diffusers
9
- # --------------------------------------------------------
10
-
11
- from typing import Optional, Tuple, Union
12
-
13
- import numpy as np
14
- import torch
15
- import torch.nn as nn
16
- import torch.nn.functional as F
17
- from diffusers.models.embeddings import get_1d_sincos_pos_embed_from_grid, get_2d_sincos_pos_embed_from_grid
18
-
19
-
20
- class CogVideoXDownsample3D(nn.Module):
21
- # Todo: Wait for paper relase.
22
- r"""
23
- A 3D Downsampling layer using in [CogVideoX]() by Tsinghua University & ZhipuAI
24
-
25
- Args:
26
- in_channels (`int`):
27
- Number of channels in the input image.
28
- out_channels (`int`):
29
- Number of channels produced by the convolution.
30
- kernel_size (`int`, defaults to `3`):
31
- Size of the convolving kernel.
32
- stride (`int`, defaults to `2`):
33
- Stride of the convolution.
34
- padding (`int`, defaults to `0`):
35
- Padding added to all four sides of the input.
36
- compress_time (`bool`, defaults to `False`):
37
- Whether or not to compress the time dimension.
38
- """
39
-
40
- def __init__(
41
- self,
42
- in_channels: int,
43
- out_channels: int,
44
- kernel_size: int = 3,
45
- stride: int = 2,
46
- padding: int = 0,
47
- compress_time: bool = False,
48
- ):
49
- super().__init__()
50
-
51
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
52
- self.compress_time = compress_time
53
-
54
- def forward(self, x: torch.Tensor) -> torch.Tensor:
55
- if self.compress_time:
56
- batch_size, channels, frames, height, width = x.shape
57
-
58
- # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
59
- x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
60
-
61
- if x.shape[-1] % 2 == 1:
62
- x_first, x_rest = x[..., 0], x[..., 1:]
63
- if x_rest.shape[-1] > 0:
64
- # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
65
- x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
66
-
67
- x = torch.cat([x_first[..., None], x_rest], dim=-1)
68
- # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
69
- x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
70
- else:
71
- # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
72
- x = F.avg_pool1d(x, kernel_size=2, stride=2)
73
- # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
74
- x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
75
-
76
- # Pad the tensor
77
- pad = (0, 1, 0, 1)
78
- x = F.pad(x, pad, mode="constant", value=0)
79
- batch_size, channels, frames, height, width = x.shape
80
- # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
81
- x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
82
- x = self.conv(x)
83
- # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
84
- x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
85
- return x
86
-
87
-
88
- class CogVideoXUpsample3D(nn.Module):
89
- r"""
90
- A 3D Upsample layer using in CogVideoX by Tsinghua University & ZhipuAI # Todo: Wait for paper relase.
91
-
92
- Args:
93
- in_channels (`int`):
94
- Number of channels in the input image.
95
- out_channels (`int`):
96
- Number of channels produced by the convolution.
97
- kernel_size (`int`, defaults to `3`):
98
- Size of the convolving kernel.
99
- stride (`int`, defaults to `1`):
100
- Stride of the convolution.
101
- padding (`int`, defaults to `1`):
102
- Padding added to all four sides of the input.
103
- compress_time (`bool`, defaults to `False`):
104
- Whether or not to compress the time dimension.
105
- """
106
-
107
- def __init__(
108
- self,
109
- in_channels: int,
110
- out_channels: int,
111
- kernel_size: int = 3,
112
- stride: int = 1,
113
- padding: int = 1,
114
- compress_time: bool = False,
115
- ) -> None:
116
- super().__init__()
117
-
118
- self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
119
- self.compress_time = compress_time
120
-
121
- def forward(self, inputs: torch.Tensor) -> torch.Tensor:
122
- if self.compress_time:
123
- if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
124
- # split first frame
125
- x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
126
-
127
- x_first = F.interpolate(x_first, scale_factor=2.0)
128
- x_rest = F.interpolate(x_rest, scale_factor=2.0)
129
- x_first = x_first[:, :, None, :, :]
130
- inputs = torch.cat([x_first, x_rest], dim=2)
131
- elif inputs.shape[2] > 1:
132
- inputs = F.interpolate(inputs, scale_factor=2.0)
133
- else:
134
- inputs = inputs.squeeze(2)
135
- inputs = F.interpolate(inputs, scale_factor=2.0)
136
- inputs = inputs[:, :, None, :, :]
137
- else:
138
- # only interpolate 2D
139
- b, c, t, h, w = inputs.shape
140
- inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
141
- inputs = F.interpolate(inputs, scale_factor=2.0)
142
- inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
143
-
144
- b, c, t, h, w = inputs.shape
145
- inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
146
- inputs = self.conv(inputs)
147
- inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
148
-
149
- return inputs
150
-
151
-
152
- def get_3d_sincos_pos_embed(
153
- embed_dim: int,
154
- spatial_size: Union[int, Tuple[int, int]],
155
- temporal_size: int,
156
- spatial_interpolation_scale: float = 1.0,
157
- temporal_interpolation_scale: float = 1.0,
158
- ) -> np.ndarray:
159
- r"""
160
- Args:
161
- embed_dim (`int`):
162
- spatial_size (`int` or `Tuple[int, int]`):
163
- temporal_size (`int`):
164
- spatial_interpolation_scale (`float`, defaults to 1.0):
165
- temporal_interpolation_scale (`float`, defaults to 1.0):
166
- """
167
- if embed_dim % 4 != 0:
168
- raise ValueError("`embed_dim` must be divisible by 4")
169
- if isinstance(spatial_size, int):
170
- spatial_size = (spatial_size, spatial_size)
171
-
172
- embed_dim_spatial = 3 * embed_dim // 4
173
- embed_dim_temporal = embed_dim // 4
174
-
175
- # 1. Spatial
176
- grid_h = np.arange(spatial_size[1], dtype=np.float32) / spatial_interpolation_scale
177
- grid_w = np.arange(spatial_size[0], dtype=np.float32) / spatial_interpolation_scale
178
- grid = np.meshgrid(grid_w, grid_h) # here w goes first
179
- grid = np.stack(grid, axis=0)
180
-
181
- grid = grid.reshape([2, 1, spatial_size[1], spatial_size[0]])
182
- pos_embed_spatial = get_2d_sincos_pos_embed_from_grid(embed_dim_spatial, grid)
183
-
184
- # 2. Temporal
185
- grid_t = np.arange(temporal_size, dtype=np.float32) / temporal_interpolation_scale
186
- pos_embed_temporal = get_1d_sincos_pos_embed_from_grid(embed_dim_temporal, grid_t)
187
-
188
- # 3. Concat
189
- pos_embed_spatial = pos_embed_spatial[np.newaxis, :, :]
190
- pos_embed_spatial = np.repeat(pos_embed_spatial, temporal_size, axis=0) # [T, H*W, D // 4 * 3]
191
-
192
- pos_embed_temporal = pos_embed_temporal[:, np.newaxis, :]
193
- pos_embed_temporal = np.repeat(pos_embed_temporal, spatial_size[0] * spatial_size[1], axis=1) # [T, H*W, D // 4]
194
-
195
- pos_embed = np.concatenate([pos_embed_temporal, pos_embed_spatial], axis=-1) # [T, H*W, D]
196
- return pos_embed
197
-
198
-
199
- class CogVideoXPatchEmbed(nn.Module):
200
- def __init__(
201
- self,
202
- patch_size: int = 2,
203
- in_channels: int = 16,
204
- embed_dim: int = 1920,
205
- text_embed_dim: int = 4096,
206
- bias: bool = True,
207
- ) -> None:
208
- super().__init__()
209
- self.patch_size = patch_size
210
-
211
- self.proj = nn.Conv2d(
212
- in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
213
- )
214
- self.text_proj = nn.Linear(text_embed_dim, embed_dim)
215
-
216
- def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
217
- r"""
218
- Args:
219
- text_embeds (`torch.Tensor`):
220
- Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
221
- image_embeds (`torch.Tensor`):
222
- Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
223
- """
224
- text_embeds = self.text_proj(text_embeds)
225
-
226
- batch, num_frames, channels, height, width = image_embeds.shape
227
- image_embeds = image_embeds.reshape(-1, channels, height, width)
228
- image_embeds = self.proj(image_embeds)
229
- image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
230
- image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
231
- image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
232
-
233
- embeds = torch.cat(
234
- [text_embeds, image_embeds], dim=1
235
- ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
236
- return embeds
237
-
238
-
239
- class CogVideoXLayerNormZero(nn.Module):
240
- def __init__(
241
- self,
242
- conditioning_dim: int,
243
- embedding_dim: int,
244
- elementwise_affine: bool = True,
245
- eps: float = 1e-5,
246
- bias: bool = True,
247
- ) -> None:
248
- super().__init__()
249
-
250
- self.silu = nn.SiLU()
251
- self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
252
- self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
253
-
254
- def forward(
255
- self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
256
- ) -> Tuple[torch.Tensor, torch.Tensor]:
257
- shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
258
- hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
259
- encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
260
- return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
261
-
262
-
263
- class AdaLayerNorm(nn.Module):
264
- r"""
265
- Norm layer modified to incorporate timestep embeddings.
266
-
267
- Parameters:
268
- embedding_dim (`int`): The size of each embedding vector.
269
- num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
270
- output_dim (`int`, *optional*):
271
- norm_elementwise_affine (`bool`, defaults to `False):
272
- norm_eps (`bool`, defaults to `False`):
273
- chunk_dim (`int`, defaults to `0`):
274
- """
275
-
276
- def __init__(
277
- self,
278
- embedding_dim: int,
279
- num_embeddings: Optional[int] = None,
280
- output_dim: Optional[int] = None,
281
- norm_elementwise_affine: bool = False,
282
- norm_eps: float = 1e-5,
283
- chunk_dim: int = 0,
284
- ):
285
- super().__init__()
286
-
287
- self.chunk_dim = chunk_dim
288
- output_dim = output_dim or embedding_dim * 2
289
-
290
- if num_embeddings is not None:
291
- self.emb = nn.Embedding(num_embeddings, embedding_dim)
292
- else:
293
- self.emb = None
294
-
295
- self.silu = nn.SiLU()
296
- self.linear = nn.Linear(embedding_dim, output_dim)
297
- self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
298
-
299
- def forward(
300
- self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
301
- ) -> torch.Tensor:
302
- if self.emb is not None:
303
- temb = self.emb(timestep)
304
-
305
- temb = self.linear(self.silu(temb))
306
-
307
- if self.chunk_dim == 1:
308
- # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
309
- # other if-branch. This branch is specific to CogVideoX for now.
310
- shift, scale = temb.chunk(2, dim=1)
311
- shift = shift[:, None, :]
312
- scale = scale[:, None, :]
313
- else:
314
- scale, shift = temb.chunk(2, dim=0)
315
-
316
- x = self.norm(x) * (1 + scale) + shift
317
- return x
videosys/models/cogvideo/retrieve_timesteps.py DELETED
@@ -1,74 +0,0 @@
1
- # Adapted from CogVideo
2
-
3
- # This source code is licensed under the license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- # --------------------------------------------------------
6
- # References:
7
- # CogVideo: https://github.com/THUDM/CogVideo
8
- # diffusers: https://github.com/huggingface/diffusers
9
- # --------------------------------------------------------
10
-
11
- import inspect
12
- from typing import List, Optional, Union
13
-
14
- import torch
15
-
16
-
17
- # Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
18
- def retrieve_timesteps(
19
- scheduler,
20
- num_inference_steps: Optional[int] = None,
21
- device: Optional[Union[str, torch.device]] = None,
22
- timesteps: Optional[List[int]] = None,
23
- sigmas: Optional[List[float]] = None,
24
- **kwargs,
25
- ):
26
- """
27
- Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
28
- custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.
29
-
30
- Args:
31
- scheduler (`SchedulerMixin`):
32
- The scheduler to get timesteps from.
33
- num_inference_steps (`int`):
34
- The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
35
- must be `None`.
36
- device (`str` or `torch.device`, *optional*):
37
- The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
38
- timesteps (`List[int]`, *optional*):
39
- Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
40
- `num_inference_steps` and `sigmas` must be `None`.
41
- sigmas (`List[float]`, *optional*):
42
- Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
43
- `num_inference_steps` and `timesteps` must be `None`.
44
-
45
- Returns:
46
- `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
47
- second element is the number of inference steps.
48
- """
49
- if timesteps is not None and sigmas is not None:
50
- raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
51
- if timesteps is not None:
52
- accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
53
- if not accepts_timesteps:
54
- raise ValueError(
55
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
56
- f" timestep schedules. Please check whether you are using the correct scheduler."
57
- )
58
- scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
59
- timesteps = scheduler.timesteps
60
- num_inference_steps = len(timesteps)
61
- elif sigmas is not None:
62
- accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
63
- if not accept_sigmas:
64
- raise ValueError(
65
- f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
66
- f" sigmas schedules. Please check whether you are using the correct scheduler."
67
- )
68
- scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
69
- timesteps = scheduler.timesteps
70
- num_inference_steps = len(timesteps)
71
- else:
72
- scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
73
- timesteps = scheduler.timesteps
74
- return timesteps, num_inference_steps
videosys/models/latte/__init__.py DELETED
@@ -1,7 +0,0 @@
1
- from .pipeline import LatteConfig, LattePABConfig, LattePipeline
2
-
3
- __all__ = [
4
- "LattePipeline",
5
- "LattePABConfig",
6
- "LatteConfig",
7
- ]
{eval/pab/experiments β†’ videosys/models/modules}/__init__.py RENAMED
File without changes
videosys/models/modules/activations.py ADDED
@@ -0,0 +1,3 @@
1
+ import torch.nn as nn
2
+
3
+ approx_gelu = lambda: nn.GELU(approximate="tanh")
videosys/{modules/attn.py β†’ models/modules/attentions.py} RENAMED
@@ -1,12 +1,8 @@
1
- from dataclasses import dataclass
2
- from typing import Iterable, List, Optional, Sequence, Tuple
3
-
4
  import torch
5
  import torch.nn as nn
6
- import torch.nn.functional as F
7
  import torch.utils.checkpoint
8
 
9
- from videosys.modules.layers import LlamaRMSNorm
10
 
11
 
12
  class Attention(nn.Module):
@@ -19,8 +15,9 @@ class Attention(nn.Module):
19
  attn_drop: float = 0.0,
20
  proj_drop: float = 0.0,
21
  norm_layer: nn.Module = LlamaRMSNorm,
22
- enable_flashattn: bool = False,
23
  rope=None,
 
24
  ) -> None:
25
  super().__init__()
26
  assert dim % num_heads == 0, "dim should be divisible by num_heads"
@@ -28,11 +25,12 @@ class Attention(nn.Module):
28
  self.num_heads = num_heads
29
  self.head_dim = dim // num_heads
30
  self.scale = self.head_dim**-0.5
31
- self.enable_flashattn = enable_flashattn
32
 
33
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
34
  self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
35
  self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
 
36
  self.attn_drop = nn.Dropout(attn_drop)
37
  self.proj = nn.Linear(dim, dim)
38
  self.proj_drop = nn.Dropout(proj_drop)
@@ -44,18 +42,32 @@ class Attention(nn.Module):
44
 
45
  def forward(self, x: torch.Tensor) -> torch.Tensor:
46
  B, N, C = x.shape
47
-
 
48
  qkv = self.qkv(x)
49
- qkv = qkv.view(B, N, 3, self.num_heads, self.head_dim).permute(2, 0, 1, 3, 4)
 
 
50
  q, k, v = qkv.unbind(0)
51
- if self.rope:
52
- q = self.rotary_emb(q)
53
- k = self.rotary_emb(k)
54
- q, k = self.q_norm(q), self.k_norm(k)
 
 
 
 
 
 
 
55
 
56
- if self.enable_flashattn:
57
  from flash_attn import flash_attn_func
58
 
 
 
 
 
59
  x = flash_attn_func(
60
  q,
61
  k,
@@ -64,13 +76,17 @@ class Attention(nn.Module):
64
  softmax_scale=self.scale,
65
  )
66
  else:
67
- q, k, v = map(lambda t: t.permute(0, 2, 1, 3), (q, k, v))
68
- x = F.scaled_dot_product_attention(
69
- q, k, v, scale=self.scale, dropout_p=self.attn_drop.p if self.training else 0.0
70
- )
 
 
 
 
71
 
72
  x_output_shape = (B, N, C)
73
- if not self.enable_flashattn:
74
  x = x.transpose(1, 2)
75
  x = x.reshape(x_output_shape)
76
  x = self.proj(x)
@@ -79,139 +95,37 @@ class Attention(nn.Module):
79
 
80
 
81
  class MultiHeadCrossAttention(nn.Module):
82
- def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0, enable_flashattn=False):
83
  super(MultiHeadCrossAttention, self).__init__()
84
  assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
85
 
86
  self.d_model = d_model
87
  self.num_heads = num_heads
88
  self.head_dim = d_model // num_heads
89
- self.enable_flashattn = enable_flashattn
90
 
91
  self.q_linear = nn.Linear(d_model, d_model)
92
  self.kv_linear = nn.Linear(d_model, d_model * 2)
93
  self.attn_drop = nn.Dropout(attn_drop)
94
  self.proj = nn.Linear(d_model, d_model)
95
  self.proj_drop = nn.Dropout(proj_drop)
96
- self.last_out = None
97
- self.count = 0
98
 
99
- def forward(self, x, cond, mask=None, timestep=None):
100
  # query/value: img tokens; key: condition; mask: if padding tokens
101
  B, N, C = x.shape
102
 
103
  q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
104
  kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
105
  k, v = kv.unbind(2)
106
- x = self.flash_attn_impl(q, k, v, mask, B, N, C)
107
 
108
- x = self.proj(x)
109
- x = self.proj_drop(x)
110
- return x
111
 
112
- def flash_attn_impl(self, q, k, v, mask, B, N, C):
113
- from flash_attn import flash_attn_varlen_func
114
-
115
- q_seqinfo = _SeqLenInfo.from_seqlens([N] * B)
116
- k_seqinfo = _SeqLenInfo.from_seqlens(mask)
117
-
118
- x = flash_attn_varlen_func(
119
- q.view(-1, self.num_heads, self.head_dim),
120
- k.view(-1, self.num_heads, self.head_dim),
121
- v.view(-1, self.num_heads, self.head_dim),
122
- cu_seqlens_q=q_seqinfo.seqstart.cuda(),
123
- cu_seqlens_k=k_seqinfo.seqstart.cuda(),
124
- max_seqlen_q=q_seqinfo.max_seqlen,
125
- max_seqlen_k=k_seqinfo.max_seqlen,
126
- dropout_p=self.attn_drop.p if self.training else 0.0,
127
- )
128
- x = x.view(B, N, C)
129
- return x
130
-
131
- def torch_impl(self, q, k, v, mask, B, N, C):
132
- q = q.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
133
- k = k.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
134
- v = v.view(B, -1, self.num_heads, self.head_dim).transpose(1, 2)
135
-
136
- attn_mask = torch.zeros(B, N, k.shape[2], dtype=torch.float32, device=q.device)
137
- for i, m in enumerate(mask):
138
- attn_mask[i, :, m:] = -1e8
139
-
140
- scale = 1 / q.shape[-1] ** 0.5
141
- q = q * scale
142
- attn = q @ k.transpose(-2, -1)
143
- attn = attn.to(torch.float32)
144
  if mask is not None:
145
- attn = attn + attn_mask.unsqueeze(1)
146
- attn = attn.softmax(-1)
147
- attn = attn.to(v.dtype)
148
- out = attn @ v
149
 
150
- x = out.transpose(1, 2).contiguous().view(B, N, C)
 
 
151
  return x
152
-
153
-
154
- @dataclass
155
- class _SeqLenInfo:
156
- """
157
- copied from xformers
158
-
159
- (Internal) Represents the division of a dimension into blocks.
160
- For example, to represents a dimension of length 7 divided into
161
- three blocks of lengths 2, 3 and 2, use `from_seqlength([2, 3, 2])`.
162
- The members will be:
163
- max_seqlen: 3
164
- min_seqlen: 2
165
- seqstart_py: [0, 2, 5, 7]
166
- seqstart: torch.IntTensor([0, 2, 5, 7])
167
- """
168
-
169
- seqstart: torch.Tensor
170
- max_seqlen: int
171
- min_seqlen: int
172
- seqstart_py: List[int]
173
-
174
- def to(self, device: torch.device) -> None:
175
- self.seqstart = self.seqstart.to(device, non_blocking=True)
176
-
177
- def intervals(self) -> Iterable[Tuple[int, int]]:
178
- yield from zip(self.seqstart_py, self.seqstart_py[1:])
179
-
180
- @classmethod
181
- def from_seqlens(cls, seqlens: Iterable[int]) -> "_SeqLenInfo":
182
- """
183
- Input tensors are assumed to be in shape [B, M, *]
184
- """
185
- assert not isinstance(seqlens, torch.Tensor)
186
- seqstart_py = [0]
187
- max_seqlen = -1
188
- min_seqlen = -1
189
- for seqlen in seqlens:
190
- min_seqlen = min(min_seqlen, seqlen) if min_seqlen != -1 else seqlen
191
- max_seqlen = max(max_seqlen, seqlen)
192
- seqstart_py.append(seqstart_py[len(seqstart_py) - 1] + seqlen)
193
- seqstart = torch.tensor(seqstart_py, dtype=torch.int32)
194
- return cls(
195
- max_seqlen=max_seqlen,
196
- min_seqlen=min_seqlen,
197
- seqstart=seqstart,
198
- seqstart_py=seqstart_py,
199
- )
200
-
201
- def split(self, x: torch.Tensor, batch_sizes: Optional[Sequence[int]] = None) -> List[torch.Tensor]:
202
- if self.seqstart_py[-1] != x.shape[1] or x.shape[0] != 1:
203
- raise ValueError(
204
- f"Invalid `torch.Tensor` of shape {x.shape}, expected format "
205
- f"(B, M, *) with B=1 and M={self.seqstart_py[-1]}\n"
206
- f" seqstart: {self.seqstart_py}"
207
- )
208
- if batch_sizes is None:
209
- batch_sizes = [1] * (len(self.seqstart_py) - 1)
210
- split_chunks = []
211
- it = 0
212
- for batch_size in batch_sizes:
213
- split_chunks.append(self.seqstart_py[it + batch_size] - self.seqstart_py[it])
214
- it += batch_size
215
- return [
216
- tensor.reshape([bs, -1, *tensor.shape[2:]]) for bs, tensor in zip(batch_sizes, x.split(split_chunks, dim=1))
217
- ]
 
 
 
 
1
  import torch
2
  import torch.nn as nn
 
3
  import torch.utils.checkpoint
4
 
5
+ from videosys.models.modules.normalization import LlamaRMSNorm
6
 
7
 
8
  class Attention(nn.Module):
 
15
  attn_drop: float = 0.0,
16
  proj_drop: float = 0.0,
17
  norm_layer: nn.Module = LlamaRMSNorm,
18
+ enable_flash_attn: bool = False,
19
  rope=None,
20
+ qk_norm_legacy: bool = False,
21
  ) -> None:
22
  super().__init__()
23
  assert dim % num_heads == 0, "dim should be divisible by num_heads"
 
25
  self.num_heads = num_heads
26
  self.head_dim = dim // num_heads
27
  self.scale = self.head_dim**-0.5
28
+ self.enable_flash_attn = enable_flash_attn
29
 
30
  self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
31
  self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
32
  self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
33
+ self.qk_norm_legacy = qk_norm_legacy
34
  self.attn_drop = nn.Dropout(attn_drop)
35
  self.proj = nn.Linear(dim, dim)
36
  self.proj_drop = nn.Dropout(proj_drop)
 
42
 
43
  def forward(self, x: torch.Tensor) -> torch.Tensor:
44
  B, N, C = x.shape
45
+ # flash attn is not memory efficient for small sequences, this is empirical
46
+ enable_flash_attn = self.enable_flash_attn and (N > B)
47
  qkv = self.qkv(x)
48
+ qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
49
+
50
+ qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
51
  q, k, v = qkv.unbind(0)
52
+ if self.qk_norm_legacy:
53
+ # WARNING: this may be a bug
54
+ if self.rope:
55
+ q = self.rotary_emb(q)
56
+ k = self.rotary_emb(k)
57
+ q, k = self.q_norm(q), self.k_norm(k)
58
+ else:
59
+ q, k = self.q_norm(q), self.k_norm(k)
60
+ if self.rope:
61
+ q = self.rotary_emb(q)
62
+ k = self.rotary_emb(k)
63
 
64
+ if enable_flash_attn:
65
  from flash_attn import flash_attn_func
66
 
67
+ # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
68
+ q = q.permute(0, 2, 1, 3)
69
+ k = k.permute(0, 2, 1, 3)
70
+ v = v.permute(0, 2, 1, 3)
71
  x = flash_attn_func(
72
  q,
73
  k,
 
76
  softmax_scale=self.scale,
77
  )
78
  else:
79
+ dtype = q.dtype
80
+ q = q * self.scale
81
+ attn = q @ k.transpose(-2, -1) # translate attn to float32
82
+ attn = attn.to(torch.float32)
83
+ attn = attn.softmax(dim=-1)
84
+ attn = attn.to(dtype) # cast back attn to original dtype
85
+ attn = self.attn_drop(attn)
86
+ x = attn @ v
87
 
88
  x_output_shape = (B, N, C)
89
+ if not enable_flash_attn:
90
  x = x.transpose(1, 2)
91
  x = x.reshape(x_output_shape)
92
  x = self.proj(x)
 
95
 
96
 
97
  class MultiHeadCrossAttention(nn.Module):
98
+ def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
99
  super(MultiHeadCrossAttention, self).__init__()
100
  assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
101
 
102
  self.d_model = d_model
103
  self.num_heads = num_heads
104
  self.head_dim = d_model // num_heads
 
105
 
106
  self.q_linear = nn.Linear(d_model, d_model)
107
  self.kv_linear = nn.Linear(d_model, d_model * 2)
108
  self.attn_drop = nn.Dropout(attn_drop)
109
  self.proj = nn.Linear(d_model, d_model)
110
  self.proj_drop = nn.Dropout(proj_drop)
 
 
111
 
112
+ def forward(self, x, cond, mask=None):
113
  # query/value: img tokens; key: condition; mask: if padding tokens
114
  B, N, C = x.shape
115
 
116
  q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
117
  kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
118
  k, v = kv.unbind(2)
 
119
 
120
+ attn_bias = None
121
+ # TODO: support torch computation
122
+ import xformers.ops
123
 
 
124
  if mask is not None:
125
+ attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
126
+ x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
 
 
127
 
128
+ x = x.view(B, -1, C)
129
+ x = self.proj(x)
130
+ x = self.proj_drop(x)
131
  return x
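
The cross-attention above packs all samples into one sequence and relies on xformers' block-diagonal bias to keep each sample's queries on its own condition tokens. A rough sketch of that convention; the shapes, device, and dtype are illustrative assumptions (the fused kernels generally expect CUDA tensors).

import torch
import xformers.ops

B, N, heads, dim = 2, 16, 4, 8
q = torch.randn(1, B * N, heads, dim, device="cuda", dtype=torch.float16)
kv_lens = [5, 7]  # valid condition tokens per sample
k = torch.randn(1, sum(kv_lens), heads, dim, device="cuda", dtype=torch.float16)
v = torch.randn_like(k)
bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, kv_lens)
out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=bias)  # (1, B * N, heads, dim)
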
videosys/models/modules/downsampling.py ADDED
@@ -0,0 +1,71 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class CogVideoXDownsample3D(nn.Module):
7
+ # TODO: Wait for paper release.
8
+ r"""
9
+ A 3D Downsampling layer used in [CogVideoX]() by Tsinghua University & ZhipuAI
10
+
11
+ Args:
12
+ in_channels (`int`):
13
+ Number of channels in the input image.
14
+ out_channels (`int`):
15
+ Number of channels produced by the convolution.
16
+ kernel_size (`int`, defaults to `3`):
17
+ Size of the convolving kernel.
18
+ stride (`int`, defaults to `2`):
19
+ Stride of the convolution.
20
+ padding (`int`, defaults to `0`):
21
+ Padding added to all four sides of the input.
22
+ compress_time (`bool`, defaults to `False`):
23
+ Whether or not to compress the time dimension.
24
+ """
25
+
26
+ def __init__(
27
+ self,
28
+ in_channels: int,
29
+ out_channels: int,
30
+ kernel_size: int = 3,
31
+ stride: int = 2,
32
+ padding: int = 0,
33
+ compress_time: bool = False,
34
+ ):
35
+ super().__init__()
36
+
37
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
38
+ self.compress_time = compress_time
39
+
40
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
41
+ if self.compress_time:
42
+ batch_size, channels, frames, height, width = x.shape
43
+
44
+ # (batch_size, channels, frames, height, width) -> (batch_size, height, width, channels, frames) -> (batch_size * height * width, channels, frames)
45
+ x = x.permute(0, 3, 4, 1, 2).reshape(batch_size * height * width, channels, frames)
46
+
47
+ if x.shape[-1] % 2 == 1:
48
+ x_first, x_rest = x[..., 0], x[..., 1:]
49
+ if x_rest.shape[-1] > 0:
50
+ # (batch_size * height * width, channels, frames - 1) -> (batch_size * height * width, channels, (frames - 1) // 2)
51
+ x_rest = F.avg_pool1d(x_rest, kernel_size=2, stride=2)
52
+
53
+ x = torch.cat([x_first[..., None], x_rest], dim=-1)
54
+ # (batch_size * height * width, channels, (frames // 2) + 1) -> (batch_size, height, width, channels, (frames // 2) + 1) -> (batch_size, channels, (frames // 2) + 1, height, width)
55
+ x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
56
+ else:
57
+ # (batch_size * height * width, channels, frames) -> (batch_size * height * width, channels, frames // 2)
58
+ x = F.avg_pool1d(x, kernel_size=2, stride=2)
59
+ # (batch_size * height * width, channels, frames // 2) -> (batch_size, height, width, channels, frames // 2) -> (batch_size, channels, frames // 2, height, width)
60
+ x = x.reshape(batch_size, height, width, channels, x.shape[-1]).permute(0, 3, 4, 1, 2)
61
+
62
+ # Pad the tensor
63
+ pad = (0, 1, 0, 1)
64
+ x = F.pad(x, pad, mode="constant", value=0)
65
+ batch_size, channels, frames, height, width = x.shape
66
+ # (batch_size, channels, frames, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size * frames, channels, height, width)
67
+ x = x.permute(0, 2, 1, 3, 4).reshape(batch_size * frames, channels, height, width)
68
+ x = self.conv(x)
69
+ # (batch_size * frames, channels, height, width) -> (batch_size, frames, channels, height, width) -> (batch_size, channels, frames, height, width)
70
+ x = x.reshape(batch_size, frames, x.shape[1], x.shape[2], x.shape[3]).permute(0, 2, 1, 3, 4)
71
+ return x
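
A small usage sketch for the downsampling block above; the import path mirrors this commit's module layout and is an assumption.

import torch
from videosys.models.modules.downsampling import CogVideoXDownsample3D  # assumed path

down = CogVideoXDownsample3D(in_channels=16, out_channels=16, compress_time=True)
x = torch.randn(2, 16, 9, 64, 64)  # (B, C, T, H, W) with odd T
y = down(x)
print(y.shape)  # torch.Size([2, 16, 5, 32, 32]): first frame kept, rest average-pooled in time, H/W halved by the stride-2 conv
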
videosys/models/{open_sora/modules.py β†’ modules/embeddings.py} RENAMED
@@ -1,16 +1,8 @@
1
- # Adapted from OpenSora
2
-
3
- # This source code is licensed under the license found in the
4
- # LICENSE file in the root directory of this source tree.
5
- # --------------------------------------------------------
6
- # References:
7
- # OpenSora: https://github.com/hpcaitech/Open-Sora
8
- # --------------------------------------------------------
9
-
10
  import functools
11
  import math
12
- from typing import Optional
13
 
 
14
  import torch
15
  import torch.nn as nn
16
  import torch.nn.functional as F
@@ -18,40 +10,48 @@ import torch.utils.checkpoint
18
  from einops import rearrange
19
  from timm.models.vision_transformer import Mlp
20
 
21
- approx_gelu = lambda: nn.GELU(approximate="tanh")
22
 
23
-
24
- class LlamaRMSNorm(nn.Module):
25
- def __init__(self, hidden_size, eps=1e-6):
26
- """
27
- LlamaRMSNorm is equivalent to T5LayerNorm
28
- """
 
 
 
29
  super().__init__()
30
- self.weight = nn.Parameter(torch.ones(hidden_size))
31
- self.variance_epsilon = eps
32
-
33
- def forward(self, hidden_states):
34
- input_dtype = hidden_states.dtype
35
- hidden_states = hidden_states.to(torch.float32)
36
- variance = hidden_states.pow(2).mean(-1, keepdim=True)
37
- hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
38
- return self.weight * hidden_states.to(input_dtype)
39
-
40
-
41
- def get_layernorm(hidden_size: torch.Tensor, eps: float, affine: bool):
42
- return nn.LayerNorm(hidden_size, eps, elementwise_affine=affine)
43
-
44
 
45
- def t2i_modulate(x, shift, scale):
46
- return x * (1 + scale) + shift
 
 
 
 
 
 
 
 
 
 
 
 
47
 
 
 
 
 
 
 
48
 
49
- # ===============================================
50
- # General-purpose Layers
51
- # ===============================================
 
52
 
53
 
54
- class PatchEmbed3D(nn.Module):
55
  """Video to Patch Embedding.
56
 
57
  Args:
@@ -104,176 +104,6 @@ class PatchEmbed3D(nn.Module):
104
  return x
105
 
106
 
107
- class Attention(nn.Module):
108
- def __init__(
109
- self,
110
- dim: int,
111
- num_heads: int = 8,
112
- qkv_bias: bool = False,
113
- qk_norm: bool = False,
114
- attn_drop: float = 0.0,
115
- proj_drop: float = 0.0,
116
- norm_layer: nn.Module = LlamaRMSNorm,
117
- enable_flash_attn: bool = False,
118
- rope=None,
119
- qk_norm_legacy: bool = False,
120
- ) -> None:
121
- super().__init__()
122
- assert dim % num_heads == 0, "dim should be divisible by num_heads"
123
- self.dim = dim
124
- self.num_heads = num_heads
125
- self.head_dim = dim // num_heads
126
- self.scale = self.head_dim**-0.5
127
- self.enable_flash_attn = enable_flash_attn
128
-
129
- self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
130
- self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
131
- self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
132
- self.qk_norm_legacy = qk_norm_legacy
133
- self.attn_drop = nn.Dropout(attn_drop)
134
- self.proj = nn.Linear(dim, dim)
135
- self.proj_drop = nn.Dropout(proj_drop)
136
-
137
- self.rope = False
138
- if rope is not None:
139
- self.rope = True
140
- self.rotary_emb = rope
141
-
142
- def forward(self, x: torch.Tensor) -> torch.Tensor:
143
- B, N, C = x.shape
144
- # flash attn is not memory efficient for small sequences, this is empirical
145
- enable_flash_attn = self.enable_flash_attn and (N > B)
146
- qkv = self.qkv(x)
147
- qkv_shape = (B, N, 3, self.num_heads, self.head_dim)
148
-
149
- qkv = qkv.view(qkv_shape).permute(2, 0, 3, 1, 4)
150
- q, k, v = qkv.unbind(0)
151
- if self.qk_norm_legacy:
152
- # WARNING: this may be a bug
153
- if self.rope:
154
- q = self.rotary_emb(q)
155
- k = self.rotary_emb(k)
156
- q, k = self.q_norm(q), self.k_norm(k)
157
- else:
158
- q, k = self.q_norm(q), self.k_norm(k)
159
- if self.rope:
160
- q = self.rotary_emb(q)
161
- k = self.rotary_emb(k)
162
-
163
- if enable_flash_attn:
164
- from flash_attn import flash_attn_func
165
-
166
- # (B, #heads, N, #dim) -> (B, N, #heads, #dim)
167
- q = q.permute(0, 2, 1, 3)
168
- k = k.permute(0, 2, 1, 3)
169
- v = v.permute(0, 2, 1, 3)
170
- x = flash_attn_func(
171
- q,
172
- k,
173
- v,
174
- dropout_p=self.attn_drop.p if self.training else 0.0,
175
- softmax_scale=self.scale,
176
- )
177
- else:
178
- dtype = q.dtype
179
- q = q * self.scale
180
- attn = q @ k.transpose(-2, -1) # translate attn to float32
181
- attn = attn.to(torch.float32)
182
- attn = attn.softmax(dim=-1)
183
- attn = attn.to(dtype) # cast back attn to original dtype
184
- attn = self.attn_drop(attn)
185
- x = attn @ v
186
-
187
- x_output_shape = (B, N, C)
188
- if not enable_flash_attn:
189
- x = x.transpose(1, 2)
190
- x = x.reshape(x_output_shape)
191
- x = self.proj(x)
192
- x = self.proj_drop(x)
193
- return x
194
-
195
-
196
- class MultiHeadCrossAttention(nn.Module):
197
- def __init__(self, d_model, num_heads, attn_drop=0.0, proj_drop=0.0):
198
- super(MultiHeadCrossAttention, self).__init__()
199
- assert d_model % num_heads == 0, "d_model must be divisible by num_heads"
200
-
201
- self.d_model = d_model
202
- self.num_heads = num_heads
203
- self.head_dim = d_model // num_heads
204
-
205
- self.q_linear = nn.Linear(d_model, d_model)
206
- self.kv_linear = nn.Linear(d_model, d_model * 2)
207
- self.attn_drop = nn.Dropout(attn_drop)
208
- self.proj = nn.Linear(d_model, d_model)
209
- self.proj_drop = nn.Dropout(proj_drop)
210
-
211
- def forward(self, x, cond, mask=None):
212
- # query/value: img tokens; key: condition; mask: if padding tokens
213
- B, N, C = x.shape
214
-
215
- q = self.q_linear(x).view(1, -1, self.num_heads, self.head_dim)
216
- kv = self.kv_linear(cond).view(1, -1, 2, self.num_heads, self.head_dim)
217
- k, v = kv.unbind(2)
218
-
219
- attn_bias = None
220
- # TODO: support torch computation
221
- import xformers.ops
222
-
223
- if mask is not None:
224
- attn_bias = xformers.ops.fmha.BlockDiagonalMask.from_seqlens([N] * B, mask)
225
- x = xformers.ops.memory_efficient_attention(q, k, v, p=self.attn_drop.p, attn_bias=attn_bias)
226
-
227
- x = x.view(B, -1, C)
228
- x = self.proj(x)
229
- x = self.proj_drop(x)
230
- return x
231
-
232
-
233
- class T2IFinalLayer(nn.Module):
234
- """
235
- The final layer of PixArt.
236
- """
237
-
238
- def __init__(self, hidden_size, num_patch, out_channels, d_t=None, d_s=None):
239
- super().__init__()
240
- self.norm_final = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
241
- self.linear = nn.Linear(hidden_size, num_patch * out_channels, bias=True)
242
- self.scale_shift_table = nn.Parameter(torch.randn(2, hidden_size) / hidden_size**0.5)
243
- self.out_channels = out_channels
244
- self.d_t = d_t
245
- self.d_s = d_s
246
-
247
- def t_mask_select(self, x_mask, x, masked_x, T, S):
248
- # x: [B, (T, S), C]
249
- # mased_x: [B, (T, S), C]
250
- # x_mask: [B, T]
251
- x = rearrange(x, "B (T S) C -> B T S C", T=T, S=S)
252
- masked_x = rearrange(masked_x, "B (T S) C -> B T S C", T=T, S=S)
253
- x = torch.where(x_mask[:, :, None, None], x, masked_x)
254
- x = rearrange(x, "B T S C -> B (T S) C")
255
- return x
256
-
257
- def forward(self, x, t, x_mask=None, t0=None, T=None, S=None):
258
- if T is None:
259
- T = self.d_t
260
- if S is None:
261
- S = self.d_s
262
- shift, scale = (self.scale_shift_table[None] + t[:, None]).chunk(2, dim=1)
263
- x = t2i_modulate(self.norm_final(x), shift, scale)
264
- if x_mask is not None:
265
- shift_zero, scale_zero = (self.scale_shift_table[None] + t0[:, None]).chunk(2, dim=1)
266
- x_zero = t2i_modulate(self.norm_final(x), shift_zero, scale_zero)
267
- x = self.t_mask_select(x_mask, x, x_zero, T, S)
268
- x = self.linear(x)
269
- return x
270
-
271
-
272
- # ===============================================
273
- # Embedding Layers for Timesteps and Class Labels
274
- # ===============================================
275
-
276
-
277
  class TimestepEmbedder(nn.Module):
278
  """
279
  Embeds scalar timesteps into vector representations.
@@ -350,7 +180,7 @@ class SizeEmbedder(TimestepEmbedder):
350
  return next(self.parameters()).dtype
351
 
352
 
353
- class CaptionEmbedder(nn.Module):
354
  """
355
  Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
356
  """
@@ -398,7 +228,7 @@ class CaptionEmbedder(nn.Module):
398
  return caption
399
 
400
 
401
- class PositionEmbedding2D(nn.Module):
402
  def __init__(self, dim: int) -> None:
403
  super().__init__()
404
  self.dim = dim
@@ -448,3 +278,135 @@ class PositionEmbedding2D(nn.Module):
448
  base_size: Optional[int] = None,
449
  ) -> torch.Tensor:
450
  return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
1
  import functools
2
  import math
3
+ from typing import Optional, Tuple, Union
4
 
5
+ import numpy as np
6
  import torch
7
  import torch.nn as nn
8
  import torch.nn.functional as F
 
10
  from einops import rearrange
11
  from timm.models.vision_transformer import Mlp
12
 
 
13
 
14
+ class CogVideoXPatchEmbed(nn.Module):
15
+ def __init__(
16
+ self,
17
+ patch_size: int = 2,
18
+ in_channels: int = 16,
19
+ embed_dim: int = 1920,
20
+ text_embed_dim: int = 4096,
21
+ bias: bool = True,
22
+ ) -> None:
23
  super().__init__()
24
+ self.patch_size = patch_size
 
 
 
 
 
 
 
 
 
 
 
 
 
25
 
26
+ self.proj = nn.Conv2d(
27
+ in_channels, embed_dim, kernel_size=(patch_size, patch_size), stride=patch_size, bias=bias
28
+ )
29
+ self.text_proj = nn.Linear(text_embed_dim, embed_dim)
30
+
31
+ def forward(self, text_embeds: torch.Tensor, image_embeds: torch.Tensor):
32
+ r"""
33
+ Args:
34
+ text_embeds (`torch.Tensor`):
35
+ Input text embeddings. Expected shape: (batch_size, seq_length, embedding_dim).
36
+ image_embeds (`torch.Tensor`):
37
+ Input image embeddings. Expected shape: (batch_size, num_frames, channels, height, width).
38
+ """
39
+ text_embeds = self.text_proj(text_embeds)
40
 
41
+ batch, num_frames, channels, height, width = image_embeds.shape
42
+ image_embeds = image_embeds.reshape(-1, channels, height, width)
43
+ image_embeds = self.proj(image_embeds)
44
+ image_embeds = image_embeds.view(batch, num_frames, *image_embeds.shape[1:])
45
+ image_embeds = image_embeds.flatten(3).transpose(2, 3) # [batch, num_frames, height x width, channels]
46
+ image_embeds = image_embeds.flatten(1, 2) # [batch, num_frames x height x width, channels]
47
 
48
+ embeds = torch.cat(
49
+ [text_embeds, image_embeds], dim=1
50
+ ).contiguous() # [batch, seq_length + num_frames x height x width, channels]
51
+ return embeds
52
 
53
 
54
+ class OpenSoraPatchEmbed3D(nn.Module):
55
  """Video to Patch Embedding.
56
 
57
  Args:
 
104
  return x
105
 
106
 
107
  class TimestepEmbedder(nn.Module):
108
  """
109
  Embeds scalar timesteps into vector representations.
 
180
  return next(self.parameters()).dtype
181
 
182
 
183
+ class OpenSoraCaptionEmbedder(nn.Module):
184
  """
185
  Embeds class labels into vector representations. Also handles label dropout for classifier-free guidance.
186
  """
 
228
  return caption
229
 
230
 
231
+ class OpenSoraPositionEmbedding2D(nn.Module):
232
  def __init__(self, dim: int) -> None:
233
  super().__init__()
234
  self.dim = dim
 
278
  base_size: Optional[int] = None,
279
  ) -> torch.Tensor:
280
  return self._get_cached_emb(x.device, x.dtype, h, w, scale, base_size)
281
+
282
+
283
+ def get_3d_rotary_pos_embed(
284
+ embed_dim, crops_coords, grid_size, temporal_size, theta: int = 10000, use_real: bool = True
285
+ ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
286
+ """
287
+ RoPE for video tokens with 3D structure.
288
+
289
+ Args:
290
+ embed_dim: (`int`):
291
+ The embedding dimension size, corresponding to hidden_size_head.
292
+ crops_coords (`Tuple[int]`):
293
+ The top-left and bottom-right coordinates of the crop.
294
+ grid_size (`Tuple[int]`):
295
+ The grid size of the spatial positional embedding (height, width).
296
+ temporal_size (`int`):
297
+ The size of the temporal dimension.
298
+ theta (`float`):
299
+ Scaling factor for frequency computation.
300
+ use_real (`bool`):
301
+ If True, return real part and imaginary part separately. Otherwise, return complex numbers.
302
+
303
+ Returns:
304
+ `torch.Tensor`: positional embedding with shape `(temporal_size * grid_size[0] * grid_size[1], embed_dim)`.
305
+ """
306
+ start, stop = crops_coords
307
+ grid_h = np.linspace(start[0], stop[0], grid_size[0], endpoint=False, dtype=np.float32)
308
+ grid_w = np.linspace(start[1], stop[1], grid_size[1], endpoint=False, dtype=np.float32)
309
+ grid_t = np.linspace(0, temporal_size, temporal_size, endpoint=False, dtype=np.float32)
310
+
311
+ # Compute dimensions for each axis
312
+ dim_t = embed_dim // 4
313
+ dim_h = embed_dim // 8 * 3
314
+ dim_w = embed_dim // 8 * 3
315
+
316
+ # Temporal frequencies
317
+ freqs_t = 1.0 / (theta ** (torch.arange(0, dim_t, 2).float() / dim_t))
318
+ grid_t = torch.from_numpy(grid_t).float()
319
+ freqs_t = torch.einsum("n , f -> n f", grid_t, freqs_t)
320
+ freqs_t = freqs_t.repeat_interleave(2, dim=-1)
321
+
322
+ # Spatial frequencies for height and width
323
+ freqs_h = 1.0 / (theta ** (torch.arange(0, dim_h, 2).float() / dim_h))
324
+ freqs_w = 1.0 / (theta ** (torch.arange(0, dim_w, 2).float() / dim_w))
325
+ grid_h = torch.from_numpy(grid_h).float()
326
+ grid_w = torch.from_numpy(grid_w).float()
327
+ freqs_h = torch.einsum("n , f -> n f", grid_h, freqs_h)
328
+ freqs_w = torch.einsum("n , f -> n f", grid_w, freqs_w)
329
+ freqs_h = freqs_h.repeat_interleave(2, dim=-1)
330
+ freqs_w = freqs_w.repeat_interleave(2, dim=-1)
331
+
332
+ # Broadcast and concatenate tensors along specified dimension
333
+ def broadcast(tensors, dim=-1):
334
+ num_tensors = len(tensors)
335
+ shape_lens = {len(t.shape) for t in tensors}
336
+ assert len(shape_lens) == 1, "tensors must all have the same number of dimensions"
337
+ shape_len = list(shape_lens)[0]
338
+ dim = (dim + shape_len) if dim < 0 else dim
339
+ dims = list(zip(*(list(t.shape) for t in tensors)))
340
+ expandable_dims = [(i, val) for i, val in enumerate(dims) if i != dim]
341
+ assert all(
342
+ [*(len(set(t[1])) <= 2 for t in expandable_dims)]
343
+ ), "invalid dimensions for broadcastable concatenation"
344
+ max_dims = [(t[0], max(t[1])) for t in expandable_dims]
345
+ expanded_dims = [(t[0], (t[1],) * num_tensors) for t in max_dims]
346
+ expanded_dims.insert(dim, (dim, dims[dim]))
347
+ expandable_shapes = list(zip(*(t[1] for t in expanded_dims)))
348
+ tensors = [t[0].expand(*t[1]) for t in zip(tensors, expandable_shapes)]
349
+ return torch.cat(tensors, dim=dim)
350
+
351
+ freqs = broadcast((freqs_t[:, None, None, :], freqs_h[None, :, None, :], freqs_w[None, None, :, :]), dim=-1)
352
+
353
+ t, h, w, d = freqs.shape
354
+ freqs = freqs.view(t * h * w, d)
355
+
356
+ # Generate sine and cosine components
357
+ sin = freqs.sin()
358
+ cos = freqs.cos()
359
+
360
+ if use_real:
361
+ return cos, sin
362
+ else:
363
+ freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
364
+ return freqs_cis
365
+
366
+
367
+ def apply_rotary_emb(
368
+ x: torch.Tensor,
369
+ freqs_cis: Union[torch.Tensor, Tuple[torch.Tensor]],
370
+ use_real: bool = True,
371
+ use_real_unbind_dim: int = -1,
372
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
373
+ """
374
+ Apply rotary embeddings to input tensors using the given frequency tensor. This function applies rotary embeddings
375
+ to the given query or key 'x' tensors using the provided frequency tensor 'freqs_cis'. The input tensors are
376
+ reshaped as complex numbers, and the frequency tensor is reshaped for broadcasting compatibility. The resulting
377
+ tensors contain rotary embeddings and are returned as real tensors.
378
+
379
+ Args:
380
+ x (`torch.Tensor`):
381
+ Query or key tensor to which rotary embeddings are applied. Expected shape: [B, H, S, D].
382
+ freqs_cis (`Tuple[torch.Tensor]`): Precomputed frequency tensor for complex exponentials. ([S, D], [S, D],)
383
+
384
+ Returns:
385
+ Tuple[torch.Tensor, torch.Tensor]: Tuple of modified query tensor and key tensor with rotary embeddings.
386
+ """
387
+ if use_real:
388
+ cos, sin = freqs_cis # [S, D]
389
+ cos = cos[None, None]
390
+ sin = sin[None, None]
391
+ cos, sin = cos.to(x.device), sin.to(x.device)
392
+
393
+ if use_real_unbind_dim == -1:
394
+ # Use for example in Lumina
395
+ x_real, x_imag = x.reshape(*x.shape[:-1], -1, 2).unbind(-1) # [B, S, H, D//2]
396
+ x_rotated = torch.stack([-x_imag, x_real], dim=-1).flatten(3)
397
+ elif use_real_unbind_dim == -2:
398
+ # Use for example in Stable Audio
399
+ x_real, x_imag = x.reshape(*x.shape[:-1], 2, -1).unbind(-2) # [B, S, H, D//2]
400
+ x_rotated = torch.cat([-x_imag, x_real], dim=-1)
401
+ else:
402
+ raise ValueError(f"`use_real_unbind_dim={use_real_unbind_dim}` but should be -1 or -2.")
403
+
404
+ out = (x.float() * cos + x_rotated.float() * sin).to(x.dtype)
405
+
406
+ return out
407
+ else:
408
+ x_rotated = torch.view_as_complex(x.float().reshape(*x.shape[:-1], -1, 2))
409
+ freqs_cis = freqs_cis.unsqueeze(2)
410
+ x_out = torch.view_as_real(x_rotated * freqs_cis).flatten(3)
411
+
412
+ return x_out.type_as(x)
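
A sketch of how the two RoPE helpers above compose; the grid sizes are arbitrary and the import path is assumed from this commit.

import torch
from videosys.models.modules.embeddings import apply_rotary_emb, get_3d_rotary_pos_embed  # assumed path

head_dim, grid, t_len = 64, (8, 8), 4
cos, sin = get_3d_rotary_pos_embed(
    head_dim, crops_coords=((0, 0), grid), grid_size=grid, temporal_size=t_len, use_real=True
)  # each of shape (t_len * 8 * 8, head_dim)
q = torch.randn(1, 8, t_len * 8 * 8, head_dim)  # (B, heads, S, head_dim)
q_rot = apply_rotary_emb(q, (cos, sin))  # same shape, rotary phases applied pairwise
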
videosys/models/modules/normalization.py ADDED
@@ -0,0 +1,102 @@
1
+ from typing import Optional, Tuple
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+
6
+
7
+ class LlamaRMSNorm(nn.Module):
8
+ def __init__(self, hidden_size, eps=1e-6):
9
+ """
10
+ LlamaRMSNorm is equivalent to T5LayerNorm
11
+ """
12
+ super().__init__()
13
+ self.weight = nn.Parameter(torch.ones(hidden_size))
14
+ self.variance_epsilon = eps
15
+
16
+ def forward(self, hidden_states):
17
+ input_dtype = hidden_states.dtype
18
+ hidden_states = hidden_states.to(torch.float32)
19
+ variance = hidden_states.pow(2).mean(-1, keepdim=True)
20
+ hidden_states = hidden_states * torch.rsqrt(variance + self.variance_epsilon)
21
+ return self.weight * hidden_states.to(input_dtype)
22
+
23
+
24
+ class CogVideoXLayerNormZero(nn.Module):
25
+ def __init__(
26
+ self,
27
+ conditioning_dim: int,
28
+ embedding_dim: int,
29
+ elementwise_affine: bool = True,
30
+ eps: float = 1e-5,
31
+ bias: bool = True,
32
+ ) -> None:
33
+ super().__init__()
34
+
35
+ self.silu = nn.SiLU()
36
+ self.linear = nn.Linear(conditioning_dim, 6 * embedding_dim, bias=bias)
37
+ self.norm = nn.LayerNorm(embedding_dim, eps=eps, elementwise_affine=elementwise_affine)
38
+
39
+ def forward(
40
+ self, hidden_states: torch.Tensor, encoder_hidden_states: torch.Tensor, temb: torch.Tensor
41
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
42
+ shift, scale, gate, enc_shift, enc_scale, enc_gate = self.linear(self.silu(temb)).chunk(6, dim=1)
43
+ hidden_states = self.norm(hidden_states) * (1 + scale)[:, None, :] + shift[:, None, :]
44
+ encoder_hidden_states = self.norm(encoder_hidden_states) * (1 + enc_scale)[:, None, :] + enc_shift[:, None, :]
45
+ return hidden_states, encoder_hidden_states, gate[:, None, :], enc_gate[:, None, :]
46
+
47
+
48
+ class AdaLayerNorm(nn.Module):
49
+ r"""
50
+ Norm layer modified to incorporate timestep embeddings.
51
+
52
+ Parameters:
53
+ embedding_dim (`int`): The size of each embedding vector.
54
+ num_embeddings (`int`, *optional*): The size of the embeddings dictionary.
55
+ output_dim (`int`, *optional*):
56
+ norm_elementwise_affine (`bool`, defaults to `False`):
57
+ norm_eps (`float`, defaults to `1e-5`):
58
+ chunk_dim (`int`, defaults to `0`):
59
+ """
60
+
61
+ def __init__(
62
+ self,
63
+ embedding_dim: int,
64
+ num_embeddings: Optional[int] = None,
65
+ output_dim: Optional[int] = None,
66
+ norm_elementwise_affine: bool = False,
67
+ norm_eps: float = 1e-5,
68
+ chunk_dim: int = 0,
69
+ ):
70
+ super().__init__()
71
+
72
+ self.chunk_dim = chunk_dim
73
+ output_dim = output_dim or embedding_dim * 2
74
+
75
+ if num_embeddings is not None:
76
+ self.emb = nn.Embedding(num_embeddings, embedding_dim)
77
+ else:
78
+ self.emb = None
79
+
80
+ self.silu = nn.SiLU()
81
+ self.linear = nn.Linear(embedding_dim, output_dim)
82
+ self.norm = nn.LayerNorm(output_dim // 2, norm_eps, norm_elementwise_affine)
83
+
84
+ def forward(
85
+ self, x: torch.Tensor, timestep: Optional[torch.Tensor] = None, temb: Optional[torch.Tensor] = None
86
+ ) -> torch.Tensor:
87
+ if self.emb is not None:
88
+ temb = self.emb(timestep)
89
+
90
+ temb = self.linear(self.silu(temb))
91
+
92
+ if self.chunk_dim == 1:
93
+ # This is a bit weird why we have the order of "shift, scale" here and "scale, shift" in the
94
+ # other if-branch. This branch is specific to CogVideoX for now.
95
+ shift, scale = temb.chunk(2, dim=1)
96
+ shift = shift[:, None, :]
97
+ scale = scale[:, None, :]
98
+ else:
99
+ scale, shift = temb.chunk(2, dim=0)
100
+
101
+ x = self.norm(x) * (1 + scale) + shift
102
+ return x
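
A brief sketch of how CogVideoXLayerNormZero conditions both token streams; the dimensions are illustrative and the import path is assumed.

import torch
from videosys.models.modules.normalization import CogVideoXLayerNormZero  # assumed path

norm = CogVideoXLayerNormZero(conditioning_dim=512, embedding_dim=1920)
hidden = torch.randn(2, 226, 1920)         # video tokens
encoder_hidden = torch.randn(2, 77, 1920)  # text tokens
temb = torch.randn(2, 512)                 # timestep embedding
h, e, gate, enc_gate = norm(hidden, encoder_hidden, temb)
# h and e are LayerNorm outputs modulated per sample; gate and enc_gate are (B, 1, D) gates for the residual branches
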
videosys/models/modules/upsampling.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class CogVideoXUpsample3D(nn.Module):
7
+ r"""
8
+ A 3D Upsample layer used in CogVideoX by Tsinghua University & ZhipuAI # TODO: Wait for paper release.
9
+
10
+ Args:
11
+ in_channels (`int`):
12
+ Number of channels in the input image.
13
+ out_channels (`int`):
14
+ Number of channels produced by the convolution.
15
+ kernel_size (`int`, defaults to `3`):
16
+ Size of the convolving kernel.
17
+ stride (`int`, defaults to `1`):
18
+ Stride of the convolution.
19
+ padding (`int`, defaults to `1`):
20
+ Padding added to all four sides of the input.
21
+ compress_time (`bool`, defaults to `False`):
22
+ Whether or not to compress the time dimension.
23
+ """
24
+
25
+ def __init__(
26
+ self,
27
+ in_channels: int,
28
+ out_channels: int,
29
+ kernel_size: int = 3,
30
+ stride: int = 1,
31
+ padding: int = 1,
32
+ compress_time: bool = False,
33
+ ) -> None:
34
+ super().__init__()
35
+
36
+ self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding)
37
+ self.compress_time = compress_time
38
+
39
+ def forward(self, inputs: torch.Tensor) -> torch.Tensor:
40
+ if self.compress_time:
41
+ if inputs.shape[2] > 1 and inputs.shape[2] % 2 == 1:
42
+ # split first frame
43
+ x_first, x_rest = inputs[:, :, 0], inputs[:, :, 1:]
44
+
45
+ x_first = F.interpolate(x_first, scale_factor=2.0)
46
+ x_rest = F.interpolate(x_rest, scale_factor=2.0)
47
+ x_first = x_first[:, :, None, :, :]
48
+ inputs = torch.cat([x_first, x_rest], dim=2)
49
+ elif inputs.shape[2] > 1:
50
+ inputs = F.interpolate(inputs, scale_factor=2.0)
51
+ else:
52
+ inputs = inputs.squeeze(2)
53
+ inputs = F.interpolate(inputs, scale_factor=2.0)
54
+ inputs = inputs[:, :, None, :, :]
55
+ else:
56
+ # only interpolate 2D
57
+ b, c, t, h, w = inputs.shape
58
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
59
+ inputs = F.interpolate(inputs, scale_factor=2.0)
60
+ inputs = inputs.reshape(b, t, c, *inputs.shape[2:]).permute(0, 2, 1, 3, 4)
61
+
62
+ b, c, t, h, w = inputs.shape
63
+ inputs = inputs.permute(0, 2, 1, 3, 4).reshape(b * t, c, h, w)
64
+ inputs = self.conv(inputs)
65
+ inputs = inputs.reshape(b, t, *inputs.shape[1:]).permute(0, 2, 1, 3, 4)
66
+
67
+ return inputs
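
And a matching sketch for the upsampling block, under the same assumed import path.

import torch
from videosys.models.modules.upsampling import CogVideoXUpsample3D  # assumed path

up = CogVideoXUpsample3D(in_channels=16, out_channels=16, compress_time=True)
x = torch.randn(2, 16, 5, 32, 32)  # (B, C, T, H, W), odd T > 1
y = up(x)
print(y.shape)  # torch.Size([2, 16, 9, 64, 64]): first frame kept, remaining frames doubled in time, H/W doubled
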