Spaces:
Runtime error
LinB203 committed on
Commit • bab971b
1 Parent(s): 3fdb84a
update
Browse files
This view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +45 -0
- LICENSE +21 -0
- app.py +8 -0
- docker/LICENSE +21 -0
- docker/docker_build.sh +8 -0
- docker/docker_run.sh +45 -0
- docker/dockerfile.base +24 -0
- docker/packages.txt +3 -0
- docker/ports.txt +1 -0
- docker/postinstallscript.sh +3 -0
- docker/requirements.txt +40 -0
- docker/setup_env.sh +11 -0
- docs/Contribution_Guidelines.md +87 -0
- docs/Data.md +39 -0
- docs/EVAL.md +110 -0
- docs/Report-v1.0.0-cn.md +135 -0
- docs/Report-v1.0.0.md +136 -0
- docs/Train_And_Eval_CausalVideoVAE.md +158 -0
- docs/VQVAE.md +57 -0
- examples/get_latents_std.py +38 -0
- examples/prompt_list_0.txt +24 -0
- examples/rec_image.py +57 -0
- examples/rec_imvi_vae.py +164 -0
- examples/rec_video_vae.py +275 -0
- opensora/__init__.py +1 -0
- opensora/dataset/__init__.py +65 -0
- opensora/dataset/feature_datasets.py +213 -0
- opensora/dataset/t2v_datasets.py +203 -0
- opensora/dataset/transform.py +573 -0
- opensora/dataset/ucf101.py +80 -0
- opensora/eval/cal_flolpips.py +83 -0
- opensora/eval/cal_fvd.py +85 -0
- opensora/eval/cal_lpips.py +97 -0
- opensora/eval/cal_psnr.py +84 -0
- opensora/eval/cal_ssim.py +113 -0
- opensora/eval/eval_clip_score.py +225 -0
- opensora/eval/eval_common_metric.py +224 -0
- opensora/eval/flolpips/correlation/correlation.py +397 -0
- opensora/eval/flolpips/flolpips.py +308 -0
- opensora/eval/flolpips/pretrained_networks.py +180 -0
- opensora/eval/flolpips/pwcnet.py +344 -0
- opensora/eval/flolpips/utils.py +95 -0
- opensora/eval/fvd/styleganv/fvd.py +90 -0
- opensora/eval/fvd/videogpt/fvd.py +137 -0
- opensora/eval/fvd/videogpt/pytorch_i3d.py +322 -0
- opensora/eval/script/cal_clip_score.sh +23 -0
- opensora/eval/script/cal_fvd.sh +9 -0
- opensora/eval/script/cal_lpips.sh +8 -0
- opensora/eval/script/cal_psnr.sh +9 -0
- opensora/eval/script/cal_ssim.sh +8 -0
.gitignore
ADDED
@@ -0,0 +1,45 @@
ucf101_stride4x4x4
__pycache__
*.mp4
.ipynb_checkpoints
*.pth
UCF-101/
results/
vae
build/
opensora.egg-info/
wandb/
.idea
*.ipynb
*.jpg
*.mp3
*.safetensors
*.mp4
*.png
*.gif
*.pth
*.pt
cache_dir/
wandb/
test*
sample_video*
sample_image*
512*
720*
1024*
debug*
private*
caption*
*deepspeed*
revised*
129f*
all*
read*
YSH*
*pick*
*ysh*
hw*
257f*
513f*
taming*
221hw*
LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 PKU-YUAN's Group (袁粒课题组-北大信工) and Rabbitpre AI

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
app.py
ADDED
@@ -0,0 +1,8 @@
import os
import sys
current_path = os.path.abspath(__file__)
parent_path = os.path.dirname(current_path)
print(parent_path)
sys.path.append(parent_path)
print(sys.path)
os.system('python opensora/serve/gradio_web_server.py')
docker/LICENSE
ADDED
@@ -0,0 +1,21 @@
MIT License

Copyright (c) 2024 SimonLee

Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:

The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.
docker/docker_build.sh
ADDED
@@ -0,0 +1,8 @@
#!/usr/bin/env bash

WORK_DIR=$(dirname "$(readlink -f "$0")")
cd $WORK_DIR

source setup_env.sh

docker build -t $TAG --build-arg BASE_TAG=$BASE_TAG --build-arg USER_NAME=$USER_NAME --build-arg USER_PASSWD=$USER_PASSWD . -f dockerfile.base
docker/docker_run.sh
ADDED
@@ -0,0 +1,45 @@
#!/usr/bin/env bash

WORK_DIR=$(dirname "$(readlink -f "$0")")
source $WORK_DIR/setup_env.sh

RUNNING_IDS="$(docker ps --filter ancestor=$TAG --format "{{.ID}}")"

if [ -n "$RUNNING_IDS" ]; then
    # Initialize an array to hold the container IDs
    declare -a container_ids=($RUNNING_IDS)

    # Get the first container ID using array indexing
    ID=${container_ids[0]}

    # Print the first container ID
    echo ' '
    echo "The running container ID is: $ID, enter it!"
else
    echo ' '
    echo "No running container found, starting a new one!"

    # Run a new docker container instance
    ID=$(docker run \
        --rm \
        --gpus all \
        -itd \
        --ipc=host \
        --ulimit memlock=-1 \
        --ulimit stack=67108864 \
        -e DISPLAY=$DISPLAY \
        -v /tmp/.X11-unix/:/tmp/.X11-unix/ \
        -v $PWD:/home/$USER_NAME/workspace \
        -w /home/$USER_NAME/workspace \
        $(cat $WORK_DIR/ports.txt) \
        $TAG)
fi

docker logs $ID

echo ' '
echo ' '
echo '========================================='
echo ' '

docker exec -it $ID bash
docker/dockerfile.base
ADDED
@@ -0,0 +1,24 @@
ARG BASE_TAG
FROM ${BASE_TAG}
ARG USER_NAME=myuser
ARG USER_PASSWD=111111
ARG DEBIAN_FRONTEND=noninteractive

# Pre-install packages, pip install requirements and run post install script.
COPY packages.txt .
COPY requirements.txt .
COPY postinstallscript.sh .
RUN apt-get update && apt-get install -y sudo $(cat packages.txt)
RUN pip install --no-cache-dir -r requirements.txt
RUN bash postinstallscript.sh

# Create a new user and group using the username argument
RUN groupadd -r ${USER_NAME} && useradd -r -m -g${USER_NAME} ${USER_NAME}
RUN echo "${USER_NAME}:${USER_PASSWD}" | chpasswd
RUN usermod -aG sudo ${USER_NAME}
USER ${USER_NAME}
ENV USER=${USER_NAME}
WORKDIR /home/${USER_NAME}/workspace

# Set the prompt to highlight the username
RUN echo "export PS1='\[\033[01;32m\]\u\[\033[00m\]@\[\033[01;34m\]\h\[\033[00m\]:\[\033[01;36m\]\w\[\033[00m\]\$'" >> /home/${USER_NAME}/.bashrc
docker/packages.txt
ADDED
@@ -0,0 +1,3 @@
wget
curl
git
docker/ports.txt
ADDED
@@ -0,0 +1 @@
-p 6006:6006
docker/postinstallscript.sh
ADDED
@@ -0,0 +1,3 @@
#!/usr/bin/env bash
# This script runs when the docker image is built.

docker/requirements.txt
ADDED
@@ -0,0 +1,40 @@
setuptools>=61.0
torch==2.0.1
torchvision==0.15.2
transformers==4.32.0
albumentations==1.4.0
av==11.0.0
decord==0.6.0
einops==0.3.0
fastapi==0.110.0
accelerate==0.21.0
gdown==5.1.0
h5py==3.10.0
idna==3.6
imageio==2.34.0
matplotlib==3.7.5
numpy==1.24.4
omegaconf==2.1.1
opencv-python==4.9.0.80
opencv-python-headless==4.9.0.80
pandas==2.0.3
pillow==10.2.0
pydub==0.25.1
pytorch-lightning==1.4.2
pytorchvideo==0.1.5
PyYAML==6.0.1
regex==2023.12.25
requests==2.31.0
scikit-learn==1.3.2
scipy==1.10.1
six==1.16.0
tensorboard==2.14.0
test-tube==0.7.5
timm==0.9.16
torchdiffeq==0.2.3
torchmetrics==0.5.0
tqdm==4.66.2
urllib3==2.2.1
uvicorn==0.27.1
diffusers==0.24.0
scikit-video==1.1.11
docker/setup_env.sh
ADDED
@@ -0,0 +1,11 @@
# Docker tag for the newly built image
TAG=open_sora_plan:dev

# Base docker image tag used by docker build
BASE_TAG=nvcr.io/nvidia/pytorch:23.05-py3

# User name used in the docker container
USER_NAME=developer

# User password used in the docker container
USER_PASSWD=666666
docs/Contribution_Guidelines.md
ADDED
@@ -0,0 +1,87 @@
# Contributing to the Open-Sora Plan Community

The Open-Sora Plan open-source community is a collaborative initiative driven by the community, emphasizing a commitment to being free and void of exploitation. Organized spontaneously by community members, we invite you to contribute to the Open-Sora Plan open-source community and help elevate it to new heights!

## Submitting a Pull Request (PR)

As a contributor, before submitting your request, kindly follow these guidelines:

1. Start by checking the [Open-Sora Plan GitHub](https://github.com/PKU-YuanGroup/Open-Sora-Plan/pulls) to see if there are any open or closed pull requests related to your intended submission. Avoid duplicating existing work.

2. [Fork](https://github.com/PKU-YuanGroup/Open-Sora-Plan/fork) the [Open-Sora-Plan](https://github.com/PKU-YuanGroup/Open-Sora-Plan) repository and clone your forked repository to your local machine.

```bash
git clone [your-forked-repository-url]
```

3. Add the original Open-Sora Plan repository as a remote to sync with the latest updates:

```bash
git remote add upstream https://github.com/PKU-YuanGroup/Open-Sora-Plan
```

4. Sync the code from the main repository to your local machine, and then push it back to your forked remote repository.

```
# Pull the latest code from the upstream branch
git fetch upstream

# Switch to the main branch
git checkout main

# Merge the updates from the upstream branch into main, synchronizing the local main branch with the upstream
git merge upstream/main

# Additionally, sync the local main branch to the remote branch of your forked repository
git push origin main
```

> Note: Sync the code from the main repository before each submission.

5. Create a branch in your forked repository for your changes, ensuring the branch name is meaningful.

```bash
git checkout -b my-docs-branch main
```

6. While making modifications and committing changes, adhere to our [Commit Message Format](#Commit-Message-Format).

```bash
git commit -m "[docs]: xxxx"
```

7. Push your changes to your GitHub repository.

```bash
git push origin my-docs-branch
```

8. Submit a pull request to `Open-Sora-Plan:main` on the GitHub repository page.

## Commit Message Format

Commit messages must include both `<type>` and `<summary>` sections.

```bash
[<type>]: <summary>
   │          │
   │          └─⫸ Briefly describe your changes, without ending with a period.
   │
   └─⫸ Commit Type: |docs|feat|fix|refactor|
```

### Type

* **docs**: Modify or add documents.
* **feat**: Introduce a new feature.
* **fix**: Fix a bug.
* **refactor**: Restructure code, excluding new features or bug fixes.

### Summary

Describe modifications in English, without ending with a period.

> e.g., git commit -m "[docs]: add a contributing.md file"

This guideline is borrowed from [minisora](https://github.com/mini-sora/minisora). We sincerely thank the MiniSora authors for their awesome template.
docs/Data.md
ADDED
@@ -0,0 +1,39 @@

**We need more datasets**; please refer to the [Open-Sora-Dataset](https://github.com/PKU-YuanGroup/Open-Sora-Dataset) for details.

## v1.0.0

### Text-to-Video

We open-source all the v1.0.0 training data; the annotations and the original videos can be found [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0).

The data consist of segmented video clips, each obtained through center cropping. The resolution of each clip is 512×512. There are 64 frames in each clip, and their corresponding captions can be found in the annotation files.

We present additional details in the [report](https://github.com/PKU-YuanGroup/Open-Sora-Plan/blob/main/docs/Report-v1.0.0.md#data-construction) and [Open-Sora-Dataset](https://github.com/PKU-YuanGroup/Open-Sora-Dataset).

### Class-condition

To download the UCF-101 dataset, you can obtain the necessary files from [here](https://www.crcv.ucf.edu/data/UCF101.php). The code assumes a `UCF-101` directory with the following structure:
```
UCF-101/
    ApplyEyeMakeup/
        v1.avi
        ...
    ...
    YoYo/
        v1.avi
        ...
```

### Un-condition

We use [sky_timelapse](https://drive.google.com/open?id=1xWLiU-MBGN7MrsFHQm4_yXmfHBsMbJQo), which is an unconditional dataset.

```
sky_timelapse
├── readme
├── sky_test
├── sky_train
├── test_videofolder.py
└── video_folder.py
```
docs/EVAL.md
ADDED
@@ -0,0 +1,110 @@
# Evaluate the quality of generated videos

You can easily calculate the following video quality metrics, which support batch-wise processing.
- **CLIP-SCORE**: It uses the pretrained CLIP model to measure the cosine similarity between two modalities.
- **FVD**: Fréchet Video Distance
- **SSIM**: structural similarity index measure
- **LPIPS**: learned perceptual image patch similarity
- **PSNR**: peak-signal-to-noise ratio

# Requirements
## Environment
- install PyTorch (torch>=1.7.1)
- install CLIP
```
pip install git+https://github.com/openai/CLIP.git
```
- install clip-score from PyPI
```
pip install clip-score
```
- Other packages
```
pip install lpips
pip install scipy (use scipy==1.7.3 or 1.9.3; if you use 1.11.3, **you will calculate a WRONG FVD VALUE!!!**)
pip install numpy
pip install pillow
pip install torchvision>=0.8.2
pip install ftfy
pip install regex
pip install tqdm
```
## Pretrained models
- FVD
  Before you calculate FVD, you should first download the FVD pre-trained model. You can manually download either of the following and put it into the FVD folder.
  - `i3d_torchscript.pt` from [here](https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt)
  - `i3d_pretrained_400.pt` from [here](https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI)


## Other Notices
1. Make sure the pixel values of the videos are in [0, 1].
2. We average SSIM when images have 3 channels; SSIM is the only metric that is extremely sensitive to gray being compared to b/w.
3. Because the I3D model downsamples in the time dimension, `frames_num` should be > 10 when calculating FVD, so the FVD calculation begins from the 10th frame, as in the example above.
4. For grayscale videos, we duplicate them to 3 channels.
5. Data input specifications for clip_score
> - Image files: All images should be stored in a single directory. The image files can be in either .png or .jpg format.
>
> - Text files: All text data should be contained in plain text files in a separate directory. These text files should have the extension .txt.
>
> Note: The number of files in the image directory should be exactly equal to the number of files in the text directory. Additionally, the files in the image directory and text directory should be paired by file name. For instance, if there is a cat.png in the image directory, there should be a corresponding cat.txt in the text directory.
>
> Directory Structure Example:
> ```
> ├── path/to/image
> │    ├── cat.png
> │    ├── dog.png
> │    └── bird.jpg
> └── path/to/text
>      ├── cat.txt
>      ├── dog.txt
>      └── bird.txt
> ```

6. Data input specifications for fvd, psnr, ssim, lpips

> Directory Structure Example:
> ```
> ├── path/to/generated_image
> │    ├── cat.mp4
> │    ├── dog.mp4
> │    └── bird.mp4
> └── path/to/real_image
>      ├── cat.mp4
>      ├── dog.mp4
>      └── bird.mp4
> ```


# Usage

```
# change the file paths and set frame_num, resolution, etc.

# clip_score cross modality
cd opensora/eval
bash script/cal_clip_score.sh


# fvd
cd opensora/eval
bash script/cal_fvd.sh

# psnr
cd opensora/eval
bash script/cal_psnr.sh


# ssim
cd opensora/eval
bash script/cal_ssim.sh


# lpips
cd opensora/eval
bash script/cal_lpips.sh
```

# Acknowledgement
The evaluation codebase refers to [clip-score](https://github.com/Taited/clip-score) and [common_metrics](https://github.com/JunyaoHu/common_metrics_on_video_quality).
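For a quick sanity check of the metric definitions above, outside the provided scripts, the following minimal Python sketch computes per-frame PSNR for a pair of videos. The array shapes, the random test data, and the `psnr_per_frame` helper are illustrative assumptions and not part of this commit; `opensora/eval/cal_psnr.py` remains the reference implementation.

```python
# Minimal PSNR sanity check for two videos, assuming float arrays in [0, 1]
# with shape (T, H, W, C), as required in "Other Notices" above.
import numpy as np

def psnr_per_frame(real: np.ndarray, fake: np.ndarray, max_val: float = 1.0) -> np.ndarray:
    """Compute PSNR for each frame of two (T, H, W, C) videos in [0, 1]."""
    assert real.shape == fake.shape, "videos must have identical shapes"
    # Mean squared error per frame, averaged over H, W, C.
    mse = np.mean((real.astype(np.float64) - fake.astype(np.float64)) ** 2, axis=(1, 2, 3))
    # Guard against identical frames (which would give infinite PSNR).
    mse = np.maximum(mse, 1e-12)
    return 10.0 * np.log10(max_val ** 2 / mse)

if __name__ == "__main__":
    rng = np.random.default_rng(0)
    real = rng.random((16, 64, 64, 3))                                    # stand-in for a real clip
    fake = np.clip(real + 0.01 * rng.standard_normal(real.shape), 0.0, 1.0)  # stand-in for a generated clip
    scores = psnr_per_frame(real, fake)
    print(f"mean PSNR over {len(scores)} frames: {scores.mean():.2f} dB")
```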
docs/Report-v1.0.0-cn.md
ADDED
@@ -0,0 +1,135 @@
# 技术报告 v1.0.0

在2024年3月，我们推出了Open-Sora-Plan，一个旨在复现OpenAI [Sora](https://openai.com/sora)的开源计划。它作为一个基础的开源框架，能够训练视频生成模型，包括无条件视频生成、类别引导视频生成、文生视频。

**今天，我们兴奋地展示Open-Sora-Plan v1.0.0，极大地改进视频生成质量、文本控制能力。**

相比于之前的视频生成模型，Open-Sora-Plan v1.0.0 有以下的改进：

1. **CausalVideoVAE高效的训练与推理**。我们用4×8×8对视频进行时间和空间的压缩。
2. **图片视频联合训练提升视觉质量**。CausalVideoVAE 将首帧看作图片，天然支持同时编码图片和视频。这允许扩散模型提取更多时空细节来改善质量。


### Open-Source Release
我们开源了Open-Sora-Plan去促进视频生成社区的进一步发展。公开代码、数据、模型。
- 在线演示：Hugging Face [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0), [![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512) 和 [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb), 感谢[@camenduru](https://github.com/camenduru)大力支持我们的工作！🤝
- 代码：所有训练脚本和采样代码。
- 模型：包括扩散模型和CausalVideoVAE [这里](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0)。
- 数据：所有原视频和对应描述 [这里](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0)。

## 效果

Open-Sora-Plan v1.0.0支持图片视频联合训练。我们在此展示视频和图片的重建以及生成：

720×1280**视频重建**。因为GitHub的限制，原视频放在：[1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8)。

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68

1536×1024**图片重建**

<img src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1684c3ec-245d-4a60-865c-b8946d788eb9" width="45%"/> <img src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/46ef714e-3e5b-492c-aec4-3793cb2260b5" width="45%"/>

65×1024×1024**文生视频**

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/2641a8aa-66ac-4cda-8279-86b2e6a6e011

65×512×512**文生视频**

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/37e3107e-56b3-4b09-8920-fa1d8d144b9e


512×512**文生图**

![download](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/491d72bc-e762-48ff-bdcc-cc69350f56d6)

## 详细技术报告

### CausalVideoVAE

#### 模型结构

![image](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/e3c8b35d-a217-4d96-b2e9-5c248a2859c8)

因果VAE架构继承了[Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main)。为了保证图片VAE的预训练权重可以无缝运用到视频VAE中，模型结构采取如下设计：

1. **CausalConv3D**：将Conv2D转变成CausalConv3D可以实现图片和视频的联合训练。CausalConv3D对第一帧进行特殊处理，因为它无法访问后续帧。对于更多细节，请参考 https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145

2. **初始化**：将Conv2D扩展到Conv3D常用的[方法](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5)有两种：平均初始化和中心初始化。但我们采用了特定的初始化方法（尾部初始化）。这种初始化方法确保模型无需任何训练就能够直接重建图像，甚至视频。

#### 训练细节

<img width="833" alt="image" src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/9ffb6dc4-23f6-4274-a066-bbebc7522a14">

我们展示了 17×256×256 下两种不同初始化方法的损失曲线。黄色曲线代表使用尾部初始化的损失，而蓝色曲线对应中心初始化的损失。如图所示，尾部初始化在损失曲线上表现出更好的性能。此外，我们发现中心初始化会导致错误累积，导致长时间内崩溃。

#### 推理技巧
尽管训练Diffusion中VAE始终是冻住的，我们仍然无法负担CausalVideoVAE的花销。在我们的实验中，80G的显存只能够在半精度下推理一个256×512×512或32×1024×1024的视频，这限制了我们扩展到更长更高清的视频。因此我们采用tile convolution，能够以几乎恒定的内存推理任意时长或任意分辨率的视频。

### 数据构建
我们定义高质量的视频数据集包括两个核心法则：(1) 没有与内容无关的水印。(2) 高质量的文本注释。

**对于法则1**，我们从开源网站（CC0协议）爬取了大约40k videos：1234个来自[mixkit](https://mixkit.co/)，7408个来自[pexels](https://www.pexels.com/)，31616个来自[pixabay](https://pixabay.com/)。我们根据[Panda70M](https://github.com/snap-research/Panda-70M/blob/main/splitting/README.md)提供的场景变换剪切script将这些视频切成大约434k video clips。事实上，根据我们的剪切结果，从这些网站上爬取的99%的视频都是单一的场景。另外，我们发现爬取的数据中超过60%为风景相关视频。更多细节可以在[这](https://github.com/PKU-YuanGroup/Open-Sora-Dataset)找到。

**对于法则2**，很难有大量的高质量的文本注释能够从网上直接爬取。因此我们用成熟的图片标注模型来获取高质量的稠密描述。我们对2个多模态大模型进行消融实验：[ShareGPT4V-Captioner-7B](https://github.com/InternLM/InternLM-XComposer/blob/main/projects/ShareGPT4V/README.md) 和 [LLaVA-1.6-34B](https://github.com/haotian-liu/LLaVA)。前者是专门用来制作文本注释的模型，而后者是一个通用的多模态大模型。经过我们的消融实验，他们在caption的表现差不多。然而他们的推理速度在A800上差距很大：40s/it of batch size of 12 for ShareGPT4V-Captioner-7B，15s/it of batch size of 1 for LLaVA-1.6-34B。我们开源所有的[文本注释和原视频](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0)。

| 模型名字 | 平均长度 | 最大值 | 标准差 |
|---|---|---|---|
| ShareGPT4V-Captioner-7B | 170.0827524529121 | 467 | 53.689967539537776 |
| LLaVA-1.6-34B | 141.75851073472666 | 472 | 48.52492072346965 |

### 训练扩散模型
与之前的工作类似，我们采用多阶段的级联训练方法，总共消耗了2048个A800 GPU小时。我们发现联合图片训练能够显著加速模型的收敛并且增强视觉观感，这与[Latte](https://github.com/Vchitect/Latte)一致。以下是我们的训练花销。

| 名字 | Stage 1 | Stage 2 | Stage 3 | Stage 4 |
|---|---|---|---|---|
| 训练视频尺寸 | 17×256×256 | 65×256×256 | 65×512×512 | 65×1024×1024 |
| 计算资源 (#A800 GPU x #小时) | 32 × 40 | 32 × 18 | 32 × 6 | 训练中 |
| 权重 | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512) | 训练中 |
| 日志 | [wandb](https://api.wandb.ai/links/linbin/p6n3evym) | [wandb](https://api.wandb.ai/links/linbin/t2g53sew) | [wandb](https://api.wandb.ai/links/linbin/uomr0xzb) | 训练中 |
| 训练数据 | ~40k videos | ~40k videos | ~40k videos | ~40k videos |

## 下版本预览
### CausalVideoVAE
目前我们发布的CausalVideoVAE v1.0.0版本存在2个主要的缺陷：**运动模糊**以及**网格效应**。我们对CausalVideoVAE做了一系列的改进，使它推理成本更低且性能更强大，我们暂时叫它为预览版本，将在下个版本发布。

**1分钟720×1280视频重建**。受限于GitHub，我们将原视频放在这：[原视频](https://streamable.com/u4onbb)，[重建视频](https://streamable.com/qt8ncc)。

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/cdcfa9a3-4de0-42d4-94c0-0669710e407b

我们从Kinetics-400的验证集中随机选取100个样本进行评估，结果如下表所示：

| | SSIM↑ | LPIPS↓ | PSNR↑ | FLOLPIPS↓ |
|---|---|---|---|---|
| v1.0.0 | 0.829 | 0.106 | 27.171 | 0.119 |
| Preview | 0.877 | 0.064 | 29.695 | 0.070 |

#### 运动模糊

| **v1.0.0** | **预览版本** |
| --- | --- |
| ![6862cae0-b1b6-48d1-bd11-84348cf42b42](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/f815636f-fb38-4891-918b-50b1f9aa086d) | ![9189da06-ef2c-42e6-ad34-bd702a6f538e](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1e413f50-a785-485a-9851-a1449f952f1c) |

#### 网格效应

| **v1.0.0** | **预览版本** |
| --- | --- |
| ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/7fec5bed-3c83-4ee9-baef-4a3dacafc658) | ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/4f41b432-a3ef-484e-a492-8afd8a691bf7) |

### 数据构建

**数据源**：正如上文提到，我们的数据集中超过60%为风景视频。这意味着我们的开域视频生成能力有限。然而当前的大规模开源数据集大多从YouTube爬取，尽管视频的数量多，但我们担忧视频本身的质量是否达标。因此，我们将继续收集高质量的数据集，同时也欢迎开源社区的推荐。

**Caption生成流程**：当我们训练时长增加时，我们不得不考虑更有效的视频caption生成方法，而不是多模态图片大模型。我们正在开发一个新的视频注释生成管线，它能够很好地支持长视频，敬请期待。

### 训练扩散模型
尽管目前v1.0.0展现了可喜的结果，但我们仍然离Sora有一段距离。在接下来的工作中，我们主要围绕这三个方面：

1. **动态分辨率与时长的训练**：我们的目标是开发出能够以不同分辨率和持续时间训练模型的技术，使训练过程更加灵活、适应性更强。

2. **更长的视频生成**：我们将探索扩展模型生成能力的方法，使其能够制作更长的视频，超越目前的限制。

3. **更多条件控制**：我们力求增强模型的条件控制能力，为用户提供更多的选项和对生成视频的控制能力。

另外，通过仔细观察生成的视频，我们发现存在一些不符合常理的斑点或异常的流动，这是由于CausalVideoVAE的性能不足导致的，如上面提到。在未来的实验中，我们将使用更强的VAE，重新训练一个扩散模型。
docs/Report-v1.0.0.md
ADDED
@@ -0,0 +1,136 @@
# Report v1.0.0

In March 2024, we launched a plan called Open-Sora-Plan, which aims to reproduce the OpenAI [Sora](https://openai.com/sora) through an open-source framework. As a foundational open-source framework, it enables training of video generation models, including Unconditioned Video Generation, Class Video Generation, and Text-to-Video Generation.

**Today, we are thrilled to present Open-Sora-Plan v1.0.0, which significantly enhances video generation quality and text control capabilities.**

Compared with previous video generation models, Open-Sora-Plan v1.0.0 has several improvements:

1. **Efficient training and inference with CausalVideoVAE**. We apply a spatial-temporal compression to the videos by 4×8×8.
2. **Joint image-video training for better quality**. Our CausalVideoVAE considers the first frame as an image, allowing for the simultaneous encoding of both images and videos in a natural manner. This allows the diffusion model to grasp more spatial-visual details to improve visual quality.

### Open-Source Release
We open-source the Open-Sora-Plan to facilitate future development of Video Generation in the community. Code, data, and models are made publicly available.
- Demo: Hugging Face demo [![hf_space](https://img.shields.io/badge/🤗-Open%20In%20Spaces-blue.svg)](https://huggingface.co/spaces/LanguageBind/Open-Sora-Plan-v1.0.0). 🤝 Enjoy the [![Replicate demo and cloud API](https://replicate.com/camenduru/open-sora-plan-512x512/badge)](https://replicate.com/camenduru/open-sora-plan-512x512) and [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/camenduru/Open-Sora-Plan-jupyter/blob/main/Open_Sora_Plan_jupyter.ipynb), created by [@camenduru](https://github.com/camenduru), who generously supports our research!
- Code: All training scripts and sample scripts.
- Model: Both Diffusion Model and CausalVideoVAE [here](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0).
- Data: Both raw videos and captions [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0).

## Gallery

Open-Sora-Plan v1.0.0 supports joint training of images and videos. Here, we present the capabilities of Video/Image Reconstruction and Generation:

### CausalVideoVAE Reconstruction

**Video Reconstruction** with 720×1280. Since GitHub can't host large videos, we put them here: [1](https://streamable.com/gqojal), [2](https://streamable.com/6nu3j8).

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/c100bb02-2420-48a3-9d7b-4608a41f14aa

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/8aa8f587-d9f1-4e8b-8a82-d3bf9ba91d68

**Image Reconstruction** in 1536×1024.

<img src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1684c3ec-245d-4a60-865c-b8946d788eb9" width="45%"/> <img src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/46ef714e-3e5b-492c-aec4-3793cb2260b5" width="45%"/>

**Text-to-Video Generation** with 65×1024×1024

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/2641a8aa-66ac-4cda-8279-86b2e6a6e011

**Text-to-Video Generation** with 65×512×512

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/37e3107e-56b3-4b09-8920-fa1d8d144b9e


**Text-to-Image Generation** with 512×512

![download](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/62638829/491d72bc-e762-48ff-bdcc-cc69350f56d6)

## Detailed Technical Report

### CausalVideoVAE

#### Model Structure

![image](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/e3c8b35d-a217-4d96-b2e9-5c248a2859c8)

The CausalVideoVAE architecture inherits from the [Stable-Diffusion Image VAE](https://github.com/CompVis/stable-diffusion/tree/main). To ensure that the pretrained weights of the Image VAE can be seamlessly applied to the Video VAE, the model structure has been designed as follows:

1. **CausalConv3D**: Converting Conv2D to CausalConv3D enables joint training of image and video data. CausalConv3D applies a special treatment to the first frame, as it does not have access to subsequent frames. For more specific details, please refer to https://github.com/PKU-YuanGroup/Open-Sora-Plan/pull/145

2. **Initialization**: There are two common [methods](https://github.com/hassony2/inflated_convnets_pytorch/blob/master/src/inflate.py#L5) to expand Conv2D to Conv3D: average initialization and center initialization. But we employ a specific initialization method (tail initialization). This initialization method ensures that, without any training, the model is capable of directly reconstructing images, and even videos.

#### Training Details

<img width="833" alt="image" src="https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/9ffb6dc4-23f6-4274-a066-bbebc7522a14">

We present the loss curves for two distinct initialization methods under 17×256×256. The yellow curve represents the loss using tail init, while the blue curve corresponds to the loss from center initialization. As shown in the graph, tail initialization demonstrates better performance on the loss curve. Additionally, we found that center initialization leads to error accumulation, causing collapse over extended durations.

#### Inference Tricks
Despite the VAE in Diffusion training being frozen, we still find it challenging to afford the cost of the CausalVideoVAE. In our case, with 80GB of GPU memory, we can only infer a video of either 256×512×512 or 32×1024×1024 resolution using half-precision, which limits our ability to scale up to longer and higher-resolution videos. Therefore, we adopt tile convolution, which allows us to infer videos of arbitrary duration or resolution with nearly constant memory usage.

### Data Construction
We define a high-quality video dataset based on two core principles: (1) No content-unrelated watermarks. (2) High-quality and dense captions.

**For principle 1**, we crawled approximately 40,000 videos from open-source websites under the CC0 license. Specifically, we obtained 1,234 videos from [mixkit](https://mixkit.co/), 7,408 videos from [pexels](https://www.pexels.com/), and 31,616 videos from [pixabay](https://pixabay.com/). These videos adhere to the principle of having no content-unrelated watermarks. According to the scene transformation and clipping script provided by [Panda70M](https://github.com/snap-research/Panda-70M/blob/main/splitting/README.md), we have divided these videos into approximately 434,000 video clips. In fact, based on our clipping results, 99% of the videos obtained from these online sources are found to contain single scenes. Additionally, we have observed that over 60% of the crawled data comprises landscape videos. More details can be found [here](https://github.com/PKU-YuanGroup/Open-Sora-Dataset).

**For principle 2**, it is challenging to directly crawl a large quantity of high-quality dense captions from the internet. Therefore, we utilize a mature image-captioning model to obtain high-quality dense captions. We conducted ablation experiments on two multimodal large models: [ShareGPT4V-Captioner-7B](https://github.com/InternLM/InternLM-XComposer/blob/main/projects/ShareGPT4V/README.md) and [LLaVA-1.6-34B](https://github.com/haotian-liu/LLaVA). The former is specifically designed for caption generation, while the latter is a general-purpose multimodal large model. After conducting our ablation experiments, we found that they are comparable in performance. However, there is a significant difference in their inference speed on the A800 GPU: 40s/it at a batch size of 12 for ShareGPT4V-Captioner-7B, and 15s/it at a batch size of 1 for LLaVA-1.6-34B. We open-source all annotations [here](https://huggingface.co/datasets/LanguageBind/Open-Sora-Plan-v1.0.0). We show some statistics here; we set the maximum length of the model to 300, which covers almost 99% of the samples.

| Name | Avg length | Max | Std |
|---|---|---|---|
| ShareGPT4V-Captioner-7B | 170.0827524529121 | 467 | 53.689967539537776 |
| LLaVA-1.6-34B | 141.75851073472666 | 472 | 48.52492072346965 |

### Training Diffusion Model
Similar to previous work, we employ a multi-stage cascaded training approach, which consumes a total of 2,528 A800 GPU hours. We found that joint training with images significantly accelerates model convergence and enhances visual perception, aligning with the findings of [Latte](https://github.com/Vchitect/Latte). Below is our training card:

| Name | Stage 1 | Stage 2 | Stage 3 | Stage 4 |
|---|---|---|---|---|
| Training Video Size | 17×256×256 | 65×256×256 | 65×512×512 | 65×1024×1024 |
| Compute (#A800 GPU x #Hours) | 32 × 40 | 32 × 22 | 32 × 17 | Under training |
| Checkpoint | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/17x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x256x256) | [HF](https://huggingface.co/LanguageBind/Open-Sora-Plan-v1.0.0/tree/main/65x512x512) | Under training |
| Log | [wandb](https://api.wandb.ai/links/linbin/p6n3evym) | [wandb](https://api.wandb.ai/links/linbin/t2g53sew) | [wandb](https://api.wandb.ai/links/linbin/uomr0xzb) | Under training |
| Training Data | ~40k videos | ~40k videos | ~40k videos | ~40k videos |

## Next Release Preview
### CausalVideoVAE
Currently, the released version of CausalVideoVAE (v1.0.0) has two main drawbacks: **motion blurring** and **gridding effect**. We have made a series of improvements to CausalVideoVAE to reduce its inference cost and enhance its performance. We are currently referring to this enhanced version as the "preview version," which will be released in the next update. Preview reconstruction is as follows:

**1 min Video Reconstruction with 720×1280**. Since GitHub can't host such a large video, we put it here: [origin video](https://streamable.com/u4onbb), [reconstruction video](https://streamable.com/qt8ncc).

https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/cdcfa9a3-4de0-42d4-94c0-0669710e407b

We randomly selected 100 samples from the validation set of Kinetics-400 for evaluation, and the results are presented in the following table:

| | SSIM↑ | LPIPS↓ | PSNR↑ | FLOLPIPS↓ |
|---|---|---|---|---|
| v1.0.0 | 0.829 | 0.106 | 27.171 | 0.119 |
| Preview | 0.877 | 0.064 | 29.695 | 0.070 |

#### Motion Blurring

| **v1.0.0** | **Preview** |
| --- | --- |
| ![6862cae0-b1b6-48d1-bd11-84348cf42b42](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/f815636f-fb38-4891-918b-50b1f9aa086d) | ![9189da06-ef2c-42e6-ad34-bd702a6f538e](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/1e413f50-a785-485a-9851-a1449f952f1c) |

#### Gridding Effect

| **v1.0.0** | **Preview** |
| --- | --- |
| ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/7fec5bed-3c83-4ee9-baef-4a3dacafc658) | ![img](https://github.com/PKU-YuanGroup/Open-Sora-Plan/assets/88202804/4f41b432-a3ef-484e-a492-8afd8a691bf7) |

### Data Construction

**Data source**. As mentioned earlier, over 60% of our dataset consists of landscape videos. This implies that our ability to generate videos in other domains is limited. However, most of the current large-scale open-source datasets are primarily obtained through web scraping from platforms like YouTube. While these datasets provide a vast quantity of videos, we have concerns about the quality of the videos themselves. Therefore, we will continue to collect high-quality datasets and also welcome recommendations from the open-source community. We are launching an Open-Sora-Dataset project; check out the details at [Open-Sora-Dataset](https://github.com/PKU-YuanGroup/Open-Sora-Dataset).

**Caption Generation Pipeline**. As the video duration increases, we need to consider more efficient methods for video caption generation instead of relying solely on large multimodal image models. We are currently developing a new video caption generation pipeline that provides robust support for long videos. We are excited to share more details with you in the near future. Stay tuned!

### Training Diffusion Model
Although v1.0.0 has shown promising results, we acknowledge that we still have a way to go to reach the level of Sora. In our upcoming work, we will primarily focus on three aspects:

1. **Training support for dynamic resolution and duration**: We aim to develop techniques that enable training models with varying resolutions and durations, allowing for more flexible and adaptable training processes.

2. **Support for longer video generation**: We will explore methods to extend the generation capabilities of our models, enabling them to produce longer videos beyond the current limitations.

3. **Enhanced conditional control**: We seek to enhance the conditional control capabilities of our models, providing users with more options and control over the generated videos.

Furthermore, through careful observation of the generated videos, we have noticed the presence of some non-physiological speckles or abnormal flow. This can be attributed to the limited performance of CausalVideoVAE, as mentioned earlier. In future experiments, we plan to retrain a diffusion model using a more powerful version of CausalVideoVAE to address these issues.
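To make the CausalConv3D and tail-initialization ideas in the report above concrete, here is a minimal PyTorch sketch. It is a simplified reading of the description, not the repository's implementation (the real layer lives under `opensora/models/ae/videobase/modules` and may differ in padding mode, strides, and other details); the module and helper names are hypothetical.

```python
# Sketch of (1) a causal 3D convolution that pads only toward the past in time,
# so the first frame never sees future frames, and (2) "tail initialization",
# which copies pretrained 2D weights into the last temporal slice so that an
# untrained 3D layer reproduces the 2D convolution frame by frame.
import torch
import torch.nn as nn
import torch.nn.functional as F


class CausalConv3d(nn.Module):
    def __init__(self, in_ch, out_ch, kernel_size=(3, 3, 3)):
        super().__init__()
        kt, kh, kw = kernel_size
        self.time_pad = kt - 1  # pad only the past side of the time axis
        self.conv = nn.Conv3d(in_ch, out_ch, kernel_size,
                              padding=(0, kh // 2, kw // 2))

    def forward(self, x):
        # x: (B, C, T, H, W); replicate the first frame backwards in time.
        x = F.pad(x, (0, 0, 0, 0, self.time_pad, 0), mode="replicate")
        return self.conv(x)


def tail_init_from_2d(conv3d: nn.Conv3d, conv2d: nn.Conv2d) -> None:
    """Copy 2D weights into the last temporal slice of the 3D kernel; zero the rest."""
    with torch.no_grad():
        conv3d.weight.zero_()
        conv3d.weight[:, :, -1] = conv2d.weight  # last temporal slice gets the 2D kernel
        if conv2d.bias is not None and conv3d.bias is not None:
            conv3d.bias.copy_(conv2d.bias)


if __name__ == "__main__":
    conv2d = nn.Conv2d(3, 8, 3, padding=1)   # stand-in for a pretrained image-VAE conv
    layer = CausalConv3d(3, 8)
    tail_init_from_2d(layer.conv, conv2d)
    video = torch.randn(1, 3, 5, 32, 32)     # B, C, T, H, W
    out3d = layer(video)                     # (1, 8, 5, 32, 32)
    out2d = conv2d(video[:, :, 0])           # first frame through the 2D conv
    print(torch.allclose(out3d[:, :, 0], out2d, atol=1e-5))  # should print True
```

The final check illustrates the point made in the report: with tail initialization, the inflated 3D layer reconstructs each frame exactly as the pretrained 2D layer would, before any training.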
docs/Train_And_Eval_CausalVideoVAE.md
ADDED
@@ -0,0 +1,158 @@
# Training

Execute in the terminal: `bash scripts/causalvae/train.sh`

> When using GAN loss for training, two backward propagations are required. However, when [custom optimizers](https://lightning.ai/docs/pytorch/stable/model/manual_optimization.html#use-multiple-optimizers-like-gans) are implemented in PyTorch Lightning, it can lead to the training step count being doubled, meaning each training loop effectively results in two steps. This issue can make it counterintuitive when setting the training step count and the starting step count for the GAN loss.

## Code Structure

CausalVideoVAE is located in the directory `opensora/models/ae/videobase`. The directory structure is as follows:

```
.
├── causal_vae
├── causal_vqvae
├── configuration_videobase.py
├── dataset_videobase.py
├── __init__.py
├── losses
├── modeling_videobase.py
├── modules
├── __pycache__
├── trainer_videobase.py
├── utils
└── vqvae
```

The `causal_vae` directory defines the overall structure of the CausalVideoVAE model, and the `modules` directory contains some of the required modules for the model, including **CausalConv3D**, **ResnetBlock3D**, **Attention**, etc. The `losses` directory includes **GAN loss**, **Perception loss**, and other content.

## Configuration

Model training requires two key files: one is the `config.json` file, which configures the model structure, loss function, learning rate, etc. The other is the `train.sh` file, which configures the dataset, training steps, precision, etc.

### Model Configuration File

Taking the release version model configuration file `release.json` as an example:

```json
{
  "_class_name": "CausalVAEModel",
  "_diffusers_version": "0.27.2",
  "attn_resolutions": [],
  "decoder_attention": "AttnBlock3D",
  "decoder_conv_in": "CausalConv3d",
  "decoder_conv_out": "CausalConv3d",
  "decoder_mid_resnet": "ResnetBlock3D",
  "decoder_resnet_blocks": [
    "ResnetBlock3D",
    "ResnetBlock3D",
    "ResnetBlock3D",
    "ResnetBlock3D"
  ],
  "decoder_spatial_upsample": [
    "",
    "SpatialUpsample2x",
    "SpatialUpsample2x",
    "SpatialUpsample2x"
  ],
  "decoder_temporal_upsample": [
    "",
    "",
    "TimeUpsample2x",
    "TimeUpsample2x"
  ],
  "double_z": true,
  "dropout": 0.0,
  "embed_dim": 4,
  "encoder_attention": "AttnBlock3D",
  "encoder_conv_in": "CausalConv3d",
  "encoder_conv_out": "CausalConv3d",
  "encoder_mid_resnet": "ResnetBlock3D",
  "encoder_resnet_blocks": [
    "ResnetBlock3D",
    "ResnetBlock3D",
    "ResnetBlock3D",
    "ResnetBlock3D"
  ],
  "encoder_spatial_downsample": [
    "SpatialDownsample2x",
    "SpatialDownsample2x",
    "SpatialDownsample2x",
    ""
  ],
  "encoder_temporal_downsample": [
    "TimeDownsample2x",
    "TimeDownsample2x",
    "",
    ""
  ],
  "hidden_size": 128,
  "hidden_size_mult": [
    1,
    2,
    4,
    4
  ],
  "loss_params": {
    "disc_start": 2001,
    "disc_weight": 0.5,
    "kl_weight": 1e-06,
    "logvar_init": 0.0
  },
  "loss_type": "opensora.models.ae.videobase.losses.LPIPSWithDiscriminator",
  "lr": 1e-05,
  "num_res_blocks": 2,
  "q_conv": "CausalConv3d",
  "resolution": 256,
  "z_channels": 4
}
```

It configures the modules used in different layers of the encoder and decoder, as well as the loss. By changing the model configuration file, it is easy to train different model structures.

### Training Script

The following is a description of the parameters for `train_causalvae.py`:

| Parameter | Default Value | Description |
|-----------------------------|-----------------|--------------------------------------------------------|
| `--exp_name` | "causalvae" | The name of the experiment, used for the folder where results are saved. |
| `--batch_size` | 1 | The number of samples per training iteration. |
| `--precision` | "bf16" | The numerical precision type used for training. |
| `--max_steps` | 100000 | The maximum number of steps for the training process. |
| `--save_steps` | 2000 | The interval at which to save the model during training. |
| `--output_dir` | "results/causalvae" | The directory where training results are saved. |
| `--video_path` | "/remote-home1/dataset/data_split_tt" | The path where the video data is stored. |
| `--video_num_frames` | 17 | The number of frames per video. |
| `--sample_rate` | 1 | The sampling rate, indicating the number of video frames per second. |
| `--dynamic_sample` | False | Whether to use dynamic sampling. |
| `--model_config` | "scripts/causalvae/288.yaml" | The path to the model configuration file. |
| `--n_nodes` | 1 | The number of nodes used for training. |
| `--devices` | 8 | The number of devices used for training. |
| `--resolution` | 256 | The resolution of the videos. |
| `--num_workers` | 8 | The number of subprocesses used for data handling. |
| `--resume_from_checkpoint` | None | Resume training from a specified checkpoint. |
| `--load_from_checkpoint` | None | Load the model from a specified checkpoint. |

Please ensure that the values provided for these parameters are appropriate for your training setup.

# Evaluation

1. Video Generation:
   The script `scripts/causalvae/gen_video.sh` in the repository is utilized for generating videos. For the parameters, please refer to the script itself.

2. Video Evaluation:
   After video generation, you can evaluate the generated videos using the `scripts/causalvae/eval.sh` script. This evaluation script supports common metrics, including LPIPS, FloLPIPS, SSIM, PSNR, and more.

> Please note that you must generate the videos before executing the eval script. Additionally, it is essential to ensure that the video parameters used when generating the videos are consistent with those used during the evaluation.

# How to Import a Trained Model

Our model class inherits from the configuration and model management classes of Hugging Face, supporting downloading and loading models from Hugging Face. It can also import models trained with PyTorch Lightning.

```
model = CausalVAEModel.from_pretrained(args.ckpt)
model = model.to(device)
```
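As a complement to the snippet above, a minimal end-to-end sketch of loading a trained checkpoint and round-tripping a clip might look as follows. The checkpoint path, the `(B, C, T, H, W)` layout, and the dummy input are assumptions for illustration; the `encode(...).sample()` call follows the usage shown in `examples/get_latents_std.py`.

```python
# Load a trained CausalVAE, encode a dummy clip, and decode it back.
import torch
from opensora.models.ae.videobase import CausalVAEModel

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Hypothetical checkpoint directory; replace with your own trained checkpoint.
model = CausalVAEModel.from_pretrained("results/causalvae/checkpoint-26000")
model = model.to(device).eval()

with torch.no_grad():
    clip = torch.randn(1, 3, 17, 256, 256, device=device)  # B, C, T, H, W dummy input
    latents = model.encode(clip).sample()                   # sample from the posterior
    recon = model.decode(latents)                           # reconstruct the clip

print(latents.shape, recon.shape)
```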
docs/VQVAE.md
ADDED
@@ -0,0 +1,57 @@
# VQVAE Documentation

# Introduction

The Vector Quantized Variational AutoEncoder (VQ-VAE) is a type of autoencoder that uses a discrete latent representation. It is particularly useful for tasks that require discrete latent variables, such as text-to-speech and video generation.

# Usage

## Initialization

To initialize a VQVAE model, you can use the `VideoGPTVQVAE` class. This class is part of the `opensora.models.ae` module.

```python
from opensora.models.ae import VideoGPTVQVAE

vqvae = VideoGPTVQVAE()
```

### Training

To train the VQVAE model, you can use the `train_videogpt.sh` script. This script will train the model using the parameters specified in the script.

```bash
bash scripts/videogpt/train_videogpt.sh
```

### Loading Pretrained Models

You can load a pretrained model using the `download_and_load_model` method. This method will download the checkpoint file and load the model.

```python
vqvae = VideoGPTVQVAE.download_and_load_model("bair_stride4x2x2")
```

Alternatively, you can load a model from a checkpoint using the `load_from_checkpoint` method.

```python
vqvae = VQVAEModel.load_from_checkpoint("results/VQVAE/checkpoint-1000")
```

### Encoding and Decoding

You can encode a video using the `encode` method. This method will return the encodings and embeddings of the video.

```python
encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
```

You can reconstruct a video from its encodings using the `decode` method.

```python
video_recon = vqvae.decode(encodings)
```

## Testing

You can test the VQVAE model by reconstructing a video. The `examples/rec_video.py` script provides an example of how to do this.
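The snippets above can be combined into a single round-trip check; the dummy input shape and value range are assumptions for illustration, and `examples/rec_video.py` remains the reference script.

```python
# Round-trip a dummy clip through a pretrained VideoGPT VQ-VAE.
import torch
from opensora.models.ae import VideoGPTVQVAE

vqvae = VideoGPTVQVAE.download_and_load_model("bair_stride4x2x2")
vqvae = vqvae.eval()

with torch.no_grad():
    x_vae = torch.randn(1, 3, 16, 64, 64)                        # dummy B, C, T, H, W clip
    encodings, embeddings = vqvae.encode(x_vae, include_embeddings=True)
    video_recon = vqvae.decode(encodings)

print(encodings.shape, video_recon.shape)
```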
examples/get_latents_std.py
ADDED
@@ -0,0 +1,38 @@
import torch
from torch.utils.data import DataLoader, Subset
import sys
sys.path.append(".")
from opensora.models.ae.videobase import CausalVAEModel, CausalVAEDataset

num_workers = 4
batch_size = 12

torch.manual_seed(0)
torch.set_grad_enabled(False)

device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

pretrained_model_name_or_path = 'results/causalvae/checkpoint-26000'
data_path = '/remote-home1/dataset/UCF-101'
video_num_frames = 17
resolution = 128
sample_rate = 10

vae = CausalVAEModel.load_from_checkpoint(pretrained_model_name_or_path)
vae.to(device)

dataset = CausalVAEDataset(data_path, sequence_length=video_num_frames, resolution=resolution, sample_rate=sample_rate)
subset_indices = list(range(1000))
subset_dataset = Subset(dataset, subset_indices)
loader = DataLoader(subset_dataset, batch_size=8, pin_memory=True)

all_latents = []
for video_data in loader:
    video_data = video_data['video'].to(device)
    latents = vae.encode(video_data).sample()
    # Collect the latents (not the raw videos) so the std is computed in latent space.
    all_latents.append(latents.cpu())

all_latents_tensor = torch.cat(all_latents)
std = all_latents_tensor.std().item()
normalizer = 1 / std
print(f'{normalizer = }')
examples/prompt_list_0.txt
ADDED
@@ -0,0 +1,24 @@
A quiet beach at dawn, the waves gently lapping at the shore and the sky painted in pastel hues.
A quiet beach at dawn, the waves softly lapping at the shore, pink and orange hues painting the sky, offering a moment of solitude and reflection.
A stylish woman walks down a Tokyo street filled with warm glowing neon and animated city signage. She wears a black leather jacket, a long red dress, and black boots, and carries a black purse. She wears sunglasses and red lipstick. She walks confidently and casually. The street is damp and reflective, creating a mirror effect of the colorful lights. Many pedestrians walk about.
Several giant wooly mammoths approach treading through a snowy meadow, their long wooly fur lightly blows in the wind as they walk, snow covered trees and dramatic snow capped mountains in the distance, mid afternoon light with wispy clouds and a sun high in the distance creates a warm glow, the low camera view is stunning capturing the large furry mammal with beautiful photography, depth of field.
A movie trailer featuring the adventures of the 30 year old space man wearing a red wool knitted motorcycle helmet, blue sky, salt desert, cinematic style, shot on 35mm film, vivid colors.
Drone view of waves crashing against the rugged cliffs along Big Sur’s garay point beach. The crashing blue waters create white-tipped waves, while the golden light of the setting sun illuminates the rocky shore. A small island with a lighthouse sits in the distance, and green shrubbery covers the cliff’s edge. The steep drop from the road down to the beach is a dramatic feat, with the cliff’s edges jutting out over the sea. This is a view that captures the raw beauty of the coast and the rugged landscape of the Pacific Coast Highway.
Animated scene features a close-up of a short fluffy monster kneeling beside a melting red candle. The art style is 3D and realistic, with a focus on lighting and texture. The mood of the painting is one of wonder and curiosity, as the monster gazes at the flame with wide eyes and open mouth. Its pose and expression convey a sense of innocence and playfulness, as if it is exploring the world around it for the first time. The use of warm colors and dramatic lighting further enhances the cozy atmosphere of the image.
A gorgeously rendered papercraft world of a coral reef, rife with colorful fish and sea creatures.
This close-up shot of a Victoria crowned pigeon showcases its striking blue plumage and red chest. Its crest is made of delicate, lacy feathers, while its eye is a striking red color. The bird’s head is tilted slightly to the side, giving the impression of it looking regal and majestic. The background is blurred, drawing attention to the bird’s striking appearance.
Photorealistic closeup video of two pirate ships battling each other as they sail inside a cup of coffee.
The majestic beauty of a waterfall cascading down a cliff into a serene lake.
Sunset over the sea.
a cat wearing sunglasses and working as a lifeguard at pool.
Slow pan upward of blazing oak fire in an indoor fireplace.
Yellow and black tropical fish dart through the sea.
a serene winter scene in a forest. The forest is blanketed in a thick layer of snow, which has settled on the branches of the trees, creating a canopy of white. The trees, a mix of evergreens and deciduous, stand tall and silent, their forms partially obscured by the snow. The ground is a uniform white, with no visible tracks or signs of human activity. The sun is low in the sky, casting a warm glow that contrasts with the cool tones of the snow. The light filters through the trees, creating a soft, diffused illumination that highlights the texture of the snow and the contours of the trees. The overall style of the scene is naturalistic, with a focus on the tranquility and beauty of the winter landscape.
a dynamic interaction between the ocean and a large rock. The rock, with its rough texture and jagged edges, is partially submerged in the water, suggesting it is a natural feature of the coastline. The water around the rock is in motion, with white foam and waves crashing against the rock, indicating the force of the ocean's movement. The background is a vast expanse of the ocean, with small ripples and waves, suggesting a moderate sea state. The overall style of the scene is a realistic depiction of a natural landscape, with a focus on the interplay between the rock and the water.
A serene waterfall cascading down moss-covered rocks, its soothing sound creating a harmonious symphony with nature.
A soaring drone footage captures the majestic beauty of a coastal cliff, its red and yellow stratified rock faces rich in color and against the vibrant turquoise of the sea. Seabirds can be seen taking flight around the cliff's precipices. As the drone slowly moves from different angles, the changing sunlight casts shifting shadows that highlight the rugged textures of the cliff and the surrounding calm sea. The water gently laps at the rock base and the greenery that clings to the top of the cliff, and the scene gives a sense of peaceful isolation at the fringes of the ocean. The video captures the essence of pristine natural beauty untouched by human structures.
The video captures the majestic beauty of a waterfall cascading down a cliff into a serene lake. The waterfall, with its powerful flow, is the central focus of the video. The surrounding landscape is lush and green, with trees and foliage adding to the natural beauty of the scene. The camera angle provides a bird's eye view of the waterfall, allowing viewers to appreciate the full height and grandeur of the waterfall. The video is a stunning representation of nature's power and beauty.
A vibrant scene of a snowy mountain landscape. The sky is filled with a multitude of colorful hot air balloons, each floating at different heights, creating a dynamic and lively atmosphere. The balloons are scattered across the sky, some closer to the viewer, others further away, adding depth to the scene. Below, the mountainous terrain is blanketed in a thick layer of snow, with a few patches of bare earth visible here and there. The snow-covered mountains provide a stark contrast to the colorful balloons, enhancing the visual appeal of the scene.
A serene underwater scene featuring a sea turtle swimming through a coral reef. The turtle, with its greenish-brown shell, is the main focus of the video, swimming gracefully towards the right side of the frame. The coral reef, teeming with life, is visible in the background, providing a vibrant and colorful backdrop to the turtle's journey. Several small fish, darting around the turtle, add a sense of movement and dynamism to the scene.
A snowy forest landscape with a dirt road running through it. The road is flanked by trees covered in snow, and the ground is also covered in snow. The sun is shining, creating a bright and serene atmosphere. The road appears to be empty, and there are no people or animals visible in the video. The style of the video is a natural landscape shot, with a focus on the beauty of the snowy forest and the peacefulness of the road.
The dynamic movement of tall, wispy grasses swaying in the wind. The sky above is filled with clouds, creating a dramatic backdrop. The sunlight pierces through the clouds, casting a warm glow on the scene. The grasses are a mix of green and brown, indicating a change in seasons. The overall style of the video is naturalistic, capturing the beauty of the landscape in a realistic manner. The focus is on the grasses and their movement, with the sky serving as a secondary element. The video does not contain any human or animal elements.
examples/rec_image.py
ADDED
@@ -0,0 +1,57 @@
import sys
sys.path.append(".")
from PIL import Image
import torch
from torchvision.transforms import ToTensor, Compose, Resize, Normalize
from torch.nn import functional as F
from opensora.models.ae.videobase import CausalVAEModel
import argparse
import numpy as np


def preprocess(video_data: torch.Tensor, short_size: int = 128) -> torch.Tensor:
    transform = Compose(
        [
            ToTensor(),
            Normalize((0.5), (0.5)),
            Resize(size=short_size),
        ]
    )
    outputs = transform(video_data)
    outputs = outputs.unsqueeze(0).unsqueeze(2)
    return outputs


def main(args: argparse.Namespace):
    image_path = args.image_path
    resolution = args.resolution
    device = args.device

    vqvae = CausalVAEModel.load_from_checkpoint(args.ckpt)
    vqvae.eval()
    vqvae = vqvae.to(device)

    with torch.no_grad():
        x_vae = preprocess(Image.open(image_path), resolution)
        x_vae = x_vae.to(device)
        latents = vqvae.encode(x_vae)
        recon = vqvae.decode(latents.sample())
    x = recon[0, :, 0, :, :]
    x = x.squeeze()
    x = x.detach().cpu().numpy()
    x = np.clip(x, -1, 1)
    x = (x + 1) / 2
    x = (255 * x).astype(np.uint8)
    x = x.transpose(1, 2, 0)
    image = Image.fromarray(x)
    image.save(args.rec_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--image-path', type=str, default='')
    parser.add_argument('--rec-path', type=str, default='')
    parser.add_argument('--ckpt', type=str, default='')
    parser.add_argument('--resolution', type=int, default=336)
    parser.add_argument('--device', type=str, default='cuda')

    args = parser.parse_args()
    main(args)
examples/rec_imvi_vae.py
ADDED
@@ -0,0 +1,164 @@
import math
import random
import argparse
from typing import Optional

import cv2
import numpy as np
import numpy.typing as npt
import torch
from PIL import Image
from decord import VideoReader, cpu
from torch.nn import functional as F
from pytorchvideo.transforms import ShortSideScale
from torchvision.transforms import Lambda, Compose

import sys
sys.path.append(".")

from opensora.models.ae import getae_wrapper
from opensora.dataset.transform import CenterCropVideo, resize
from opensora.models.ae.videobase import CausalVAEModel


def array_to_video(image_array: npt.NDArray, fps: float = 30.0, output_file: str = 'output_video.mp4') -> None:
    height, width, channels = image_array[0].shape
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))

    for image in image_array:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        video_writer.write(image_rgb)

    video_writer.release()


def custom_to_video(x: torch.Tensor, fps: float = 2.0, output_file: str = 'output_video.mp4') -> None:
    x = x.detach().cpu()
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2
    x = x.permute(0, 2, 3, 1).numpy()
    x = (255 * x).astype(np.uint8)
    array_to_video(x, fps=fps, output_file=output_file)
    return


def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
    decord_vr = VideoReader(video_path, ctx=cpu(0))
    total_frames = len(decord_vr)
    sample_frames_len = sample_rate * num_frames

    if total_frames > sample_frames_len:
        s = random.randint(0, total_frames - sample_frames_len - 1)
        s = 0
        e = s + sample_frames_len
        num_frames = num_frames
    else:
        s = 0
        e = total_frames
        num_frames = int(total_frames / sample_frames_len * num_frames)
        print(f'sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}', video_path,
              total_frames)

    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
    video_data = decord_vr.get_batch(frame_id_list).asnumpy()
    video_data = torch.from_numpy(video_data)
    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    return video_data


class ResizeVideo:
    def __init__(
            self,
            size,
            interpolation_mode="bilinear",
    ):
        self.size = size

        self.interpolation_mode = interpolation_mode

    def __call__(self, clip):
        _, _, h, w = clip.shape
        if w < h:
            new_h = int(math.floor((float(h) / w) * self.size))
            new_w = self.size
        else:
            new_h = self.size
            new_w = int(math.floor((float(w) / h) * self.size))
        return torch.nn.functional.interpolate(
            clip, size=(new_h, new_w), mode=self.interpolation_mode, align_corners=False, antialias=True
        )

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"


def preprocess(video_data: torch.Tensor, short_size: int = 128, crop_size: Optional[int] = None) -> torch.Tensor:
    transform = Compose(
        [
            Lambda(lambda x: ((x / 255.0) * 2 - 1)),
            ResizeVideo(size=short_size),
            CenterCropVideo(crop_size) if crop_size is not None else Lambda(lambda x: x),
        ]
    )

    video_outputs = transform(video_data)
    video_outputs = torch.unsqueeze(video_outputs, 0)

    return video_outputs


def main(args: argparse.Namespace):
    device = args.device
    kwarg = {}
    # vae = getae_wrapper(args.ae)(args.model_path, subfolder="vae", cache_dir='cache_dir', **kwarg).to(device)
    vae = getae_wrapper(args.ae)(args.ae_path, **kwarg).to(device)
    if args.enable_tiling:
        vae.vae.enable_tiling()
        vae.vae.tile_overlap_factor = args.tile_overlap_factor
    vae.eval()
    vae = vae.to(device)
    vae = vae.half()

    with torch.no_grad():
        x_vae = preprocess(read_video(args.video_path, args.num_frames, args.sample_rate), args.resolution,
                           args.crop_size)
        x_vae = x_vae.to(device, dtype=torch.float16)  # b c t h w
        # from tqdm import tqdm
        # for i in tqdm(range(10000000)):
        latents = vae.encode(x_vae)
        latents = latents.to(torch.float16)
        video_recon = vae.decode(latents)  # b t c h w

    if video_recon.shape[2] == 1:
        x = video_recon[0, 0, :, :, :]
        x = x.squeeze()
        x = x.detach().cpu().numpy()
        x = np.clip(x, -1, 1)
        x = (x + 1) / 2
        x = (255 * x).astype(np.uint8)
        x = x.transpose(1, 2, 0)
        image = Image.fromarray(x)
        image.save(args.rec_path.replace('mp4', 'jpg'))
    else:
        custom_to_video(video_recon[0], fps=args.fps, output_file=args.rec_path)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--video_path', type=str, default='')
    parser.add_argument('--rec_path', type=str, default='')
    parser.add_argument('--ae', type=str, default='')
    parser.add_argument('--ae_path', type=str, default='')
    parser.add_argument('--model_path', type=str, default='results/pretrained')
    parser.add_argument('--fps', type=int, default=30)
    parser.add_argument('--resolution', type=int, default=336)
    parser.add_argument('--crop_size', type=int, default=None)
    parser.add_argument('--num_frames', type=int, default=100)
    parser.add_argument('--sample_rate', type=int, default=1)
    parser.add_argument('--device', type=str, default="cuda")
    parser.add_argument('--tile_overlap_factor', type=float, default=0.25)
    parser.add_argument('--enable_tiling', action='store_true')
    parser.add_argument('--enable_time_chunk', action='store_true')

    args = parser.parse_args()
    main(args)
examples/rec_video_vae.py
ADDED
@@ -0,0 +1,275 @@
import random
import argparse
import cv2
from tqdm import tqdm
import numpy as np
import numpy.typing as npt
import torch
from decord import VideoReader, cpu
from torch.nn import functional as F
from pytorchvideo.transforms import ShortSideScale
from torchvision.transforms import Lambda, Compose
from torchvision.transforms._transforms_video import CenterCropVideo
import sys
from torch.utils.data import Dataset, DataLoader, Subset
import os

sys.path.append(".")
from opensora.models.ae.videobase import CausalVAEModel
import torch.nn as nn


def array_to_video(
    image_array: npt.NDArray, fps: float = 30.0, output_file: str = "output_video.mp4"
) -> None:
    height, width, channels = image_array[0].shape
    fourcc = cv2.VideoWriter_fourcc(*"mp4v")
    video_writer = cv2.VideoWriter(output_file, fourcc, float(fps), (width, height))

    for image in image_array:
        image_rgb = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
        video_writer.write(image_rgb)

    video_writer.release()


def custom_to_video(
    x: torch.Tensor, fps: float = 2.0, output_file: str = "output_video.mp4"
) -> None:
    x = x.detach().cpu()
    x = torch.clamp(x, -1, 1)
    x = (x + 1) / 2
    x = x.permute(1, 2, 3, 0).float().numpy()
    x = (255 * x).astype(np.uint8)
    array_to_video(x, fps=fps, output_file=output_file)
    return


def read_video(video_path: str, num_frames: int, sample_rate: int) -> torch.Tensor:
    decord_vr = VideoReader(video_path, ctx=cpu(0), num_threads=8)
    total_frames = len(decord_vr)
    sample_frames_len = sample_rate * num_frames

    if total_frames > sample_frames_len:
        s = 0
        e = s + sample_frames_len
        num_frames = num_frames
    else:
        s = 0
        e = total_frames
        num_frames = int(total_frames / sample_frames_len * num_frames)
        print(
            f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}",
            video_path,
            total_frames,
        )

    frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
    video_data = decord_vr.get_batch(frame_id_list).asnumpy()
    video_data = torch.from_numpy(video_data)
    video_data = video_data.permute(3, 0, 1, 2)  # (T, H, W, C) -> (C, T, H, W)
    return video_data


class RealVideoDataset(Dataset):
    def __init__(
        self,
        real_video_dir,
        num_frames,
        sample_rate=1,
        crop_size=None,
        resolution=128,
    ) -> None:
        super().__init__()
        self.real_video_files = self._combine_without_prefix(real_video_dir)
        self.num_frames = num_frames
        self.sample_rate = sample_rate
        self.crop_size = crop_size
        self.short_size = resolution

    def __len__(self):
        return len(self.real_video_files)

    def __getitem__(self, index):
        if index >= len(self):
            raise IndexError
        real_video_file = self.real_video_files[index]
        real_video_tensor = self._load_video(real_video_file)
        video_name = os.path.basename(real_video_file)
        return {'video': real_video_tensor, 'file_name': video_name}

    def _load_video(self, video_path):
        num_frames = self.num_frames
        sample_rate = self.sample_rate
        decord_vr = VideoReader(video_path, ctx=cpu(0))
        total_frames = len(decord_vr)
        sample_frames_len = sample_rate * num_frames

        if total_frames > sample_frames_len:
            s = 0
            e = s + sample_frames_len
            num_frames = num_frames
        else:
            s = 0
            e = total_frames
            num_frames = int(total_frames / sample_frames_len * num_frames)
            print(
                f"sample_frames_len {sample_frames_len}, only can sample {num_frames * sample_rate}",
                video_path,
                total_frames,
            )

        frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
        video_data = decord_vr.get_batch(frame_id_list).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(3, 0, 1, 2)
        return _preprocess(
            video_data, short_size=self.short_size, crop_size=self.crop_size
        )

    def _combine_without_prefix(self, folder_path, prefix="."):
        folder = []
        for name in os.listdir(folder_path):
            if name[0] == prefix:
                continue
            folder.append(os.path.join(folder_path, name))
        folder.sort()
        return folder


def resize(x, resolution):
    height, width = x.shape[-2:]
    aspect_ratio = width / height
    if width <= height:
        new_width = resolution
        new_height = int(resolution / aspect_ratio)
    else:
        new_height = resolution
        new_width = int(resolution * aspect_ratio)
    resized_x = F.interpolate(x, size=(new_height, new_width), mode='bilinear', align_corners=True, antialias=True)
    return resized_x


def _preprocess(video_data, short_size=128, crop_size=None):
    transform = Compose(
        [
            Lambda(lambda x: ((x / 255.0) * 2 - 1)),
            Lambda(lambda x: resize(x, short_size)),
            (
                CenterCropVideo(crop_size=crop_size)
                if crop_size is not None
                else Lambda(lambda x: x)
            ),
        ]
    )
    video_outputs = transform(video_data)
    video_outputs = _format_video_shape(video_outputs)
    return video_outputs


def _format_video_shape(video, time_compress=4, spatial_compress=8):
    time = video.shape[1]
    height = video.shape[2]
    width = video.shape[3]
    new_time = (
        (time - (time - 1) % time_compress)
        if (time - 1) % time_compress != 0
        else time
    )
    new_height = (
        (height - (height) % spatial_compress)
        if height % spatial_compress != 0
        else height
    )
    new_width = (
        (width - (width) % spatial_compress) if width % spatial_compress != 0 else width
    )
    return video[:, :new_time, :new_height, :new_width]


@torch.no_grad()
def main(args: argparse.Namespace):
    real_video_dir = args.real_video_dir
    generated_video_dir = args.generated_video_dir
    ckpt = args.ckpt
    sample_rate = args.sample_rate
    resolution = args.resolution
    crop_size = args.crop_size
    num_frames = args.num_frames
    sample_rate = args.sample_rate
    device = args.device
    sample_fps = args.sample_fps
    batch_size = args.batch_size
    num_workers = args.num_workers
    subset_size = args.subset_size

    if not os.path.exists(args.generated_video_dir):
        os.makedirs(args.generated_video_dir, exist_ok=True)

    data_type = torch.bfloat16

    # ---- Load Model ----
    device = args.device
    vqvae = CausalVAEModel.from_pretrained(args.ckpt)
    vqvae = vqvae.to(device).to(data_type)
    if args.enable_tiling:
        vqvae.enable_tiling()
        vqvae.tile_overlap_factor = args.tile_overlap_factor
    # ---- Load Model ----

    # ---- Prepare Dataset ----
    dataset = RealVideoDataset(
        real_video_dir=real_video_dir,
        num_frames=num_frames,
        sample_rate=sample_rate,
        crop_size=crop_size,
        resolution=resolution,
    )

    if subset_size:
        indices = range(subset_size)
        dataset = Subset(dataset, indices=indices)

    dataloader = DataLoader(
        dataset, batch_size=batch_size, pin_memory=True, num_workers=num_workers
    )
    # ---- Prepare Dataset

    # ---- Inference ----
    for batch in tqdm(dataloader):
        x, file_names = batch['video'], batch['file_name']
        x = x.to(device=device, dtype=data_type)  # b c t h w
        latents = vqvae.encode(x).sample().to(data_type)
        video_recon = vqvae.decode(latents)
        for idx, video in enumerate(video_recon):
            output_path = os.path.join(generated_video_dir, file_names[idx])
            if args.output_origin:
                os.makedirs(os.path.join(generated_video_dir, "origin/"), exist_ok=True)
                origin_output_path = os.path.join(generated_video_dir, "origin/", file_names[idx])
                custom_to_video(
                    x[idx], fps=sample_fps / sample_rate, output_file=origin_output_path
                )
            custom_to_video(
                video, fps=sample_fps / sample_rate, output_file=output_path
            )
    # ---- Inference ----


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--real_video_dir", type=str, default="")
    parser.add_argument("--generated_video_dir", type=str, default="")
    parser.add_argument("--ckpt", type=str, default="")
    parser.add_argument("--sample_fps", type=int, default=30)
    parser.add_argument("--resolution", type=int, default=336)
    parser.add_argument("--crop_size", type=int, default=None)
    parser.add_argument("--num_frames", type=int, default=17)
    parser.add_argument("--sample_rate", type=int, default=1)
    parser.add_argument("--batch_size", type=int, default=1)
    parser.add_argument("--num_workers", type=int, default=8)
    parser.add_argument("--subset_size", type=int, default=None)
    parser.add_argument("--tile_overlap_factor", type=float, default=0.25)
    parser.add_argument('--enable_tiling', action='store_true')
    parser.add_argument('--output_origin', action='store_true')
    parser.add_argument("--device", type=str, default="cuda")

    args = parser.parse_args()
    main(args)
opensora/__init__.py
ADDED
@@ -0,0 +1 @@
#
opensora/dataset/__init__.py
ADDED
@@ -0,0 +1,65 @@
from torchvision.transforms import Compose
from transformers import AutoTokenizer

from .feature_datasets import T2V_Feature_dataset, T2V_T5_Feature_dataset
from torchvision import transforms
from torchvision.transforms import Lambda

from .t2v_datasets import T2V_dataset
from .transform import ToTensorVideo, TemporalRandomCrop, RandomHorizontalFlipVideo, CenterCropResizeVideo, LongSideResizeVideo, SpatialStrideCropVideo


ae_norm = {
    'CausalVAEModel_4x8x8': Lambda(lambda x: 2. * x - 1.),
    'CausalVQVAEModel_4x4x4': Lambda(lambda x: x - 0.5),
    'CausalVQVAEModel_4x8x8': Lambda(lambda x: x - 0.5),
    'VQVAEModel_4x4x4': Lambda(lambda x: x - 0.5),
    'VQVAEModel_4x8x8': Lambda(lambda x: x - 0.5),
    "bair_stride4x2x2": Lambda(lambda x: x - 0.5),
    "ucf101_stride4x4x4": Lambda(lambda x: x - 0.5),
    "kinetics_stride4x4x4": Lambda(lambda x: x - 0.5),
    "kinetics_stride2x4x4": Lambda(lambda x: x - 0.5),
    'stabilityai/sd-vae-ft-mse': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    'stabilityai/sd-vae-ft-ema': transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
    'vqgan_imagenet_f16_1024': Lambda(lambda x: 2. * x - 1.),
    'vqgan_imagenet_f16_16384': Lambda(lambda x: 2. * x - 1.),
    'vqgan_gumbel_f8': Lambda(lambda x: 2. * x - 1.),

}
ae_denorm = {
    'CausalVAEModel_4x8x8': lambda x: (x + 1.) / 2.,
    'CausalVQVAEModel_4x4x4': lambda x: x + 0.5,
    'CausalVQVAEModel_4x8x8': lambda x: x + 0.5,
    'VQVAEModel_4x4x4': lambda x: x + 0.5,
    'VQVAEModel_4x8x8': lambda x: x + 0.5,
    "bair_stride4x2x2": lambda x: x + 0.5,
    "ucf101_stride4x4x4": lambda x: x + 0.5,
    "kinetics_stride4x4x4": lambda x: x + 0.5,
    "kinetics_stride2x4x4": lambda x: x + 0.5,
    'stabilityai/sd-vae-ft-mse': lambda x: 0.5 * x + 0.5,
    'stabilityai/sd-vae-ft-ema': lambda x: 0.5 * x + 0.5,
    'vqgan_imagenet_f16_1024': lambda x: (x + 1.) / 2.,
    'vqgan_imagenet_f16_16384': lambda x: (x + 1.) / 2.,
    'vqgan_gumbel_f8': lambda x: (x + 1.) / 2.,
}


def getdataset(args):
    temporal_sample = TemporalRandomCrop(args.num_frames * args.sample_rate)  # 16 x
    norm_fun = ae_norm[args.ae]
    if args.dataset == 't2v':
        if args.multi_scale:
            resize = [
                LongSideResizeVideo(args.max_image_size, skip_low_resolution=True),
                SpatialStrideCropVideo(args.stride)
            ]
        else:
            resize = [CenterCropResizeVideo(args.max_image_size), ]
        transform = transforms.Compose([
            ToTensorVideo(),
            *resize,
            # RandomHorizontalFlipVideo(p=0.5),  # disabled in case captions contain position descriptions
            norm_fun
        ])
        tokenizer = AutoTokenizer.from_pretrained(args.text_encoder_name, cache_dir=args.cache_dir)
        return T2V_dataset(args, transform=transform, temporal_sample=temporal_sample, tokenizer=tokenizer)
    raise NotImplementedError(args.dataset)
opensora/dataset/feature_datasets.py
ADDED
@@ -0,0 +1,213 @@
import json
import os
import torch
import random
import torch.utils.data as data

import numpy as np
from glob import glob
from PIL import Image
from torch.utils.data import Dataset
from tqdm import tqdm

from opensora.dataset.transform import center_crop, RandomCropVideo
from opensora.utils.dataset_utils import DecordInit


class T2V_Feature_dataset(Dataset):
    def __init__(self, args, temporal_sample):

        self.video_folder = args.video_folder
        self.num_frames = args.video_length
        self.temporal_sample = temporal_sample

        print('Building dataset...')
        if os.path.exists('samples_430k.json'):
            with open('samples_430k.json', 'r') as f:
                self.samples = json.load(f)
        else:
            self.samples = self._make_dataset()
            with open('samples_430k.json', 'w') as f:
                json.dump(self.samples, f, indent=2)

        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def _make_dataset(self):
        all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True))
        # all_mp4 = all_mp4[:1000]
        samples = []
        for i in tqdm(all_mp4):
            video_id = os.path.basename(i).split('.')[0]
            ae = os.path.split(i)[0].replace('data_split_tt', 'lb_causalvideovae444_feature')
            ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy')
            if not os.path.exists(ae):
                continue

            t5 = os.path.split(i)[0].replace('data_split_tt', 'lb_t5_feature')
            cond_list = []
            cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy')
            mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy')
            if os.path.exists(cond_llava) and os.path.exists(mask_llava):
                llava = dict(cond=cond_llava, mask=mask_llava)
                cond_list.append(llava)
            cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy')
            mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy')
            if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v):
                sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v)
                cond_list.append(sharegpt4v)
            if len(cond_list) > 0:
                sample = dict(ae=ae, t5=cond_list)
                samples.append(sample)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        # try:
        sample = self.samples[idx]
        ae, t5 = sample['ae'], sample['t5']
        t5 = random.choice(t5)
        video_origin = np.load(ae)[0]  # C T H W
        _, total_frames, _, _ = video_origin.shape
        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        assert end_frame_ind - start_frame_ind >= self.num_frames
        select_video_idx = np.linspace(start_frame_ind, end_frame_ind - 1, num=self.num_frames, dtype=int)  # start, stop, num=50
        # print('select_video_idx', total_frames, select_video_idx)
        video = video_origin[:, select_video_idx]  # C num_frames H W
        video = torch.from_numpy(video)

        cond = torch.from_numpy(np.load(t5['cond']))[0]  # L
        cond_mask = torch.from_numpy(np.load(t5['mask']))[0]  # L D

        if self.use_image_num != 0 and self.use_img_from_vid:
            select_image_idx = np.random.randint(0, total_frames, self.use_image_num)
            # print('select_image_idx', total_frames, self.use_image_num, select_image_idx)
            images = video_origin[:, select_image_idx]  # c, num_img, h, w
            images = torch.from_numpy(images)
            video = torch.cat([video, images], dim=1)  # c, num_frame+num_img, h, w
            cond = torch.stack([cond] * (1 + self.use_image_num))  # 1+self.use_image_num, l
            cond_mask = torch.stack([cond_mask] * (1 + self.use_image_num))  # 1+self.use_image_num, l
        elif self.use_image_num != 0 and not self.use_img_from_vid:
            images, captions = self.img_cap_list[idx]
            raise NotImplementedError
        else:
            pass

        return video, cond, cond_mask
        # except Exception as e:
        #     print(f'Error with {e}, {sample}')
        #     return self.__getitem__(random.randint(0, self.__len__() - 1))

    def get_img_cap_list(self):
        raise NotImplementedError




class T2V_T5_Feature_dataset(Dataset):
    def __init__(self, args, transform, temporal_sample):

        self.video_folder = args.video_folder
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.v_decoder = DecordInit()

        print('Building dataset...')
        if os.path.exists('samples_430k.json'):
            with open('samples_430k.json', 'r') as f:
                self.samples = json.load(f)
                self.samples = [dict(ae=i['ae'].replace('lb_causalvideovae444_feature', 'data_split_1024').replace('_causalvideovae444.npy', '.mp4'), t5=i['t5']) for i in self.samples]
        else:
            self.samples = self._make_dataset()
            with open('samples_430k.json', 'w') as f:
                json.dump(self.samples, f, indent=2)

        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def _make_dataset(self):
        all_mp4 = list(glob(os.path.join(self.video_folder, '**', '*.mp4'), recursive=True))
        # all_mp4 = all_mp4[:1000]
        samples = []
        for i in tqdm(all_mp4):
            video_id = os.path.basename(i).split('.')[0]
            # ae = os.path.split(i)[0].replace('data_split', 'lb_causalvideovae444_feature')
            # ae = os.path.join(ae, f'{video_id}_causalvideovae444.npy')
            ae = i
            if not os.path.exists(ae):
                continue

            t5 = os.path.split(i)[0].replace('data_split_1024', 'lb_t5_feature')
            cond_list = []
            cond_llava = os.path.join(t5, f'{video_id}_t5_llava_fea.npy')
            mask_llava = os.path.join(t5, f'{video_id}_t5_llava_mask.npy')
            if os.path.exists(cond_llava) and os.path.exists(mask_llava):
                llava = dict(cond=cond_llava, mask=mask_llava)
                cond_list.append(llava)
            cond_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_fea.npy')
            mask_sharegpt4v = os.path.join(t5, f'{video_id}_t5_sharegpt4v_mask.npy')
            if os.path.exists(cond_sharegpt4v) and os.path.exists(mask_sharegpt4v):
                sharegpt4v = dict(cond=cond_sharegpt4v, mask=mask_sharegpt4v)
                cond_list.append(sharegpt4v)
            if len(cond_list) > 0:
                sample = dict(ae=ae, t5=cond_list)
                samples.append(sample)
        return samples

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        try:
            sample = self.samples[idx]
            ae, t5 = sample['ae'], sample['t5']
            t5 = random.choice(t5)

            video = self.decord_read(ae)
            video = self.transform(video)  # T C H W -> T C H W
            video = video.transpose(0, 1)  # T C H W -> C T H W
            total_frames = video.shape[1]
            cond = torch.from_numpy(np.load(t5['cond']))[0]  # L
            cond_mask = torch.from_numpy(np.load(t5['mask']))[0]  # L D

            if self.use_image_num != 0 and self.use_img_from_vid:
                select_image_idx = np.random.randint(0, total_frames, self.use_image_num)
                # print('select_image_idx', total_frames, self.use_image_num, select_image_idx)
                images = video.numpy()[:, select_image_idx]  # c, num_img, h, w
                images = torch.from_numpy(images)
                video = torch.cat([video, images], dim=1)  # c, num_frame+num_img, h, w
                cond = torch.stack([cond] * (1 + self.use_image_num))  # 1+self.use_image_num, l
                cond_mask = torch.stack([cond_mask] * (1 + self.use_image_num))  # 1+self.use_image_num, l
            elif self.use_image_num != 0 and not self.use_img_from_vid:
                images, captions = self.img_cap_list[idx]
                raise NotImplementedError
            else:
                pass

            return video, cond, cond_mask
        except Exception as e:
            print(f'Error with {e}, {sample}')
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def decord_read(self, path):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)
        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T C H W)
        return video_data

    def get_img_cap_list(self):
        raise NotImplementedError
opensora/dataset/t2v_datasets.py
ADDED
@@ -0,0 +1,203 @@
import json
import os, io, csv, math, random
import numpy as np
import torchvision
from einops import rearrange
from decord import VideoReader
from os.path import join as opj
import gc
import torch
import torchvision.transforms as transforms
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from PIL import Image

from opensora.utils.dataset_utils import DecordInit
from opensora.utils.utils import text_preprocessing


def random_video_noise(t, c, h, w):
    vid = torch.rand(t, c, h, w) * 255.0
    vid = vid.to(torch.uint8)
    return vid


class T2V_dataset(Dataset):
    def __init__(self, args, transform, temporal_sample, tokenizer):
        self.image_data = args.image_data
        self.video_data = args.video_data
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.tokenizer = tokenizer
        self.model_max_length = args.model_max_length
        self.v_decoder = DecordInit()

        self.vid_cap_list = self.get_vid_cap_list()

        self.use_image_num = args.use_image_num
        self.use_img_from_vid = args.use_img_from_vid
        if self.use_image_num != 0 and not self.use_img_from_vid:
            self.img_cap_list = self.get_img_cap_list()

    def __len__(self):
        return len(self.vid_cap_list)

    def __getitem__(self, idx):
        try:
            # import ipdb;ipdb.set_trace()
            video_data = self.get_video(idx)
            image_data = {}
            if self.use_image_num != 0 and self.use_img_from_vid:
                image_data = self.get_image_from_video(video_data)
            elif self.use_image_num != 0 and not self.use_img_from_vid:
                image_data = self.get_image(idx)
            else:
                raise NotImplementedError
            gc.collect()
            return dict(video_data=video_data, image_data=image_data)
        except Exception as e:
            # print(f'Error with {e}, {self.vid_cap_list[idx]}')
            if os.path.exists(self.vid_cap_list[idx]['path']) and '_resize_1080p' in self.vid_cap_list[idx]['path']:
                os.remove(self.vid_cap_list[idx]['path'])
                print('remove:', self.vid_cap_list[idx]['path'])
            return self.__getitem__(random.randint(0, self.__len__() - 1))

    def get_video(self, idx):
        # video = random.choice([random_video_noise(65, 3, 720, 360) * 255, random_video_noise(65, 3, 1024, 1024), random_video_noise(65, 3, 360, 720)])
        # # print('random shape', video.shape)
        # input_ids = torch.ones(1, 120).to(torch.long).squeeze(0)
        # cond_mask = torch.cat([torch.ones(1, 60).to(torch.long), torch.ones(1, 60).to(torch.long)], dim=1).squeeze(0)

        video_path = self.vid_cap_list[idx]['path']
        frame_idx = self.vid_cap_list[idx]['frame_idx']
        # print('before decord')
        video = self.decord_read(video_path, frame_idx)
        # video = self.tv_read(video_path, frame_idx)
        # print('after decord')
        video = self.transform(video)  # T C H W -> T C H W
        # del raw_video
        # gc.collect()
        # video = torch.rand(65, 3, 512, 512)
        # print('after transform')
        video = video.transpose(0, 1)  # T C H W -> C T H W
        text = self.vid_cap_list[idx]['cap']

        text = text_preprocessing(text)
        text_tokens_and_mask = self.tokenizer(
            text,
            max_length=self.model_max_length,
            padding='max_length',
            truncation=True,
            return_attention_mask=True,
            add_special_tokens=True,
            return_tensors='pt'
        )
        input_ids = text_tokens_and_mask['input_ids']
        cond_mask = text_tokens_and_mask['attention_mask']
        return dict(video=video, input_ids=input_ids, cond_mask=cond_mask)

    def get_image_from_video(self, video_data):
        select_image_idx = np.linspace(0, self.num_frames - 1, self.use_image_num, dtype=int)
        assert self.num_frames >= self.use_image_num
        image = [video_data['video'][:, i:i + 1] for i in select_image_idx]  # num_img [c, 1, h, w]
        input_ids = video_data['input_ids'].repeat(self.use_image_num, 1)  # self.use_image_num, l
        cond_mask = video_data['cond_mask'].repeat(self.use_image_num, 1)  # self.use_image_num, l
        return dict(image=image, input_ids=input_ids, cond_mask=cond_mask)

    def get_image(self, idx):
        idx = idx % len(self.img_cap_list)  # out of range
        image_data = self.img_cap_list[idx]  # [{'path': path, 'cap': cap}, ...]

        image = [Image.open(i['path']).convert('RGB') for i in image_data]  # num_img [h, w, c]
        image = [torch.from_numpy(np.array(i)) for i in image]  # num_img [h, w, c]
        image = [rearrange(i, 'h w c -> c h w').unsqueeze(0) for i in image]  # num_img [1 c h w]
        image = [self.transform(i) for i in image]  # num_img [1 C H W] -> num_img [1 C H W]
        image = [i.transpose(0, 1) for i in image]  # num_img [1 C H W] -> num_img [C 1 H W]

        caps = [i['cap'] for i in image_data]
        text = [text_preprocessing(cap) for cap in caps]
        input_ids, cond_mask = [], []
        for t in text:
            text_tokens_and_mask = self.tokenizer(
                t,
                max_length=self.model_max_length,
                padding='max_length',
                truncation=True,
                return_attention_mask=True,
                add_special_tokens=True,
                return_tensors='pt'
            )
            input_ids.append(text_tokens_and_mask['input_ids'])
            cond_mask.append(text_tokens_and_mask['attention_mask'])
        input_ids = torch.cat(input_ids)  # self.use_image_num, l
        cond_mask = torch.cat(cond_mask)  # self.use_image_num, l
        return dict(image=image, input_ids=input_ids, cond_mask=cond_mask)

    def tv_read(self, path, frame_idx=None):
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)
        if frame_idx is None:
            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        else:
            start_frame_ind, end_frame_ind = frame_idx.split(':')
            start_frame_ind, end_frame_ind = int(start_frame_ind), int(end_frame_ind)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
        # frame_indice = np.linspace(0, 63, self.num_frames, dtype=int)

        video = vframes[frame_indice]  # (T, C, H, W)

        return video

    def decord_read(self, path, frame_idx=None):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)
        # Sampling video frames
        if frame_idx is None:
            start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        else:
            start_frame_ind, end_frame_ind = frame_idx.split(':')
            start_frame_ind, end_frame_ind = int(start_frame_ind), int(end_frame_ind)
            start_frame_ind, end_frame_ind = int(start_frame_ind), int(start_frame_ind) + self.num_frames
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
        # frame_indice = np.linspace(0, 63, self.num_frames, dtype=int)

        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T C H W)
        return video_data

    def get_vid_cap_list(self):
        vid_cap_lists = []
        with open(self.video_data, 'r') as f:
            folder_anno = [i.strip().split(',') for i in f.readlines() if len(i.strip()) > 0]
        # print(folder_anno)
        for folder, anno in folder_anno:
            with open(anno, 'r') as f:
                vid_cap_list = json.load(f)
            print(f'Building {anno}...')
            for i in tqdm(range(len(vid_cap_list))):
                path = opj(folder, vid_cap_list[i]['path'])
                if os.path.exists(path.replace('.mp4', '_resize_1080p.mp4')):
                    path = path.replace('.mp4', '_resize_1080p.mp4')
                vid_cap_list[i]['path'] = path

            vid_cap_lists += vid_cap_list
        return vid_cap_lists

    def get_img_cap_list(self):
        img_cap_lists = []
        with open(self.image_data, 'r') as f:
            folder_anno = [i.strip().split(',') for i in f.readlines() if len(i.strip()) > 0]
        for folder, anno in folder_anno:
            with open(anno, 'r') as f:
                img_cap_list = json.load(f)
            print(f'Building {anno}...')
            for i in tqdm(range(len(img_cap_list))):
                img_cap_list[i]['path'] = opj(folder, img_cap_list[i]['path'])
            img_cap_lists += img_cap_list
        img_cap_lists = [img_cap_lists[i: i + self.use_image_num] for i in range(0, len(img_cap_lists), self.use_image_num)]
        return img_cap_lists[:-1]  # drop last to avoid error length
opensora/dataset/transform.py
ADDED
@@ -0,0 +1,573 @@
1 |
+
import torch
|
2 |
+
import random
|
3 |
+
import numbers
|
4 |
+
from torchvision.transforms import RandomCrop, RandomResizedCrop
|
5 |
+
|
6 |
+
|
7 |
+
def _is_tensor_video_clip(clip):
|
8 |
+
if not torch.is_tensor(clip):
|
9 |
+
raise TypeError("clip should be Tensor. Got %s" % type(clip))
|
10 |
+
|
11 |
+
if not clip.ndimension() == 4:
|
12 |
+
raise ValueError("clip should be 4D. Got %dD" % clip.dim())
|
13 |
+
|
14 |
+
return True
|
15 |
+
|
16 |
+
|
17 |
+
def center_crop_arr(pil_image, image_size):
|
18 |
+
"""
|
19 |
+
Center cropping implementation from ADM.
|
20 |
+
https://github.com/openai/guided-diffusion/blob/8fb3ad9197f16bbc40620447b2742e13458d2831/guided_diffusion/image_datasets.py#L126
|
21 |
+
"""
|
22 |
+
while min(*pil_image.size) >= 2 * image_size:
|
23 |
+
pil_image = pil_image.resize(
|
24 |
+
tuple(x // 2 for x in pil_image.size), resample=Image.BOX
|
25 |
+
)
|
26 |
+
|
27 |
+
scale = image_size / min(*pil_image.size)
|
28 |
+
pil_image = pil_image.resize(
|
29 |
+
tuple(round(x * scale) for x in pil_image.size), resample=Image.BICUBIC
|
30 |
+
)
|
31 |
+
|
32 |
+
arr = np.array(pil_image)
|
33 |
+
crop_y = (arr.shape[0] - image_size) // 2
|
34 |
+
crop_x = (arr.shape[1] - image_size) // 2
|
35 |
+
return Image.fromarray(arr[crop_y: crop_y + image_size, crop_x: crop_x + image_size])
|
36 |
+
|
37 |
+
|
38 |
+
def crop(clip, i, j, h, w):
|
39 |
+
"""
|
40 |
+
Args:
|
41 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
42 |
+
"""
|
43 |
+
if len(clip.size()) != 4:
|
44 |
+
raise ValueError("clip should be a 4D tensor")
|
45 |
+
return clip[..., i: i + h, j: j + w]
|
46 |
+
|
47 |
+
|
48 |
+
def resize(clip, target_size, interpolation_mode):
|
49 |
+
if len(target_size) != 2:
|
50 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
51 |
+
return torch.nn.functional.interpolate(clip, size=target_size, mode=interpolation_mode, align_corners=True, antialias=True)
|
52 |
+
|
53 |
+
|
54 |
+
def resize_scale(clip, target_size, interpolation_mode):
|
55 |
+
if len(target_size) != 2:
|
56 |
+
raise ValueError(f"target size should be tuple (height, width), instead got {target_size}")
|
57 |
+
H, W = clip.size(-2), clip.size(-1)
|
58 |
+
scale_ = target_size[0] / min(H, W)
|
59 |
+
return torch.nn.functional.interpolate(clip, scale_factor=scale_, mode=interpolation_mode, align_corners=True, antialias=True)
|
60 |
+
|
61 |
+
|
62 |
+
def resized_crop(clip, i, j, h, w, size, interpolation_mode="bilinear"):
|
63 |
+
"""
|
64 |
+
Do spatial cropping and resizing to the video clip
|
65 |
+
Args:
|
66 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
67 |
+
i (int): i in (i,j) i.e coordinates of the upper left corner.
|
68 |
+
j (int): j in (i,j) i.e coordinates of the upper left corner.
|
69 |
+
h (int): Height of the cropped region.
|
70 |
+
w (int): Width of the cropped region.
|
71 |
+
size (tuple(int, int)): height and width of resized clip
|
72 |
+
Returns:
|
73 |
+
clip (torch.tensor): Resized and cropped clip. Size is (T, C, H, W)
|
74 |
+
"""
|
75 |
+
if not _is_tensor_video_clip(clip):
|
76 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
77 |
+
clip = crop(clip, i, j, h, w)
|
78 |
+
clip = resize(clip, size, interpolation_mode)
|
79 |
+
return clip
|
80 |
+
|
81 |
+
|
82 |
+
def center_crop(clip, crop_size):
|
83 |
+
if not _is_tensor_video_clip(clip):
|
84 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
85 |
+
h, w = clip.size(-2), clip.size(-1)
|
86 |
+
th, tw = crop_size
|
87 |
+
if h < th or w < tw:
|
88 |
+
raise ValueError("height and width must be no smaller than crop_size")
|
89 |
+
|
90 |
+
i = int(round((h - th) / 2.0))
|
91 |
+
j = int(round((w - tw) / 2.0))
|
92 |
+
return crop(clip, i, j, th, tw)
|
93 |
+
|
94 |
+
|
95 |
+
def center_crop_using_short_edge(clip):
|
96 |
+
if not _is_tensor_video_clip(clip):
|
97 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
98 |
+
h, w = clip.size(-2), clip.size(-1)
|
99 |
+
if h < w:
|
100 |
+
th, tw = h, h
|
101 |
+
i = 0
|
102 |
+
j = int(round((w - tw) / 2.0))
|
103 |
+
else:
|
104 |
+
th, tw = w, w
|
105 |
+
i = int(round((h - th) / 2.0))
|
106 |
+
j = 0
|
107 |
+
return crop(clip, i, j, th, tw)
|
108 |
+
|
109 |
+
|
110 |
+
def random_shift_crop(clip):
|
111 |
+
'''
|
112 |
+
Slide along the long edge, with the short edge as crop size
|
113 |
+
'''
|
114 |
+
if not _is_tensor_video_clip(clip):
|
115 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
116 |
+
h, w = clip.size(-2), clip.size(-1)
|
117 |
+
|
118 |
+
if h <= w:
|
119 |
+
long_edge = w
|
120 |
+
short_edge = h
|
121 |
+
else:
|
122 |
+
long_edge = h
|
123 |
+
short_edge = w
|
124 |
+
|
125 |
+
th, tw = short_edge, short_edge
|
126 |
+
|
127 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
128 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
129 |
+
return crop(clip, i, j, th, tw)
|
130 |
+
|
131 |
+
|
132 |
+
def to_tensor(clip):
|
133 |
+
"""
|
134 |
+
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
135 |
+
permute the dimensions of clip tensor
|
136 |
+
Args:
|
137 |
+
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
138 |
+
Return:
|
139 |
+
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
140 |
+
"""
|
141 |
+
_is_tensor_video_clip(clip)
|
142 |
+
if not clip.dtype == torch.uint8:
|
143 |
+
raise TypeError("clip tensor should have data type uint8. Got %s" % str(clip.dtype))
|
144 |
+
# return clip.float().permute(3, 0, 1, 2) / 255.0
|
145 |
+
return clip.float() / 255.0
|
146 |
+
|
147 |
+
|
148 |
+
def normalize(clip, mean, std, inplace=False):
|
149 |
+
"""
|
150 |
+
Args:
|
151 |
+
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
152 |
+
mean (tuple): pixel RGB mean. Size is (3)
|
153 |
+
std (tuple): pixel standard deviation. Size is (3)
|
154 |
+
Returns:
|
155 |
+
normalized clip (torch.tensor): Size is (T, C, H, W)
|
156 |
+
"""
|
157 |
+
if not _is_tensor_video_clip(clip):
|
158 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
159 |
+
if not inplace:
|
160 |
+
clip = clip.clone()
|
161 |
+
mean = torch.as_tensor(mean, dtype=clip.dtype, device=clip.device)
|
162 |
+
# print(mean)
|
163 |
+
std = torch.as_tensor(std, dtype=clip.dtype, device=clip.device)
|
164 |
+
clip.sub_(mean[:, None, None, None]).div_(std[:, None, None, None])
|
165 |
+
return clip
|
166 |
+
|
167 |
+
|
168 |
+
def hflip(clip):
|
169 |
+
"""
|
170 |
+
Args:
|
171 |
+
clip (torch.tensor): Video clip to be normalized. Size is (T, C, H, W)
|
172 |
+
Returns:
|
173 |
+
flipped clip (torch.tensor): Size is (T, C, H, W)
|
174 |
+
"""
|
175 |
+
if not _is_tensor_video_clip(clip):
|
176 |
+
raise ValueError("clip should be a 4D torch.tensor")
|
177 |
+
return clip.flip(-1)
|
178 |
+
|
179 |
+
|
180 |
+
class RandomCropVideo:
|
181 |
+
def __init__(self, size):
|
182 |
+
if isinstance(size, numbers.Number):
|
183 |
+
self.size = (int(size), int(size))
|
184 |
+
else:
|
185 |
+
self.size = size
|
186 |
+
|
187 |
+
def __call__(self, clip):
|
188 |
+
"""
|
189 |
+
Args:
|
190 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
191 |
+
Returns:
|
192 |
+
torch.tensor: randomly cropped video clip.
|
193 |
+
size is (T, C, OH, OW)
|
194 |
+
"""
|
195 |
+
i, j, h, w = self.get_params(clip)
|
196 |
+
return crop(clip, i, j, h, w)
|
197 |
+
|
198 |
+
def get_params(self, clip):
|
199 |
+
h, w = clip.shape[-2:]
|
200 |
+
th, tw = self.size
|
201 |
+
|
202 |
+
if h < th or w < tw:
|
203 |
+
raise ValueError(f"Required crop size {(th, tw)} is larger than input image size {(h, w)}")
|
204 |
+
|
205 |
+
if w == tw and h == th:
|
206 |
+
return 0, 0, h, w
|
207 |
+
|
208 |
+
i = torch.randint(0, h - th + 1, size=(1,)).item()
|
209 |
+
j = torch.randint(0, w - tw + 1, size=(1,)).item()
|
210 |
+
|
211 |
+
return i, j, th, tw
|
212 |
+
|
213 |
+
def __repr__(self) -> str:
|
214 |
+
return f"{self.__class__.__name__}(size={self.size})"
|
215 |
+
|
216 |
+
|
217 |
+
class SpatialStrideCropVideo:
|
218 |
+
def __init__(self, stride):
|
219 |
+
self.stride = stride
|
220 |
+
|
221 |
+
def __call__(self, clip):
|
222 |
+
"""
|
223 |
+
Args:
|
224 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
225 |
+
Returns:
|
226 |
+
torch.tensor: cropped video clip by stride.
|
227 |
+
size is (T, C, OH, OW)
|
228 |
+
"""
|
229 |
+
i, j, h, w = self.get_params(clip)
|
230 |
+
return crop(clip, i, j, h, w)
|
231 |
+
|
232 |
+
def get_params(self, clip):
|
233 |
+
h, w = clip.shape[-2:]
|
234 |
+
|
235 |
+
th, tw = h // self.stride * self.stride, w // self.stride * self.stride
|
236 |
+
|
237 |
+
return 0, 0, th, tw # from top-left
|
238 |
+
|
239 |
+
def __repr__(self) -> str:
|
240 |
+
return f"{self.__class__.__name__}(size={self.size})"
|
241 |
+
|
242 |
+
class LongSideResizeVideo:
|
243 |
+
'''
|
244 |
+
First use the long side,
|
245 |
+
then resize to the specified size
|
246 |
+
'''
|
247 |
+
|
248 |
+
def __init__(
|
249 |
+
self,
|
250 |
+
size,
|
251 |
+
skip_low_resolution=False,
|
252 |
+
interpolation_mode="bilinear",
|
253 |
+
):
|
254 |
+
self.size = size
|
255 |
+
self.skip_low_resolution = skip_low_resolution
|
256 |
+
self.interpolation_mode = interpolation_mode
|
257 |
+
|
258 |
+
def __call__(self, clip):
|
259 |
+
"""
|
260 |
+
Args:
|
261 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
262 |
+
Returns:
|
263 |
+
torch.tensor: scale resized video clip.
|
264 |
+
size is (T, C, 512, *) or (T, C, *, 512)
|
265 |
+
"""
|
266 |
+
_, _, h, w = clip.shape
|
267 |
+
if self.skip_low_resolution and max(h, w) <= self.size:
|
268 |
+
return clip
|
269 |
+
if h > w:
|
270 |
+
w = int(w * self.size / h)
|
271 |
+
h = self.size
|
272 |
+
else:
|
273 |
+
h = int(h * self.size / w)
|
274 |
+
w = self.size
|
275 |
+
resize_clip = resize(clip, target_size=(h, w),
|
276 |
+
interpolation_mode=self.interpolation_mode)
|
277 |
+
return resize_clip
|
278 |
+
|
279 |
+
def __repr__(self) -> str:
|
280 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
281 |
+
|
282 |
+
class CenterCropResizeVideo:
|
283 |
+
'''
|
284 |
+
First use the short side for cropping length,
|
285 |
+
center crop video, then resize to the specified size
|
286 |
+
'''
|
287 |
+
|
288 |
+
def __init__(
|
289 |
+
self,
|
290 |
+
size,
|
291 |
+
interpolation_mode="bilinear",
|
292 |
+
):
|
293 |
+
if isinstance(size, tuple):
|
294 |
+
if len(size) != 2:
|
295 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
296 |
+
self.size = size
|
297 |
+
else:
|
298 |
+
self.size = (size, size)
|
299 |
+
|
300 |
+
self.interpolation_mode = interpolation_mode
|
301 |
+
|
302 |
+
def __call__(self, clip):
|
303 |
+
"""
|
304 |
+
Args:
|
305 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
306 |
+
Returns:
|
307 |
+
torch.tensor: scale resized / center cropped video clip.
|
308 |
+
size is (T, C, crop_size, crop_size)
|
309 |
+
"""
|
310 |
+
clip_center_crop = center_crop_using_short_edge(clip)
|
311 |
+
clip_center_crop_resize = resize(clip_center_crop, target_size=self.size,
|
312 |
+
interpolation_mode=self.interpolation_mode)
|
313 |
+
return clip_center_crop_resize
|
314 |
+
|
315 |
+
def __repr__(self) -> str:
|
316 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
317 |
+
|
318 |
+
|
319 |
+
class UCFCenterCropVideo:
|
320 |
+
'''
|
321 |
+
First scale to the specified size in equal proportion to the short edge,
|
322 |
+
then center cropping
|
323 |
+
'''
|
324 |
+
|
325 |
+
def __init__(
|
326 |
+
self,
|
327 |
+
size,
|
328 |
+
interpolation_mode="bilinear",
|
329 |
+
):
|
330 |
+
if isinstance(size, tuple):
|
331 |
+
if len(size) != 2:
|
332 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
333 |
+
self.size = size
|
334 |
+
else:
|
335 |
+
self.size = (size, size)
|
336 |
+
|
337 |
+
self.interpolation_mode = interpolation_mode
|
338 |
+
|
339 |
+
def __call__(self, clip):
|
340 |
+
"""
|
341 |
+
Args:
|
342 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
343 |
+
Returns:
|
344 |
+
torch.tensor: scale resized / center cropped video clip.
|
345 |
+
size is (T, C, crop_size, crop_size)
|
346 |
+
"""
|
347 |
+
clip_resize = resize_scale(clip=clip, target_size=self.size, interpolation_mode=self.interpolation_mode)
|
348 |
+
clip_center_crop = center_crop(clip_resize, self.size)
|
349 |
+
return clip_center_crop
|
350 |
+
|
351 |
+
def __repr__(self) -> str:
|
352 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
353 |
+
|
354 |
+
|
355 |
+
class KineticsRandomCropResizeVideo:
|
356 |
+
'''
|
357 |
+
Slide along the long edge, with the short edge as crop size. And resie to the desired size.
|
358 |
+
'''
|
359 |
+
|
360 |
+
def __init__(
|
361 |
+
self,
|
362 |
+
size,
|
363 |
+
interpolation_mode="bilinear",
|
364 |
+
):
|
365 |
+
if isinstance(size, tuple):
|
366 |
+
if len(size) != 2:
|
367 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
368 |
+
self.size = size
|
369 |
+
else:
|
370 |
+
self.size = (size, size)
|
371 |
+
|
372 |
+
self.interpolation_mode = interpolation_mode
|
373 |
+
|
374 |
+
def __call__(self, clip):
|
375 |
+
clip_random_crop = random_shift_crop(clip)
|
376 |
+
clip_resize = resize(clip_random_crop, self.size, self.interpolation_mode)
|
377 |
+
return clip_resize
|
378 |
+
|
379 |
+
|
380 |
+
class CenterCropVideo:
|
381 |
+
def __init__(
|
382 |
+
self,
|
383 |
+
size,
|
384 |
+
interpolation_mode="bilinear",
|
385 |
+
):
|
386 |
+
if isinstance(size, tuple):
|
387 |
+
if len(size) != 2:
|
388 |
+
raise ValueError(f"size should be tuple (height, width), instead got {size}")
|
389 |
+
self.size = size
|
390 |
+
else:
|
391 |
+
self.size = (size, size)
|
392 |
+
|
393 |
+
self.interpolation_mode = interpolation_mode
|
394 |
+
|
395 |
+
def __call__(self, clip):
|
396 |
+
"""
|
397 |
+
Args:
|
398 |
+
clip (torch.tensor): Video clip to be cropped. Size is (T, C, H, W)
|
399 |
+
Returns:
|
400 |
+
torch.tensor: center cropped video clip.
|
401 |
+
size is (T, C, crop_size, crop_size)
|
402 |
+
"""
|
403 |
+
clip_center_crop = center_crop(clip, self.size)
|
404 |
+
return clip_center_crop
|
405 |
+
|
406 |
+
def __repr__(self) -> str:
|
407 |
+
return f"{self.__class__.__name__}(size={self.size}, interpolation_mode={self.interpolation_mode}"
|
408 |
+
|
409 |
+
|
410 |
+
class NormalizeVideo:
|
411 |
+
"""
|
412 |
+
Normalize the video clip by mean subtraction and division by standard deviation
|
413 |
+
Args:
|
414 |
+
mean (3-tuple): pixel RGB mean
|
415 |
+
std (3-tuple): pixel RGB standard deviation
|
416 |
+
inplace (boolean): whether do in-place normalization
|
417 |
+
"""
|
418 |
+
|
419 |
+
def __init__(self, mean, std, inplace=False):
|
420 |
+
self.mean = mean
|
421 |
+
self.std = std
|
422 |
+
self.inplace = inplace
|
423 |
+
|
424 |
+
def __call__(self, clip):
|
425 |
+
"""
|
426 |
+
Args:
|
427 |
+
clip (torch.tensor): video clip must be normalized. Size is (C, T, H, W)
|
428 |
+
"""
|
429 |
+
return normalize(clip, self.mean, self.std, self.inplace)
|
430 |
+
|
431 |
+
def __repr__(self) -> str:
|
432 |
+
return f"{self.__class__.__name__}(mean={self.mean}, std={self.std}, inplace={self.inplace})"
|
433 |
+
|
434 |
+
|
435 |
+
class ToTensorVideo:
|
436 |
+
"""
|
437 |
+
Convert tensor data type from uint8 to float, divide value by 255.0 and
|
438 |
+
permute the dimensions of clip tensor
|
439 |
+
"""
|
440 |
+
|
441 |
+
def __init__(self):
|
442 |
+
pass
|
443 |
+
|
444 |
+
def __call__(self, clip):
|
445 |
+
"""
|
446 |
+
Args:
|
447 |
+
clip (torch.tensor, dtype=torch.uint8): Size is (T, C, H, W)
|
448 |
+
Return:
|
449 |
+
clip (torch.tensor, dtype=torch.float): Size is (T, C, H, W)
|
450 |
+
"""
|
451 |
+
return to_tensor(clip)
|
452 |
+
|
453 |
+
def __repr__(self) -> str:
|
454 |
+
return self.__class__.__name__
|
455 |
+
|
456 |
+
|
457 |
+
class RandomHorizontalFlipVideo:
|
458 |
+
"""
|
459 |
+
Flip the video clip along the horizontal direction with a given probability
|
460 |
+
Args:
|
461 |
+
p (float): probability of the clip being flipped. Default value is 0.5
|
462 |
+
"""
|
463 |
+
|
464 |
+
def __init__(self, p=0.5):
|
465 |
+
self.p = p
|
466 |
+
|
467 |
+
def __call__(self, clip):
|
468 |
+
"""
|
469 |
+
Args:
|
470 |
+
clip (torch.tensor): Size is (T, C, H, W)
|
471 |
+
Return:
|
472 |
+
clip (torch.tensor): Size is (T, C, H, W)
|
473 |
+
"""
|
474 |
+
if random.random() < self.p:
|
475 |
+
clip = hflip(clip)
|
476 |
+
return clip
|
477 |
+
|
478 |
+
def __repr__(self) -> str:
|
479 |
+
return f"{self.__class__.__name__}(p={self.p})"
|
480 |
+
|
481 |
+
|
482 |
+
# ------------------------------------------------------------
|
483 |
+
# --------------------- Sampling ---------------------------
|
484 |
+
# ------------------------------------------------------------
|
485 |
+
class TemporalRandomCrop(object):
|
486 |
+
"""Temporally crop the given frame indices at a random location.
|
487 |
+
|
488 |
+
Args:
|
489 |
+
size (int): Desired length of frames will be seen in the model.
|
490 |
+
"""
|
491 |
+
|
492 |
+
def __init__(self, size):
|
493 |
+
self.size = size
|
494 |
+
|
495 |
+
def __call__(self, total_frames):
|
496 |
+
rand_end = max(0, total_frames - self.size - 1)
|
497 |
+
begin_index = random.randint(0, rand_end)
|
498 |
+
end_index = min(begin_index + self.size, total_frames)
|
499 |
+
return begin_index, end_index
|
500 |
+
|
501 |
+
class DynamicSampleDuration(object):
|
502 |
+
"""Temporally crop the given frame indices at a random location.
|
503 |
+
|
504 |
+
Args:
|
505 |
+
size (int): Desired length of frames will be seen in the model.
|
506 |
+
"""
|
507 |
+
|
508 |
+
def __init__(self, t_stride, extra_1):
|
509 |
+
self.t_stride = t_stride
|
510 |
+
self.extra_1 = extra_1
|
511 |
+
|
512 |
+
def __call__(self, t, h, w):
|
513 |
+
if self.extra_1:
|
514 |
+
t = t - 1
|
515 |
+
truncate_t_list = list(range(t+1))[t//2:][::self.t_stride] # need half at least
|
516 |
+
truncate_t = random.choice(truncate_t_list)
|
517 |
+
if self.extra_1:
|
518 |
+
truncate_t = truncate_t + 1
|
519 |
+
return 0, truncate_t
|
520 |
+
|
521 |
+
if __name__ == '__main__':
|
522 |
+
from torchvision import transforms
|
523 |
+
import torchvision.io as io
|
524 |
+
import numpy as np
|
525 |
+
from torchvision.utils import save_image
|
526 |
+
import os
|
527 |
+
|
528 |
+
vframes, aframes, info = io.read_video(
|
529 |
+
filename='./v_Archery_g01_c03.avi',
|
530 |
+
pts_unit='sec',
|
531 |
+
output_format='TCHW'
|
532 |
+
)
|
533 |
+
|
534 |
+
trans = transforms.Compose([
|
535 |
+
ToTensorVideo(),
|
536 |
+
RandomHorizontalFlipVideo(),
|
537 |
+
UCFCenterCropVideo(512),
|
538 |
+
# NormalizeVideo(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
|
539 |
+
transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True)
|
540 |
+
])
|
541 |
+
|
542 |
+
target_video_len = 32
|
543 |
+
frame_interval = 1
|
544 |
+
total_frames = len(vframes)
|
545 |
+
print(total_frames)
|
546 |
+
|
547 |
+
temporal_sample = TemporalRandomCrop(target_video_len * frame_interval)
|
548 |
+
|
549 |
+
# Sampling video frames
|
550 |
+
start_frame_ind, end_frame_ind = temporal_sample(total_frames)
|
551 |
+
# print(start_frame_ind)
|
552 |
+
# print(end_frame_ind)
|
553 |
+
assert end_frame_ind - start_frame_ind >= target_video_len
|
554 |
+
frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, target_video_len, dtype=int)
|
555 |
+
print(frame_indice)
|
556 |
+
|
557 |
+
select_vframes = vframes[frame_indice]
|
558 |
+
print(select_vframes.shape)
|
559 |
+
print(select_vframes.dtype)
|
560 |
+
|
561 |
+
select_vframes_trans = trans(select_vframes)
|
562 |
+
print(select_vframes_trans.shape)
|
563 |
+
print(select_vframes_trans.dtype)
|
564 |
+
|
565 |
+
select_vframes_trans_int = ((select_vframes_trans * 0.5 + 0.5) * 255).to(dtype=torch.uint8)
|
566 |
+
print(select_vframes_trans_int.dtype)
|
567 |
+
print(select_vframes_trans_int.permute(0, 2, 3, 1).shape)
|
568 |
+
|
569 |
+
io.write_video('./test.avi', select_vframes_trans_int.permute(0, 2, 3, 1), fps=8)
|
570 |
+
|
571 |
+
for i in range(target_video_len):
|
572 |
+
save_image(select_vframes_trans[i], os.path.join('./test000', '%04d.png' % i), normalize=True,
|
573 |
+
value_range=(-1, 1))
|
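The `__main__` block above already exercises the UCF-style pipeline on a video file; as a shorter hedged sketch, the same classes can be composed on a synthetic uint8 clip with no file I/O (the clip shape and target size here are arbitrary illustration values, not settings used by the project):

import torch
from torchvision import transforms

from opensora.dataset.transform import ToTensorVideo, RandomHorizontalFlipVideo, CenterCropResizeVideo

fake_clip = torch.randint(0, 256, (16, 3, 360, 640), dtype=torch.uint8)  # (T, C, H, W)

pipeline = transforms.Compose([
    ToTensorVideo(),                   # uint8 [0, 255] -> float [0, 1]
    RandomHorizontalFlipVideo(p=0.5),  # random horizontal flip
    CenterCropResizeVideo(256),        # short-side center crop, then resize to 256x256
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])

out = pipeline(fake_clip)
print(out.shape)  # torch.Size([16, 3, 256, 256]), values roughly in [-1, 1]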
opensora/dataset/ucf101.py
ADDED
@@ -0,0 +1,80 @@
import math
import os

import decord
import numpy as np
import torch
import torchvision
from decord import VideoReader, cpu
from torch.utils.data import Dataset
from torchvision.transforms import Compose, Lambda, ToTensor
from torchvision.transforms._transforms_video import NormalizeVideo, RandomCropVideo, RandomHorizontalFlipVideo
from pytorchvideo.transforms import ApplyTransformToKey, ShortSideScale, UniformTemporalSubsample
from torch.nn import functional as F
import random

from opensora.utils.dataset_utils import DecordInit


class UCF101(Dataset):
    def __init__(self, args, transform, temporal_sample):
        self.data_path = args.data_path
        self.num_frames = args.num_frames
        self.transform = transform
        self.temporal_sample = temporal_sample
        self.v_decoder = DecordInit()

        self.classes = sorted(os.listdir(self.data_path))
        self.class_to_idx = {cls_name: idx for idx, cls_name in enumerate(self.classes)}
        self.samples = self._make_dataset()


    def _make_dataset(self):
        dataset = []
        for class_name in self.classes:
            class_path = os.path.join(self.data_path, class_name)
            for fname in os.listdir(class_path):
                if fname.endswith('.avi'):
                    item = (os.path.join(class_path, fname), self.class_to_idx[class_name])
                    dataset.append(item)
        return dataset

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        video_path, label = self.samples[idx]
        try:
            video = self.tv_read(video_path)
            video = self.transform(video)  # T C H W -> T C H W
            video = video.transpose(0, 1)  # T C H W -> C T H W
            return video, label
        except Exception as e:
            print(f'Error with {e}, {video_path}')
            return self.__getitem__(random.randint(0, self.__len__()-1))

    def tv_read(self, path):
        vframes, aframes, info = torchvision.io.read_video(filename=path, pts_unit='sec', output_format='TCHW')
        total_frames = len(vframes)

        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)
        video = vframes[frame_indice]  # (T, C, H, W)

        return video

    def decord_read(self, path):
        decord_vr = self.v_decoder(path)
        total_frames = len(decord_vr)
        # Sampling video frames
        start_frame_ind, end_frame_ind = self.temporal_sample(total_frames)
        # assert end_frame_ind - start_frame_ind >= self.num_frames
        frame_indice = np.linspace(start_frame_ind, end_frame_ind - 1, self.num_frames, dtype=int)

        video_data = decord_vr.get_batch(frame_indice).asnumpy()
        video_data = torch.from_numpy(video_data)
        video_data = video_data.permute(0, 3, 1, 2)  # (T, H, W, C) -> (T C H W)
        return video_data
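UCF101 expects the usual class-per-folder layout of .avi files. A hedged sketch of wiring it up with the transforms and TemporalRandomCrop from transform.py; the args namespace, path, and sizes below are placeholders, not the project's actual training configuration:

from argparse import Namespace

from torchvision import transforms

from opensora.dataset.transform import ToTensorVideo, UCFCenterCropVideo, TemporalRandomCrop
from opensora.dataset.ucf101 import UCF101

args = Namespace(data_path='UCF-101/', num_frames=16)   # placeholder path and frame count
temporal_sample = TemporalRandomCrop(args.num_frames)   # frame_interval assumed to be 1
transform = transforms.Compose([
    ToTensorVideo(),
    UCFCenterCropVideo(256),
    transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], inplace=True),
])

dataset = UCF101(args, transform, temporal_sample)
video, label = dataset[0]
print(video.shape, label)  # (C, T, H, W) after the transpose in __getitem__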
opensora/eval/cal_flolpips.py
ADDED
@@ -0,0 +1,83 @@
import numpy as np
import torch
from tqdm import tqdm
import math
from einops import rearrange
import sys
sys.path.append(".")
from opensora.eval.flolpips.pwcnet import Network as PWCNet
from opensora.eval.flolpips.flolpips import FloLPIPS

loss_fn = FloLPIPS(net='alex', version='0.1').eval().requires_grad_(False)
flownet = PWCNet().eval().requires_grad_(False)

def trans(x):
    return x


def calculate_flolpips(videos1, videos2, device):
    global loss_fn, flownet

    print("calculate_flolpips...")
    loss_fn = loss_fn.to(device)
    flownet = flownet.to(device)

    if videos1.shape != videos2.shape:
        print("Warning: the shapes of the videos are not equal.")
        min_frames = min(videos1.shape[1], videos2.shape[1])
        videos1 = videos1[:, :min_frames]
        videos2 = videos2[:, :min_frames]

    videos1 = trans(videos1)
    videos2 = trans(videos2)

    flolpips_results = []
    for video_num in tqdm(range(videos1.shape[0])):
        video1 = videos1[video_num].to(device)
        video2 = videos2[video_num].to(device)
        frames_rec = video1[:-1]
        frames_rec_next = video1[1:]
        frames_gt = video2[:-1]
        frames_gt_next = video2[1:]
        t, c, h, w = frames_gt.shape
        flow_gt = flownet(frames_gt, frames_gt_next)
        flow_dis = flownet(frames_rec, frames_rec_next)
        flow_diff = flow_gt - flow_dis
        flolpips = loss_fn.forward(frames_gt, frames_rec, flow_diff, normalize=True)
        flolpips_results.append(flolpips.cpu().numpy().tolist())

    flolpips_results = np.array(flolpips_results)  # [batch_size, num_frames]
    flolpips = {}
    flolpips_std = {}

    for clip_timestamp in range(flolpips_results.shape[1]):
        flolpips[clip_timestamp] = np.mean(flolpips_results[:, clip_timestamp], axis=-1)
        flolpips_std[clip_timestamp] = np.std(flolpips_results[:, clip_timestamp], axis=-1)

    result = {
        "value": flolpips,
        "value_std": flolpips_std,
        "video_setting": video1.shape,
        "video_setting_name": "time, channel, height, width",
        "result": flolpips_results,
        "details": flolpips_results.tolist()
    }

    return result

# test code / usage example

def main():
    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)

    import json
    result = calculate_flolpips(videos1, videos2, "cuda:0")
    print(json.dumps(result, indent=4))

if __name__ == "__main__":
    main()
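calculate_flolpips scores each consecutive frame pair: `video[:-1]` is compared against `video[1:]`, so a clip of T frames yields T-1 per-pair values. A tiny illustration of just that pairing, independent of the flow and LPIPS networks:

import torch

video = torch.arange(5)  # stand-in for a (T, C, H, W) clip; only the T axis matters here
frames, frames_next = video[:-1], video[1:]
print(list(zip(frames.tolist(), frames_next.tolist())))  # [(0, 1), (1, 2), (2, 3), (3, 4)]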
opensora/eval/cal_fvd.py
ADDED
@@ -0,0 +1,85 @@
import numpy as np
import torch
from tqdm import tqdm

def trans(x):
    # if greyscale images, add a channel
    if x.shape[-3] == 1:
        x = x.repeat(1, 1, 3, 1, 1)

    # permute BTCHW -> BCTHW
    x = x.permute(0, 2, 1, 3, 4)

    return x

def calculate_fvd(videos1, videos2, device, method='styleganv'):

    if method == 'styleganv':
        from fvd.styleganv.fvd import get_fvd_feats, frechet_distance, load_i3d_pretrained
    elif method == 'videogpt':
        from fvd.videogpt.fvd import load_i3d_pretrained
        from fvd.videogpt.fvd import get_fvd_logits as get_fvd_feats
        from fvd.videogpt.fvd import frechet_distance

    print("calculate_fvd...")

    # videos [batch_size, timestamps, channel, h, w]

    assert videos1.shape == videos2.shape

    i3d = load_i3d_pretrained(device=device)
    fvd_results = []

    # support grayscale input; if grayscale -> channel*3
    # BTCHW -> BCTHW
    # videos -> [batch_size, channel, timestamps, h, w]

    videos1 = trans(videos1)
    videos2 = trans(videos2)

    fvd_results = {}

    # to calculate FVD, each clip must contain at least 10 frames
    for clip_timestamp in tqdm(range(10, videos1.shape[-3] + 1)):

        # get a video clip
        # videos_clip [batch_size, channel, timestamps[:clip], h, w]
        videos_clip1 = videos1[:, :, : clip_timestamp]
        videos_clip2 = videos2[:, :, : clip_timestamp]

        # get FVD features
        feats1 = get_fvd_feats(videos_clip1, i3d=i3d, device=device)
        feats2 = get_fvd_feats(videos_clip2, i3d=i3d, device=device)

        # calculate FVD for timestamps[:clip]
        fvd_results[clip_timestamp] = frechet_distance(feats1, feats2)

    result = {
        "value": fvd_results,
        "video_setting": videos1.shape,
        "video_setting_name": "batch_size, channel, time, height, width",
    }

    return result

# test code / usage example

def main():
    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    device = torch.device("cuda")
    # device = torch.device("cpu")

    import json
    result = calculate_fvd(videos1, videos2, device, method='videogpt')
    print(json.dumps(result, indent=4))

    result = calculate_fvd(videos1, videos2, device, method='styleganv')
    print(json.dumps(result, indent=4))

if __name__ == "__main__":
    main()
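calculate_fvd expects both inputs as float tensors shaped (batch, time, channel, height, width) in [0, 1]; `trans` then permutes to (B, C, T, H, W), and at least 10 frames are required. A hedged sketch of preparing two clips for it; the file names are placeholders, and it assumes the working directory is opensora/eval so the `fvd.*` imports inside `calculate_fvd` resolve:

import torch
from torchvision.io import read_video

from cal_fvd import calculate_fvd  # assumes the working directory is opensora/eval

real, _, _ = read_video('real.mp4', pts_unit='sec', output_format='TCHW')       # (T, C, H, W), uint8
fake, _, _ = read_video('generated.mp4', pts_unit='sec', output_format='TCHW')

T = min(len(real), len(fake))                   # shapes must match exactly (same T, H, W)
real = real[:T].float().div(255).unsqueeze(0)   # -> (1, T, C, H, W) in [0, 1]
fake = fake[:T].float().div(255).unsqueeze(0)

# In practice many clips are stacked along the batch dimension,
# since the Frechet distance is not meaningful for a single sample per set.
result = calculate_fvd(real, fake, torch.device('cuda'), method='styleganv')
print(result['value'])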
opensora/eval/cal_lpips.py
ADDED
@@ -0,0 +1,97 @@
import numpy as np
import torch
from tqdm import tqdm
import math

import torch
import lpips

spatial = True  # Return a spatial map of perceptual distance.

# Linearly calibrated models (LPIPS)
loss_fn = lpips.LPIPS(net='alex', spatial=spatial)  # Can also set net = 'squeeze' or 'vgg'
# loss_fn = lpips.LPIPS(net='alex', spatial=spatial, lpips=False)  # Can also set net = 'squeeze' or 'vgg'

def trans(x):
    # if greyscale images, add a channel
    if x.shape[-3] == 1:
        x = x.repeat(1, 1, 3, 1, 1)

    # value range [0, 1] -> [-1, 1]
    x = x * 2 - 1

    return x

def calculate_lpips(videos1, videos2, device):
    # images should be RGB, IMPORTANT: normalized to [-1, 1]
    print("calculate_lpips...")

    assert videos1.shape == videos2.shape

    # videos [batch_size, timestamps, channel, h, w]

    # support grayscale input, if grayscale -> channel*3
    # value range [0, 1] -> [-1, 1]
    videos1 = trans(videos1)
    videos2 = trans(videos2)

    lpips_results = []

    for video_num in tqdm(range(videos1.shape[0])):
        # get a video
        # video [timestamps, channel, h, w]
        video1 = videos1[video_num]
        video2 = videos2[video_num]

        lpips_results_of_a_video = []
        for clip_timestamp in range(len(video1)):
            # get an image
            # img [timestamps[x], channel, h, w]
            # img [channel, h, w] tensor

            img1 = video1[clip_timestamp].unsqueeze(0).to(device)
            img2 = video2[clip_timestamp].unsqueeze(0).to(device)

            loss_fn.to(device)

            # calculate lpips of a frame
            lpips_results_of_a_video.append(loss_fn.forward(img1, img2).mean().detach().cpu().tolist())
        lpips_results.append(lpips_results_of_a_video)

    lpips_results = np.array(lpips_results)

    lpips = {}
    lpips_std = {}

    for clip_timestamp in range(len(video1)):
        lpips[clip_timestamp] = np.mean(lpips_results[:, clip_timestamp])
        lpips_std[clip_timestamp] = np.std(lpips_results[:, clip_timestamp])


    result = {
        "value": lpips,
        "value_std": lpips_std,
        "video_setting": video1.shape,
        "video_setting_name": "time, channel, height, width",
    }

    return result

# test code / usage example

def main():
    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.ones(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    device = torch.device("cuda")
    # device = torch.device("cpu")

    import json
    result = calculate_lpips(videos1, videos2, device)
    print(json.dumps(result, indent=4))

if __name__ == "__main__":
    main()
opensora/eval/cal_psnr.py
ADDED
@@ -0,0 +1,84 @@
import numpy as np
import torch
from tqdm import tqdm
import math

def img_psnr(img1, img2):
    # inputs are expected in [0, 1]
    # compute mse
    # mse = np.mean((img1-img2)**2)
    mse = np.mean((img1 / 1.0 - img2 / 1.0) ** 2)
    # compute psnr
    if mse < 1e-10:
        return 100
    psnr = 20 * math.log10(1 / math.sqrt(mse))
    return psnr

def trans(x):
    return x

def calculate_psnr(videos1, videos2):
    print("calculate_psnr...")

    # videos [batch_size, timestamps, channel, h, w]

    assert videos1.shape == videos2.shape

    videos1 = trans(videos1)
    videos2 = trans(videos2)

    psnr_results = []

    for video_num in tqdm(range(videos1.shape[0])):
        # get a video
        # video [timestamps, channel, h, w]
        video1 = videos1[video_num]
        video2 = videos2[video_num]

        psnr_results_of_a_video = []
        for clip_timestamp in range(len(video1)):
            # get an image
            # img [timestamps[x], channel, h, w]
            # img [channel, h, w] numpy

            img1 = video1[clip_timestamp].numpy()
            img2 = video2[clip_timestamp].numpy()

            # calculate psnr of a frame
            psnr_results_of_a_video.append(img_psnr(img1, img2))

        psnr_results.append(psnr_results_of_a_video)

    psnr_results = np.array(psnr_results)  # [batch_size, num_frames]
    psnr = {}
    psnr_std = {}

    for clip_timestamp in range(len(video1)):
        psnr[clip_timestamp] = np.mean(psnr_results[:, clip_timestamp])
        psnr_std[clip_timestamp] = np.std(psnr_results[:, clip_timestamp])

    result = {
        "value": psnr,
        "value_std": psnr_std,
        "video_setting": video1.shape,
        "video_setting_name": "time, channel, height, width",
    }

    return result

# test code / usage example

def main():
    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)

    import json
    result = calculate_psnr(videos1, videos2)
    print(json.dumps(result, indent=4))

if __name__ == "__main__":
    main()
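Because img_psnr assumes inputs already scaled to [0, 1], the formula reduces to 20*log10(1/sqrt(MSE)) = -10*log10(MSE). A quick worked check against the function above (the import assumes the working directory is opensora/eval):

import numpy as np

from cal_psnr import img_psnr  # assumes the working directory is opensora/eval

a = np.zeros((3, 4, 4))
b = np.full((3, 4, 4), 0.1)  # constant error of 0.1 -> MSE = 0.01
print(img_psnr(a, b))        # -10 * log10(0.01) = 20.0 dB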
opensora/eval/cal_ssim.py
ADDED
@@ -0,0 +1,113 @@
import numpy as np
import torch
from tqdm import tqdm
import cv2

def ssim(img1, img2):
    C1 = 0.01 ** 2
    C2 = 0.03 ** 2
    img1 = img1.astype(np.float64)
    img2 = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())
    mu1 = cv2.filter2D(img1, -1, window)[5:-5, 5:-5]  # valid region only
    mu2 = cv2.filter2D(img2, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(img1 ** 2, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(img2 ** 2, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(img1 * img2, -1, window)[5:-5, 5:-5] - mu1_mu2
    ssim_map = ((2 * mu1_mu2 + C1) * (2 * sigma12 + C2)) / ((mu1_sq + mu2_sq + C1) *
                                                            (sigma1_sq + sigma2_sq + C2))
    return ssim_map.mean()


def calculate_ssim_function(img1, img2):
    # inputs are expected in [0, 1]
    # ssim is the only metric extremely sensitive to gray being compared to b/w
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[0] == 3:
            ssims = []
            for i in range(3):
                ssims.append(ssim(img1[i], img2[i]))
            return np.array(ssims).mean()
        elif img1.shape[0] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')

def trans(x):
    return x

def calculate_ssim(videos1, videos2):
    print("calculate_ssim...")

    # videos [batch_size, timestamps, channel, h, w]

    assert videos1.shape == videos2.shape

    videos1 = trans(videos1)
    videos2 = trans(videos2)

    ssim_results = []

    for video_num in tqdm(range(videos1.shape[0])):
        # get a video
        # video [timestamps, channel, h, w]
        video1 = videos1[video_num]
        video2 = videos2[video_num]

        ssim_results_of_a_video = []
        for clip_timestamp in range(len(video1)):
            # get an image
            # img [timestamps[x], channel, h, w]
            # img [channel, h, w] numpy

            img1 = video1[clip_timestamp].numpy()
            img2 = video2[clip_timestamp].numpy()

            # calculate ssim of a frame
            ssim_results_of_a_video.append(calculate_ssim_function(img1, img2))

        ssim_results.append(ssim_results_of_a_video)

    ssim_results = np.array(ssim_results)

    ssim = {}
    ssim_std = {}

    for clip_timestamp in range(len(video1)):
        ssim[clip_timestamp] = np.mean(ssim_results[:, clip_timestamp])
        ssim_std[clip_timestamp] = np.std(ssim_results[:, clip_timestamp])

    result = {
        "value": ssim,
        "value_std": ssim_std,
        "video_setting": video1.shape,
        "video_setting_name": "time, channel, height, width",
    }

    return result

# test code / usage example

def main():
    NUMBER_OF_VIDEOS = 8
    VIDEO_LENGTH = 50
    CHANNEL = 3
    SIZE = 64
    videos1 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    videos2 = torch.zeros(NUMBER_OF_VIDEOS, VIDEO_LENGTH, CHANNEL, SIZE, SIZE, requires_grad=False)
    device = torch.device("cuda")

    import json
    result = calculate_ssim(videos1, videos2)
    print(json.dumps(result, indent=4))

if __name__ == "__main__":
    main()
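calculate_ssim_function applies the classic 11x11 Gaussian-window SSIM per channel and averages over channels, operating on numpy arrays in [0, 1] with shape (C, H, W); the [5:-5] cropping keeps only the valid filter region, so frames should be comfortably larger than the window. A hedged sanity check (the import assumes the working directory is opensora/eval):

import numpy as np

from cal_ssim import calculate_ssim_function  # assumes the working directory is opensora/eval

rng = np.random.default_rng(0)
frame = rng.random((3, 64, 64))  # fake RGB frame in [0, 1]
noisy = np.clip(frame + rng.normal(0, 0.05, frame.shape), 0, 1)

print(calculate_ssim_function(frame, frame))  # ~1.0 for identical inputs
print(calculate_ssim_function(frame, noisy))  # < 1.0, decreasing as the noise grows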
opensora/eval/eval_clip_score.py
ADDED
@@ -0,0 +1,225 @@
1 |
+
"""Calculates the CLIP Scores
|
2 |
+
|
3 |
+
The CLIP model is a contrasitively learned language-image model. There is
|
4 |
+
an image encoder and a text encoder. It is believed that the CLIP model could
|
5 |
+
measure the similarity of cross modalities. Please find more information from
|
6 |
+
https://github.com/openai/CLIP.
|
7 |
+
|
8 |
+
The CLIP Score measures the Cosine Similarity between two embedded features.
|
9 |
+
This repository utilizes the pretrained CLIP Model to calculate
|
10 |
+
the mean average of cosine similarities.
|
11 |
+
|
12 |
+
See --help to see further details.
|
13 |
+
|
14 |
+
Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP.
|
15 |
+
|
16 |
+
Copyright 2023 The Hong Kong Polytechnic University
|
17 |
+
|
18 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
19 |
+
you may not use this file except in compliance with the License.
|
20 |
+
You may obtain a copy of the License at
|
21 |
+
|
22 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
23 |
+
|
24 |
+
Unless required by applicable law or agreed to in writing, software
|
25 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
26 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
27 |
+
See the License for the specific language governing permissions and
|
28 |
+
limitations under the License.
|
29 |
+
"""
|
30 |
+
import os
|
31 |
+
import os.path as osp
|
32 |
+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
33 |
+
|
34 |
+
import clip
|
35 |
+
import torch
|
36 |
+
from PIL import Image
|
37 |
+
from torch.utils.data import Dataset, DataLoader
|
38 |
+
|
39 |
+
try:
|
40 |
+
from tqdm import tqdm
|
41 |
+
except ImportError:
|
42 |
+
# If tqdm is not available, provide a mock version of it
|
43 |
+
def tqdm(x):
|
44 |
+
return x
|
45 |
+
|
46 |
+
|
47 |
+
IMAGE_EXTENSIONS = {'bmp', 'jpg', 'jpeg', 'pgm', 'png', 'ppm',
|
48 |
+
'tif', 'tiff', 'webp'}
|
49 |
+
|
50 |
+
TEXT_EXTENSIONS = {'txt'}
|
51 |
+
|
52 |
+
|
53 |
+
class DummyDataset(Dataset):
|
54 |
+
|
55 |
+
FLAGS = ['img', 'txt']
|
56 |
+
def __init__(self, real_path, generated_path,
|
57 |
+
real_flag: str = 'img',
|
58 |
+
generated_flag: str = 'img',
|
59 |
+
transform = None,
|
60 |
+
tokenizer = None) -> None:
|
61 |
+
super().__init__()
|
62 |
+
assert real_flag in self.FLAGS and generated_flag in self.FLAGS, \
|
63 |
+
'CLIP Score only support modality of {}. However, get {} and {}'.format(
|
64 |
+
self.FLAGS, real_flag, generated_flag
|
65 |
+
)
|
66 |
+
self.real_folder = self._combine_without_prefix(real_path)
|
67 |
+
self.real_flag = real_flag
|
68 |
+
self.fake_foler = self._combine_without_prefix(generated_path)
|
69 |
+
self.generated_flag = generated_flag
|
70 |
+
self.transform = transform
|
71 |
+
self.tokenizer = tokenizer
|
72 |
+
# assert self._check()
|
73 |
+
|
74 |
+
def __len__(self):
|
75 |
+
return len(self.real_folder)
|
76 |
+
|
77 |
+
def __getitem__(self, index):
|
78 |
+
if index >= len(self):
|
79 |
+
raise IndexError
|
80 |
+
real_path = self.real_folder[index]
|
81 |
+
generated_path = self.fake_foler[index]
|
82 |
+
real_data = self._load_modality(real_path, self.real_flag)
|
83 |
+
fake_data = self._load_modality(generated_path, self.generated_flag)
|
84 |
+
|
85 |
+
sample = dict(real=real_data, fake=fake_data)
|
86 |
+
return sample
|
87 |
+
|
88 |
+
def _load_modality(self, path, modality):
|
89 |
+
if modality == 'img':
|
90 |
+
data = self._load_img(path)
|
91 |
+
elif modality == 'txt':
|
92 |
+
data = self._load_txt(path)
|
93 |
+
else:
|
94 |
+
raise TypeError("Got unexpected modality: {}".format(modality))
|
95 |
+
return data
|
96 |
+
|
97 |
+
def _load_img(self, path):
|
98 |
+
img = Image.open(path)
|
99 |
+
if self.transform is not None:
|
100 |
+
img = self.transform(img)
|
101 |
+
return img
|
102 |
+
|
103 |
+
def _load_txt(self, path):
|
104 |
+
with open(path, 'r') as fp:
|
105 |
+
data = fp.read()
|
106 |
+
fp.close()
|
107 |
+
if self.tokenizer is not None:
|
108 |
+
data = self.tokenizer(data).squeeze()
|
109 |
+
return data
|
110 |
+
|
111 |
+
def _check(self):
|
112 |
+
for idx in range(len(self)):
|
113 |
+
real_name = self.real_folder[idx].split('.')
|
114 |
+
fake_name = self.fake_folder[idx].split('.')
|
115 |
+
if fake_name != real_name:
|
116 |
+
return False
|
117 |
+
return True
|
118 |
+
|
119 |
+
def _combine_without_prefix(self, folder_path, prefix='.'):
|
120 |
+
folder = []
|
121 |
+
for name in os.listdir(folder_path):
|
122 |
+
if name[0] == prefix:
|
123 |
+
continue
|
124 |
+
folder.append(osp.join(folder_path, name))
|
125 |
+
folder.sort()
|
126 |
+
return folder
|
127 |
+
|
128 |
+
|
129 |
+
@torch.no_grad()
|
130 |
+
def calculate_clip_score(dataloader, model, real_flag, generated_flag):
|
131 |
+
score_acc = 0.
|
132 |
+
sample_num = 0.
|
133 |
+
logit_scale = model.logit_scale.exp()
|
134 |
+
for batch_data in tqdm(dataloader):
|
135 |
+
real = batch_data['real']
|
136 |
+
real_features = forward_modality(model, real, real_flag)
|
137 |
+
fake = batch_data['fake']
|
138 |
+
fake_features = forward_modality(model, fake, generated_flag)
|
139 |
+
|
140 |
+
# normalize features
|
141 |
+
real_features = real_features / real_features.norm(dim=1, keepdim=True).to(torch.float32)
|
142 |
+
fake_features = fake_features / fake_features.norm(dim=1, keepdim=True).to(torch.float32)
|
143 |
+
|
144 |
+
# calculate scores
|
145 |
+
# score = logit_scale * real_features @ fake_features.t()
|
146 |
+
# score_acc += torch.diag(score).sum()
|
147 |
+
score = logit_scale * (fake_features * real_features).sum()
|
148 |
+
score_acc += score
|
149 |
+
sample_num += real.shape[0]
|
150 |
+
|
151 |
+
return score_acc / sample_num
|
152 |
+
|
153 |
+
|
154 |
+
def forward_modality(model, data, flag):
|
155 |
+
device = next(model.parameters()).device
|
156 |
+
if flag == 'img':
|
157 |
+
features = model.encode_image(data.to(device))
|
158 |
+
elif flag == 'txt':
|
159 |
+
features = model.encode_text(data.to(device))
|
160 |
+
else:
|
161 |
+
raise TypeError
|
162 |
+
return features
|
163 |
+
|
164 |
+
|
165 |
+
def main():
|
166 |
+
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
167 |
+
parser.add_argument('--batch-size', type=int, default=50,
|
168 |
+
help='Batch size to use')
|
169 |
+
parser.add_argument('--clip-model', type=str, default='ViT-B/32',
|
170 |
+
help='CLIP model to use')
|
171 |
+
parser.add_argument('--num-workers', type=int, default=8,
|
172 |
+
help=('Number of processes to use for data loading. '
|
173 |
+
'Defaults to `min(8, num_cpus)`'))
|
174 |
+
parser.add_argument('--device', type=str, default=None,
|
175 |
+
help='Device to use. Like cuda, cuda:0 or cpu')
|
176 |
+
parser.add_argument('--real_flag', type=str, default='img',
|
177 |
+
help=('The modality of real path. '
|
178 |
+
'Default to img'))
|
179 |
+
parser.add_argument('--generated_flag', type=str, default='txt',
|
180 |
+
help=('The modality of generated path. '
|
181 |
+
'Default to txt'))
|
182 |
+
parser.add_argument('--real_path', type=str,
|
183 |
+
help=('Paths to the real images or '
|
184 |
+
'to .npz statistic files'))
|
185 |
+
parser.add_argument('--generated_path', type=str,
|
186 |
+
help=('Paths to the generated images or '
|
187 |
+
'to .npz statistic files'))
|
188 |
+
args = parser.parse_args()
|
189 |
+
|
190 |
+
if args.device is None:
|
191 |
+
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
|
192 |
+
else:
|
193 |
+
device = torch.device(args.device)
|
194 |
+
|
195 |
+
if args.num_workers is None:
|
196 |
+
try:
|
197 |
+
num_cpus = len(os.sched_getaffinity(0))
|
198 |
+
except AttributeError:
|
199 |
+
# os.sched_getaffinity is not available under Windows, use
|
200 |
+
# os.cpu_count instead (which may not return the *available* number
|
201 |
+
# of CPUs).
|
202 |
+
num_cpus = os.cpu_count()
|
203 |
+
|
204 |
+
num_workers = min(num_cpus, 8) if num_cpus is not None else 0
|
205 |
+
else:
|
206 |
+
num_workers = args.num_workers
|
207 |
+
|
208 |
+
print('Loading CLIP model: {}'.format(args.clip_model))
|
209 |
+
model, preprocess = clip.load(args.clip_model, device=device)
|
210 |
+
|
211 |
+
dataset = DummyDataset(args.real_path, args.generated_path,
|
212 |
+
args.real_flag, args.generated_flag,
|
213 |
+
transform=preprocess, tokenizer=clip.tokenize)
|
214 |
+
dataloader = DataLoader(dataset, args.batch_size,
|
215 |
+
num_workers=num_workers, pin_memory=True)
|
216 |
+
|
217 |
+
print('Calculating CLIP Score:')
|
218 |
+
clip_score = calculate_clip_score(dataloader, model,
|
219 |
+
args.real_flag, args.generated_flag)
|
220 |
+
clip_score = clip_score.cpu().item()
|
221 |
+
print('CLIP Score: ', clip_score)
|
222 |
+
|
223 |
+
|
224 |
+
if __name__ == '__main__':
|
225 |
+
main()
|
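DummyDataset pairs the i-th file of `--real_path` with the i-th file of `--generated_path` after sorting, so the two folders must list corresponding samples in the same order. A possible invocation built only from the flags defined above (the folder names are placeholders):

python opensora/eval/eval_clip_score.py \
    --real_path path/to/images \
    --generated_path path/to/captions \
    --real_flag img \
    --generated_flag txt \
    --clip-model ViT-B/32 \
    --batch-size 50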
opensora/eval/eval_common_metric.py
ADDED
@@ -0,0 +1,224 @@
1 |
+
"""Calculates the CLIP Scores
|
2 |
+
|
3 |
+
The CLIP model is a contrasitively learned language-image model. There is
|
4 |
+
an image encoder and a text encoder. It is believed that the CLIP model could
|
5 |
+
measure the similarity of cross modalities. Please find more information from
|
6 |
+
https://github.com/openai/CLIP.
|
7 |
+
|
8 |
+
The CLIP Score measures the Cosine Similarity between two embedded features.
|
9 |
+
This repository utilizes the pretrained CLIP Model to calculate
|
10 |
+
the mean average of cosine similarities.
|
11 |
+
|
12 |
+
See --help to see further details.
|
13 |
+
|
14 |
+
Code apapted from https://github.com/mseitzer/pytorch-fid and https://github.com/openai/CLIP.
|
15 |
+
|
16 |
+
Copyright 2023 The Hong Kong Polytechnic University
|
17 |
+
|
18 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
19 |
+
you may not use this file except in compliance with the License.
|
20 |
+
You may obtain a copy of the License at
|
21 |
+
|
22 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
23 |
+
|
24 |
+
Unless required by applicable law or agreed to in writing, software
|
25 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
26 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
27 |
+
See the License for the specific language governing permissions and
|
28 |
+
limitations under the License.
|
29 |
+
"""
|
30 |
+
|
31 |
+
import os
|
32 |
+
import os.path as osp
|
33 |
+
from argparse import ArgumentDefaultsHelpFormatter, ArgumentParser
|
34 |
+
import numpy as np
|
35 |
+
import torch
|
36 |
+
from torch.utils.data import Dataset, DataLoader, Subset
|
37 |
+
from decord import VideoReader, cpu
|
38 |
+
import random
|
39 |
+
from pytorchvideo.transforms import ShortSideScale
|
40 |
+
from torchvision.io import read_video
|
41 |
+
from torchvision.transforms import Lambda, Compose
|
42 |
+
from torchvision.transforms._transforms_video import CenterCropVideo
|
43 |
+
import sys
|
44 |
+
sys.path.append(".")
|
45 |
+
from opensora.eval.cal_lpips import calculate_lpips
|
46 |
+
from opensora.eval.cal_fvd import calculate_fvd
|
47 |
+
from opensora.eval.cal_psnr import calculate_psnr
|
48 |
+
from opensora.eval.cal_flolpips import calculate_flolpips
|
49 |
+
from opensora.eval.cal_ssim import calculate_ssim
|
50 |
+
|
51 |
+
try:
|
52 |
+
from tqdm import tqdm
|
53 |
+
except ImportError:
|
54 |
+
# If tqdm is not available, provide a mock version of it
|
55 |
+
def tqdm(x):
|
56 |
+
return x
|
57 |
+
|
58 |
+
class VideoDataset(Dataset):
|
59 |
+
def __init__(self,
|
60 |
+
real_video_dir,
|
61 |
+
generated_video_dir,
|
62 |
+
num_frames,
|
63 |
+
sample_rate = 1,
|
64 |
+
crop_size=None,
|
65 |
+
resolution=128,
|
66 |
+
) -> None:
|
67 |
+
super().__init__()
|
68 |
+
self.real_video_files = self._combine_without_prefix(real_video_dir)
|
69 |
+
self.generated_video_files = self._combine_without_prefix(generated_video_dir)
|
70 |
+
self.num_frames = num_frames
|
71 |
+
self.sample_rate = sample_rate
|
72 |
+
self.crop_size = crop_size
|
73 |
+
self.short_size = resolution
|
74 |
+
|
75 |
+
|
76 |
+
def __len__(self):
|
77 |
+
return len(self.real_video_files)
|
78 |
+
|
79 |
+
def __getitem__(self, index):
|
80 |
+
if index >= len(self):
|
81 |
+
raise IndexError
|
82 |
+
real_video_file = self.real_video_files[index]
|
83 |
+
generated_video_file = self.generated_video_files[index]
|
84 |
+
print(real_video_file, generated_video_file)
|
85 |
+
real_video_tensor = self._load_video(real_video_file)
|
86 |
+
generated_video_tensor = self._load_video(generated_video_file)
|
87 |
+
return {'real': real_video_tensor, 'generated':generated_video_tensor }
|
88 |
+
|
89 |
+
|
90 |
+
def _load_video(self, video_path):
|
91 |
+
num_frames = self.num_frames
|
92 |
+
sample_rate = self.sample_rate
|
93 |
+
decord_vr = VideoReader(video_path, ctx=cpu(0))
|
94 |
+
total_frames = len(decord_vr)
|
95 |
+
sample_frames_len = sample_rate * num_frames
|
96 |
+
|
97 |
+
if total_frames >= sample_frames_len:
|
98 |
+
s = 0
|
99 |
+
e = s + sample_frames_len
|
100 |
+
num_frames = num_frames
|
101 |
+
else:
|
102 |
+
s = 0
|
103 |
+
e = total_frames
|
104 |
+
num_frames = int(total_frames / sample_frames_len * num_frames)
|
105 |
+
print(f'sample_frames_len {sample_frames_len}, can only sample {num_frames * sample_rate}', video_path,
|
106 |
+
total_frames)
|
107 |
+
|
108 |
+
|
109 |
+
frame_id_list = np.linspace(s, e - 1, num_frames, dtype=int)
|
110 |
+
video_data = decord_vr.get_batch(frame_id_list).asnumpy()
|
111 |
+
video_data = torch.from_numpy(video_data)
|
112 |
+
video_data = video_data.permute(0, 3, 1, 2) # (T, H, W, C) -> (C, T, H, W)
|
113 |
+
return _preprocess(video_data, short_size=self.short_size, crop_size = self.crop_size)
|
114 |
+
|
115 |
+
|
116 |
+
def _combine_without_prefix(self, folder_path, prefix='.'):
|
117 |
+
folder = []
|
118 |
+
os.makedirs(folder_path, exist_ok=True)
|
119 |
+
for name in os.listdir(folder_path):
|
120 |
+
if name[0] == prefix:
|
121 |
+
continue
|
122 |
+
if osp.isfile(osp.join(folder_path, name)):
|
123 |
+
folder.append(osp.join(folder_path, name))
|
124 |
+
folder.sort()
|
125 |
+
return folder
|
126 |
+
|
127 |
+
def _preprocess(video_data, short_size=128, crop_size=None):
|
128 |
+
transform = Compose(
|
129 |
+
[
|
130 |
+
Lambda(lambda x: x / 255.0),
|
131 |
+
ShortSideScale(size=short_size),
|
132 |
+
CenterCropVideo(crop_size=crop_size),
|
133 |
+
]
|
134 |
+
)
|
135 |
+
video_outputs = transform(video_data)
|
136 |
+
# video_outputs = torch.unsqueeze(video_outputs, 0) # (bz,c,t,h,w)
|
137 |
+
return video_outputs
|
138 |
+
|
139 |
+
|
140 |
+
def calculate_common_metric(args, dataloader, device):
|
141 |
+
|
142 |
+
score_list = []
|
143 |
+
for batch_data in tqdm(dataloader): # {'real': real_video_tensor, 'generated':generated_video_tensor }
|
144 |
+
real_videos = batch_data['real']
|
145 |
+
generated_videos = batch_data['generated']
|
146 |
+
assert real_videos.shape[2] == generated_videos.shape[2]
|
147 |
+
if args.metric == 'fvd':
|
148 |
+
tmp_list = list(calculate_fvd(real_videos, generated_videos, args.device, method=args.fvd_method)['value'].values())
|
149 |
+
elif args.metric == 'ssim':
|
150 |
+
tmp_list = list(calculate_ssim(real_videos, generated_videos)['value'].values())
|
151 |
+
elif args.metric == 'psnr':
|
152 |
+
tmp_list = list(calculate_psnr(real_videos, generated_videos)['value'].values())
|
153 |
+
elif args.metric == 'flolpips':
|
154 |
+
result = calculate_flolpips(real_videos, generated_videos, args.device)
|
155 |
+
tmp_list = list(result['value'].values())
|
156 |
+
else:
|
157 |
+
tmp_list = list(calculate_lpips(real_videos, generated_videos, args.device)['value'].values())
|
158 |
+
score_list += tmp_list
|
159 |
+
return np.mean(score_list)
|
160 |
+
|
161 |
+
def main():
|
162 |
+
parser = ArgumentParser(formatter_class=ArgumentDefaultsHelpFormatter)
|
163 |
+
parser.add_argument('--batch_size', type=int, default=2,
|
164 |
+
help='Batch size to use')
|
165 |
+
parser.add_argument('--real_video_dir', type=str,
|
166 |
+
help=('the path of real videos'))
|
167 |
+
parser.add_argument('--generated_video_dir', type=str,
|
168 |
+
help=('the path of generated videos'))
|
169 |
+
parser.add_argument('--device', type=str, default=None,
|
170 |
+
help='Device to use. Like cuda, cuda:0 or cpu')
|
171 |
+
parser.add_argument('--num_workers', type=int, default=None,
|
172 |
+
help=('Number of processes to use for data loading. '
|
173 |
+
'Defaults to `min(8, num_cpus)`'))
|
174 |
+
parser.add_argument('--sample_fps', type=int, default=30)
|
175 |
+
parser.add_argument('--resolution', type=int, default=336)
|
176 |
+
parser.add_argument('--crop_size', type=int, default=None)
|
177 |
+
parser.add_argument('--num_frames', type=int, default=100)
|
178 |
+
parser.add_argument('--sample_rate', type=int, default=1)
|
179 |
+
parser.add_argument('--subset_size', type=int, default=None)
|
180 |
+
parser.add_argument("--metric", type=str, default="fvd",choices=['fvd','psnr','ssim','lpips', 'flolpips'])
|
181 |
+
parser.add_argument("--fvd_method", type=str, default='styleganv',choices=['styleganv','videogpt'])
|
182 |
+
|
183 |
+
|
184 |
+
args = parser.parse_args()
|
185 |
+
|
186 |
+
if args.device is None:
|
187 |
+
device = torch.device('cuda' if (torch.cuda.is_available()) else 'cpu')
|
188 |
+
else:
|
189 |
+
device = torch.device(args.device)
|
190 |
+
|
191 |
+
if args.num_workers is None:
|
192 |
+
try:
|
193 |
+
num_cpus = len(os.sched_getaffinity(0))
|
194 |
+
except AttributeError:
|
195 |
+
# os.sched_getaffinity is not available under Windows, use
|
196 |
+
# os.cpu_count instead (which may not return the *available* number
|
197 |
+
# of CPUs).
|
198 |
+
num_cpus = os.cpu_count()
|
199 |
+
|
200 |
+
num_workers = min(num_cpus, 8) if num_cpus is not None else 0
|
201 |
+
else:
|
202 |
+
num_workers = args.num_workers
|
203 |
+
|
204 |
+
|
205 |
+
dataset = VideoDataset(args.real_video_dir,
|
206 |
+
args.generated_video_dir,
|
207 |
+
num_frames = args.num_frames,
|
208 |
+
sample_rate = args.sample_rate,
|
209 |
+
crop_size=args.crop_size,
|
210 |
+
resolution=args.resolution)
|
211 |
+
|
212 |
+
if args.subset_size:
|
213 |
+
indices = range(args.subset_size)
|
214 |
+
dataset = Subset(dataset, indices=indices)
|
215 |
+
|
216 |
+
dataloader = DataLoader(dataset, args.batch_size,
|
217 |
+
num_workers=num_workers, pin_memory=True)
|
218 |
+
|
219 |
+
|
220 |
+
metric_score = calculate_common_metric(args, dataloader,device)
|
221 |
+
print('metric: ', args.metric, " ",metric_score)
|
222 |
+
|
223 |
+
if __name__ == '__main__':
|
224 |
+
main()
|
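For reference, a minimal sketch of driving the evaluation above from Python instead of the command line; the two directory paths and the frame/crop settings below are hypothetical placeholders, and the folders are assumed to hold the same number of corresponding clips:

import torch
from types import SimpleNamespace
from torch.utils.data import DataLoader

# Hypothetical directories with paired real/generated clips.
args = SimpleNamespace(
    real_video_dir='./real_videos',
    generated_video_dir='./generated_videos',
    num_frames=17, sample_rate=1, crop_size=256, resolution=256,
    batch_size=2, metric='psnr', fvd_method='styleganv',
    device='cuda' if torch.cuda.is_available() else 'cpu',
)
dataset = VideoDataset(args.real_video_dir, args.generated_video_dir,
                       num_frames=args.num_frames, sample_rate=args.sample_rate,
                       crop_size=args.crop_size, resolution=args.resolution)
loader = DataLoader(dataset, args.batch_size, num_workers=0, pin_memory=True)
print('psnr:', calculate_common_metric(args, loader, torch.device(args.device)))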
opensora/eval/flolpips/correlation/correlation.py
ADDED
@@ -0,0 +1,397 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import cupy
|
6 |
+
import re
|
7 |
+
|
8 |
+
kernel_Correlation_rearrange = '''
|
9 |
+
extern "C" __global__ void kernel_Correlation_rearrange(
|
10 |
+
const int n,
|
11 |
+
const float* input,
|
12 |
+
float* output
|
13 |
+
) {
|
14 |
+
int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
|
15 |
+
|
16 |
+
if (intIndex >= n) {
|
17 |
+
return;
|
18 |
+
}
|
19 |
+
|
20 |
+
int intSample = blockIdx.z;
|
21 |
+
int intChannel = blockIdx.y;
|
22 |
+
|
23 |
+
float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
|
24 |
+
|
25 |
+
__syncthreads();
|
26 |
+
|
27 |
+
int intPaddedY = (intIndex / SIZE_3(input)) + 4;
|
28 |
+
int intPaddedX = (intIndex % SIZE_3(input)) + 4;
|
29 |
+
int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
|
30 |
+
|
31 |
+
output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
|
32 |
+
}
|
33 |
+
'''
|
34 |
+
|
35 |
+
kernel_Correlation_updateOutput = '''
|
36 |
+
extern "C" __global__ void kernel_Correlation_updateOutput(
|
37 |
+
const int n,
|
38 |
+
const float* rbot0,
|
39 |
+
const float* rbot1,
|
40 |
+
float* top
|
41 |
+
) {
|
42 |
+
extern __shared__ char patch_data_char[];
|
43 |
+
|
44 |
+
float *patch_data = (float *)patch_data_char;
|
45 |
+
|
46 |
+
// First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
|
47 |
+
int x1 = blockIdx.x + 4;
|
48 |
+
int y1 = blockIdx.y + 4;
|
49 |
+
int item = blockIdx.z;
|
50 |
+
int ch_off = threadIdx.x;
|
51 |
+
|
52 |
+
// Load 3D patch into shared memory
|
53 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
54 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
55 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
56 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
57 |
+
int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
|
58 |
+
int idxPatchData = ji_off + ch;
|
59 |
+
patch_data[idxPatchData] = rbot0[idx1];
|
60 |
+
}
|
61 |
+
}
|
62 |
+
}
|
63 |
+
|
64 |
+
__syncthreads();
|
65 |
+
|
66 |
+
__shared__ float sum[32];
|
67 |
+
|
68 |
+
// Compute correlation
|
69 |
+
for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
|
70 |
+
sum[ch_off] = 0;
|
71 |
+
|
72 |
+
int s2o = top_channel % 9 - 4;
|
73 |
+
int s2p = top_channel / 9 - 4;
|
74 |
+
|
75 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
76 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
77 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
78 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
79 |
+
int x2 = x1 + s2o;
|
80 |
+
int y2 = y1 + s2p;
|
81 |
+
|
82 |
+
int idxPatchData = ji_off + ch;
|
83 |
+
int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
|
84 |
+
|
85 |
+
sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
|
86 |
+
}
|
87 |
+
}
|
88 |
+
}
|
89 |
+
|
90 |
+
__syncthreads();
|
91 |
+
|
92 |
+
if (ch_off == 0) {
|
93 |
+
float total_sum = 0;
|
94 |
+
for (int idx = 0; idx < 32; idx++) {
|
95 |
+
total_sum += sum[idx];
|
96 |
+
}
|
97 |
+
const int sumelems = SIZE_3(rbot0);
|
98 |
+
const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
|
99 |
+
top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
|
100 |
+
}
|
101 |
+
}
|
102 |
+
}
|
103 |
+
'''
|
104 |
+
|
105 |
+
kernel_Correlation_updateGradFirst = '''
|
106 |
+
#define ROUND_OFF 50000
|
107 |
+
|
108 |
+
extern "C" __global__ void kernel_Correlation_updateGradFirst(
|
109 |
+
const int n,
|
110 |
+
const int intSample,
|
111 |
+
const float* rbot0,
|
112 |
+
const float* rbot1,
|
113 |
+
const float* gradOutput,
|
114 |
+
float* gradFirst,
|
115 |
+
float* gradSecond
|
116 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
117 |
+
int n = intIndex % SIZE_1(gradFirst); // channels
|
118 |
+
int l = (intIndex / SIZE_1(gradFirst)) % SIZE_3(gradFirst) + 4; // w-pos
|
119 |
+
int m = (intIndex / SIZE_1(gradFirst) / SIZE_3(gradFirst)) % SIZE_2(gradFirst) + 4; // h-pos
|
120 |
+
|
121 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
122 |
+
// We use a large offset, for the inner part not to become negative.
|
123 |
+
const int round_off = ROUND_OFF;
|
124 |
+
const int round_off_s1 = round_off;
|
125 |
+
|
126 |
+
// We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
127 |
+
int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
|
128 |
+
int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (m - 4)
|
129 |
+
|
130 |
+
// Same here:
|
131 |
+
int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
|
132 |
+
int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
|
133 |
+
|
134 |
+
float sum = 0;
|
135 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
136 |
+
xmin = max(0,xmin);
|
137 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
138 |
+
|
139 |
+
ymin = max(0,ymin);
|
140 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
141 |
+
|
142 |
+
for (int p = -4; p <= 4; p++) {
|
143 |
+
for (int o = -4; o <= 4; o++) {
|
144 |
+
// Get rbot1 data:
|
145 |
+
int s2o = o;
|
146 |
+
int s2p = p;
|
147 |
+
int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
|
148 |
+
float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
|
149 |
+
|
150 |
+
// Index offset for gradOutput in following loops:
|
151 |
+
int op = (p+4) * 9 + (o+4); // index[o,p]
|
152 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
153 |
+
|
154 |
+
for (int y = ymin; y <= ymax; y++) {
|
155 |
+
for (int x = xmin; x <= xmax; x++) {
|
156 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
157 |
+
sum += gradOutput[idxgradOutput] * bot1tmp;
|
158 |
+
}
|
159 |
+
}
|
160 |
+
}
|
161 |
+
}
|
162 |
+
}
|
163 |
+
const int sumelems = SIZE_1(gradFirst);
|
164 |
+
const int bot0index = ((n * SIZE_2(gradFirst)) + (m-4)) * SIZE_3(gradFirst) + (l-4);
|
165 |
+
gradFirst[bot0index + intSample*SIZE_1(gradFirst)*SIZE_2(gradFirst)*SIZE_3(gradFirst)] = sum / (float)sumelems;
|
166 |
+
} }
|
167 |
+
'''
|
168 |
+
|
169 |
+
kernel_Correlation_updateGradSecond = '''
|
170 |
+
#define ROUND_OFF 50000
|
171 |
+
|
172 |
+
extern "C" __global__ void kernel_Correlation_updateGradSecond(
|
173 |
+
const int n,
|
174 |
+
const int intSample,
|
175 |
+
const float* rbot0,
|
176 |
+
const float* rbot1,
|
177 |
+
const float* gradOutput,
|
178 |
+
float* gradFirst,
|
179 |
+
float* gradSecond
|
180 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
181 |
+
int n = intIndex % SIZE_1(gradSecond); // channels
|
182 |
+
int l = (intIndex / SIZE_1(gradSecond)) % SIZE_3(gradSecond) + 4; // w-pos
|
183 |
+
int m = (intIndex / SIZE_1(gradSecond) / SIZE_3(gradSecond)) % SIZE_2(gradSecond) + 4; // h-pos
|
184 |
+
|
185 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
186 |
+
// We use a large offset, for the inner part not to become negative.
|
187 |
+
const int round_off = ROUND_OFF;
|
188 |
+
const int round_off_s1 = round_off;
|
189 |
+
|
190 |
+
float sum = 0;
|
191 |
+
for (int p = -4; p <= 4; p++) {
|
192 |
+
for (int o = -4; o <= 4; o++) {
|
193 |
+
int s2o = o;
|
194 |
+
int s2p = p;
|
195 |
+
|
196 |
+
//Get X,Y ranges and clamp
|
197 |
+
// We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
198 |
+
int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
|
199 |
+
int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (m - 4 - s2p)
|
200 |
+
|
201 |
+
// Same here:
|
202 |
+
int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
|
203 |
+
int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)
|
204 |
+
|
205 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
206 |
+
xmin = max(0,xmin);
|
207 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
208 |
+
|
209 |
+
ymin = max(0,ymin);
|
210 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
211 |
+
|
212 |
+
// Get rbot0 data:
|
213 |
+
int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
|
214 |
+
float bot0tmp = rbot0[idxbot0]; // rbot0[l-s2o,m-s2p,n]
|
215 |
+
|
216 |
+
// Index offset for gradOutput in following loops:
|
217 |
+
int op = (p+4) * 9 + (o+4); // index[o,p]
|
218 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
219 |
+
|
220 |
+
for (int y = ymin; y <= ymax; y++) {
|
221 |
+
for (int x = xmin; x <= xmax; x++) {
|
222 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
223 |
+
sum += gradOutput[idxgradOutput] * bot0tmp;
|
224 |
+
}
|
225 |
+
}
|
226 |
+
}
|
227 |
+
}
|
228 |
+
}
|
229 |
+
const int sumelems = SIZE_1(gradSecond);
|
230 |
+
const int bot1index = ((n * SIZE_2(gradSecond)) + (m-4)) * SIZE_3(gradSecond) + (l-4);
|
231 |
+
gradSecond[bot1index + intSample*SIZE_1(gradSecond)*SIZE_2(gradSecond)*SIZE_3(gradSecond)] = sum / (float)sumelems;
|
232 |
+
} }
|
233 |
+
'''
|
234 |
+
|
235 |
+
def cupy_kernel(strFunction, objVariables):
|
236 |
+
strKernel = globals()[strFunction]
|
237 |
+
|
238 |
+
while True:
|
239 |
+
objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)
|
240 |
+
|
241 |
+
if objMatch is None:
|
242 |
+
break
|
243 |
+
# end
|
244 |
+
|
245 |
+
intArg = int(objMatch.group(2))
|
246 |
+
|
247 |
+
strTensor = objMatch.group(4)
|
248 |
+
intSizes = objVariables[strTensor].size()
|
249 |
+
|
250 |
+
strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg]))
|
251 |
+
# end
|
252 |
+
|
253 |
+
while True:
|
254 |
+
objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)
|
255 |
+
|
256 |
+
if objMatch is None:
|
257 |
+
break
|
258 |
+
# end
|
259 |
+
|
260 |
+
intArgs = int(objMatch.group(2))
|
261 |
+
strArgs = objMatch.group(4).split(',')
|
262 |
+
|
263 |
+
strTensor = strArgs[0]
|
264 |
+
intStrides = objVariables[strTensor].stride()
|
265 |
+
strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg]) + ')' for intArg in range(intArgs) ]
|
266 |
+
|
267 |
+
strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
|
268 |
+
# end
|
269 |
+
|
270 |
+
return strKernel
|
271 |
+
# end
|
272 |
+
|
273 |
+
@cupy.memoize(for_each_device=True)
|
274 |
+
def cupy_launch(strFunction, strKernel):
|
275 |
+
return cupy.RawKernel(strKernel, strFunction)
|
276 |
+
# end
|
277 |
+
|
278 |
+
class _FunctionCorrelation(torch.autograd.Function):
|
279 |
+
@staticmethod
|
280 |
+
def forward(self, first, second):
|
281 |
+
rbot0 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ])
|
282 |
+
rbot1 = first.new_zeros([ first.shape[0], first.shape[2] + 8, first.shape[3] + 8, first.shape[1] ])
|
283 |
+
|
284 |
+
self.save_for_backward(first, second, rbot0, rbot1)
|
285 |
+
|
286 |
+
first = first.contiguous(); assert(first.is_cuda == True)
|
287 |
+
second = second.contiguous(); assert(second.is_cuda == True)
|
288 |
+
|
289 |
+
output = first.new_zeros([ first.shape[0], 81, first.shape[2], first.shape[3] ])
|
290 |
+
|
291 |
+
if first.is_cuda == True:
|
292 |
+
n = first.shape[2] * first.shape[3]
|
293 |
+
cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
|
294 |
+
'input': first,
|
295 |
+
'output': rbot0
|
296 |
+
}))(
|
297 |
+
grid=tuple([ int((n + 16 - 1) / 16), first.shape[1], first.shape[0] ]),
|
298 |
+
block=tuple([ 16, 1, 1 ]),
|
299 |
+
args=[ n, first.data_ptr(), rbot0.data_ptr() ]
|
300 |
+
)
|
301 |
+
|
302 |
+
n = second.shape[2] * second.shape[3]
|
303 |
+
cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
|
304 |
+
'input': second,
|
305 |
+
'output': rbot1
|
306 |
+
}))(
|
307 |
+
grid=tuple([ int((n + 16 - 1) / 16), second.shape[1], second.shape[0] ]),
|
308 |
+
block=tuple([ 16, 1, 1 ]),
|
309 |
+
args=[ n, second.data_ptr(), rbot1.data_ptr() ]
|
310 |
+
)
|
311 |
+
|
312 |
+
n = output.shape[1] * output.shape[2] * output.shape[3]
|
313 |
+
cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
|
314 |
+
'rbot0': rbot0,
|
315 |
+
'rbot1': rbot1,
|
316 |
+
'top': output
|
317 |
+
}))(
|
318 |
+
grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
|
319 |
+
block=tuple([ 32, 1, 1 ]),
|
320 |
+
shared_mem=first.shape[1] * 4,
|
321 |
+
args=[ n, rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
|
322 |
+
)
|
323 |
+
|
324 |
+
elif first.is_cuda == False:
|
325 |
+
raise NotImplementedError()
|
326 |
+
|
327 |
+
# end
|
328 |
+
|
329 |
+
return output
|
330 |
+
# end
|
331 |
+
|
332 |
+
@staticmethod
|
333 |
+
def backward(self, gradOutput):
|
334 |
+
first, second, rbot0, rbot1 = self.saved_tensors
|
335 |
+
|
336 |
+
gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True)
|
337 |
+
|
338 |
+
gradFirst = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[0] == True else None
|
339 |
+
gradSecond = first.new_zeros([ first.shape[0], first.shape[1], first.shape[2], first.shape[3] ]) if self.needs_input_grad[1] == True else None
|
340 |
+
|
341 |
+
if first.is_cuda == True:
|
342 |
+
if gradFirst is not None:
|
343 |
+
for intSample in range(first.shape[0]):
|
344 |
+
n = first.shape[1] * first.shape[2] * first.shape[3]
|
345 |
+
cupy_launch('kernel_Correlation_updateGradFirst', cupy_kernel('kernel_Correlation_updateGradFirst', {
|
346 |
+
'rbot0': rbot0,
|
347 |
+
'rbot1': rbot1,
|
348 |
+
'gradOutput': gradOutput,
|
349 |
+
'gradFirst': gradFirst,
|
350 |
+
'gradSecond': None
|
351 |
+
}))(
|
352 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
353 |
+
block=tuple([ 512, 1, 1 ]),
|
354 |
+
args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradFirst.data_ptr(), None ]
|
355 |
+
)
|
356 |
+
# end
|
357 |
+
# end
|
358 |
+
|
359 |
+
if gradSecond is not None:
|
360 |
+
for intSample in range(first.shape[0]):
|
361 |
+
n = first.shape[1] * first.shape[2] * first.shape[3]
|
362 |
+
cupy_launch('kernel_Correlation_updateGradSecond', cupy_kernel('kernel_Correlation_updateGradSecond', {
|
363 |
+
'rbot0': rbot0,
|
364 |
+
'rbot1': rbot1,
|
365 |
+
'gradOutput': gradOutput,
|
366 |
+
'gradFirst': None,
|
367 |
+
'gradSecond': gradSecond
|
368 |
+
}))(
|
369 |
+
grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
|
370 |
+
block=tuple([ 512, 1, 1 ]),
|
371 |
+
args=[ n, intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradSecond.data_ptr() ]
|
372 |
+
)
|
373 |
+
# end
|
374 |
+
# end
|
375 |
+
|
376 |
+
elif first.is_cuda == False:
|
377 |
+
raise NotImplementedError()
|
378 |
+
|
379 |
+
# end
|
380 |
+
|
381 |
+
return gradFirst, gradSecond
|
382 |
+
# end
|
383 |
+
# end
|
384 |
+
|
385 |
+
def FunctionCorrelation(tenFirst, tenSecond):
|
386 |
+
return _FunctionCorrelation.apply(tenFirst, tenSecond)
|
387 |
+
# end
|
388 |
+
|
389 |
+
class ModuleCorrelation(torch.nn.Module):
|
390 |
+
def __init__(self):
|
391 |
+
super(ModuleCorrelation, self).__init__()
|
392 |
+
# end
|
393 |
+
|
394 |
+
def forward(self, tenFirst, tenSecond):
|
395 |
+
return _FunctionCorrelation.apply(tenFirst, tenSecond)
|
396 |
+
# end
|
397 |
+
# end
|
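As a hedged usage sketch of the layer above: given two CUDA feature maps of identical shape, FunctionCorrelation returns an 81-channel cost volume, one channel per displacement in the 9x9 search window (the tensor sizes here are arbitrary):

import torch

if torch.cuda.is_available():
    feat_first = torch.randn(1, 64, 48, 64, device='cuda')   # arbitrary B, C, H, W
    feat_second = torch.randn(1, 64, 48, 64, device='cuda')
    cost_volume = FunctionCorrelation(feat_first, feat_second)
    print(cost_volume.shape)  # torch.Size([1, 81, 48, 64])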
opensora/eval/flolpips/flolpips.py
ADDED
@@ -0,0 +1,308 @@
1 |
+
|
2 |
+
from __future__ import absolute_import
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
from torch.autograd import Variable
|
8 |
+
from .pretrained_networks import vgg16, alexnet, squeezenet
|
9 |
+
import torch.nn
|
10 |
+
import torch.nn.functional as F
|
11 |
+
import torchvision.transforms.functional as TF
|
12 |
+
import cv2
|
13 |
+
|
14 |
+
from .pwcnet import Network as PWCNet
|
15 |
+
from .utils import *
|
16 |
+
|
17 |
+
def spatial_average(in_tens, keepdim=True):
|
18 |
+
return in_tens.mean([2,3],keepdim=keepdim)
|
19 |
+
|
20 |
+
def mw_spatial_average(in_tens, flow, keepdim=True):
|
21 |
+
_,_,h,w = in_tens.shape
|
22 |
+
flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear')
|
23 |
+
flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2)
|
24 |
+
flow_mag = flow_mag / torch.sum(flow_mag, dim=[1,2,3], keepdim=True)
|
25 |
+
return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim)
|
26 |
+
|
27 |
+
|
28 |
+
def mtw_spatial_average(in_tens, flow, texture, keepdim=True):
|
29 |
+
_,_,h,w = in_tens.shape
|
30 |
+
flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear')
|
31 |
+
texture = F.interpolate(texture, (h,w), align_corners=False, mode='bilinear')
|
32 |
+
flow_mag = torch.sqrt(flow[:,0:1]**2 + flow[:,1:2]**2)
|
33 |
+
flow_mag = (flow_mag - flow_mag.min()) / (flow_mag.max() - flow_mag.min()) + 1e-6
|
34 |
+
texture = (texture - texture.min()) / (texture.max() - texture.min()) + 1e-6
|
35 |
+
weight = flow_mag / texture
|
36 |
+
weight /= torch.sum(weight)
|
37 |
+
return torch.sum(in_tens*weight, dim=[2,3],keepdim=keepdim)
|
38 |
+
|
39 |
+
|
40 |
+
|
41 |
+
def m2w_spatial_average(in_tens, flow, keepdim=True):
|
42 |
+
_,_,h,w = in_tens.shape
|
43 |
+
flow = F.interpolate(flow, (h,w), align_corners=False, mode='bilinear')
|
44 |
+
flow_mag = flow[:,0:1]**2 + flow[:,1:2]**2 # B,1,H,W
|
45 |
+
flow_mag = flow_mag / torch.sum(flow_mag)
|
46 |
+
return torch.sum(in_tens*flow_mag, dim=[2,3],keepdim=keepdim)
|
47 |
+
|
48 |
+
def upsample(in_tens, out_HW=(64,64)): # assumes scale factor is same for H and W
|
49 |
+
in_H, in_W = in_tens.shape[2], in_tens.shape[3]
|
50 |
+
return nn.Upsample(size=out_HW, mode='bilinear', align_corners=False)(in_tens)
|
51 |
+
|
52 |
+
# Learned perceptual metric
|
53 |
+
class LPIPS(nn.Module):
|
54 |
+
def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False,
|
55 |
+
pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False):
|
56 |
+
# lpips - [True] means with linear calibration on top of base network
|
57 |
+
# pretrained - [True] means load linear weights
|
58 |
+
|
59 |
+
super(LPIPS, self).__init__()
|
60 |
+
if(verbose):
|
61 |
+
print('Setting up [%s] perceptual loss: trunk [%s], v[%s], spatial [%s]'%
|
62 |
+
('LPIPS' if lpips else 'baseline', net, version, 'on' if spatial else 'off'))
|
63 |
+
|
64 |
+
self.pnet_type = net
|
65 |
+
self.pnet_tune = pnet_tune
|
66 |
+
self.pnet_rand = pnet_rand
|
67 |
+
self.spatial = spatial
|
68 |
+
self.lpips = lpips # false means baseline of just averaging all layers
|
69 |
+
self.version = version
|
70 |
+
self.scaling_layer = ScalingLayer()
|
71 |
+
|
72 |
+
if(self.pnet_type in ['vgg','vgg16']):
|
73 |
+
net_type = vgg16
|
74 |
+
self.chns = [64,128,256,512,512]
|
75 |
+
elif(self.pnet_type=='alex'):
|
76 |
+
net_type = alexnet
|
77 |
+
self.chns = [64,192,384,256,256]
|
78 |
+
elif(self.pnet_type=='squeeze'):
|
79 |
+
net_type = squeezenet
|
80 |
+
self.chns = [64,128,256,384,384,512,512]
|
81 |
+
self.L = len(self.chns)
|
82 |
+
|
83 |
+
self.net = net_type(pretrained=not self.pnet_rand, requires_grad=self.pnet_tune)
|
84 |
+
|
85 |
+
if(lpips):
|
86 |
+
self.lin0 = NetLinLayer(self.chns[0], use_dropout=use_dropout)
|
87 |
+
self.lin1 = NetLinLayer(self.chns[1], use_dropout=use_dropout)
|
88 |
+
self.lin2 = NetLinLayer(self.chns[2], use_dropout=use_dropout)
|
89 |
+
self.lin3 = NetLinLayer(self.chns[3], use_dropout=use_dropout)
|
90 |
+
self.lin4 = NetLinLayer(self.chns[4], use_dropout=use_dropout)
|
91 |
+
self.lins = [self.lin0,self.lin1,self.lin2,self.lin3,self.lin4]
|
92 |
+
if(self.pnet_type=='squeeze'): # 7 layers for squeezenet
|
93 |
+
self.lin5 = NetLinLayer(self.chns[5], use_dropout=use_dropout)
|
94 |
+
self.lin6 = NetLinLayer(self.chns[6], use_dropout=use_dropout)
|
95 |
+
self.lins+=[self.lin5,self.lin6]
|
96 |
+
self.lins = nn.ModuleList(self.lins)
|
97 |
+
|
98 |
+
if(pretrained):
|
99 |
+
if(model_path is None):
|
100 |
+
import inspect
|
101 |
+
import os
|
102 |
+
model_path = os.path.abspath(os.path.join(inspect.getfile(self.__init__), '..', 'weights/v%s/%s.pth'%(version,net)))
|
103 |
+
|
104 |
+
if(verbose):
|
105 |
+
print('Loading model from: %s'%model_path)
|
106 |
+
self.load_state_dict(torch.load(model_path, map_location='cpu'), strict=False)
|
107 |
+
|
108 |
+
if(eval_mode):
|
109 |
+
self.eval()
|
110 |
+
|
111 |
+
def forward(self, in0, in1, retPerLayer=False, normalize=False):
|
112 |
+
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
|
113 |
+
in0 = 2 * in0 - 1
|
114 |
+
in1 = 2 * in1 - 1
|
115 |
+
|
116 |
+
# v0.0 - original release had a bug, where input was not scaled
|
117 |
+
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
|
118 |
+
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
|
119 |
+
feats0, feats1, diffs = {}, {}, {}
|
120 |
+
|
121 |
+
for kk in range(self.L):
|
122 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
123 |
+
diffs[kk] = (feats0[kk]-feats1[kk])**2
|
124 |
+
|
125 |
+
if(self.lpips):
|
126 |
+
if(self.spatial):
|
127 |
+
res = [upsample(self.lins[kk](diffs[kk]), out_HW=in0.shape[2:]) for kk in range(self.L)]
|
128 |
+
else:
|
129 |
+
res = [spatial_average(self.lins[kk](diffs[kk]), keepdim=True) for kk in range(self.L)]
|
130 |
+
else:
|
131 |
+
if(self.spatial):
|
132 |
+
res = [upsample(diffs[kk].sum(dim=1,keepdim=True), out_HW=in0.shape[2:]) for kk in range(self.L)]
|
133 |
+
else:
|
134 |
+
res = [spatial_average(diffs[kk].sum(dim=1,keepdim=True), keepdim=True) for kk in range(self.L)]
|
135 |
+
|
136 |
+
# val = res[0]
|
137 |
+
# for l in range(1,self.L):
|
138 |
+
# val += res[l]
|
139 |
+
# print(val)
|
140 |
+
|
141 |
+
# a = spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
|
142 |
+
# b = torch.max(self.lins[kk](feats0[kk]**2))
|
143 |
+
# for kk in range(self.L):
|
144 |
+
# a += spatial_average(self.lins[kk](diffs[kk]), keepdim=True)
|
145 |
+
# b = torch.max(b,torch.max(self.lins[kk](feats0[kk]**2)))
|
146 |
+
# a = a/self.L
|
147 |
+
# from IPython import embed
|
148 |
+
# embed()
|
149 |
+
# return 10*torch.log10(b/a)
|
150 |
+
|
151 |
+
# if(retPerLayer):
|
152 |
+
# return (val, res)
|
153 |
+
# else:
|
154 |
+
return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False)
|
155 |
+
|
156 |
+
|
157 |
+
class ScalingLayer(nn.Module):
|
158 |
+
def __init__(self):
|
159 |
+
super(ScalingLayer, self).__init__()
|
160 |
+
self.register_buffer('shift', torch.Tensor([-.030,-.088,-.188])[None,:,None,None])
|
161 |
+
self.register_buffer('scale', torch.Tensor([.458,.448,.450])[None,:,None,None])
|
162 |
+
|
163 |
+
def forward(self, inp):
|
164 |
+
return (inp - self.shift) / self.scale
|
165 |
+
|
166 |
+
|
167 |
+
class NetLinLayer(nn.Module):
|
168 |
+
''' A single linear layer which does a 1x1 conv '''
|
169 |
+
def __init__(self, chn_in, chn_out=1, use_dropout=False):
|
170 |
+
super(NetLinLayer, self).__init__()
|
171 |
+
|
172 |
+
layers = [nn.Dropout(),] if(use_dropout) else []
|
173 |
+
layers += [nn.Conv2d(chn_in, chn_out, 1, stride=1, padding=0, bias=False),]
|
174 |
+
self.model = nn.Sequential(*layers)
|
175 |
+
|
176 |
+
def forward(self, x):
|
177 |
+
return self.model(x)
|
178 |
+
|
179 |
+
class Dist2LogitLayer(nn.Module):
|
180 |
+
''' takes 2 distances, puts through fc layers, spits out value between [0,1] (if use_sigmoid is True) '''
|
181 |
+
def __init__(self, chn_mid=32, use_sigmoid=True):
|
182 |
+
super(Dist2LogitLayer, self).__init__()
|
183 |
+
|
184 |
+
layers = [nn.Conv2d(5, chn_mid, 1, stride=1, padding=0, bias=True),]
|
185 |
+
layers += [nn.LeakyReLU(0.2,True),]
|
186 |
+
layers += [nn.Conv2d(chn_mid, chn_mid, 1, stride=1, padding=0, bias=True),]
|
187 |
+
layers += [nn.LeakyReLU(0.2,True),]
|
188 |
+
layers += [nn.Conv2d(chn_mid, 1, 1, stride=1, padding=0, bias=True),]
|
189 |
+
if(use_sigmoid):
|
190 |
+
layers += [nn.Sigmoid(),]
|
191 |
+
self.model = nn.Sequential(*layers)
|
192 |
+
|
193 |
+
def forward(self,d0,d1,eps=0.1):
|
194 |
+
return self.model.forward(torch.cat((d0,d1,d0-d1,d0/(d1+eps),d1/(d0+eps)),dim=1))
|
195 |
+
|
196 |
+
class BCERankingLoss(nn.Module):
|
197 |
+
def __init__(self, chn_mid=32):
|
198 |
+
super(BCERankingLoss, self).__init__()
|
199 |
+
self.net = Dist2LogitLayer(chn_mid=chn_mid)
|
200 |
+
# self.parameters = list(self.net.parameters())
|
201 |
+
self.loss = torch.nn.BCELoss()
|
202 |
+
|
203 |
+
def forward(self, d0, d1, judge):
|
204 |
+
per = (judge+1.)/2.
|
205 |
+
self.logit = self.net.forward(d0,d1)
|
206 |
+
return self.loss(self.logit, per)
|
207 |
+
|
208 |
+
# L2, DSSIM metrics
|
209 |
+
class FakeNet(nn.Module):
|
210 |
+
def __init__(self, use_gpu=True, colorspace='Lab'):
|
211 |
+
super(FakeNet, self).__init__()
|
212 |
+
self.use_gpu = use_gpu
|
213 |
+
self.colorspace = colorspace
|
214 |
+
|
215 |
+
class L2(FakeNet):
|
216 |
+
def forward(self, in0, in1, retPerLayer=None):
|
217 |
+
assert(in0.size()[0]==1) # currently only supports batchSize 1
|
218 |
+
|
219 |
+
if(self.colorspace=='RGB'):
|
220 |
+
(N,C,X,Y) = in0.size()
|
221 |
+
value = torch.mean(torch.mean(torch.mean((in0-in1)**2,dim=1).view(N,1,X,Y),dim=2).view(N,1,1,Y),dim=3).view(N)
|
222 |
+
return value
|
223 |
+
elif(self.colorspace=='Lab'):
|
224 |
+
value = l2(tensor2np(tensor2tensorlab(in0.data,to_norm=False)),
|
225 |
+
tensor2np(tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
|
226 |
+
ret_var = Variable( torch.Tensor((value,) ) )
|
227 |
+
if(self.use_gpu):
|
228 |
+
ret_var = ret_var.cuda()
|
229 |
+
return ret_var
|
230 |
+
|
231 |
+
class DSSIM(FakeNet):
|
232 |
+
|
233 |
+
def forward(self, in0, in1, retPerLayer=None):
|
234 |
+
assert(in0.size()[0]==1) # currently only supports batchSize 1
|
235 |
+
|
236 |
+
if(self.colorspace=='RGB'):
|
237 |
+
value = dssim(1.*tensor2im(in0.data), 1.*tensor2im(in1.data), range=255.).astype('float')
|
238 |
+
elif(self.colorspace=='Lab'):
|
239 |
+
value = dssim(tensor2np(tensor2tensorlab(in0.data,to_norm=False)),
|
240 |
+
tensor2np(tensor2tensorlab(in1.data,to_norm=False)), range=100.).astype('float')
|
241 |
+
ret_var = Variable( torch.Tensor((value,) ) )
|
242 |
+
if(self.use_gpu):
|
243 |
+
ret_var = ret_var.cuda()
|
244 |
+
return ret_var
|
245 |
+
|
246 |
+
def print_network(net):
|
247 |
+
num_params = 0
|
248 |
+
for param in net.parameters():
|
249 |
+
num_params += param.numel()
|
250 |
+
print('Network',net)
|
251 |
+
print('Total number of parameters: %d' % num_params)
|
252 |
+
|
253 |
+
|
254 |
+
class FloLPIPS(LPIPS):
|
255 |
+
def __init__(self, pretrained=True, net='alex', version='0.1', lpips=True, spatial=False, pnet_rand=False, pnet_tune=False, use_dropout=True, model_path=None, eval_mode=True, verbose=False):
|
256 |
+
super(FloLPIPS, self).__init__(pretrained, net, version, lpips, spatial, pnet_rand, pnet_tune, use_dropout, model_path, eval_mode, verbose)
|
257 |
+
|
258 |
+
def forward(self, in0, in1, flow, retPerLayer=False, normalize=False):
|
259 |
+
if normalize: # turn on this flag if input is [0,1] so it can be adjusted to [-1, +1]
|
260 |
+
in0 = 2 * in0 - 1
|
261 |
+
in1 = 2 * in1 - 1
|
262 |
+
|
263 |
+
in0_input, in1_input = (self.scaling_layer(in0), self.scaling_layer(in1)) if self.version=='0.1' else (in0, in1)
|
264 |
+
outs0, outs1 = self.net.forward(in0_input), self.net.forward(in1_input)
|
265 |
+
feats0, feats1, diffs = {}, {}, {}
|
266 |
+
|
267 |
+
for kk in range(self.L):
|
268 |
+
feats0[kk], feats1[kk] = normalize_tensor(outs0[kk]), normalize_tensor(outs1[kk])
|
269 |
+
diffs[kk] = (feats0[kk]-feats1[kk])**2
|
270 |
+
|
271 |
+
res = [mw_spatial_average(self.lins[kk](diffs[kk]), flow, keepdim=True) for kk in range(self.L)]
|
272 |
+
|
273 |
+
return torch.sum(torch.cat(res, 1), dim=(1,2,3), keepdims=False)
|
274 |
+
|
275 |
+
|
276 |
+
|
277 |
+
|
278 |
+
|
279 |
+
class Flolpips(nn.Module):
|
280 |
+
def __init__(self):
|
281 |
+
super(Flolpips, self).__init__()
|
282 |
+
self.loss_fn = FloLPIPS(net='alex',version='0.1')
|
283 |
+
self.flownet = PWCNet()
|
284 |
+
|
285 |
+
@torch.no_grad()
|
286 |
+
def forward(self, I0, I1, frame_dis, frame_ref):
|
287 |
+
"""
|
288 |
+
args:
|
289 |
+
I0: first frame of the triplet, shape: [B, C, H, W]
|
290 |
+
I1: third frame of the triplet, shape: [B, C, H, W]
|
291 |
+
frame_dis: prediction of the intermediate frame, shape: [B, C, H, W]
|
292 |
+
frame_ref: ground-truth of the intermediate frame, shape: [B, C, H, W]
|
293 |
+
"""
|
294 |
+
assert I0.size() == I1.size() == frame_dis.size() == frame_ref.size(), \
|
295 |
+
"the 4 input tensors should have same size"
|
296 |
+
|
297 |
+
flow_ref = self.flownet(frame_ref, I0)
|
298 |
+
flow_dis = self.flownet(frame_dis, I0)
|
299 |
+
flow_diff = flow_ref - flow_dis
|
300 |
+
flolpips_wrt_I0 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True)
|
301 |
+
|
302 |
+
flow_ref = self.flownet(frame_ref, I1)
|
303 |
+
flow_dis = self.flownet(frame_dis, I1)
|
304 |
+
flow_diff = flow_ref - flow_dis
|
305 |
+
flolpips_wrt_I1 = self.loss_fn.forward(frame_ref, frame_dis, flow_diff, normalize=True)
|
306 |
+
|
307 |
+
flolpips = (flolpips_wrt_I0 + flolpips_wrt_I1) / 2
|
308 |
+
return flolpips
|
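A hedged sketch of scoring one interpolated frame with the wrapper above; it assumes a CUDA device and that the pretrained LPIPS and PWC-Net weights loaded by FloLPIPS/PWCNet are available, and the random tensors merely stand in for real frames:

import torch

metric = Flolpips().cuda().eval()
B, C, H, W = 1, 3, 256, 448                          # arbitrary frame size
I0 = torch.rand(B, C, H, W, device='cuda')           # first frame of the triplet
I1 = torch.rand(B, C, H, W, device='cuda')           # third frame of the triplet
frame_dis = torch.rand(B, C, H, W, device='cuda')    # predicted middle frame
frame_ref = torch.rand(B, C, H, W, device='cuda')    # ground-truth middle frame
print(metric(I0, I1, frame_dis, frame_ref))          # lower means closer to the reference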
opensora/eval/flolpips/pretrained_networks.py
ADDED
@@ -0,0 +1,180 @@
1 |
+
from collections import namedtuple
|
2 |
+
import torch
|
3 |
+
from torchvision import models as tv
|
4 |
+
|
5 |
+
class squeezenet(torch.nn.Module):
|
6 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
7 |
+
super(squeezenet, self).__init__()
|
8 |
+
pretrained_features = tv.squeezenet1_1(pretrained=pretrained).features
|
9 |
+
self.slice1 = torch.nn.Sequential()
|
10 |
+
self.slice2 = torch.nn.Sequential()
|
11 |
+
self.slice3 = torch.nn.Sequential()
|
12 |
+
self.slice4 = torch.nn.Sequential()
|
13 |
+
self.slice5 = torch.nn.Sequential()
|
14 |
+
self.slice6 = torch.nn.Sequential()
|
15 |
+
self.slice7 = torch.nn.Sequential()
|
16 |
+
self.N_slices = 7
|
17 |
+
for x in range(2):
|
18 |
+
self.slice1.add_module(str(x), pretrained_features[x])
|
19 |
+
for x in range(2,5):
|
20 |
+
self.slice2.add_module(str(x), pretrained_features[x])
|
21 |
+
for x in range(5, 8):
|
22 |
+
self.slice3.add_module(str(x), pretrained_features[x])
|
23 |
+
for x in range(8, 10):
|
24 |
+
self.slice4.add_module(str(x), pretrained_features[x])
|
25 |
+
for x in range(10, 11):
|
26 |
+
self.slice5.add_module(str(x), pretrained_features[x])
|
27 |
+
for x in range(11, 12):
|
28 |
+
self.slice6.add_module(str(x), pretrained_features[x])
|
29 |
+
for x in range(12, 13):
|
30 |
+
self.slice7.add_module(str(x), pretrained_features[x])
|
31 |
+
if not requires_grad:
|
32 |
+
for param in self.parameters():
|
33 |
+
param.requires_grad = False
|
34 |
+
|
35 |
+
def forward(self, X):
|
36 |
+
h = self.slice1(X)
|
37 |
+
h_relu1 = h
|
38 |
+
h = self.slice2(h)
|
39 |
+
h_relu2 = h
|
40 |
+
h = self.slice3(h)
|
41 |
+
h_relu3 = h
|
42 |
+
h = self.slice4(h)
|
43 |
+
h_relu4 = h
|
44 |
+
h = self.slice5(h)
|
45 |
+
h_relu5 = h
|
46 |
+
h = self.slice6(h)
|
47 |
+
h_relu6 = h
|
48 |
+
h = self.slice7(h)
|
49 |
+
h_relu7 = h
|
50 |
+
vgg_outputs = namedtuple("SqueezeOutputs", ['relu1','relu2','relu3','relu4','relu5','relu6','relu7'])
|
51 |
+
out = vgg_outputs(h_relu1,h_relu2,h_relu3,h_relu4,h_relu5,h_relu6,h_relu7)
|
52 |
+
|
53 |
+
return out
|
54 |
+
|
55 |
+
|
56 |
+
class alexnet(torch.nn.Module):
|
57 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
58 |
+
super(alexnet, self).__init__()
|
59 |
+
alexnet_pretrained_features = tv.alexnet(pretrained=pretrained).features
|
60 |
+
self.slice1 = torch.nn.Sequential()
|
61 |
+
self.slice2 = torch.nn.Sequential()
|
62 |
+
self.slice3 = torch.nn.Sequential()
|
63 |
+
self.slice4 = torch.nn.Sequential()
|
64 |
+
self.slice5 = torch.nn.Sequential()
|
65 |
+
self.N_slices = 5
|
66 |
+
for x in range(2):
|
67 |
+
self.slice1.add_module(str(x), alexnet_pretrained_features[x])
|
68 |
+
for x in range(2, 5):
|
69 |
+
self.slice2.add_module(str(x), alexnet_pretrained_features[x])
|
70 |
+
for x in range(5, 8):
|
71 |
+
self.slice3.add_module(str(x), alexnet_pretrained_features[x])
|
72 |
+
for x in range(8, 10):
|
73 |
+
self.slice4.add_module(str(x), alexnet_pretrained_features[x])
|
74 |
+
for x in range(10, 12):
|
75 |
+
self.slice5.add_module(str(x), alexnet_pretrained_features[x])
|
76 |
+
if not requires_grad:
|
77 |
+
for param in self.parameters():
|
78 |
+
param.requires_grad = False
|
79 |
+
|
80 |
+
def forward(self, X):
|
81 |
+
h = self.slice1(X)
|
82 |
+
h_relu1 = h
|
83 |
+
h = self.slice2(h)
|
84 |
+
h_relu2 = h
|
85 |
+
h = self.slice3(h)
|
86 |
+
h_relu3 = h
|
87 |
+
h = self.slice4(h)
|
88 |
+
h_relu4 = h
|
89 |
+
h = self.slice5(h)
|
90 |
+
h_relu5 = h
|
91 |
+
alexnet_outputs = namedtuple("AlexnetOutputs", ['relu1', 'relu2', 'relu3', 'relu4', 'relu5'])
|
92 |
+
out = alexnet_outputs(h_relu1, h_relu2, h_relu3, h_relu4, h_relu5)
|
93 |
+
|
94 |
+
return out
|
95 |
+
|
96 |
+
class vgg16(torch.nn.Module):
|
97 |
+
def __init__(self, requires_grad=False, pretrained=True):
|
98 |
+
super(vgg16, self).__init__()
|
99 |
+
vgg_pretrained_features = tv.vgg16(pretrained=pretrained).features
|
100 |
+
self.slice1 = torch.nn.Sequential()
|
101 |
+
self.slice2 = torch.nn.Sequential()
|
102 |
+
self.slice3 = torch.nn.Sequential()
|
103 |
+
self.slice4 = torch.nn.Sequential()
|
104 |
+
self.slice5 = torch.nn.Sequential()
|
105 |
+
self.N_slices = 5
|
106 |
+
for x in range(4):
|
107 |
+
self.slice1.add_module(str(x), vgg_pretrained_features[x])
|
108 |
+
for x in range(4, 9):
|
109 |
+
self.slice2.add_module(str(x), vgg_pretrained_features[x])
|
110 |
+
for x in range(9, 16):
|
111 |
+
self.slice3.add_module(str(x), vgg_pretrained_features[x])
|
112 |
+
for x in range(16, 23):
|
113 |
+
self.slice4.add_module(str(x), vgg_pretrained_features[x])
|
114 |
+
for x in range(23, 30):
|
115 |
+
self.slice5.add_module(str(x), vgg_pretrained_features[x])
|
116 |
+
if not requires_grad:
|
117 |
+
for param in self.parameters():
|
118 |
+
param.requires_grad = False
|
119 |
+
|
120 |
+
def forward(self, X):
|
121 |
+
h = self.slice1(X)
|
122 |
+
h_relu1_2 = h
|
123 |
+
h = self.slice2(h)
|
124 |
+
h_relu2_2 = h
|
125 |
+
h = self.slice3(h)
|
126 |
+
h_relu3_3 = h
|
127 |
+
h = self.slice4(h)
|
128 |
+
h_relu4_3 = h
|
129 |
+
h = self.slice5(h)
|
130 |
+
h_relu5_3 = h
|
131 |
+
vgg_outputs = namedtuple("VggOutputs", ['relu1_2', 'relu2_2', 'relu3_3', 'relu4_3', 'relu5_3'])
|
132 |
+
out = vgg_outputs(h_relu1_2, h_relu2_2, h_relu3_3, h_relu4_3, h_relu5_3)
|
133 |
+
|
134 |
+
return out
|
135 |
+
|
136 |
+
|
137 |
+
|
138 |
+
class resnet(torch.nn.Module):
|
139 |
+
def __init__(self, requires_grad=False, pretrained=True, num=18):
|
140 |
+
super(resnet, self).__init__()
|
141 |
+
if(num==18):
|
142 |
+
self.net = tv.resnet18(pretrained=pretrained)
|
143 |
+
elif(num==34):
|
144 |
+
self.net = tv.resnet34(pretrained=pretrained)
|
145 |
+
elif(num==50):
|
146 |
+
self.net = tv.resnet50(pretrained=pretrained)
|
147 |
+
elif(num==101):
|
148 |
+
self.net = tv.resnet101(pretrained=pretrained)
|
149 |
+
elif(num==152):
|
150 |
+
self.net = tv.resnet152(pretrained=pretrained)
|
151 |
+
self.N_slices = 5
|
152 |
+
|
153 |
+
self.conv1 = self.net.conv1
|
154 |
+
self.bn1 = self.net.bn1
|
155 |
+
self.relu = self.net.relu
|
156 |
+
self.maxpool = self.net.maxpool
|
157 |
+
self.layer1 = self.net.layer1
|
158 |
+
self.layer2 = self.net.layer2
|
159 |
+
self.layer3 = self.net.layer3
|
160 |
+
self.layer4 = self.net.layer4
|
161 |
+
|
162 |
+
def forward(self, X):
|
163 |
+
h = self.conv1(X)
|
164 |
+
h = self.bn1(h)
|
165 |
+
h = self.relu(h)
|
166 |
+
h_relu1 = h
|
167 |
+
h = self.maxpool(h)
|
168 |
+
h = self.layer1(h)
|
169 |
+
h_conv2 = h
|
170 |
+
h = self.layer2(h)
|
171 |
+
h_conv3 = h
|
172 |
+
h = self.layer3(h)
|
173 |
+
h_conv4 = h
|
174 |
+
h = self.layer4(h)
|
175 |
+
h_conv5 = h
|
176 |
+
|
177 |
+
outputs = namedtuple("Outputs", ['relu1','conv2','conv3','conv4','conv5'])
|
178 |
+
out = outputs(h_relu1, h_conv2, h_conv3, h_conv4, h_conv5)
|
179 |
+
|
180 |
+
return out
|
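A hedged sketch of using one of the backbones above on its own: each wrapper returns a namedtuple of intermediate ReLU activations, which is what LPIPS compares feature-wise (the input size is arbitrary, and torchvision must be able to fetch the pretrained weights):

import torch

backbone = vgg16(requires_grad=False, pretrained=True).eval()
x = torch.randn(1, 3, 224, 224)      # arbitrary RGB input
feats = backbone(x)                  # VggOutputs(relu1_2, ..., relu5_3)
print([tuple(f.shape) for f in feats])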
opensora/eval/flolpips/pwcnet.py
ADDED
@@ -0,0 +1,344 @@
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import torch
|
4 |
+
|
5 |
+
import getopt
|
6 |
+
import math
|
7 |
+
import numpy
|
8 |
+
import os
|
9 |
+
import PIL
|
10 |
+
import PIL.Image
|
11 |
+
import sys
|
12 |
+
|
13 |
+
# try:
|
14 |
+
from .correlation import correlation # the custom cost volume layer
|
15 |
+
# except:
|
16 |
+
# sys.path.insert(0, './correlation'); import correlation # you should consider upgrading python
|
17 |
+
# end
|
18 |
+
|
19 |
+
##########################################################
|
20 |
+
|
21 |
+
# assert(int(str('').join(torch.__version__.split('.')[0:2])) >= 13) # requires at least pytorch version 1.3.0
|
22 |
+
|
23 |
+
# torch.set_grad_enabled(False) # make sure to not compute gradients for computational performance
|
24 |
+
|
25 |
+
# torch.backends.cudnn.enabled = True # make sure to use cudnn for computational performance
|
26 |
+
|
27 |
+
# ##########################################################
|
28 |
+
|
29 |
+
# arguments_strModel = 'default' # 'default', or 'chairs-things'
|
30 |
+
# arguments_strFirst = './images/first.png'
|
31 |
+
# arguments_strSecond = './images/second.png'
|
32 |
+
# arguments_strOut = './out.flo'
|
33 |
+
|
34 |
+
# for strOption, strArgument in getopt.getopt(sys.argv[1:], '', [ strParameter[2:] + '=' for strParameter in sys.argv[1::2] ])[0]:
|
35 |
+
# if strOption == '--model' and strArgument != '': arguments_strModel = strArgument # which model to use
|
36 |
+
# if strOption == '--first' and strArgument != '': arguments_strFirst = strArgument # path to the first frame
|
37 |
+
# if strOption == '--second' and strArgument != '': arguments_strSecond = strArgument # path to the second frame
|
38 |
+
# if strOption == '--out' and strArgument != '': arguments_strOut = strArgument # path to where the output should be stored
|
39 |
+
# end
|
40 |
+
|
41 |
+
##########################################################
|
42 |
+
|
43 |
+
|
44 |
+
|
45 |
+
def backwarp(tenInput, tenFlow):
|
46 |
+
backwarp_tenGrid = {}
|
47 |
+
backwarp_tenPartial = {}
|
48 |
+
if str(tenFlow.shape) not in backwarp_tenGrid:
|
49 |
+
tenHor = torch.linspace(-1.0 + (1.0 / tenFlow.shape[3]), 1.0 - (1.0 / tenFlow.shape[3]), tenFlow.shape[3]).view(1, 1, 1, -1).expand(-1, -1, tenFlow.shape[2], -1)
|
50 |
+
tenVer = torch.linspace(-1.0 + (1.0 / tenFlow.shape[2]), 1.0 - (1.0 / tenFlow.shape[2]), tenFlow.shape[2]).view(1, 1, -1, 1).expand(-1, -1, -1, tenFlow.shape[3])
|
51 |
+
|
52 |
+
backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([ tenHor, tenVer ], 1).cuda()
|
53 |
+
# end
|
54 |
+
|
55 |
+
if str(tenFlow.shape) not in backwarp_tenPartial:
|
56 |
+
backwarp_tenPartial[str(tenFlow.shape)] = tenFlow.new_ones([ tenFlow.shape[0], 1, tenFlow.shape[2], tenFlow.shape[3] ])
|
57 |
+
# end
|
58 |
+
|
59 |
+
tenFlow = torch.cat([ tenFlow[:, 0:1, :, :] / ((tenInput.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenInput.shape[2] - 1.0) / 2.0) ], 1)
|
60 |
+
tenInput = torch.cat([ tenInput, backwarp_tenPartial[str(tenFlow.shape)] ], 1)
|
61 |
+
|
62 |
+
tenOutput = torch.nn.functional.grid_sample(input=tenInput, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=False)
|
63 |
+
|
64 |
+
tenMask = tenOutput[:, -1:, :, :]; tenMask[tenMask > 0.999] = 1.0; tenMask[tenMask < 1.0] = 0.0
|
65 |
+
|
66 |
+
return tenOutput[:, :-1, :, :] * tenMask
|
67 |
+
# end
|
68 |
+
|
69 |
+
##########################################################
|
70 |
+
|
71 |
+
class Network(torch.nn.Module):
|
72 |
+
def __init__(self):
|
73 |
+
super(Network, self).__init__()
|
74 |
+
|
75 |
+
class Extractor(torch.nn.Module):
|
76 |
+
def __init__(self):
|
77 |
+
super(Extractor, self).__init__()
|
78 |
+
|
79 |
+
self.netOne = torch.nn.Sequential(
|
80 |
+
torch.nn.Conv2d(in_channels=3, out_channels=16, kernel_size=3, stride=2, padding=1),
|
81 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
82 |
+
torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
|
83 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
84 |
+
torch.nn.Conv2d(in_channels=16, out_channels=16, kernel_size=3, stride=1, padding=1),
|
85 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
86 |
+
)
|
87 |
+
|
88 |
+
self.netTwo = torch.nn.Sequential(
|
89 |
+
torch.nn.Conv2d(in_channels=16, out_channels=32, kernel_size=3, stride=2, padding=1),
|
90 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
91 |
+
torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
|
92 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
|
93 |
+
torch.nn.Conv2d(in_channels=32, out_channels=32, kernel_size=3, stride=1, padding=1),
|
94 |
+
torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
|
95 |
+
)
|
96 |
+
|
97 |
+
self.netThr = torch.nn.Sequential(
|
98 |
+
torch.nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, stride=2, padding=1),
|
99 |
+
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
+                )
+
+                self.netFou = torch.nn.Sequential(
+                    torch.nn.Conv2d(in_channels=64, out_channels=96, kernel_size=3, stride=2, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=96, out_channels=96, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
+                )
+
+                self.netFiv = torch.nn.Sequential(
+                    torch.nn.Conv2d(in_channels=96, out_channels=128, kernel_size=3, stride=2, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
+                )
+
+                self.netSix = torch.nn.Sequential(
+                    torch.nn.Conv2d(in_channels=128, out_channels=196, kernel_size=3, stride=2, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=196, out_channels=196, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
+                )
+            # end
+
+            def forward(self, tenInput):
+                tenOne = self.netOne(tenInput)
+                tenTwo = self.netTwo(tenOne)
+                tenThr = self.netThr(tenTwo)
+                tenFou = self.netFou(tenThr)
+                tenFiv = self.netFiv(tenFou)
+                tenSix = self.netSix(tenFiv)
+
+                return [ tenOne, tenTwo, tenThr, tenFou, tenFiv, tenSix ]
+            # end
+        # end
+
+        class Decoder(torch.nn.Module):
+            def __init__(self, intLevel):
+                super(Decoder, self).__init__()
+
+                intPrevious = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 1]
+                intCurrent = [ None, None, 81 + 32 + 2 + 2, 81 + 64 + 2 + 2, 81 + 96 + 2 + 2, 81 + 128 + 2 + 2, 81, None ][intLevel + 0]
+
+                if intLevel < 6: self.netUpflow = torch.nn.ConvTranspose2d(in_channels=2, out_channels=2, kernel_size=4, stride=2, padding=1)
+                if intLevel < 6: self.netUpfeat = torch.nn.ConvTranspose2d(in_channels=intPrevious + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=4, stride=2, padding=1)
+                if intLevel < 6: self.fltBackwarp = [ None, None, None, 5.0, 2.5, 1.25, 0.625, None ][intLevel + 1]
+
+                self.netOne = torch.nn.Sequential(
+                    torch.nn.Conv2d(in_channels=intCurrent, out_channels=128, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
+                )
+
+                self.netTwo = torch.nn.Sequential(
+                    torch.nn.Conv2d(in_channels=intCurrent + 128, out_channels=128, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
+                )
+
+                self.netThr = torch.nn.Sequential(
+                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128, out_channels=96, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
+                )
+
+                self.netFou = torch.nn.Sequential(
+                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96, out_channels=64, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
+                )
+
+                self.netFiv = torch.nn.Sequential(
+                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64, out_channels=32, kernel_size=3, stride=1, padding=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1)
+                )
+
+                self.netSix = torch.nn.Sequential(
+                    torch.nn.Conv2d(in_channels=intCurrent + 128 + 128 + 96 + 64 + 32, out_channels=2, kernel_size=3, stride=1, padding=1)
+                )
+            # end
+
+            def forward(self, tenFirst, tenSecond, objPrevious):
+                tenFlow = None
+                tenFeat = None
+
+                if objPrevious is None:
+                    tenFlow = None
+                    tenFeat = None
+
+                    tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=tenSecond), negative_slope=0.1, inplace=False)
+
+                    tenFeat = torch.cat([ tenVolume ], 1)
+
+                elif objPrevious is not None:
+                    tenFlow = self.netUpflow(objPrevious['tenFlow'])
+                    tenFeat = self.netUpfeat(objPrevious['tenFeat'])
+
+                    tenVolume = torch.nn.functional.leaky_relu(input=correlation.FunctionCorrelation(tenFirst=tenFirst, tenSecond=backwarp(tenInput=tenSecond, tenFlow=tenFlow * self.fltBackwarp)), negative_slope=0.1, inplace=False)
+
+                    tenFeat = torch.cat([ tenVolume, tenFirst, tenFlow, tenFeat ], 1)
+
+                # end
+
+                tenFeat = torch.cat([ self.netOne(tenFeat), tenFeat ], 1)
+                tenFeat = torch.cat([ self.netTwo(tenFeat), tenFeat ], 1)
+                tenFeat = torch.cat([ self.netThr(tenFeat), tenFeat ], 1)
+                tenFeat = torch.cat([ self.netFou(tenFeat), tenFeat ], 1)
+                tenFeat = torch.cat([ self.netFiv(tenFeat), tenFeat ], 1)
+
+                tenFlow = self.netSix(tenFeat)
+
+                return {
+                    'tenFlow': tenFlow,
+                    'tenFeat': tenFeat
+                }
+            # end
+        # end
+
+        class Refiner(torch.nn.Module):
+            def __init__(self):
+                super(Refiner, self).__init__()
+
+                self.netMain = torch.nn.Sequential(
+                    torch.nn.Conv2d(in_channels=81 + 32 + 2 + 2 + 128 + 128 + 96 + 64 + 32, out_channels=128, kernel_size=3, stride=1, padding=1, dilation=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=2, dilation=2),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=128, out_channels=128, kernel_size=3, stride=1, padding=4, dilation=4),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=128, out_channels=96, kernel_size=3, stride=1, padding=8, dilation=8),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=96, out_channels=64, kernel_size=3, stride=1, padding=16, dilation=16),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=64, out_channels=32, kernel_size=3, stride=1, padding=1, dilation=1),
+                    torch.nn.LeakyReLU(inplace=False, negative_slope=0.1),
+                    torch.nn.Conv2d(in_channels=32, out_channels=2, kernel_size=3, stride=1, padding=1, dilation=1)
+                )
+            # end
+
+            def forward(self, tenInput):
+                return self.netMain(tenInput)
+            # end
+        # end
+
+        self.netExtractor = Extractor()
+
+        self.netTwo = Decoder(2)
+        self.netThr = Decoder(3)
+        self.netFou = Decoder(4)
+        self.netFiv = Decoder(5)
+        self.netSix = Decoder(6)
+
+        self.netRefiner = Refiner()
+
+        self.load_state_dict({ strKey.replace('module', 'net'): tenWeight for strKey, tenWeight in torch.hub.load_state_dict_from_url(url='http://content.sniklaus.com/github/pytorch-pwc/network-' + 'default' + '.pytorch').items() })
+    # end
+
+    def forward(self, tenFirst, tenSecond):
+        intWidth = tenFirst.shape[3]
+        intHeight = tenFirst.shape[2]
+
+        intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
+        intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))
+
+        tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
+        tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
+
+        tenFirst = self.netExtractor(tenPreprocessedFirst)
+        tenSecond = self.netExtractor(tenPreprocessedSecond)
+
+        objEstimate = self.netSix(tenFirst[-1], tenSecond[-1], None)
+        objEstimate = self.netFiv(tenFirst[-2], tenSecond[-2], objEstimate)
+        objEstimate = self.netFou(tenFirst[-3], tenSecond[-3], objEstimate)
+        objEstimate = self.netThr(tenFirst[-4], tenSecond[-4], objEstimate)
+        objEstimate = self.netTwo(tenFirst[-5], tenSecond[-5], objEstimate)
+
+        tenFlow = objEstimate['tenFlow'] + self.netRefiner(objEstimate['tenFeat'])
+        tenFlow = 20.0 * torch.nn.functional.interpolate(input=tenFlow, size=(intHeight, intWidth), mode='bilinear', align_corners=False)
+        tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
+        tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
+
+        return tenFlow
+    # end
+# end
+
+netNetwork = None
+
+##########################################################
+
+def estimate(tenFirst, tenSecond):
+    global netNetwork
+
+    if netNetwork is None:
+        netNetwork = Network().cuda().eval()
+    # end
+
+    assert(tenFirst.shape[1] == tenSecond.shape[1])
+    assert(tenFirst.shape[2] == tenSecond.shape[2])
+
+    intWidth = tenFirst.shape[2]
+    intHeight = tenFirst.shape[1]
+
+    assert(intWidth == 1024) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
+    assert(intHeight == 436) # remember that there is no guarantee for correctness, comment this line out if you acknowledge this and want to continue
+
+    tenPreprocessedFirst = tenFirst.cuda().view(1, 3, intHeight, intWidth)
+    tenPreprocessedSecond = tenSecond.cuda().view(1, 3, intHeight, intWidth)
+
+    intPreprocessedWidth = int(math.floor(math.ceil(intWidth / 64.0) * 64.0))
+    intPreprocessedHeight = int(math.floor(math.ceil(intHeight / 64.0) * 64.0))
+
+    tenPreprocessedFirst = torch.nn.functional.interpolate(input=tenPreprocessedFirst, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
+    tenPreprocessedSecond = torch.nn.functional.interpolate(input=tenPreprocessedSecond, size=(intPreprocessedHeight, intPreprocessedWidth), mode='bilinear', align_corners=False)
+
+    tenFlow = 20.0 * torch.nn.functional.interpolate(input=netNetwork(tenPreprocessedFirst, tenPreprocessedSecond), size=(intHeight, intWidth), mode='bilinear', align_corners=False)
+
+    tenFlow[:, 0, :, :] *= float(intWidth) / float(intPreprocessedWidth)
+    tenFlow[:, 1, :, :] *= float(intHeight) / float(intPreprocessedHeight)
+
+    return tenFlow[0, :, :, :].cpu()
+# end
+
+##########################################################
+
+# if __name__ == '__main__':
+#     tenFirst = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strFirst))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
+#     tenSecond = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(arguments_strSecond))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
+
+#     tenOutput = estimate(tenFirst, tenSecond)
+
+#     objOutput = open(arguments_strOut, 'wb')
+
+#     numpy.array([ 80, 73, 69, 72 ], numpy.uint8).tofile(objOutput)
+#     numpy.array([ tenOutput.shape[2], tenOutput.shape[1] ], numpy.int32).tofile(objOutput)
+#     numpy.array(tenOutput.numpy().transpose(1, 2, 0), numpy.float32).tofile(objOutput)
+
+#     objOutput.close()
+# end
opensora/eval/flolpips/utils.py
ADDED
@@ -0,0 +1,95 @@
+import numpy as np
+import cv2
+import torch
+
+
+def normalize_tensor(in_feat, eps=1e-10):
+    norm_factor = torch.sqrt(torch.sum(in_feat**2, dim=1, keepdim=True))
+    return in_feat / (norm_factor + eps)
+
+def l2(p0, p1, range=255.):
+    return .5 * np.mean((p0 / range - p1 / range)**2)
+
+def dssim(p0, p1, range=255.):
+    from skimage.measure import compare_ssim
+    return (1 - compare_ssim(p0, p1, data_range=range, multichannel=True)) / 2.
+
+def tensor2im(image_tensor, imtype=np.uint8, cent=1., factor=255./2.):
+    image_numpy = image_tensor[0].cpu().float().numpy()
+    image_numpy = (np.transpose(image_numpy, (1, 2, 0)) + cent) * factor
+    return image_numpy.astype(imtype)
+
+def tensor2np(tensor_obj):
+    # change dimension of a tensor object into a numpy array
+    return tensor_obj[0].cpu().float().numpy().transpose((1, 2, 0))
+
+def np2tensor(np_obj):
+    # change dimension of np array into tensor array
+    return torch.Tensor(np_obj[:, :, :, np.newaxis].transpose((3, 2, 0, 1)))
+
+def tensor2tensorlab(image_tensor, to_norm=True, mc_only=False):
+    # image tensor to lab tensor
+    from skimage import color
+
+    img = tensor2im(image_tensor)
+    img_lab = color.rgb2lab(img)
+    if (mc_only):
+        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
+    if (to_norm and not mc_only):
+        img_lab[:, :, 0] = img_lab[:, :, 0] - 50
+        img_lab = img_lab / 100.
+
+    return np2tensor(img_lab)
+
+def read_frame_yuv2rgb(stream, width, height, iFrame, bit_depth, pix_fmt='420'):
+    if pix_fmt == '420':
+        multiplier = 1
+        uv_factor = 2
+    elif pix_fmt == '444':
+        multiplier = 2
+        uv_factor = 1
+    else:
+        print('Pixel format {} is not supported'.format(pix_fmt))
+        return
+
+    if bit_depth == 8:
+        datatype = np.uint8
+        stream.seek(iFrame*1.5*width*height*multiplier)
+        Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width))
+
+        # read chroma samples and upsample since original is 4:2:0 sampling
+        U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
+            reshape((height//uv_factor, width//uv_factor))
+        V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
+            reshape((height//uv_factor, width//uv_factor))
+
+    else:
+        datatype = np.uint16
+        stream.seek(iFrame*3*width*height*multiplier)
+        Y = np.fromfile(stream, dtype=datatype, count=width*height).reshape((height, width))
+
+        U = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
+            reshape((height//uv_factor, width//uv_factor))
+        V = np.fromfile(stream, dtype=datatype, count=(width//uv_factor)*(height//uv_factor)).\
+            reshape((height//uv_factor, width//uv_factor))
+
+    if pix_fmt == '420':
+        yuv = np.empty((height*3//2, width), dtype=datatype)
+        yuv[0:height, :] = Y
+
+        yuv[height:height+height//4, :] = U.reshape(-1, width)
+        yuv[height+height//4:, :] = V.reshape(-1, width)
+
+        if bit_depth != 8:
+            yuv = (yuv / (2**bit_depth - 1) * 255).astype(np.uint8)
+
+        # convert to rgb
+        rgb = cv2.cvtColor(yuv, cv2.COLOR_YUV2RGB_I420)
+
+    else:
+        yvu = np.stack([Y, V, U], axis=2)
+        if bit_depth != 8:
+            yvu = (yvu / (2**bit_depth - 1) * 255).astype(np.uint8)
+        rgb = cv2.cvtColor(yvu, cv2.COLOR_YCrCb2RGB)
+
+    return rgb
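Usage note (not part of the diff): a minimal sketch of the tensor/array conversion helpers above; the input tensor here is a random placeholder.

    import torch
    from opensora.eval.flolpips.utils import tensor2im, np2tensor

    # Hypothetical example: round-trip a 1x3xHxW tensor in [-1, 1] through the helpers.
    img_tensor = torch.rand(1, 3, 64, 64) * 2 - 1
    img_uint8 = tensor2im(img_tensor)   # HxWx3 uint8 image in [0, 255]
    img_back = np2tensor(img_uint8)     # back to a 1x3xHxW float tensor (values still in [0, 255])
    print(img_uint8.shape, img_back.shape)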
opensora/eval/fvd/styleganv/fvd.py
ADDED
@@ -0,0 +1,90 @@
+import torch
+import os
+import math
+import torch.nn.functional as F
+
+# https://github.com/universome/fvd-comparison
+
+
+def load_i3d_pretrained(device=torch.device('cpu')):
+    i3D_WEIGHTS_URL = "https://www.dropbox.com/s/ge9e5ujwgetktms/i3d_torchscript.pt"
+    filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_torchscript.pt')
+    print(filepath)
+    if not os.path.exists(filepath):
+        print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
+        os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
+    i3d = torch.jit.load(filepath).eval().to(device)
+    i3d = torch.nn.DataParallel(i3d)
+    return i3d
+
+
+def get_feats(videos, detector, device, bs=10):
+    # videos : torch.tensor BCTHW [0, 1]
+    detector_kwargs = dict(rescale=False, resize=False, return_features=True) # Return raw features before the softmax layer.
+    feats = np.empty((0, 400))
+    with torch.no_grad():
+        for i in range((len(videos)-1)//bs + 1):
+            feats = np.vstack([feats, detector(torch.stack([preprocess_single(video) for video in videos[i*bs:(i+1)*bs]]).to(device), **detector_kwargs).detach().cpu().numpy()])
+    return feats
+
+
+def get_fvd_feats(videos, i3d, device, bs=10):
+    # videos in [0, 1] as torch tensor BCTHW
+    # videos = [preprocess_single(video) for video in videos]
+    embeddings = get_feats(videos, i3d, device, bs)
+    return embeddings
+
+
+def preprocess_single(video, resolution=224, sequence_length=None):
+    # video: CTHW, [0, 1]
+    c, t, h, w = video.shape
+
+    # temporal crop
+    if sequence_length is not None:
+        assert sequence_length <= t
+        video = video[:, :sequence_length]
+
+    # scale shorter side to resolution
+    scale = resolution / min(h, w)
+    if h < w:
+        target_size = (resolution, math.ceil(w * scale))
+    else:
+        target_size = (math.ceil(h * scale), resolution)
+    video = F.interpolate(video, size=target_size, mode='bilinear', align_corners=False)
+
+    # center crop
+    c, t, h, w = video.shape
+    w_start = (w - resolution) // 2
+    h_start = (h - resolution) // 2
+    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
+
+    # [0, 1] -> [-1, 1]
+    video = (video - 0.5) * 2
+
+    return video.contiguous()
+
+
+"""
+Copy-pasted from https://github.com/cvpr2022-stylegan-v/stylegan-v/blob/main/src/metrics/frechet_video_distance.py
+"""
+from typing import Tuple
+from scipy.linalg import sqrtm
+import numpy as np
+
+
+def compute_stats(feats: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
+    mu = feats.mean(axis=0) # [d]
+    sigma = np.cov(feats, rowvar=False) # [d, d]
+    return mu, sigma
+
+
+def frechet_distance(feats_fake: np.ndarray, feats_real: np.ndarray) -> float:
+    mu_gen, sigma_gen = compute_stats(feats_fake)
+    mu_real, sigma_real = compute_stats(feats_real)
+    m = np.square(mu_gen - mu_real).sum()
+    if feats_fake.shape[0] > 1:
+        s, _ = sqrtm(np.dot(sigma_gen, sigma_real), disp=False) # pylint: disable=no-member
+        fid = np.real(m + np.trace(sigma_gen + sigma_real - s * 2))
+    else:
+        fid = np.real(m)
+    return float(fid)
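For orientation (not part of the diff), a minimal sketch of how these helpers are typically combined into an FVD score; the video batches below are random placeholders, and `load_i3d_pretrained` downloads the TorchScript I3D weights on first use.

    import torch
    from opensora.eval.fvd.styleganv.fvd import load_i3d_pretrained, get_fvd_feats, frechet_distance

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    i3d = load_i3d_pretrained(device)

    # placeholder batches: B x C x T x H x W, values in [0, 1]
    real = torch.rand(8, 3, 16, 224, 224)
    fake = torch.rand(8, 3, 16, 224, 224)

    feats_real = get_fvd_feats(real, i3d, device, bs=4)
    feats_fake = get_fvd_feats(fake, i3d, device, bs=4)
    print('FVD:', frechet_distance(feats_fake, feats_real))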
opensora/eval/fvd/videogpt/fvd.py
ADDED
@@ -0,0 +1,137 @@
+import torch
+import os
+import math
+import torch.nn.functional as F
+import numpy as np
+import einops
+
+def load_i3d_pretrained(device=torch.device('cpu')):
+    i3D_WEIGHTS_URL = "https://onedrive.live.com/download?cid=78EEF3EB6AE7DBCB&resid=78EEF3EB6AE7DBCB%21199&authkey=AApKdFHPXzWLNyI"
+    filepath = os.path.join(os.path.dirname(os.path.abspath(__file__)), 'i3d_pretrained_400.pt')
+    print(filepath)
+    if not os.path.exists(filepath):
+        print(f"preparing for download {i3D_WEIGHTS_URL}, you can download it by yourself.")
+        os.system(f"wget {i3D_WEIGHTS_URL} -O {filepath}")
+    from .pytorch_i3d import InceptionI3d
+    i3d = InceptionI3d(400, in_channels=3).eval().to(device)
+    i3d.load_state_dict(torch.load(filepath, map_location=device))
+    i3d = torch.nn.DataParallel(i3d)
+    return i3d
+
+def preprocess_single(video, resolution, sequence_length=None):
+    # video: THWC, {0, ..., 255}
+    video = video.permute(0, 3, 1, 2).float() / 255. # TCHW
+    t, c, h, w = video.shape
+
+    # temporal crop
+    if sequence_length is not None:
+        assert sequence_length <= t
+        video = video[:sequence_length]
+
+    # scale shorter side to resolution
+    scale = resolution / min(h, w)
+    if h < w:
+        target_size = (resolution, math.ceil(w * scale))
+    else:
+        target_size = (math.ceil(h * scale), resolution)
+    video = F.interpolate(video, size=target_size, mode='bilinear',
+                          align_corners=False)
+
+    # center crop
+    t, c, h, w = video.shape
+    w_start = (w - resolution) // 2
+    h_start = (h - resolution) // 2
+    video = video[:, :, h_start:h_start + resolution, w_start:w_start + resolution]
+    video = video.permute(1, 0, 2, 3).contiguous() # CTHW
+
+    video -= 0.5
+
+    return video
+
+def preprocess(videos, target_resolution=224):
+    # we should transform videos in [0, 1] [b c t h w] as th.float
+    # -> videos in {0, ..., 255} [b t h w c] as np.uint8 array
+    videos = einops.rearrange(videos, 'b c t h w -> b t h w c')
+    videos = (videos*255).numpy().astype(np.uint8)
+
+    b, t, h, w, c = videos.shape
+    videos = torch.from_numpy(videos)
+    videos = torch.stack([preprocess_single(video, target_resolution) for video in videos])
+    return videos * 2 # [-0.5, 0.5] -> [-1, 1]
+
+def get_fvd_logits(videos, i3d, device, bs=10):
+    videos = preprocess(videos)
+    embeddings = get_logits(i3d, videos, device, bs=10)
+    return embeddings
+
+# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L161
+def _symmetric_matrix_square_root(mat, eps=1e-10):
+    u, s, v = torch.svd(mat)
+    si = torch.where(s < eps, s, torch.sqrt(s))
+    return torch.matmul(torch.matmul(u, torch.diag(si)), v.t())
+
+# https://github.com/tensorflow/gan/blob/de4b8da3853058ea380a6152bd3bd454013bf619/tensorflow_gan/python/eval/classifier_metrics.py#L400
+def trace_sqrt_product(sigma, sigma_v):
+    sqrt_sigma = _symmetric_matrix_square_root(sigma)
+    sqrt_a_sigmav_a = torch.matmul(sqrt_sigma, torch.matmul(sigma_v, sqrt_sigma))
+    return torch.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))
+
+# https://discuss.pytorch.org/t/covariance-and-gradient-support/16217/2
+def cov(m, rowvar=False):
+    '''Estimate a covariance matrix given data.
+
+    Covariance indicates the level to which two variables vary together.
+    If we examine N-dimensional samples, `X = [x_1, x_2, ... x_N]^T`,
+    then the covariance matrix element `C_{ij}` is the covariance of
+    `x_i` and `x_j`. The element `C_{ii}` is the variance of `x_i`.
+
+    Args:
+        m: A 1-D or 2-D array containing multiple variables and observations.
+            Each row of `m` represents a variable, and each column a single
+            observation of all those variables.
+        rowvar: If `rowvar` is True, then each row represents a
+            variable, with observations in the columns. Otherwise, the
+            relationship is transposed: each column represents a variable,
+            while the rows contain observations.
+
+    Returns:
+        The covariance matrix of the variables.
+    '''
+    if m.dim() > 2:
+        raise ValueError('m has more than 2 dimensions')
+    if m.dim() < 2:
+        m = m.view(1, -1)
+    if not rowvar and m.size(0) != 1:
+        m = m.t()
+
+    fact = 1.0 / (m.size(1) - 1) # unbiased estimate
+    m -= torch.mean(m, dim=1, keepdim=True)
+    mt = m.t() # if complex: mt = m.t().conj()
+    return fact * m.matmul(mt).squeeze()
+
+
+def frechet_distance(x1, x2):
+    x1 = x1.flatten(start_dim=1)
+    x2 = x2.flatten(start_dim=1)
+    m, m_w = x1.mean(dim=0), x2.mean(dim=0)
+    sigma, sigma_w = cov(x1, rowvar=False), cov(x2, rowvar=False)
+    mean = torch.sum((m - m_w) ** 2)
+    if x1.shape[0] > 1:
+        sqrt_trace_component = trace_sqrt_product(sigma, sigma_w)
+        trace = torch.trace(sigma + sigma_w) - 2.0 * sqrt_trace_component
+        fd = trace + mean
+    else:
+        fd = np.real(mean)
+    return float(fd)
+
+
+def get_logits(i3d, videos, device, bs=10):
+    # assert videos.shape[0] % 16 == 0
+    with torch.no_grad():
+        logits = []
+        for i in range(0, videos.shape[0], bs):
+            batch = videos[i:i + bs].to(device)
+            # logits.append(i3d.module.extract_features(batch)) # wrong
+            logits.append(i3d(batch)) # right
+        logits = torch.cat(logits, dim=0)
+    return logits
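A similar sketch (not part of the diff; same assumptions, random placeholder batches in [0, 1], B x C x T x H x W) for the videogpt-style path, which keeps the Fréchet computation in torch:

    import torch
    from opensora.eval.fvd.videogpt.fvd import load_i3d_pretrained, get_fvd_logits, frechet_distance

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    i3d = load_i3d_pretrained(device)

    real = torch.rand(8, 3, 16, 224, 224)   # placeholder videos in [0, 1]
    fake = torch.rand(8, 3, 16, 224, 224)

    logits_real = get_fvd_logits(real, i3d, device, bs=4)
    logits_fake = get_fvd_logits(fake, i3d, device, bs=4)
    print('FVD:', frechet_distance(logits_fake, logits_real))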
opensora/eval/fvd/videogpt/pytorch_i3d.py
ADDED
@@ -0,0 +1,322 @@
+# Original code from https://github.com/piergiaj/pytorch-i3d
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+
+class MaxPool3dSamePadding(nn.MaxPool3d):
+
+    def compute_pad(self, dim, s):
+        if s % self.stride[dim] == 0:
+            return max(self.kernel_size[dim] - self.stride[dim], 0)
+        else:
+            return max(self.kernel_size[dim] - (s % self.stride[dim]), 0)
+
+    def forward(self, x):
+        # compute 'same' padding
+        (batch, channel, t, h, w) = x.size()
+        out_t = np.ceil(float(t) / float(self.stride[0]))
+        out_h = np.ceil(float(h) / float(self.stride[1]))
+        out_w = np.ceil(float(w) / float(self.stride[2]))
+        pad_t = self.compute_pad(0, t)
+        pad_h = self.compute_pad(1, h)
+        pad_w = self.compute_pad(2, w)
+
+        pad_t_f = pad_t // 2
+        pad_t_b = pad_t - pad_t_f
+        pad_h_f = pad_h // 2
+        pad_h_b = pad_h - pad_h_f
+        pad_w_f = pad_w // 2
+        pad_w_b = pad_w - pad_w_f
+
+        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
+        x = F.pad(x, pad)
+        return super(MaxPool3dSamePadding, self).forward(x)
+
+
+class Unit3D(nn.Module):
+
+    def __init__(self, in_channels,
+                 output_channels,
+                 kernel_shape=(1, 1, 1),
+                 stride=(1, 1, 1),
+                 padding=0,
+                 activation_fn=F.relu,
+                 use_batch_norm=True,
+                 use_bias=False,
+                 name='unit_3d'):
+
+        """Initializes Unit3D module."""
+        super(Unit3D, self).__init__()
+
+        self._output_channels = output_channels
+        self._kernel_shape = kernel_shape
+        self._stride = stride
+        self._use_batch_norm = use_batch_norm
+        self._activation_fn = activation_fn
+        self._use_bias = use_bias
+        self.name = name
+        self.padding = padding
+
+        self.conv3d = nn.Conv3d(in_channels=in_channels,
+                                out_channels=self._output_channels,
+                                kernel_size=self._kernel_shape,
+                                stride=self._stride,
+                                padding=0, # we always want padding to be 0 here. We will dynamically pad based on input size in forward function
+                                bias=self._use_bias)
+
+        if self._use_batch_norm:
+            self.bn = nn.BatchNorm3d(self._output_channels, eps=1e-5, momentum=0.001)
+
+    def compute_pad(self, dim, s):
+        if s % self._stride[dim] == 0:
+            return max(self._kernel_shape[dim] - self._stride[dim], 0)
+        else:
+            return max(self._kernel_shape[dim] - (s % self._stride[dim]), 0)
+
+    def forward(self, x):
+        # compute 'same' padding
+        (batch, channel, t, h, w) = x.size()
+        out_t = np.ceil(float(t) / float(self._stride[0]))
+        out_h = np.ceil(float(h) / float(self._stride[1]))
+        out_w = np.ceil(float(w) / float(self._stride[2]))
+        pad_t = self.compute_pad(0, t)
+        pad_h = self.compute_pad(1, h)
+        pad_w = self.compute_pad(2, w)
+
+        pad_t_f = pad_t // 2
+        pad_t_b = pad_t - pad_t_f
+        pad_h_f = pad_h // 2
+        pad_h_b = pad_h - pad_h_f
+        pad_w_f = pad_w // 2
+        pad_w_b = pad_w - pad_w_f
+
+        pad = (pad_w_f, pad_w_b, pad_h_f, pad_h_b, pad_t_f, pad_t_b)
+        x = F.pad(x, pad)
+
+        x = self.conv3d(x)
+        if self._use_batch_norm:
+            x = self.bn(x)
+        if self._activation_fn is not None:
+            x = self._activation_fn(x)
+        return x
+
+
+class InceptionModule(nn.Module):
+    def __init__(self, in_channels, out_channels, name):
+        super(InceptionModule, self).__init__()
+
+        self.b0 = Unit3D(in_channels=in_channels, output_channels=out_channels[0], kernel_shape=[1, 1, 1], padding=0,
+                         name=name+'/Branch_0/Conv3d_0a_1x1')
+        self.b1a = Unit3D(in_channels=in_channels, output_channels=out_channels[1], kernel_shape=[1, 1, 1], padding=0,
+                          name=name+'/Branch_1/Conv3d_0a_1x1')
+        self.b1b = Unit3D(in_channels=out_channels[1], output_channels=out_channels[2], kernel_shape=[3, 3, 3],
+                          name=name+'/Branch_1/Conv3d_0b_3x3')
+        self.b2a = Unit3D(in_channels=in_channels, output_channels=out_channels[3], kernel_shape=[1, 1, 1], padding=0,
+                          name=name+'/Branch_2/Conv3d_0a_1x1')
+        self.b2b = Unit3D(in_channels=out_channels[3], output_channels=out_channels[4], kernel_shape=[3, 3, 3],
+                          name=name+'/Branch_2/Conv3d_0b_3x3')
+        self.b3a = MaxPool3dSamePadding(kernel_size=[3, 3, 3],
+                                        stride=(1, 1, 1), padding=0)
+        self.b3b = Unit3D(in_channels=in_channels, output_channels=out_channels[5], kernel_shape=[1, 1, 1], padding=0,
+                          name=name+'/Branch_3/Conv3d_0b_1x1')
+        self.name = name
+
+    def forward(self, x):
+        b0 = self.b0(x)
+        b1 = self.b1b(self.b1a(x))
+        b2 = self.b2b(self.b2a(x))
+        b3 = self.b3b(self.b3a(x))
+        return torch.cat([b0, b1, b2, b3], dim=1)
+
+
+class InceptionI3d(nn.Module):
+    """Inception-v1 I3D architecture.
+    The model is introduced in:
+        Quo Vadis, Action Recognition? A New Model and the Kinetics Dataset
+        Joao Carreira, Andrew Zisserman
+        https://arxiv.org/pdf/1705.07750v1.pdf.
+    See also the Inception architecture, introduced in:
+        Going deeper with convolutions
+        Christian Szegedy, Wei Liu, Yangqing Jia, Pierre Sermanet, Scott Reed,
+        Dragomir Anguelov, Dumitru Erhan, Vincent Vanhoucke, Andrew Rabinovich.
+        http://arxiv.org/pdf/1409.4842v1.pdf.
+    """
+
+    # Endpoints of the model in order. During construction, all the endpoints up
+    # to a designated `final_endpoint` are returned in a dictionary as the
+    # second return value.
+    VALID_ENDPOINTS = (
+        'Conv3d_1a_7x7',
+        'MaxPool3d_2a_3x3',
+        'Conv3d_2b_1x1',
+        'Conv3d_2c_3x3',
+        'MaxPool3d_3a_3x3',
+        'Mixed_3b',
+        'Mixed_3c',
+        'MaxPool3d_4a_3x3',
+        'Mixed_4b',
+        'Mixed_4c',
+        'Mixed_4d',
+        'Mixed_4e',
+        'Mixed_4f',
+        'MaxPool3d_5a_2x2',
+        'Mixed_5b',
+        'Mixed_5c',
+        'Logits',
+        'Predictions',
+    )
+
+    def __init__(self, num_classes=400, spatial_squeeze=True,
+                 final_endpoint='Logits', name='inception_i3d', in_channels=3, dropout_keep_prob=0.5):
+        """Initializes I3D model instance.
+        Args:
+          num_classes: The number of outputs in the logit layer (default 400, which
+              matches the Kinetics dataset).
+          spatial_squeeze: Whether to squeeze the spatial dimensions for the logits
+              before returning (default True).
+          final_endpoint: The model contains many possible endpoints.
+              `final_endpoint` specifies the last endpoint for the model to be built
+              up to. In addition to the output at `final_endpoint`, all the outputs
+              at endpoints up to `final_endpoint` will also be returned, in a
+              dictionary. `final_endpoint` must be one of
+              InceptionI3d.VALID_ENDPOINTS (default 'Logits').
+          name: A string (optional). The name of this module.
+        Raises:
+          ValueError: if `final_endpoint` is not recognized.
+        """
+
+        if final_endpoint not in self.VALID_ENDPOINTS:
+            raise ValueError('Unknown final endpoint %s' % final_endpoint)
+
+        super(InceptionI3d, self).__init__()
+        self._num_classes = num_classes
+        self._spatial_squeeze = spatial_squeeze
+        self._final_endpoint = final_endpoint
+        self.logits = None
+
+        if self._final_endpoint not in self.VALID_ENDPOINTS:
+            raise ValueError('Unknown final endpoint %s' % self._final_endpoint)
+
+        self.end_points = {}
+        end_point = 'Conv3d_1a_7x7'
+        self.end_points[end_point] = Unit3D(in_channels=in_channels, output_channels=64, kernel_shape=[7, 7, 7],
+                                            stride=(2, 2, 2), padding=(3, 3, 3), name=name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'MaxPool3d_2a_3x3'
+        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
+                                                          padding=0)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Conv3d_2b_1x1'
+        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=64, kernel_shape=[1, 1, 1], padding=0,
+                                            name=name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Conv3d_2c_3x3'
+        self.end_points[end_point] = Unit3D(in_channels=64, output_channels=192, kernel_shape=[3, 3, 3], padding=1,
+                                            name=name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'MaxPool3d_3a_3x3'
+        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[1, 3, 3], stride=(1, 2, 2),
+                                                          padding=0)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Mixed_3b'
+        self.end_points[end_point] = InceptionModule(192, [64,96,128,16,32,32], name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Mixed_3c'
+        self.end_points[end_point] = InceptionModule(256, [128,128,192,32,96,64], name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'MaxPool3d_4a_3x3'
+        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[3, 3, 3], stride=(2, 2, 2),
+                                                          padding=0)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Mixed_4b'
+        self.end_points[end_point] = InceptionModule(128+192+96+64, [192,96,208,16,48,64], name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Mixed_4c'
+        self.end_points[end_point] = InceptionModule(192+208+48+64, [160,112,224,24,64,64], name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Mixed_4d'
+        self.end_points[end_point] = InceptionModule(160+224+64+64, [128,128,256,24,64,64], name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Mixed_4e'
+        self.end_points[end_point] = InceptionModule(128+256+64+64, [112,144,288,32,64,64], name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Mixed_4f'
+        self.end_points[end_point] = InceptionModule(112+288+64+64, [256,160,320,32,128,128], name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'MaxPool3d_5a_2x2'
+        self.end_points[end_point] = MaxPool3dSamePadding(kernel_size=[2, 2, 2], stride=(2, 2, 2),
+                                                          padding=0)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Mixed_5b'
+        self.end_points[end_point] = InceptionModule(256+320+128+128, [256,160,320,32,128,128], name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Mixed_5c'
+        self.end_points[end_point] = InceptionModule(256+320+128+128, [384,192,384,48,128,128], name+end_point)
+        if self._final_endpoint == end_point: return
+
+        end_point = 'Logits'
+        self.avg_pool = nn.AvgPool3d(kernel_size=[2, 7, 7],
+                                     stride=(1, 1, 1))
+        self.dropout = nn.Dropout(dropout_keep_prob)
+        self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
+                             kernel_shape=[1, 1, 1],
+                             padding=0,
+                             activation_fn=None,
+                             use_batch_norm=False,
+                             use_bias=True,
+                             name='logits')
+
+        self.build()
+
+
+    def replace_logits(self, num_classes):
+        self._num_classes = num_classes
+        self.logits = Unit3D(in_channels=384+384+128+128, output_channels=self._num_classes,
+                             kernel_shape=[1, 1, 1],
+                             padding=0,
+                             activation_fn=None,
+                             use_batch_norm=False,
+                             use_bias=True,
+                             name='logits')
+
+
+    def build(self):
+        for k in self.end_points.keys():
+            self.add_module(k, self.end_points[k])
+
+    def forward(self, x):
+        for end_point in self.VALID_ENDPOINTS:
+            if end_point in self.end_points:
+                x = self._modules[end_point](x) # use _modules to work with dataparallel
+
+        x = self.logits(self.dropout(self.avg_pool(x)))
+        if self._spatial_squeeze:
+            logits = x.squeeze(3).squeeze(3)
+        logits = logits.mean(dim=2)
+        # logits is batch X time X classes, which is what we want to work with
+        return logits
+
+
+    def extract_features(self, x):
+        for end_point in self.VALID_ENDPOINTS:
+            if end_point in self.end_points:
+                x = self._modules[end_point](x)
+        return self.avg_pool(x)
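A minimal standalone sketch of `InceptionI3d` (not part of the diff): with no checkpoint loaded the weights are random, so this only illustrates the expected input/output shapes; `load_i3d_pretrained` above is what loads the Kinetics-400 weights in practice.

    import torch
    from opensora.eval.fvd.videogpt.pytorch_i3d import InceptionI3d

    model = InceptionI3d(num_classes=400, in_channels=3).eval()
    clip = torch.rand(2, 3, 16, 224, 224)        # B x C x T x H x W, placeholder values
    with torch.no_grad():
        logits = model(clip)                     # (B, 400) class logits, averaged over time
        feats = model.extract_features(clip)     # pooled Mixed_5c features
    print(logits.shape, feats.shape)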
opensora/eval/script/cal_clip_score.sh
ADDED
@@ -0,0 +1,23 @@
+# clip_score cross modality
+python eval_clip_score.py \
+    --real_path path/to/image \
+    --generated_path path/to/text \
+    --batch-size 50 \
+    --device "cuda"
+
+# clip_score within the same modality
+python eval_clip_score.py \
+    --real_path path/to/textA \
+    --generated_path path/to/textB \
+    --real_flag txt \
+    --generated_flag txt \
+    --batch-size 50 \
+    --device "cuda"
+
+python eval_clip_score.py \
+    --real_path path/to/imageA \
+    --generated_path path/to/imageB \
+    --real_flag img \
+    --generated_flag img \
+    --batch-size 50 \
+    --device "cuda"
opensora/eval/script/cal_fvd.sh
ADDED
@@ -0,0 +1,9 @@
+python eval_common_metric.py \
+    --real_video_dir path/to/imageA \
+    --generated_video_dir path/to/imageB \
+    --batch_size 10 \
+    --crop_size 64 \
+    --num_frames 20 \
+    --device 'cuda' \
+    --metric 'fvd' \
+    --fvd_method 'styleganv'
opensora/eval/script/cal_lpips.sh
ADDED
@@ -0,0 +1,8 @@
+python eval_common_metric.py \
+    --real_video_dir path/to/imageA \
+    --generated_video_dir path/to/imageB \
+    --batch_size 10 \
+    --num_frames 20 \
+    --crop_size 64 \
+    --device 'cuda' \
+    --metric 'lpips'
opensora/eval/script/cal_psnr.sh
ADDED
@@ -0,0 +1,9 @@
+
+python eval_common_metric.py \
+    --real_video_dir /data/xiaogeng_liu/data/video1 \
+    --generated_video_dir /data/xiaogeng_liu/data/video2 \
+    --batch_size 10 \
+    --num_frames 20 \
+    --crop_size 64 \
+    --device 'cuda' \
+    --metric 'psnr'
opensora/eval/script/cal_ssim.sh
ADDED
@@ -0,0 +1,8 @@
+python eval_common_metric.py \
+    --real_video_dir /data/xiaogeng_liu/data/video1 \
+    --generated_video_dir /data/xiaogeng_liu/data/video2 \
+    --batch_size 10 \
+    --num_frames 20 \
+    --crop_size 64 \
+    --device 'cuda' \
+    --metric 'ssim'