ericmagalhaes committed • Commit ed72c0e • 1 Parent(s): 49a817b

initial commit

This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- .gitignore +174 -0
- LICENSE +209 -0
- __init__.py +1 -0
- app.py +287 -0
- app_sadtalker.py +111 -0
- cog.yaml +35 -0
- inference.py +145 -0
- launcher.py +204 -0
- predict.py +192 -0
- quick_demo.ipynb +221 -0
- req.txt +22 -0
- requirements.txt +30 -0
- requirements3d.txt +21 -0
- scripts/download_models.sh +15 -0
- src/audio2exp_models/audio2exp.py +41 -0
- src/audio2exp_models/networks.py +74 -0
- src/audio2pose_models/audio2pose.py +94 -0
- src/audio2pose_models/audio_encoder.py +64 -0
- src/audio2pose_models/cvae.py +149 -0
- src/audio2pose_models/discriminator.py +76 -0
- src/audio2pose_models/networks.py +140 -0
- src/audio2pose_models/res_unet.py +65 -0
- src/config/auido2exp.yaml +58 -0
- src/config/auido2pose.yaml +49 -0
- src/config/facerender.yaml +45 -0
- src/config/facerender_still.yaml +45 -0
- src/config/similarity_Lm3D_all.mat +0 -0
- src/face3d/data/__init__.py +116 -0
- src/face3d/data/base_dataset.py +125 -0
- src/face3d/data/flist_dataset.py +125 -0
- src/face3d/data/image_folder.py +66 -0
- src/face3d/data/template_dataset.py +75 -0
- src/face3d/extract_kp_videos.py +108 -0
- src/face3d/extract_kp_videos_safe.py +151 -0
- src/face3d/models/__init__.py +67 -0
- src/face3d/models/arcface_torch/README.md +164 -0
- src/face3d/models/arcface_torch/backbones/__init__.py +25 -0
- src/face3d/models/arcface_torch/backbones/iresnet.py +187 -0
- src/face3d/models/arcface_torch/backbones/iresnet2060.py +176 -0
- src/face3d/models/arcface_torch/backbones/mobilefacenet.py +130 -0
- src/face3d/models/arcface_torch/configs/3millions.py +23 -0
- src/face3d/models/arcface_torch/configs/3millions_pfc.py +23 -0
- src/face3d/models/arcface_torch/configs/__init__.py +0 -0
- src/face3d/models/arcface_torch/configs/base.py +56 -0
- src/face3d/models/arcface_torch/configs/glint360k_mbf.py +26 -0
- src/face3d/models/arcface_torch/configs/glint360k_r100.py +26 -0
- src/face3d/models/arcface_torch/configs/glint360k_r18.py +26 -0
- src/face3d/models/arcface_torch/configs/glint360k_r34.py +26 -0
- src/face3d/models/arcface_torch/configs/glint360k_r50.py +26 -0
- src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py +26 -0
.gitignore
ADDED
@@ -0,0 +1,174 @@
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+.idea/
+
+examples/results/*
+gfpgan/*
+checkpoints/*
+assets/*
+results/*
+Dockerfile
+start_docker.sh
+start.sh
+
+checkpoints
+
+# Mac
+.DS_Store
LICENSE
ADDED
@@ -0,0 +1,209 @@
+Tencent is pleased to support the open source community by making SadTalker available.
+
+Copyright (C), a Tencent company. All rights reserved.
+
+SadTalker is licensed under the Apache 2.0 License, except for the third-party components listed below.
+
+Terms of the Apache License Version 2.0:
+---------------------------------------------
+Apache License
+Version 2.0, January 2004
+http://www.apache.org/licenses/
+
+TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+1. Definitions.
+
+"License" shall mean the terms and conditions for use, reproduction,
+and distribution as defined by Sections 1 through 9 of this document.
+
+"Licensor" shall mean the copyright owner or entity authorized by
+the copyright owner that is granting the License.
+
+"Legal Entity" shall mean the union of the acting entity and all
+other entities that control, are controlled by, or are under common
+control with that entity. For the purposes of this definition,
+"control" means (i) the power, direct or indirect, to cause the
+direction or management of such entity, whether by contract or
+otherwise, or (ii) ownership of fifty percent (50%) or more of the
+outstanding shares, or (iii) beneficial ownership of such entity.
+
+"You" (or "Your") shall mean an individual or Legal Entity
+exercising permissions granted by this License.
+
+"Source" form shall mean the preferred form for making modifications,
+including but not limited to software source code, documentation
+source, and configuration files.
+
+"Object" form shall mean any form resulting from mechanical
+transformation or translation of a Source form, including but
+not limited to compiled object code, generated documentation,
+and conversions to other media types.
+
+"Work" shall mean the work of authorship, whether in Source or
+Object form, made available under the License, as indicated by a
+copyright notice that is included in or attached to the work
+(an example is provided in the Appendix below).
+
+"Derivative Works" shall mean any work, whether in Source or Object
+form, that is based on (or derived from) the Work and for which the
+editorial revisions, annotations, elaborations, or other modifications
+represent, as a whole, an original work of authorship. For the purposes
+of this License, Derivative Works shall not include works that remain
+separable from, or merely link (or bind by name) to the interfaces of,
+the Work and Derivative Works thereof.
+
+"Contribution" shall mean any work of authorship, including
+the original version of the Work and any modifications or additions
+to that Work or Derivative Works thereof, that is intentionally
+submitted to Licensor for inclusion in the Work by the copyright owner
+or by an individual or Legal Entity authorized to submit on behalf of
+the copyright owner. For the purposes of this definition, "submitted"
+means any form of electronic, verbal, or written communication sent
+to the Licensor or its representatives, including but not limited to
+communication on electronic mailing lists, source code control systems,
+and issue tracking systems that are managed by, or on behalf of, the
+Licensor for the purpose of discussing and improving the Work, but
+excluding communication that is conspicuously marked or otherwise
+designated in writing by the copyright owner as "Not a Contribution."
+
+"Contributor" shall mean Licensor and any individual or Legal Entity
+on behalf of whom a Contribution has been received by Licensor and
+subsequently incorporated within the Work.
+
+2. Grant of Copyright License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+copyright license to reproduce, prepare Derivative Works of,
+publicly display, publicly perform, sublicense, and distribute the
+Work and such Derivative Works in Source or Object form.
+
+3. Grant of Patent License. Subject to the terms and conditions of
+this License, each Contributor hereby grants to You a perpetual,
+worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+(except as stated in this section) patent license to make, have made,
+use, offer to sell, sell, import, and otherwise transfer the Work,
+where such license applies only to those patent claims licensable
+by such Contributor that are necessarily infringed by their
+Contribution(s) alone or by combination of their Contribution(s)
+with the Work to which such Contribution(s) was submitted. If You
+institute patent litigation against any entity (including a
+cross-claim or counterclaim in a lawsuit) alleging that the Work
+or a Contribution incorporated within the Work constitutes direct
+or contributory patent infringement, then any patent licenses
+granted to You under this License for that Work shall terminate
+as of the date such litigation is filed.
+
+4. Redistribution. You may reproduce and distribute copies of the
+Work or Derivative Works thereof in any medium, with or without
+modifications, and in Source or Object form, provided that You
+meet the following conditions:
+
+(a) You must give any other recipients of the Work or
+Derivative Works a copy of this License; and
+
+(b) You must cause any modified files to carry prominent notices
+stating that You changed the files; and
+
+(c) You must retain, in the Source form of any Derivative Works
+that You distribute, all copyright, patent, trademark, and
+attribution notices from the Source form of the Work,
+excluding those notices that do not pertain to any part of
+the Derivative Works; and
+
+(d) If the Work includes a "NOTICE" text file as part of its
+distribution, then any Derivative Works that You distribute must
+include a readable copy of the attribution notices contained
+within such NOTICE file, excluding those notices that do not
+pertain to any part of the Derivative Works, in at least one
+of the following places: within a NOTICE text file distributed
+as part of the Derivative Works; within the Source form or
+documentation, if provided along with the Derivative Works; or,
+within a display generated by the Derivative Works, if and
+wherever such third-party notices normally appear. The contents
+of the NOTICE file are for informational purposes only and
+do not modify the License. You may add Your own attribution
+notices within Derivative Works that You distribute, alongside
+or as an addendum to the NOTICE text from the Work, provided
+that such additional attribution notices cannot be construed
+as modifying the License.
+
+You may add Your own copyright statement to Your modifications and
+may provide additional or different license terms and conditions
+for use, reproduction, or distribution of Your modifications, or
+for any such Derivative Works as a whole, provided Your use,
+reproduction, and distribution of the Work otherwise complies with
+the conditions stated in this License.
+
+5. Submission of Contributions. Unless You explicitly state otherwise,
+any Contribution intentionally submitted for inclusion in the Work
+by You to the Licensor shall be under the terms and conditions of
+this License, without any additional terms or conditions.
+Notwithstanding the above, nothing herein shall supersede or modify
+the terms of any separate license agreement you may have executed
+with Licensor regarding such Contributions.
+
+6. Trademarks. This License does not grant permission to use the trade
+names, trademarks, service marks, or product names of the Licensor,
+except as required for reasonable and customary use in describing the
+origin of the Work and reproducing the content of the NOTICE file.
+
+7. Disclaimer of Warranty. Unless required by applicable law or
+agreed to in writing, Licensor provides the Work (and each
+Contributor provides its Contributions) on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+implied, including, without limitation, any warranties or conditions
+of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+PARTICULAR PURPOSE. You are solely responsible for determining the
+appropriateness of using or redistributing the Work and assume any
+risks associated with Your exercise of permissions under this License.
+
+8. Limitation of Liability. In no event and under no legal theory,
+whether in tort (including negligence), contract, or otherwise,
+unless required by applicable law (such as deliberate and grossly
+negligent acts) or agreed to in writing, shall any Contributor be
+liable to You for damages, including any direct, indirect, special,
+incidental, or consequential damages of any character arising as a
+result of this License or out of the use or inability to use the
+Work (including but not limited to damages for loss of goodwill,
+work stoppage, computer failure or malfunction, or any and all
+other commercial damages or losses), even if such Contributor
+has been advised of the possibility of such damages.
+
+9. Accepting Warranty or Additional Liability. While redistributing
+the Work or Derivative Works thereof, You may choose to offer,
+and charge a fee for, acceptance of support, warranty, indemnity,
+or other liability obligations and/or rights consistent with this
+License. However, in accepting such obligations, You may act only
+on Your own behalf and on Your sole responsibility, not on behalf
+of any other Contributor, and only if You agree to indemnify,
+defend, and hold each Contributor harmless for any liability
+incurred by, or claims asserted against, such Contributor by reason
+of your accepting any such warranty or additional liability.
+
+END OF TERMS AND CONDITIONS
+
+APPENDIX: How to apply the Apache License to your work.
+
+To apply the Apache License to your work, attach the following
+boilerplate notice, with the fields enclosed by brackets "[]"
+replaced with your own identifying information. (Don't include
+the brackets!) The text should be enclosed in the appropriate
+comment syntax for the file format. We also recommend that a
+file or class name and description of purpose be included on the
+same "printed page" as the copyright notice for easier
+identification within third-party archives.
+
+Copyright [yyyy] [name of copyright owner]
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
__init__.py
ADDED
@@ -0,0 +1 @@
+
app.py
ADDED
@@ -0,0 +1,287 @@
+from flask import Flask, request, jsonify
+import torch
+import shutil
+import os
+import sys
+from argparse import ArgumentParser
+from time import strftime
+from argparse import Namespace
+from src.utils.preprocess import CropAndExtract
+from src.test_audio2coeff import Audio2Coeff
+from src.facerender.animate import AnimateFromCoeff
+from src.generate_batch import get_data
+from src.generate_facerender_batch import get_facerender_data
+from src.utils.init_path import init_path
+import tempfile
+from openai import OpenAI
+import threading
+import elevenlabs
+from elevenlabs import set_api_key, generate, play, clone
+from flask_cors import CORS, cross_origin
+from flask_swagger_ui import get_swaggerui_blueprint
+import uuid
+
+class AnimationConfig:
+    def __init__(self, driven_audio_path, source_image_path, result_folder):
+        self.driven_audio = driven_audio_path
+        self.source_image = source_image_path
+        self.ref_eyeblink = None
+        self.ref_pose = None
+        self.checkpoint_dir = './checkpoints'
+        self.result_dir = result_folder
+        self.pose_style = 1
+        self.batch_size = 1
+        self.size = 256
+        self.expression_scale = 1
+        self.input_yaw = None
+        self.input_pitch = None
+        self.input_roll = None
+        self.enhancer = 'gfpgan'
+        self.background_enhancer = None
+        self.cpu = False
+        self.face3dvis = False
+        self.still = False
+        self.preprocess = 'crop'
+        self.verbose = False
+        self.old_version = False
+        self.net_recon = 'resnet50'
+        self.init_path = None
+        self.use_last_fc = False
+        self.bfm_folder = './checkpoints/BFM_Fitting/'
+        self.bfm_model = 'BFM_model_front.mat'
+        self.focal = 1015.
+        self.center = 112.
+        self.camera_d = 10.
+        self.z_near = 5.
+        self.z_far = 15.
+        self.device = 'cpu'
+
+# Define the blueprint
+SWAGGER_URL="/swagger"
+API_URL="/static/swagger.json"
+
+swagger_ui_blueprint = get_swaggerui_blueprint(
+    SWAGGER_URL,
+    API_URL,
+    config={
+        'app_name': 'Access API'
+    }
+)
+
+app = Flask(__name__)
+CORS(app)
+app.register_blueprint(swagger_ui_blueprint, url_prefix=SWAGGER_URL)
+
+app.config['temp_response'] = None
+app.config['generation_thread'] = None
+
+TEMP_DIR = tempfile.TemporaryDirectory()
+
+
+def main(args):
+    pic_path = args.source_image
+    audio_path = args.driven_audio
+    save_dir = args.result_dir
+    # save_dir = os.path.join(args.result_folder, strftime("%Y_%m_%d_%H.%M.%S"))
+    # os.makedirs(save_dir, exist_ok=True)
+    print('save_dir',save_dir)
+    pose_style = args.pose_style
+    device = args.device
+    batch_size = args.batch_size
+    input_yaw_list = args.input_yaw
+    input_pitch_list = args.input_pitch
+    input_roll_list = args.input_roll
+    ref_eyeblink = args.ref_eyeblink
+    ref_pose = args.ref_pose
+
+    dir_path = os.path.dirname(os.path.realpath(__file__))
+    current_root_path = dir_path
+    print('current_root_path ',current_root_path)
+
+    sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
+    print('sadtalker_paths ',sadtalker_paths)
+
+
+
+    preprocess_model = CropAndExtract(sadtalker_paths, device)
+
+    audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
+    animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
+
+    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+    os.makedirs(first_frame_dir, exist_ok=True)
+    first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\
+        source_image_flag=True, pic_size=args.size)
+
+    print('first_coeff_path ',first_coeff_path)
+    print('crop_pic_path ',crop_pic_path)
+
+
+    if first_coeff_path is None:
+        print("Can't get the coeffs of the input")
+        return
+
+    if ref_eyeblink is not None:
+        ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
+        ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
+        os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
+        ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
+    else:
+        ref_eyeblink_coeff_path=None
+    print('ref_eyeblink_coeff_path',ref_eyeblink_coeff_path)
+
+    if ref_pose is not None:
+        if ref_pose == ref_eyeblink:
+            ref_pose_coeff_path = ref_eyeblink_coeff_path
+        else:
+            ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
+            ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
+            os.makedirs(ref_pose_frame_dir, exist_ok=True)
+            ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
+    else:
+        ref_pose_coeff_path=None
+    print('ref_eyeblink_coeff_path',ref_pose_coeff_path)
+
+    batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
+    print('batch',batch)
+    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+
+    if args.face3dvis:
+        from src.face3d.visualize import gen_composed_video
+        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))
+    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
+                               batch_size, input_yaw_list, input_pitch_list, input_roll_list,
+                               expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)
+
+    print('data ',data)
+    print('save_dir ', save_dir)
+    print('pic_path ',pic_path)
+    print('crop ',crop_info)
+
+    result, base64_video = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
+                                enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
+
+
+    print('The generated video is named:')
+    app.config['temp_response'] = base64_video
+
+    return base64_video
+
+    # shutil.move(result, save_dir+'.mp4')
+
+
+    if not args.verbose:
+        shutil.rmtree(save_dir)
+
+def save_uploaded_file(file, filename):
+    unique_filename = str(uuid.uuid4()) + "_" + filename
+    file_path = os.path.join(TEMP_DIR.name, unique_filename)
+    file.save(file_path)
+    return file_path
+
+client = OpenAI(api_key="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")
+
+def translate_text(text, target_language):
+    response = client.chat.completions.create(
+        model="gpt-4-0125-preview",
+        messages=[
+            {"role": "system", "content": "You are a helpful assistant."},
+            {"role": "user", "content": f"Translate the following text into {target_language} Completely: {text}\n"},
+        ],
+        max_tokens=len(text),
+        temperature=0.3,
+    )
+    return response
+
+@app.route("/run", methods=['POST'])
+def generate_video():
+    if request.method == 'POST':
+        source_image = request.files['source_image']
+        text_prompt = request.form['text_prompt']
+        voice_cloning = request.form.get('voice_cloning', 'no')
+        target_language = request.form.get('target_language', 'English')
+
+        if target_language != 'English':
+            response = translate_text(text_prompt, target_language)
+            text_prompt = response.choices[0].message.content.strip()
+            print('text_prompt',text_prompt)
+
+        source_image_path = save_uploaded_file(source_image, 'source_image.png')
+        print(source_image_path)
+
+        if voice_cloning == 'no':
+            response = client.audio.speech.create(model="tts-1-hd",
+                                                  voice="onyx",
+                                                  input = text_prompt)
+
+            with tempfile.NamedTemporaryFile(suffix=".wav", prefix="text_to_speech_", delete=False) as temp_file:
+                driven_audio_path = temp_file.name
+
+            response.write_to_file(driven_audio_path)
+
+        elif voice_cloning == 'yes':
+            user_voice = request.files['user_voice']
+
+            with tempfile.NamedTemporaryFile(suffix=".wav", prefix="user_voice_", delete=False) as temp_file:
+                user_voice_path = temp_file.name
+                user_voice.save(user_voice_path)
+                print('user_voice_path',user_voice_path)
+
+            set_api_key("92e149985ea2732b4359c74346c3daee")
+            voice = clone(name = "User Cloned Voice", # Not Required used to store with this name on elevenlabs server
+                          files = [user_voice_path] )
+
+            audio = generate(text = text_prompt, voice = voice, model = "eleven_multilingual_v2")
+            with tempfile.NamedTemporaryFile(suffix=".mp3", prefix="cloned_audio_", delete=False) as temp_file:
+                driven_audio_path = temp_file.name
+                elevenlabs.save(audio, driven_audio_path)
+
+        save_dir = tempfile.mkdtemp()
+        result_folder = os.path.join(save_dir, "results")
+        os.makedirs(result_folder, exist_ok=True)
+
+        # Example of using the class with some hypothetical paths
+        args = AnimationConfig(driven_audio_path=driven_audio_path, source_image_path=source_image_path, result_folder=result_folder)
+
+        if torch.cuda.is_available() and not args.cpu:
+            args.device = "cuda"
+        else:
+            args.device = "cpu"
+
+        generation_thread = threading.Thread(target=main, args=(args,))
+        app.config['generation_thread'] = generation_thread
+        generation_thread.start()
+        response_data = {"message": "Video generation started",
+                         "process_id": generation_thread.ident}
+
+        return jsonify(response_data)
+    # base64_video = main(args)
+    # return jsonify({"base64_video": base64_video})
+
+    #else:
+    #    return 'Unsupported HTTP method', 405
+
+@app.route("/status", methods=["GET"])
+def check_generation_status():
+    response = {"base64_video": "", "status": ""}
+    process_id = request.args.get('process_id', None)
+
+    # process_id is required to check the status for that specific process
+    if process_id:
+        generation_thread = app.config.get('generation_thread')
+        if generation_thread and generation_thread.ident == int(process_id) and generation_thread.is_alive():
+            return jsonify({"status": "in_progress"}), 200
+        elif app.config.get('temp_response'):
+            # app.config['temp_response']['status'] = 'completed'
+            final_response = app.config['temp_response']
+            response["base64_video"] = final_response
+            response["status"] = "completed"
+            return jsonify(response)
+    return jsonify({"error":"No process id provided"})

+@app.route("/health", methods=["GET"])
+def health_status():
+    response = {"online": "true"}
+    return jsonify(response)
+if __name__ == '__main__':
+    app.run(debug=True)
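Taken together, app.py exposes three routes: POST /run starts generation on a background thread and returns a process_id, GET /status polls that thread and returns the base64-encoded video once it finishes, and GET /health is a liveness check. The snippet below is a minimal client sketch and is not part of the commit; it assumes the app is running locally on Flask's default port, that the requests package is installed, and that face.png is a placeholder path.

import base64
import time

import requests

BASE_URL = "http://127.0.0.1:5000"  # assumed local dev server

# POST /run with the multipart fields the route reads from request.files / request.form.
with open("face.png", "rb") as image_file:
    started = requests.post(
        f"{BASE_URL}/run",
        files={"source_image": image_file},
        data={"text_prompt": "Hello there!", "voice_cloning": "no", "target_language": "English"},
    )
process_id = started.json()["process_id"]

# Poll GET /status until the background thread has stored the base64 video.
while True:
    status = requests.get(f"{BASE_URL}/status", params={"process_id": process_id}).json()
    if status.get("status") == "completed":
        with open("result.mp4", "wb") as out:
            out.write(base64.b64decode(status["base64_video"]))
        break
    time.sleep(5)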
app_sadtalker.py
ADDED
@@ -0,0 +1,111 @@
+import os, sys
+import gradio as gr
+from src.gradio_demo import SadTalker
+
+
+try:
+    import webui # in webui
+    in_webui = True
+except:
+    in_webui = False
+
+
+def toggle_audio_file(choice):
+    if choice == False:
+        return gr.update(visible=True), gr.update(visible=False)
+    else:
+        return gr.update(visible=False), gr.update(visible=True)
+
+def ref_video_fn(path_of_ref_video):
+    if path_of_ref_video is not None:
+        return gr.update(value=True)
+    else:
+        return gr.update(value=False)
+
+def sadtalker_demo(checkpoint_path='checkpoints', config_path='src/config', warpfn=None):
+
+    sad_talker = SadTalker(checkpoint_path, config_path, lazy_load=True)
+
+    with gr.Blocks(analytics_enabled=False) as sadtalker_interface:
+        gr.Markdown("<div align='center'> <h2> 😭 SadTalker: Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation (CVPR 2023) </span> </h2> \
+                    <a style='font-size:18px;color: #efefef' href='https://arxiv.org/abs/2211.12194'>Arxiv</a> \
+                    <a style='font-size:18px;color: #efefef' href='https://sadtalker.github.io'>Homepage</a> \
+                    <a style='font-size:18px;color: #efefef' href='https://github.com/Winfredy/SadTalker'> Github </div>")
+
+        with gr.Row().style(equal_height=False):
+            with gr.Column(variant='panel'):
+                with gr.Tabs(elem_id="sadtalker_source_image"):
+                    with gr.TabItem('Upload image'):
+                        with gr.Row():
+                            source_image = gr.Image(label="Source image", source="upload", type="filepath", elem_id="img2img_image").style(width=512)
+
+                with gr.Tabs(elem_id="sadtalker_driven_audio"):
+                    with gr.TabItem('Upload OR TTS'):
+                        with gr.Column(variant='panel'):
+                            driven_audio = gr.Audio(label="Input audio", source="upload", type="filepath")
+
+                        if sys.platform != 'win32' and not in_webui:
+                            from src.utils.text2speech import TTSTalker
+                            tts_talker = TTSTalker()
+                            with gr.Column(variant='panel'):
+                                input_text = gr.Textbox(label="Generating audio from text", lines=5, placeholder="please enter some text here, we genreate the audio from text using @Coqui.ai TTS.")
+                                tts = gr.Button('Generate audio',elem_id="sadtalker_audio_generate", variant='primary')
+                                tts.click(fn=tts_talker.test, inputs=[input_text], outputs=[driven_audio])
+
+            with gr.Column(variant='panel'):
+                with gr.Tabs(elem_id="sadtalker_checkbox"):
+                    with gr.TabItem('Settings'):
+                        gr.Markdown("need help? please visit our [best practice page](https://github.com/OpenTalker/SadTalker/blob/main/docs/best_practice.md) for more detials")
+                        with gr.Column(variant='panel'):
+                            # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
+                            # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
+                            pose_style = gr.Slider(minimum=0, maximum=46, step=1, label="Pose style", value=0) #
+                            size_of_image = gr.Radio([256, 512], value=256, label='face model resolution', info="use 256/512 model?") #
+                            preprocess_type = gr.Radio(['crop', 'resize','full', 'extcrop', 'extfull'], value='crop', label='preprocess', info="How to handle input image?")
+                            is_still_mode = gr.Checkbox(label="Still Mode (fewer head motion, works with preprocess `full`)")
+                            batch_size = gr.Slider(label="batch size in generation", step=1, maximum=10, value=2)
+                            enhancer = gr.Checkbox(label="GFPGAN as Face enhancer")
+                            submit = gr.Button('Generate', elem_id="sadtalker_generate", variant='primary')
+
+                with gr.Tabs(elem_id="sadtalker_genearted"):
+                    gen_video = gr.Video(label="Generated video", format="mp4").style(width=256)
+
+        if warpfn:
+            submit.click(
+                fn=warpfn(sad_talker.test),
+                inputs=[source_image,
+                        driven_audio,
+                        preprocess_type,
+                        is_still_mode,
+                        enhancer,
+                        batch_size,
+                        size_of_image,
+                        pose_style
+                        ],
+                outputs=[gen_video]
+                )
+        else:
+            submit.click(
+                fn=sad_talker.test,
+                inputs=[source_image,
+                        driven_audio,
+                        preprocess_type,
+                        is_still_mode,
+                        enhancer,
+                        batch_size,
+                        size_of_image,
+                        pose_style
+                        ],
+                outputs=[gen_video]
+                )
+
+    return sadtalker_interface
+
+
+if __name__ == "__main__":
+
+    demo = sadtalker_demo()
+    demo.queue()
+    demo.launch()
+
+
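app_sadtalker.py also accepts an optional warpfn hook: when it is provided, the Generate button is wired to warpfn(sad_talker.test) instead of sad_talker.test directly. A minimal sketch of using that hook follows; the timing wrapper is an illustrative assumption, not part of the commit.

import time

from app_sadtalker import sadtalker_demo

def timing_warpfn(fn):
    # Wrap SadTalker.test so every generation logs its wall-clock duration (illustrative only).
    def wrapped(*args, **kwargs):
        started = time.time()
        result = fn(*args, **kwargs)
        print(f"SadTalker generation finished in {time.time() - started:.1f}s")
        return result
    return wrapped

demo = sadtalker_demo(warpfn=timing_warpfn)
demo.queue()
demo.launch()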
cog.yaml
ADDED
@@ -0,0 +1,35 @@
+build:
+  gpu: true
+  cuda: "11.3"
+  python_version: "3.8"
+  system_packages:
+    - "ffmpeg"
+    - "libgl1-mesa-glx"
+    - "libglib2.0-0"
+  python_packages:
+    - "torch==1.12.1"
+    - "torchvision==0.13.1"
+    - "torchaudio==0.12.1"
+    - "joblib==1.1.0"
+    - "scikit-image==0.19.3"
+    - "basicsr==1.4.2"
+    - "facexlib==0.3.0"
+    - "resampy==0.3.1"
+    - "pydub==0.25.1"
+    - "scipy==1.10.1"
+    - "kornia==0.6.8"
+    - "face_alignment==1.3.5"
+    - "imageio==2.19.3"
+    - "imageio-ffmpeg==0.4.7"
+    - "librosa==0.9.2" #
+    - "tqdm==4.65.0"
+    - "yacs==0.1.8"
+    - "gfpgan==1.3.8"
+    - "dlib-bin==19.24.1"
+    - "av==10.0.0"
+    - "trimesh==3.9.20"
+  run:
+    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/s3fd-619a316812.pth" "https://www.adrianbulat.com/downloads/python-fan/s3fd-619a316812.pth"
+    - mkdir -p /root/.cache/torch/hub/checkpoints/ && wget --output-document "/root/.cache/torch/hub/checkpoints/2DFAN4-cd938726ad.zip" "https://www.adrianbulat.com/downloads/python-fan/2DFAN4-cd938726ad.zip"
+
+predict: "predict.py:Predictor"
inference.py
ADDED
@@ -0,0 +1,145 @@
+from glob import glob
+import shutil
+import torch
+from time import strftime
+import os, sys, time
+from argparse import ArgumentParser
+
+from src.utils.preprocess import CropAndExtract
+from src.test_audio2coeff import Audio2Coeff
+from src.facerender.animate import AnimateFromCoeff
+from src.generate_batch import get_data
+from src.generate_facerender_batch import get_facerender_data
+from src.utils.init_path import init_path
+
+def main(args):
+    #torch.backends.cudnn.enabled = False
+
+    pic_path = args.source_image
+    audio_path = args.driven_audio
+    save_dir = os.path.join(args.result_dir, strftime("%Y_%m_%d_%H.%M.%S"))
+    os.makedirs(save_dir, exist_ok=True)
+    pose_style = args.pose_style
+    device = args.device
+    batch_size = args.batch_size
+    input_yaw_list = args.input_yaw
+    input_pitch_list = args.input_pitch
+    input_roll_list = args.input_roll
+    ref_eyeblink = args.ref_eyeblink
+    ref_pose = args.ref_pose
+
+    current_root_path = os.path.split(sys.argv[0])[0]
+
+    sadtalker_paths = init_path(args.checkpoint_dir, os.path.join(current_root_path, 'src/config'), args.size, args.old_version, args.preprocess)
+
+    #init model
+    preprocess_model = CropAndExtract(sadtalker_paths, device)
+
+    audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
+
+    animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device)
+
+    #crop image and extract 3dmm from image
+    first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
+    os.makedirs(first_frame_dir, exist_ok=True)
+    print('3DMM Extraction for source image')
+    first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(pic_path, first_frame_dir, args.preprocess,\
+                                                                           source_image_flag=True, pic_size=args.size)
+    if first_coeff_path is None:
+        print("Can't get the coeffs of the input")
+        return
+
+    if ref_eyeblink is not None:
+        ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[0]
+        ref_eyeblink_frame_dir = os.path.join(save_dir, ref_eyeblink_videoname)
+        os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
+        print('3DMM Extraction for the reference video providing eye blinking')
+        ref_eyeblink_coeff_path, _, _ = preprocess_model.generate(ref_eyeblink, ref_eyeblink_frame_dir, args.preprocess, source_image_flag=False)
+    else:
+        ref_eyeblink_coeff_path=None
+
+    if ref_pose is not None:
+        if ref_pose == ref_eyeblink:
+            ref_pose_coeff_path = ref_eyeblink_coeff_path
+        else:
+            ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
+            ref_pose_frame_dir = os.path.join(save_dir, ref_pose_videoname)
+            os.makedirs(ref_pose_frame_dir, exist_ok=True)
+            print('3DMM Extraction for the reference video providing pose')
+            ref_pose_coeff_path, _, _ = preprocess_model.generate(ref_pose, ref_pose_frame_dir, args.preprocess, source_image_flag=False)
+    else:
+        ref_pose_coeff_path=None
+
+    #audio2ceoff
+    batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path, still=args.still)
+    coeff_path = audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
+
+    # 3dface render
+    if args.face3dvis:
+        from src.face3d.visualize import gen_composed_video
+        gen_composed_video(args, device, first_coeff_path, coeff_path, audio_path, os.path.join(save_dir, '3dface.mp4'))
+
+    #coeff2video
+    data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path,
+                               batch_size, input_yaw_list, input_pitch_list, input_roll_list,
+                               expression_scale=args.expression_scale, still_mode=args.still, preprocess=args.preprocess, size=args.size)
+
+    result = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, \
+                                enhancer=args.enhancer, background_enhancer=args.background_enhancer, preprocess=args.preprocess, img_size=args.size)
+
+    shutil.move(result, save_dir+'.mp4')
+    print('The generated video is named:', save_dir+'.mp4')
+
+    if not args.verbose:
+        shutil.rmtree(save_dir)
+
+
+if __name__ == '__main__':
+
+    parser = ArgumentParser()
+    parser.add_argument("--driven_audio", default='./examples/driven_audio/voice.wav', help="path to driven audio")
+    parser.add_argument("--source_image", default='./examples/source_image/istockphoto-487804668-612x612.png', help="path to source image")
+    parser.add_argument("--ref_eyeblink", default=None, help="path to reference video providing eye blinking")
+    parser.add_argument("--ref_pose", default=None, help="path to reference video providing pose")
+    parser.add_argument("--checkpoint_dir", default='./checkpoints', help="path to output")
+    parser.add_argument("--result_dir", default='./results', help="path to output")
+    parser.add_argument("--pose_style", type=int, default=0, help="input pose style from [0, 46)")
+    parser.add_argument("--batch_size", type=int, default=2, help="the batch size of facerender")
+    parser.add_argument("--size", type=int, default=256, help="the image size of the facerender")
+    parser.add_argument("--expression_scale", type=float, default=1., help="the batch size of facerender")
+    parser.add_argument('--input_yaw', nargs='+', type=int, default=None, help="the input yaw degree of the user ")
+    parser.add_argument('--input_pitch', nargs='+', type=int, default=None, help="the input pitch degree of the user")
+    parser.add_argument('--input_roll', nargs='+', type=int, default=None, help="the input roll degree of the user")
+    parser.add_argument('--enhancer', type=str, default=None, help="Face enhancer, [gfpgan, RestoreFormer]")
+    parser.add_argument('--background_enhancer', type=str, default=None, help="background enhancer, [realesrgan]")
+    parser.add_argument("--cpu", dest="cpu", action="store_true")
+    parser.add_argument("--face3dvis", action="store_true", help="generate 3d face and 3d landmarks")
+    parser.add_argument("--still", action="store_true", help="can crop back to the original videos for the full body aniamtion")
+    parser.add_argument("--preprocess", default='crop', choices=['crop', 'extcrop', 'resize', 'full', 'extfull'], help="how to preprocess the images" )
+    parser.add_argument("--verbose",action="store_true", help="saving the intermedia output or not" )
+    parser.add_argument("--old_version",action="store_true", help="use the pth other than safetensor version" )
+
+
+    # net structure and parameters
+    parser.add_argument('--net_recon', type=str, default='resnet50', choices=['resnet18', 'resnet34', 'resnet50'], help='useless')
+    parser.add_argument('--init_path', type=str, default=None, help='Useless')
+    parser.add_argument('--use_last_fc',default=False, help='zero initialize the last fc')
+    parser.add_argument('--bfm_folder', type=str, default='./checkpoints/BFM_Fitting/')
+    parser.add_argument('--bfm_model', type=str, default='BFM_model_front.mat', help='bfm model')
+
+    # default renderer parameters
+    parser.add_argument('--focal', type=float, default=1015.)
+    parser.add_argument('--center', type=float, default=112.)
+    parser.add_argument('--camera_d', type=float, default=10.)
+    parser.add_argument('--z_near', type=float, default=5.)
+    parser.add_argument('--z_far', type=float, default=15.)
+
+    args = parser.parse_args()
+
+    if torch.cuda.is_available() and not args.cpu:
+        args.device = "cuda"
+    else:
+        args.device = "cpu"
+
+    main(args)
+
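inference.py is the command-line counterpart of the Flask service: its flags mirror the fields of AnimationConfig in app.py, and it writes the rendered video next to a timestamped folder under --result_dir. The snippet below is a minimal sketch of driving it programmatically; it only uses paths and flags that appear in the parser defaults above and assumes the checkpoints have already been downloaded.

import subprocess
import sys

# Equivalent to running inference.py from a shell with explicit arguments.
subprocess.run(
    [
        sys.executable, "inference.py",
        "--driven_audio", "./examples/driven_audio/voice.wav",
        "--source_image", "./examples/source_image/istockphoto-487804668-612x612.png",
        "--result_dir", "./results",
        "--enhancer", "gfpgan",
        "--still",
    ],
    check=True,
)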
launcher.py
ADDED
@@ -0,0 +1,204 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# this scripts installs necessary requirements and launches main program in webui.py
|
2 |
+
# borrow from : https://github.com/AUTOMATIC1111/stable-diffusion-webui/blob/master/launch.py
|
3 |
+
import subprocess
|
4 |
+
import os
|
5 |
+
import sys
|
6 |
+
import importlib.util
|
7 |
+
import shlex
|
8 |
+
import platform
|
9 |
+
import json
|
10 |
+
|
11 |
+
python = sys.executable
|
12 |
+
git = os.environ.get('GIT', "git")
|
13 |
+
index_url = os.environ.get('INDEX_URL', "")
|
14 |
+
stored_commit_hash = None
|
15 |
+
skip_install = False
|
16 |
+
dir_repos = "repositories"
|
17 |
+
script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
|
18 |
+
|
19 |
+
if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
|
20 |
+
os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'
|
21 |
+
|
22 |
+
|
23 |
+
def check_python_version():
|
24 |
+
is_windows = platform.system() == "Windows"
|
25 |
+
major = sys.version_info.major
|
26 |
+
minor = sys.version_info.minor
|
27 |
+
micro = sys.version_info.micro
|
28 |
+
|
29 |
+
if is_windows:
|
30 |
+
supported_minors = [10]
|
31 |
+
else:
|
32 |
+
supported_minors = [7, 8, 9, 10, 11]
|
33 |
+
|
34 |
+
if not (major == 3 and minor in supported_minors):
|
35 |
+
|
36 |
+
raise (f"""
|
37 |
+
INCOMPATIBLE PYTHON VERSION
|
38 |
+
This program is tested with 3.10.6 Python, but you have {major}.{minor}.{micro}.
|
39 |
+
If you encounter an error with "RuntimeError: Couldn't install torch." message,
|
40 |
+
or any other error regarding unsuccessful package (library) installation,
|
41 |
+
please downgrade (or upgrade) to the latest version of 3.10 Python
|
42 |
+
and delete current Python and "venv" folder in WebUI's directory.
|
43 |
+
You can download 3.10 Python from here: https://www.python.org/downloads/release/python-3109/
|
44 |
+
{"Alternatively, use a binary release of WebUI: https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases" if is_windows else ""}
|
45 |
+
Use --skip-python-version-check to suppress this warning.
|
46 |
+
""")
|
47 |
+
|
48 |
+
|
49 |
+
def commit_hash():
|
50 |
+
global stored_commit_hash
|
51 |
+
|
52 |
+
if stored_commit_hash is not None:
|
53 |
+
return stored_commit_hash
|
54 |
+
|
55 |
+
try:
|
56 |
+
stored_commit_hash = run(f"{git} rev-parse HEAD").strip()
|
57 |
+
except Exception:
|
58 |
+
stored_commit_hash = "<none>"
|
59 |
+
|
60 |
+
return stored_commit_hash
|
61 |
+
|
62 |
+
|
63 |
+
def run(command, desc=None, errdesc=None, custom_env=None, live=False):
|
64 |
+
if desc is not None:
|
65 |
+
print(desc)
|
66 |
+
|
67 |
+
if live:
|
68 |
+
result = subprocess.run(command, shell=True, env=os.environ if custom_env is None else custom_env)
|
69 |
+
if result.returncode != 0:
|
70 |
+
raise RuntimeError(f"""{errdesc or 'Error running command'}.
|
71 |
+
Command: {command}
|
72 |
+
Error code: {result.returncode}""")
|
73 |
+
|
74 |
+
return ""
|
75 |
+
|
76 |
+
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True, env=os.environ if custom_env is None else custom_env)
|
77 |
+
|
78 |
+
if result.returncode != 0:
|
79 |
+
|
80 |
+
message = f"""{errdesc or 'Error running command'}.
|
81 |
+
Command: {command}
|
82 |
+
Error code: {result.returncode}
|
83 |
+
stdout: {result.stdout.decode(encoding="utf8", errors="ignore") if len(result.stdout)>0 else '<empty>'}
|
84 |
+
stderr: {result.stderr.decode(encoding="utf8", errors="ignore") if len(result.stderr)>0 else '<empty>'}
|
85 |
+
"""
|
86 |
+
raise RuntimeError(message)
|
87 |
+
|
88 |
+
return result.stdout.decode(encoding="utf8", errors="ignore")
|
89 |
+
|
90 |
+
|
91 |
+
def check_run(command):
|
92 |
+
result = subprocess.run(command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, shell=True)
|
93 |
+
return result.returncode == 0
|
94 |
+
|
95 |
+
|
96 |
+
def is_installed(package):
|
97 |
+
try:
|
98 |
+
spec = importlib.util.find_spec(package)
|
99 |
+
except ModuleNotFoundError:
|
100 |
+
return False
|
101 |
+
|
102 |
+
return spec is not None
|
103 |
+
|
104 |
+
|
105 |
+
def repo_dir(name):
|
106 |
+
return os.path.join(script_path, dir_repos, name)
|
107 |
+
|
108 |
+
|
109 |
+
def run_python(code, desc=None, errdesc=None):
|
110 |
+
return run(f'"{python}" -c "{code}"', desc, errdesc)
|
111 |
+
|
112 |
+
|
113 |
+
def run_pip(args, desc=None):
|
114 |
+
if skip_install:
|
115 |
+
return
|
116 |
+
|
117 |
+
index_url_line = f' --index-url {index_url}' if index_url != '' else ''
|
118 |
+
return run(f'"{python}" -m pip {args} --prefer-binary{index_url_line}', desc=f"Installing {desc}", errdesc=f"Couldn't install {desc}")
|
119 |
+
|
120 |
+
|
121 |
+
def check_run_python(code):
|
122 |
+
return check_run(f'"{python}" -c "{code}"')
|
123 |
+
|
124 |
+
|
125 |
+
def git_clone(url, dir, name, commithash=None):
|
126 |
+
# TODO clone into temporary dir and move if successful
|
127 |
+
|
128 |
+
if os.path.exists(dir):
|
129 |
+
if commithash is None:
|
130 |
+
return
|
131 |
+
|
132 |
+
current_hash = run(f'"{git}" -C "{dir}" rev-parse HEAD', None, f"Couldn't determine {name}'s hash: {commithash}").strip()
|
133 |
+
if current_hash == commithash:
|
134 |
+
return
|
135 |
+
|
136 |
+
run(f'"{git}" -C "{dir}" fetch', f"Fetching updates for {name}...", f"Couldn't fetch {name}")
|
137 |
+
run(f'"{git}" -C "{dir}" checkout {commithash}', f"Checking out commit for {name} with hash: {commithash}...", f"Couldn't checkout commit {commithash} for {name}")
|
138 |
+
return
|
139 |
+
|
140 |
+
run(f'"{git}" clone "{url}" "{dir}"', f"Cloning {name} into {dir}...", f"Couldn't clone {name}")
|
141 |
+
|
142 |
+
if commithash is not None:
|
143 |
+
run(f'"{git}" -C "{dir}" checkout {commithash}', None, "Couldn't checkout {name}'s hash: {commithash}")
|
144 |
+
|
145 |
+
|
146 |
+
def git_pull_recursive(dir):
|
147 |
+
for subdir, _, _ in os.walk(dir):
|
148 |
+
if os.path.exists(os.path.join(subdir, '.git')):
|
149 |
+
try:
|
150 |
+
output = subprocess.check_output([git, '-C', subdir, 'pull', '--autostash'])
|
151 |
+
print(f"Pulled changes for repository in '{subdir}':\n{output.decode('utf-8').strip()}\n")
|
152 |
+
except subprocess.CalledProcessError as e:
|
153 |
+
print(f"Couldn't perform 'git pull' on repository in '{subdir}':\n{e.output.decode('utf-8').strip()}\n")
|
154 |
+
|
155 |
+
|
156 |
+
def run_extension_installer(extension_dir):
|
157 |
+
path_installer = os.path.join(extension_dir, "install.py")
|
158 |
+
if not os.path.isfile(path_installer):
|
159 |
+
return
|
160 |
+
|
161 |
+
try:
|
162 |
+
env = os.environ.copy()
|
163 |
+
env['PYTHONPATH'] = os.path.abspath(".")
|
164 |
+
|
165 |
+
print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
|
166 |
+
except Exception as e:
|
167 |
+
print(e, file=sys.stderr)
|
168 |
+
|
169 |
+
|
170 |
+
def prepare_environment():
|
171 |
+
global skip_install
|
172 |
+
|
173 |
+
torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113")
|
174 |
+
|
175 |
+
## pick the requirements file per platform: req.txt on Linux/macOS, requirements.txt on Windows
|
176 |
+
if sys.platform != 'win32':
|
177 |
+
requirements_file = os.environ.get('REQS_FILE', "req.txt")
|
178 |
+
else:
|
179 |
+
requirements_file = os.environ.get('REQS_FILE', "requirements.txt")
|
180 |
+
|
181 |
+
commit = commit_hash()
|
182 |
+
|
183 |
+
print(f"Python {sys.version}")
|
184 |
+
print(f"Commit hash: {commit}")
|
185 |
+
|
186 |
+
if not is_installed("torch") or not is_installed("torchvision"):
|
187 |
+
run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)
|
188 |
+
|
189 |
+
run_pip(f"install -r \"{requirements_file}\"", "requirements for SadTalker WebUI (may take longer time in first time)")
|
190 |
+
|
191 |
+
if sys.platform != 'win32' and not is_installed('tts'):
|
192 |
+
run_pip(f"install TTS", "install TTS individually in SadTalker, which might not work on windows.")
|
193 |
+
|
194 |
+
|
195 |
+
def start():
|
196 |
+
print(f"Launching SadTalker Web UI")
|
197 |
+
from app_sadtalker import sadtalker_demo
|
198 |
+
demo = sadtalker_demo()
|
199 |
+
demo.queue()
|
200 |
+
demo.launch()
|
201 |
+
|
202 |
+
if __name__ == "__main__":
|
203 |
+
prepare_environment()
|
204 |
+
start()
|
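The helper block above follows one pattern: run a shell command, capture its output, and fail loudly with the command, exit code, stdout and stderr. A minimal stand-alone sketch of that pattern (the run_checked name is illustrative, not part of the repository):

import subprocess

def run_checked(command: str) -> str:
    """Run a shell command, return its stdout, raise with full context on failure."""
    result = subprocess.run(command, shell=True, capture_output=True, text=True)
    if result.returncode != 0:
        raise RuntimeError(
            f"Error running command.\nCommand: {command}\n"
            f"Error code: {result.returncode}\n"
            f"stdout: {result.stdout or '<empty>'}\n"
            f"stderr: {result.stderr or '<empty>'}"
        )
    return result.stdout

if __name__ == "__main__":
    # the same command commit_hash() runs above
    print(run_checked("git rev-parse HEAD").strip())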
predict.py
ADDED
@@ -0,0 +1,192 @@
1 |
+
"""run bash scripts/download_models.sh first to prepare the weights file"""
|
2 |
+
import os
|
3 |
+
import shutil
|
4 |
+
from argparse import Namespace
|
5 |
+
from src.utils.preprocess import CropAndExtract
|
6 |
+
from src.test_audio2coeff import Audio2Coeff
|
7 |
+
from src.facerender.animate import AnimateFromCoeff
|
8 |
+
from src.generate_batch import get_data
|
9 |
+
from src.generate_facerender_batch import get_facerender_data
|
10 |
+
from src.utils.init_path import init_path
|
11 |
+
from cog import BasePredictor, Input, Path
|
12 |
+
|
13 |
+
checkpoints = "checkpoints"
|
14 |
+
|
15 |
+
|
16 |
+
class Predictor(BasePredictor):
|
17 |
+
def setup(self):
|
18 |
+
"""Load the model into memory to make running multiple predictions efficient"""
|
19 |
+
device = "cuda"
|
20 |
+
|
21 |
+
|
22 |
+
sadtalker_paths = init_path(checkpoints, os.path.join("src", "config"))
|
23 |
+
|
24 |
+
# init model
|
25 |
+
self.preprocess_model = CropAndExtract(sadtalker_paths, device
|
26 |
+
)
|
27 |
+
|
28 |
+
self.audio_to_coeff = Audio2Coeff(
|
29 |
+
sadtalker_paths,
|
30 |
+
device,
|
31 |
+
)
|
32 |
+
|
33 |
+
self.animate_from_coeff = {
|
34 |
+
"full": AnimateFromCoeff(
|
35 |
+
sadtalker_paths,
|
36 |
+
device,
|
37 |
+
),
|
38 |
+
"others": AnimateFromCoeff(
|
39 |
+
sadtalker_paths,
|
40 |
+
device,
|
41 |
+
),
|
42 |
+
}
|
43 |
+
|
44 |
+
def predict(
|
45 |
+
self,
|
46 |
+
source_image: Path = Input(
|
47 |
+
description="Upload the source image, it can be video.mp4 or picture.png",
|
48 |
+
),
|
49 |
+
driven_audio: Path = Input(
|
50 |
+
description="Upload the driven audio, accepts .wav and .mp4 file",
|
51 |
+
),
|
52 |
+
enhancer: str = Input(
|
53 |
+
description="Choose a face enhancer",
|
54 |
+
choices=["gfpgan", "RestoreFormer"],
|
55 |
+
default="gfpgan",
|
56 |
+
),
|
57 |
+
preprocess: str = Input(
|
58 |
+
description="how to preprocess the images",
|
59 |
+
choices=["crop", "resize", "full"],
|
60 |
+
default="full",
|
61 |
+
),
|
62 |
+
ref_eyeblink: Path = Input(
|
63 |
+
description="path to reference video providing eye blinking",
|
64 |
+
default=None,
|
65 |
+
),
|
66 |
+
ref_pose: Path = Input(
|
67 |
+
description="path to reference video providing pose",
|
68 |
+
default=None,
|
69 |
+
),
|
70 |
+
still: bool = Input(
|
71 |
+
description="can crop back to the original videos for the full body aniamtion when preprocess is full",
|
72 |
+
default=True,
|
73 |
+
),
|
74 |
+
) -> Path:
|
75 |
+
"""Run a single prediction on the model"""
|
76 |
+
|
77 |
+
animate_from_coeff = (
|
78 |
+
self.animate_from_coeff["full"]
|
79 |
+
if preprocess == "full"
|
80 |
+
else self.animate_from_coeff["others"]
|
81 |
+
)
|
82 |
+
|
83 |
+
args = load_default()
|
84 |
+
args.pic_path = str(source_image)
|
85 |
+
args.audio_path = str(driven_audio)
|
86 |
+
device = "cuda"
|
87 |
+
args.still = still
|
88 |
+
args.ref_eyeblink = None if ref_eyeblink is None else str(ref_eyeblink)
|
89 |
+
args.ref_pose = None if ref_pose is None else str(ref_pose)
|
90 |
+
|
91 |
+
# crop image and extract 3dmm from image
|
92 |
+
results_dir = "results"
|
93 |
+
if os.path.exists(results_dir):
|
94 |
+
shutil.rmtree(results_dir)
|
95 |
+
os.makedirs(results_dir)
|
96 |
+
first_frame_dir = os.path.join(results_dir, "first_frame_dir")
|
97 |
+
os.makedirs(first_frame_dir)
|
98 |
+
|
99 |
+
print("3DMM Extraction for source image")
|
100 |
+
first_coeff_path, crop_pic_path, crop_info = self.preprocess_model.generate(
|
101 |
+
args.pic_path, first_frame_dir, preprocess, source_image_flag=True
|
102 |
+
)
|
103 |
+
if first_coeff_path is None:
|
104 |
+
print("Can't get the coeffs of the input")
|
105 |
+
return
|
106 |
+
|
107 |
+
if ref_eyeblink is not None:
|
108 |
+
ref_eyeblink_videoname = os.path.splitext(os.path.split(ref_eyeblink)[-1])[
|
109 |
+
0
|
110 |
+
]
|
111 |
+
ref_eyeblink_frame_dir = os.path.join(results_dir, ref_eyeblink_videoname)
|
112 |
+
os.makedirs(ref_eyeblink_frame_dir, exist_ok=True)
|
113 |
+
print("3DMM Extraction for the reference video providing eye blinking")
|
114 |
+
ref_eyeblink_coeff_path, _, _ = self.preprocess_model.generate(
|
115 |
+
ref_eyeblink, ref_eyeblink_frame_dir
|
116 |
+
)
|
117 |
+
else:
|
118 |
+
ref_eyeblink_coeff_path = None
|
119 |
+
|
120 |
+
if ref_pose is not None:
|
121 |
+
if ref_pose == ref_eyeblink:
|
122 |
+
ref_pose_coeff_path = ref_eyeblink_coeff_path
|
123 |
+
else:
|
124 |
+
ref_pose_videoname = os.path.splitext(os.path.split(ref_pose)[-1])[0]
|
125 |
+
ref_pose_frame_dir = os.path.join(results_dir, ref_pose_videoname)
|
126 |
+
os.makedirs(ref_pose_frame_dir, exist_ok=True)
|
127 |
+
print("3DMM Extraction for the reference video providing pose")
|
128 |
+
ref_pose_coeff_path, _, _ = self.preprocess_model.generate(
|
129 |
+
ref_pose, ref_pose_frame_dir
|
130 |
+
)
|
131 |
+
else:
|
132 |
+
ref_pose_coeff_path = None
|
133 |
+
|
134 |
+
# audio to coefficients (audio2coeff)
|
135 |
+
batch = get_data(
|
136 |
+
first_coeff_path,
|
137 |
+
args.audio_path,
|
138 |
+
device,
|
139 |
+
ref_eyeblink_coeff_path,
|
140 |
+
still=still,
|
141 |
+
)
|
142 |
+
coeff_path = self.audio_to_coeff.generate(
|
143 |
+
batch, results_dir, args.pose_style, ref_pose_coeff_path
|
144 |
+
)
|
145 |
+
# coeff2video
|
146 |
+
print("coeff2video")
|
147 |
+
data = get_facerender_data(
|
148 |
+
coeff_path,
|
149 |
+
crop_pic_path,
|
150 |
+
first_coeff_path,
|
151 |
+
args.audio_path,
|
152 |
+
args.batch_size,
|
153 |
+
args.input_yaw,
|
154 |
+
args.input_pitch,
|
155 |
+
args.input_roll,
|
156 |
+
expression_scale=args.expression_scale,
|
157 |
+
still_mode=still,
|
158 |
+
preprocess=preprocess,
|
159 |
+
)
|
160 |
+
animate_from_coeff.generate(
|
161 |
+
data, results_dir, args.pic_path, crop_info,
|
162 |
+
enhancer=enhancer, background_enhancer=args.background_enhancer,
|
163 |
+
preprocess=preprocess)
|
164 |
+
|
165 |
+
output = "/tmp/out.mp4"
|
166 |
+
mp4_path = os.path.join(results_dir, [f for f in os.listdir(results_dir) if "enhanced.mp4" in f][0])
|
167 |
+
shutil.copy(mp4_path, output)
|
168 |
+
|
169 |
+
return Path(output)
|
170 |
+
|
171 |
+
|
172 |
+
def load_default():
|
173 |
+
return Namespace(
|
174 |
+
pose_style=0,
|
175 |
+
batch_size=2,
|
176 |
+
expression_scale=1.0,
|
177 |
+
input_yaw=None,
|
178 |
+
input_pitch=None,
|
179 |
+
input_roll=None,
|
180 |
+
background_enhancer=None,
|
181 |
+
face3dvis=False,
|
182 |
+
net_recon="resnet50",
|
183 |
+
init_path=None,
|
184 |
+
use_last_fc=False,
|
185 |
+
bfm_folder="./src/config/",
|
186 |
+
bfm_model="BFM_model_front.mat",
|
187 |
+
focal=1015.0,
|
188 |
+
center=112.0,
|
189 |
+
camera_d=10.0,
|
190 |
+
z_near=5.0,
|
191 |
+
z_far=15.0,
|
192 |
+
)
|
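A minimal local driver for the Cog predictor above, as a sketch: it assumes a CUDA machine with the checkpoints already fetched by scripts/download_models.sh, and it reuses the sample files shipped under examples/ (the exact paths here are illustrative).

from predict import Predictor

predictor = Predictor()
predictor.setup()  # builds CropAndExtract, Audio2Coeff and the two AnimateFromCoeff instances once
video_path = predictor.predict(
    source_image="examples/source_image/full3.png",
    driven_audio="examples/driven_audio/RD_Radio31_000.wav",
    enhancer="gfpgan",
    preprocess="full",
    ref_eyeblink=None,
    ref_pose=None,
    still=True,
)
print(video_path)  # the enhanced result copied to /tmp/out.mp4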
quick_demo.ipynb
ADDED
@@ -0,0 +1,221 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"attachments": {},
|
5 |
+
"cell_type": "markdown",
|
6 |
+
"metadata": {
|
7 |
+
"id": "M74Gs_TjYl_B"
|
8 |
+
},
|
9 |
+
"source": [
|
10 |
+
"[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/Winfredy/SadTalker/blob/main/quick_demo.ipynb)"
|
11 |
+
]
|
12 |
+
},
|
13 |
+
{
|
14 |
+
"attachments": {},
|
15 |
+
"cell_type": "markdown",
|
16 |
+
"metadata": {
|
17 |
+
"id": "view-in-github"
|
18 |
+
},
|
19 |
+
"source": [
|
20 |
+
"### SadTalker:Learning Realistic 3D Motion Coefficients for Stylized Audio-Driven Single Image Talking Face Animation \n",
|
21 |
+
"\n",
|
22 |
+
"[arxiv](https://arxiv.org/abs/2211.12194) | [project](https://sadtalker.github.io) | [Github](https://github.com/Winfredy/SadTalker)\n",
|
23 |
+
"\n",
|
24 |
+
"Wenxuan Zhang, Xiaodong Cun, Xuan Wang, Yong Zhang, Xi Shen, Yu Guo, Ying Shan, Fei Wang.\n",
|
25 |
+
"\n",
|
26 |
+
"Xi'an Jiaotong University, Tencent AI Lab, Ant Group\n",
|
27 |
+
"\n",
|
28 |
+
"CVPR 2023\n",
|
29 |
+
"\n",
|
30 |
+
"TL;DR: A realistic and stylized talking head video generation method from a single image and audio\n"
|
31 |
+
]
|
32 |
+
},
|
33 |
+
{
|
34 |
+
"attachments": {},
|
35 |
+
"cell_type": "markdown",
|
36 |
+
"metadata": {
|
37 |
+
"id": "kA89DV-sKS4i"
|
38 |
+
},
|
39 |
+
"source": [
|
40 |
+
"Installation (around 5 mins)"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": null,
|
46 |
+
"metadata": {
|
47 |
+
"id": "qJ4CplXsYl_E"
|
48 |
+
},
|
49 |
+
"outputs": [],
|
50 |
+
"source": [
|
51 |
+
"### make sure that CUDA is available in Edit -> Nootbook settings -> GPU\n",
|
52 |
+
"!nvidia-smi --query-gpu=name,memory.total,memory.free --format=csv,noheader"
|
53 |
+
]
|
54 |
+
},
|
55 |
+
{
|
56 |
+
"cell_type": "code",
|
57 |
+
"execution_count": null,
|
58 |
+
"metadata": {
|
59 |
+
"id": "Mdq6j4E5KQAR"
|
60 |
+
},
|
61 |
+
"outputs": [],
|
62 |
+
"source": [
|
63 |
+
"!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.8 2\n",
|
64 |
+
"!update-alternatives --install /usr/local/bin/python3 python3 /usr/bin/python3.9 1\n",
|
65 |
+
"!sudo apt install python3.8\n",
|
66 |
+
"\n",
|
67 |
+
"!sudo apt-get install python3.8-distutils\n",
|
68 |
+
"\n",
|
69 |
+
"!python --version\n",
|
70 |
+
"\n",
|
71 |
+
"!apt-get update\n",
|
72 |
+
"\n",
|
73 |
+
"!apt install software-properties-common\n",
|
74 |
+
"\n",
|
75 |
+
"!sudo dpkg --remove --force-remove-reinstreq python3-pip python3-setuptools python3-wheel\n",
|
76 |
+
"\n",
|
77 |
+
"!apt-get install python3-pip\n",
|
78 |
+
"\n",
|
79 |
+
"print('Git clone project and install requirements...')\n",
|
80 |
+
"!git clone https://github.com/Winfredy/SadTalker &> /dev/null\n",
|
81 |
+
"%cd SadTalker\n",
|
82 |
+
"!export PYTHONPATH=/content/SadTalker:$PYTHONPATH\n",
|
83 |
+
"!python3.8 -m pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 torchaudio==0.12.1 --extra-index-url https://download.pytorch.org/whl/cu113\n",
|
84 |
+
"!apt update\n",
|
85 |
+
"!apt install ffmpeg &> /dev/null\n",
|
86 |
+
"!python3.8 -m pip install -r requirements.txt"
|
87 |
+
]
|
88 |
+
},
|
89 |
+
{
|
90 |
+
"attachments": {},
|
91 |
+
"cell_type": "markdown",
|
92 |
+
"metadata": {
|
93 |
+
"id": "DddcKB_nKsnk"
|
94 |
+
},
|
95 |
+
"source": [
|
96 |
+
"Download models (1 mins)"
|
97 |
+
]
|
98 |
+
},
|
99 |
+
{
|
100 |
+
"cell_type": "code",
|
101 |
+
"execution_count": null,
|
102 |
+
"metadata": {
|
103 |
+
"id": "eDw3_UN8K2xa"
|
104 |
+
},
|
105 |
+
"outputs": [],
|
106 |
+
"source": [
|
107 |
+
"print('Download pre-trained models...')\n",
|
108 |
+
"!rm -rf checkpoints\n",
|
109 |
+
"!bash scripts/download_models.sh"
|
110 |
+
]
|
111 |
+
},
|
112 |
+
{
|
113 |
+
"cell_type": "code",
|
114 |
+
"execution_count": null,
|
115 |
+
"metadata": {
|
116 |
+
"id": "kK7DYeo7Yl_H"
|
117 |
+
},
|
118 |
+
"outputs": [],
|
119 |
+
"source": [
|
120 |
+
"# borrow from makeittalk\n",
|
121 |
+
"import ipywidgets as widgets\n",
|
122 |
+
"import glob\n",
|
123 |
+
"import matplotlib.pyplot as plt\n",
|
124 |
+
"print(\"Choose the image name to animate: (saved in folder 'examples/')\")\n",
|
125 |
+
"img_list = glob.glob1('examples/source_image', '*.png')\n",
|
126 |
+
"img_list.sort()\n",
|
127 |
+
"img_list = [item.split('.')[0] for item in img_list]\n",
|
128 |
+
"default_head_name = widgets.Dropdown(options=img_list, value='full3')\n",
|
129 |
+
"def on_change(change):\n",
|
130 |
+
" if change['type'] == 'change' and change['name'] == 'value':\n",
|
131 |
+
" plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
|
132 |
+
" plt.axis('off')\n",
|
133 |
+
" plt.show()\n",
|
134 |
+
"default_head_name.observe(on_change)\n",
|
135 |
+
"display(default_head_name)\n",
|
136 |
+
"plt.imshow(plt.imread('examples/source_image/{}.png'.format(default_head_name.value)))\n",
|
137 |
+
"plt.axis('off')\n",
|
138 |
+
"plt.show()"
|
139 |
+
]
|
140 |
+
},
|
141 |
+
{
|
142 |
+
"attachments": {},
|
143 |
+
"cell_type": "markdown",
|
144 |
+
"metadata": {
|
145 |
+
"id": "-khNZcnGK4UK"
|
146 |
+
},
|
147 |
+
"source": [
|
148 |
+
"Animation"
|
149 |
+
]
|
150 |
+
},
|
151 |
+
{
|
152 |
+
"cell_type": "code",
|
153 |
+
"execution_count": null,
|
154 |
+
"metadata": {
|
155 |
+
"id": "ToBlDusjK5sS"
|
156 |
+
},
|
157 |
+
"outputs": [],
|
158 |
+
"source": [
|
159 |
+
"# selected audio from exmaple/driven_audio\n",
|
160 |
+
"img = 'examples/source_image/{}.png'.format(default_head_name.value)\n",
|
161 |
+
"print(img)\n",
|
162 |
+
"!python3.8 inference.py --driven_audio ./examples/driven_audio/RD_Radio31_000.wav \\\n",
|
163 |
+
" --source_image {img} \\\n",
|
164 |
+
" --result_dir ./results --still --preprocess full --enhancer gfpgan"
|
165 |
+
]
|
166 |
+
},
|
167 |
+
{
|
168 |
+
"cell_type": "code",
|
169 |
+
"execution_count": null,
|
170 |
+
"metadata": {
|
171 |
+
"id": "fAjwGmKKYl_I"
|
172 |
+
},
|
173 |
+
"outputs": [],
|
174 |
+
"source": [
|
175 |
+
"# visualize code from makeittalk\n",
|
176 |
+
"from IPython.display import HTML\n",
|
177 |
+
"from base64 import b64encode\n",
|
178 |
+
"import os, sys\n",
|
179 |
+
"\n",
|
180 |
+
"# get the last from results\n",
|
181 |
+
"\n",
|
182 |
+
"results = sorted(os.listdir('./results/'))\n",
|
183 |
+
"\n",
|
184 |
+
"mp4_name = glob.glob('./results/*.mp4')[0]\n",
|
185 |
+
"\n",
|
186 |
+
"mp4 = open('{}'.format(mp4_name),'rb').read()\n",
|
187 |
+
"data_url = \"data:video/mp4;base64,\" + b64encode(mp4).decode()\n",
|
188 |
+
"\n",
|
189 |
+
"print('Display animation: {}'.format(mp4_name), file=sys.stderr)\n",
|
190 |
+
"display(HTML(\"\"\"\n",
|
191 |
+
" <video width=256 controls>\n",
|
192 |
+
" <source src=\"%s\" type=\"video/mp4\">\n",
|
193 |
+
" </video>\n",
|
194 |
+
" \"\"\" % data_url))\n"
|
195 |
+
]
|
196 |
+
}
|
197 |
+
],
|
198 |
+
"metadata": {
|
199 |
+
"accelerator": "GPU",
|
200 |
+
"colab": {
|
201 |
+
"provenance": []
|
202 |
+
},
|
203 |
+
"gpuClass": "standard",
|
204 |
+
"kernelspec": {
|
205 |
+
"display_name": "base",
|
206 |
+
"language": "python",
|
207 |
+
"name": "python3"
|
208 |
+
},
|
209 |
+
"language_info": {
|
210 |
+
"name": "python",
|
211 |
+
"version": "3.9.7"
|
212 |
+
},
|
213 |
+
"vscode": {
|
214 |
+
"interpreter": {
|
215 |
+
"hash": "db5031b3636a3f037ea48eb287fd3d023feb9033aefc2a9652a92e470fb0851b"
|
216 |
+
}
|
217 |
+
}
|
218 |
+
},
|
219 |
+
"nbformat": 4,
|
220 |
+
"nbformat_minor": 0
|
221 |
+
}
|
req.txt
ADDED
@@ -0,0 +1,22 @@
+llvmlite==0.38.1
+numpy==1.21.6
+face_alignment==1.3.5
+imageio==2.19.3
+imageio-ffmpeg==0.4.7
+librosa==0.10.0.post2
+numba==0.55.1
+resampy==0.3.1
+pydub==0.25.1
+scipy==1.10.1
+kornia==0.6.8
+tqdm
+yacs==0.1.8
+pyyaml
+joblib==1.1.0
+scikit-image==0.19.3
+basicsr==1.4.2
+facexlib==0.3.0
+gradio
+gfpgan
+av
+safetensors
requirements.txt
ADDED
@@ -0,0 +1,30 @@
+numpy==1.23.4
+face_alignment==1.3.5
+imageio==2.19.3
+imageio-ffmpeg==0.4.7
+librosa==0.9.2
+numba
+resampy==0.3.1
+pydub==0.25.1
+scipy==1.10.1
+kornia==0.6.8
+tqdm
+yacs==0.1.8
+pyyaml
+joblib==1.1.0
+scikit-image==0.19.3
+basicsr==1.4.2
+facexlib==0.3.0
+torchvision==0.12.0
+elevenlabs
+gradio
+gfpgan
+av
+safetensors
+openai
+torch
+moviepy
+flask
+gunicorn
+flask_cors
+flask_swagger_ui
requirements3d.txt
ADDED
@@ -0,0 +1,21 @@
+numpy==1.23.4
+face_alignment==1.3.5
+imageio==2.19.3
+imageio-ffmpeg==0.4.7
+librosa==0.9.2
+numba
+resampy==0.3.1
+pydub==0.25.1
+scipy==1.5.3
+kornia==0.6.8
+tqdm
+yacs==0.1.8
+pyyaml
+joblib==1.1.0
+scikit-image==0.19.3
+basicsr==1.4.2
+facexlib==0.3.0
+trimesh==3.9.20
+gradio
+gfpgan
+safetensors
scripts/download_models.sh
ADDED
@@ -0,0 +1,15 @@
+#!/bin/bash
+# Create the checkpoints directory
+mkdir -p ./checkpoints
+# Download model files into the checkpoints directory
+wget -nc "https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00109-model.pth.tar" -O "./checkpoints/mapping_00109-model.pth.tar"
+wget -nc "https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/mapping_00229-model.pth.tar" -O "./checkpoints/mapping_00229-model.pth.tar"
+wget -nc "https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_256.safetensors" -O "./checkpoints/SadTalker_V0.0.2_256.safetensors"
+wget -nc "https://github.com/OpenTalker/SadTalker/releases/download/v0.0.2-rc/SadTalker_V0.0.2_512.safetensors" -O "./checkpoints/SadTalker_V0.0.2_512.safetensors"
+# Create the weights directory
+mkdir -p -v "./gfpgan/weights"
+# Download enhancer model files into the weights directory
+wget -nc "https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth" -O "./gfpgan/weights/alignment_WFLW_4HG.pth"
+wget -nc "https://github.com/xinntao/facexlib/releases/download/v0.1.0/detection_Resnet50_Final.pth" -O "./gfpgan/weights/detection_Resnet50_Final.pth"
+wget -nc "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth" -O "./gfpgan/weights/GFPGANv1.4.pth"
+wget -nc "https://github.com/xinntao/facexlib/releases/download/v0.2.2/parsing_parsenet.pth" -O "./gfpgan/weights/parsing_parsenet.pth"
src/audio2exp_models/audio2exp.py
ADDED
@@ -0,0 +1,41 @@
1 |
+
from tqdm import tqdm
|
2 |
+
import torch
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
|
6 |
+
class Audio2Exp(nn.Module):
|
7 |
+
def __init__(self, netG, cfg, device, prepare_training_loss=False):
|
8 |
+
super(Audio2Exp, self).__init__()
|
9 |
+
self.cfg = cfg
|
10 |
+
self.device = device
|
11 |
+
self.netG = netG.to(device)
|
12 |
+
|
13 |
+
def test(self, batch):
|
14 |
+
|
15 |
+
mel_input = batch['indiv_mels'] # bs T 1 80 16
|
16 |
+
bs = mel_input.shape[0]
|
17 |
+
T = mel_input.shape[1]
|
18 |
+
|
19 |
+
exp_coeff_pred = []
|
20 |
+
|
21 |
+
for i in tqdm(range(0, T, 10),'audio2exp:'): # every 10 frames
|
22 |
+
|
23 |
+
current_mel_input = mel_input[:,i:i+10]
|
24 |
+
|
25 |
+
#ref = batch['ref'][:, :, :64].repeat((1,current_mel_input.shape[1],1)) #bs T 64
|
26 |
+
ref = batch['ref'][:, :, :64][:, i:i+10]
|
27 |
+
ratio = batch['ratio_gt'][:, i:i+10] #bs T
|
28 |
+
|
29 |
+
audiox = current_mel_input.view(-1, 1, 80, 16) # bs*T 1 80 16
|
30 |
+
|
31 |
+
curr_exp_coeff_pred = self.netG(audiox, ref, ratio) # bs T 64
|
32 |
+
|
33 |
+
exp_coeff_pred += [curr_exp_coeff_pred]
|
34 |
+
|
35 |
+
# BS x T x 64
|
36 |
+
results_dict = {
|
37 |
+
'exp_coeff_pred': torch.cat(exp_coeff_pred, axis=1)
|
38 |
+
}
|
39 |
+
return results_dict
|
40 |
+
|
41 |
+
|
src/audio2exp_models/networks.py
ADDED
@@ -0,0 +1,74 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
class Conv2d(nn.Module):
|
6 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, use_act = True, *args, **kwargs):
|
7 |
+
super().__init__(*args, **kwargs)
|
8 |
+
self.conv_block = nn.Sequential(
|
9 |
+
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
10 |
+
nn.BatchNorm2d(cout)
|
11 |
+
)
|
12 |
+
self.act = nn.ReLU()
|
13 |
+
self.residual = residual
|
14 |
+
self.use_act = use_act
|
15 |
+
|
16 |
+
def forward(self, x):
|
17 |
+
out = self.conv_block(x)
|
18 |
+
if self.residual:
|
19 |
+
out += x
|
20 |
+
|
21 |
+
if self.use_act:
|
22 |
+
return self.act(out)
|
23 |
+
else:
|
24 |
+
return out
|
25 |
+
|
26 |
+
class SimpleWrapperV2(nn.Module):
|
27 |
+
def __init__(self) -> None:
|
28 |
+
super().__init__()
|
29 |
+
self.audio_encoder = nn.Sequential(
|
30 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
31 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
32 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
33 |
+
|
34 |
+
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
35 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
36 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
37 |
+
|
38 |
+
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
39 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
40 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
41 |
+
|
42 |
+
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
43 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
44 |
+
|
45 |
+
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
46 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),
|
47 |
+
)
|
48 |
+
|
49 |
+
#### load the pre-trained audio_encoder
|
50 |
+
#self.audio_encoder = self.audio_encoder.to(device)
|
51 |
+
'''
|
52 |
+
wav2lip_state_dict = torch.load('/apdcephfs_cq2/share_1290939/wenxuazhang/checkpoints/wav2lip.pth')['state_dict']
|
53 |
+
state_dict = self.audio_encoder.state_dict()
|
54 |
+
|
55 |
+
for k,v in wav2lip_state_dict.items():
|
56 |
+
if 'audio_encoder' in k:
|
57 |
+
print('init:', k)
|
58 |
+
state_dict[k.replace('module.audio_encoder.', '')] = v
|
59 |
+
self.audio_encoder.load_state_dict(state_dict)
|
60 |
+
'''
|
61 |
+
|
62 |
+
self.mapping1 = nn.Linear(512+64+1, 64)
|
63 |
+
#self.mapping2 = nn.Linear(30, 64)
|
64 |
+
#nn.init.constant_(self.mapping1.weight, 0.)
|
65 |
+
nn.init.constant_(self.mapping1.bias, 0.)
|
66 |
+
|
67 |
+
def forward(self, x, ref, ratio):
|
68 |
+
x = self.audio_encoder(x).view(x.size(0), -1)
|
69 |
+
ref_reshape = ref.reshape(x.size(0), -1)
|
70 |
+
ratio = ratio.reshape(x.size(0), -1)
|
71 |
+
|
72 |
+
y = self.mapping1(torch.cat([x, ref_reshape, ratio], dim=1))
|
73 |
+
out = y.reshape(ref.shape[0], ref.shape[1], -1) #+ ref # residual
|
74 |
+
return out
|
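A quick shape sanity check for SimpleWrapperV2, as a sketch: the batch size and frame count are arbitrary, while the 80x16 mel chunks, the 64 reference expression coefficients and the per-frame blink ratio mirror what Audio2Exp.test() feeds in.

import torch
from src.audio2exp_models.networks import SimpleWrapperV2

model = SimpleWrapperV2()
bs, T = 2, 10
audiox = torch.randn(bs * T, 1, 80, 16)  # mel chunks flattened over time, as in Audio2Exp.test()
ref = torch.randn(bs, T, 64)             # first 64 expression coefficients of the source
ratio = torch.randn(bs, T, 1)            # eye-blink ratio per frame
exp = model(audiox, ref, ratio)
print(exp.shape)                          # torch.Size([2, 10, 64])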
src/audio2pose_models/audio2pose.py
ADDED
@@ -0,0 +1,94 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from src.audio2pose_models.cvae import CVAE
|
4 |
+
from src.audio2pose_models.discriminator import PoseSequenceDiscriminator
|
5 |
+
from src.audio2pose_models.audio_encoder import AudioEncoder
|
6 |
+
|
7 |
+
class Audio2Pose(nn.Module):
|
8 |
+
def __init__(self, cfg, wav2lip_checkpoint, device='cuda'):
|
9 |
+
super().__init__()
|
10 |
+
self.cfg = cfg
|
11 |
+
self.seq_len = cfg.MODEL.CVAE.SEQ_LEN
|
12 |
+
self.latent_dim = cfg.MODEL.CVAE.LATENT_SIZE
|
13 |
+
self.device = device
|
14 |
+
|
15 |
+
self.audio_encoder = AudioEncoder(wav2lip_checkpoint, device)
|
16 |
+
self.audio_encoder.eval()
|
17 |
+
for param in self.audio_encoder.parameters():
|
18 |
+
param.requires_grad = False
|
19 |
+
|
20 |
+
self.netG = CVAE(cfg)
|
21 |
+
self.netD_motion = PoseSequenceDiscriminator(cfg)
|
22 |
+
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
|
26 |
+
batch = {}
|
27 |
+
coeff_gt = x['gt'].cuda().squeeze(0) #bs frame_len+1 73
|
28 |
+
batch['pose_motion_gt'] = coeff_gt[:, 1:, 64:70] - coeff_gt[:, :1, 64:70] #bs frame_len 6
|
29 |
+
batch['ref'] = coeff_gt[:, 0, 64:70] #bs 6
|
30 |
+
batch['class'] = x['class'].squeeze(0).cuda() # bs
|
31 |
+
indiv_mels= x['indiv_mels'].cuda().squeeze(0) # bs seq_len+1 80 16
|
32 |
+
|
33 |
+
# forward
|
34 |
+
audio_emb_list = []
|
35 |
+
audio_emb = self.audio_encoder(indiv_mels[:, 1:, :, :].unsqueeze(2)) #bs seq_len 512
|
36 |
+
batch['audio_emb'] = audio_emb
|
37 |
+
batch = self.netG(batch)
|
38 |
+
|
39 |
+
pose_motion_pred = batch['pose_motion_pred'] # bs frame_len 6
|
40 |
+
pose_gt = coeff_gt[:, 1:, 64:70].clone() # bs frame_len 6
|
41 |
+
pose_pred = coeff_gt[:, :1, 64:70] + pose_motion_pred # bs frame_len 6
|
42 |
+
|
43 |
+
batch['pose_pred'] = pose_pred
|
44 |
+
batch['pose_gt'] = pose_gt
|
45 |
+
|
46 |
+
return batch
|
47 |
+
|
48 |
+
def test(self, x):
|
49 |
+
|
50 |
+
batch = {}
|
51 |
+
ref = x['ref'] #bs 1 70
|
52 |
+
batch['ref'] = x['ref'][:,0,-6:]
|
53 |
+
batch['class'] = x['class']
|
54 |
+
bs = ref.shape[0]
|
55 |
+
|
56 |
+
indiv_mels= x['indiv_mels'] # bs T 1 80 16
|
57 |
+
indiv_mels_use = indiv_mels[:, 1:] # we regard the ref as the first frame
|
58 |
+
num_frames = x['num_frames']
|
59 |
+
num_frames = int(num_frames) - 1
|
60 |
+
|
61 |
+
#
|
62 |
+
div = num_frames//self.seq_len
|
63 |
+
re = num_frames%self.seq_len
|
64 |
+
audio_emb_list = []
|
65 |
+
pose_motion_pred_list = [torch.zeros(batch['ref'].unsqueeze(1).shape, dtype=batch['ref'].dtype,
|
66 |
+
device=batch['ref'].device)]
|
67 |
+
|
68 |
+
for i in range(div):
|
69 |
+
z = torch.randn(bs, self.latent_dim).to(ref.device)
|
70 |
+
batch['z'] = z
|
71 |
+
audio_emb = self.audio_encoder(indiv_mels_use[:, i*self.seq_len:(i+1)*self.seq_len,:,:,:]) #bs seq_len 512
|
72 |
+
batch['audio_emb'] = audio_emb
|
73 |
+
batch = self.netG.test(batch)
|
74 |
+
pose_motion_pred_list.append(batch['pose_motion_pred']) #list of bs seq_len 6
|
75 |
+
|
76 |
+
if re != 0:
|
77 |
+
z = torch.randn(bs, self.latent_dim).to(ref.device)
|
78 |
+
batch['z'] = z
|
79 |
+
audio_emb = self.audio_encoder(indiv_mels_use[:, -1*self.seq_len:,:,:,:]) #bs seq_len 512
|
80 |
+
if audio_emb.shape[1] != self.seq_len:
|
81 |
+
pad_dim = self.seq_len-audio_emb.shape[1]
|
82 |
+
pad_audio_emb = audio_emb[:, :1].repeat(1, pad_dim, 1)
|
83 |
+
audio_emb = torch.cat([pad_audio_emb, audio_emb], 1)
|
84 |
+
batch['audio_emb'] = audio_emb
|
85 |
+
batch = self.netG.test(batch)
|
86 |
+
pose_motion_pred_list.append(batch['pose_motion_pred'][:,-1*re:,:])
|
87 |
+
|
88 |
+
pose_motion_pred = torch.cat(pose_motion_pred_list, dim = 1)
|
89 |
+
batch['pose_motion_pred'] = pose_motion_pred
|
90 |
+
|
91 |
+
pose_pred = ref[:, :1, -6:] + pose_motion_pred # bs T 6
|
92 |
+
|
93 |
+
batch['pose_pred'] = pose_pred
|
94 |
+
return batch
|
src/audio2pose_models/audio_encoder.py
ADDED
@@ -0,0 +1,64 @@
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
from torch.nn import functional as F
|
4 |
+
|
5 |
+
class Conv2d(nn.Module):
|
6 |
+
def __init__(self, cin, cout, kernel_size, stride, padding, residual=False, *args, **kwargs):
|
7 |
+
super().__init__(*args, **kwargs)
|
8 |
+
self.conv_block = nn.Sequential(
|
9 |
+
nn.Conv2d(cin, cout, kernel_size, stride, padding),
|
10 |
+
nn.BatchNorm2d(cout)
|
11 |
+
)
|
12 |
+
self.act = nn.ReLU()
|
13 |
+
self.residual = residual
|
14 |
+
|
15 |
+
def forward(self, x):
|
16 |
+
out = self.conv_block(x)
|
17 |
+
if self.residual:
|
18 |
+
out += x
|
19 |
+
return self.act(out)
|
20 |
+
|
21 |
+
class AudioEncoder(nn.Module):
|
22 |
+
def __init__(self, wav2lip_checkpoint, device):
|
23 |
+
super(AudioEncoder, self).__init__()
|
24 |
+
|
25 |
+
self.audio_encoder = nn.Sequential(
|
26 |
+
Conv2d(1, 32, kernel_size=3, stride=1, padding=1),
|
27 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
28 |
+
Conv2d(32, 32, kernel_size=3, stride=1, padding=1, residual=True),
|
29 |
+
|
30 |
+
Conv2d(32, 64, kernel_size=3, stride=(3, 1), padding=1),
|
31 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
32 |
+
Conv2d(64, 64, kernel_size=3, stride=1, padding=1, residual=True),
|
33 |
+
|
34 |
+
Conv2d(64, 128, kernel_size=3, stride=3, padding=1),
|
35 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
36 |
+
Conv2d(128, 128, kernel_size=3, stride=1, padding=1, residual=True),
|
37 |
+
|
38 |
+
Conv2d(128, 256, kernel_size=3, stride=(3, 2), padding=1),
|
39 |
+
Conv2d(256, 256, kernel_size=3, stride=1, padding=1, residual=True),
|
40 |
+
|
41 |
+
Conv2d(256, 512, kernel_size=3, stride=1, padding=0),
|
42 |
+
Conv2d(512, 512, kernel_size=1, stride=1, padding=0),)
|
43 |
+
|
44 |
+
#### load the pre-trained audio_encoder; we do not need to load the wav2lip model here.
|
45 |
+
# wav2lip_state_dict = torch.load(wav2lip_checkpoint, map_location=torch.device(device))['state_dict']
|
46 |
+
# state_dict = self.audio_encoder.state_dict()
|
47 |
+
|
48 |
+
# for k,v in wav2lip_state_dict.items():
|
49 |
+
# if 'audio_encoder' in k:
|
50 |
+
# state_dict[k.replace('module.audio_encoder.', '')] = v
|
51 |
+
# self.audio_encoder.load_state_dict(state_dict)
|
52 |
+
|
53 |
+
|
54 |
+
def forward(self, audio_sequences):
|
55 |
+
# audio_sequences = (B, T, 1, 80, 16)
|
56 |
+
B = audio_sequences.size(0)
|
57 |
+
|
58 |
+
audio_sequences = torch.cat([audio_sequences[:, i] for i in range(audio_sequences.size(1))], dim=0)
|
59 |
+
|
60 |
+
audio_embedding = self.audio_encoder(audio_sequences) # B, 512, 1, 1
|
61 |
+
dim = audio_embedding.shape[1]
|
62 |
+
audio_embedding = audio_embedding.reshape((B, -1, dim, 1, 1))
|
63 |
+
|
64 |
+
return audio_embedding.squeeze(-1).squeeze(-1) #B seq_len+1 512
|
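AudioEncoder folds the time axis into the batch, runs the wav2lip-style conv stack, then unfolds it again, so a (B, T, 1, 80, 16) input becomes a (B, T, 512) embedding. A shape sketch (sizes are illustrative; the checkpoint argument is unused because loading is commented out above):

import torch
from src.audio2pose_models.audio_encoder import AudioEncoder

enc = AudioEncoder(wav2lip_checkpoint=None, device="cpu")
mels = torch.randn(2, 32, 1, 80, 16)   # B x T x 1 x 80 x 16 mel windows
emb = enc(mels)
print(emb.shape)                       # torch.Size([2, 32, 512])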
src/audio2pose_models/cvae.py
ADDED
@@ -0,0 +1,149 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
from src.audio2pose_models.res_unet import ResUnet
|
5 |
+
|
6 |
+
def class2onehot(idx, class_num):
|
7 |
+
|
8 |
+
assert torch.max(idx).item() < class_num
|
9 |
+
onehot = torch.zeros(idx.size(0), class_num).to(idx.device)
|
10 |
+
onehot.scatter_(1, idx, 1)
|
11 |
+
return onehot
|
12 |
+
|
13 |
+
class CVAE(nn.Module):
|
14 |
+
def __init__(self, cfg):
|
15 |
+
super().__init__()
|
16 |
+
encoder_layer_sizes = cfg.MODEL.CVAE.ENCODER_LAYER_SIZES
|
17 |
+
decoder_layer_sizes = cfg.MODEL.CVAE.DECODER_LAYER_SIZES
|
18 |
+
latent_size = cfg.MODEL.CVAE.LATENT_SIZE
|
19 |
+
num_classes = cfg.DATASET.NUM_CLASSES
|
20 |
+
audio_emb_in_size = cfg.MODEL.CVAE.AUDIO_EMB_IN_SIZE
|
21 |
+
audio_emb_out_size = cfg.MODEL.CVAE.AUDIO_EMB_OUT_SIZE
|
22 |
+
seq_len = cfg.MODEL.CVAE.SEQ_LEN
|
23 |
+
|
24 |
+
self.latent_size = latent_size
|
25 |
+
|
26 |
+
self.encoder = ENCODER(encoder_layer_sizes, latent_size, num_classes,
|
27 |
+
audio_emb_in_size, audio_emb_out_size, seq_len)
|
28 |
+
self.decoder = DECODER(decoder_layer_sizes, latent_size, num_classes,
|
29 |
+
audio_emb_in_size, audio_emb_out_size, seq_len)
|
30 |
+
def reparameterize(self, mu, logvar):
|
31 |
+
std = torch.exp(0.5 * logvar)
|
32 |
+
eps = torch.randn_like(std)
|
33 |
+
return mu + eps * std
|
34 |
+
|
35 |
+
def forward(self, batch):
|
36 |
+
batch = self.encoder(batch)
|
37 |
+
mu = batch['mu']
|
38 |
+
logvar = batch['logvar']
|
39 |
+
z = self.reparameterize(mu, logvar)
|
40 |
+
batch['z'] = z
|
41 |
+
return self.decoder(batch)
|
42 |
+
|
43 |
+
def test(self, batch):
|
44 |
+
'''
|
45 |
+
class_id = batch['class']
|
46 |
+
z = torch.randn([class_id.size(0), self.latent_size]).to(class_id.device)
|
47 |
+
batch['z'] = z
|
48 |
+
'''
|
49 |
+
return self.decoder(batch)
|
50 |
+
|
51 |
+
class ENCODER(nn.Module):
|
52 |
+
def __init__(self, layer_sizes, latent_size, num_classes,
|
53 |
+
audio_emb_in_size, audio_emb_out_size, seq_len):
|
54 |
+
super().__init__()
|
55 |
+
|
56 |
+
self.resunet = ResUnet()
|
57 |
+
self.num_classes = num_classes
|
58 |
+
self.seq_len = seq_len
|
59 |
+
|
60 |
+
self.MLP = nn.Sequential()
|
61 |
+
layer_sizes[0] += latent_size + seq_len*audio_emb_out_size + 6
|
62 |
+
for i, (in_size, out_size) in enumerate(zip(layer_sizes[:-1], layer_sizes[1:])):
|
63 |
+
self.MLP.add_module(
|
64 |
+
name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
|
65 |
+
self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
|
66 |
+
|
67 |
+
self.linear_means = nn.Linear(layer_sizes[-1], latent_size)
|
68 |
+
self.linear_logvar = nn.Linear(layer_sizes[-1], latent_size)
|
69 |
+
self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
|
70 |
+
|
71 |
+
self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
|
72 |
+
|
73 |
+
def forward(self, batch):
|
74 |
+
class_id = batch['class']
|
75 |
+
pose_motion_gt = batch['pose_motion_gt'] #bs seq_len 6
|
76 |
+
ref = batch['ref'] #bs 6
|
77 |
+
bs = pose_motion_gt.shape[0]
|
78 |
+
audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
|
79 |
+
|
80 |
+
#pose encode
|
81 |
+
pose_emb = self.resunet(pose_motion_gt.unsqueeze(1)) #bs 1 seq_len 6
|
82 |
+
pose_emb = pose_emb.reshape(bs, -1) #bs seq_len*6
|
83 |
+
|
84 |
+
#audio mapping
|
85 |
+
#print(audio_in.shape)
|
86 |
+
audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
|
87 |
+
audio_out = audio_out.reshape(bs, -1)
|
88 |
+
|
89 |
+
class_bias = self.classbias[class_id] #bs latent_size
|
90 |
+
x_in = torch.cat([ref, pose_emb, audio_out, class_bias], dim=-1) #bs seq_len*(audio_emb_out_size+6)+latent_size
|
91 |
+
x_out = self.MLP(x_in)
|
92 |
+
|
93 |
+
mu = self.linear_means(x_out)
|
94 |
+
logvar = self.linear_logvar(x_out) #bs latent_size
|
95 |
+
|
96 |
+
batch.update({'mu':mu, 'logvar':logvar})
|
97 |
+
return batch
|
98 |
+
|
99 |
+
class DECODER(nn.Module):
|
100 |
+
def __init__(self, layer_sizes, latent_size, num_classes,
|
101 |
+
audio_emb_in_size, audio_emb_out_size, seq_len):
|
102 |
+
super().__init__()
|
103 |
+
|
104 |
+
self.resunet = ResUnet()
|
105 |
+
self.num_classes = num_classes
|
106 |
+
self.seq_len = seq_len
|
107 |
+
|
108 |
+
self.MLP = nn.Sequential()
|
109 |
+
input_size = latent_size + seq_len*audio_emb_out_size + 6
|
110 |
+
for i, (in_size, out_size) in enumerate(zip([input_size]+layer_sizes[:-1], layer_sizes)):
|
111 |
+
self.MLP.add_module(
|
112 |
+
name="L{:d}".format(i), module=nn.Linear(in_size, out_size))
|
113 |
+
if i+1 < len(layer_sizes):
|
114 |
+
self.MLP.add_module(name="A{:d}".format(i), module=nn.ReLU())
|
115 |
+
else:
|
116 |
+
self.MLP.add_module(name="sigmoid", module=nn.Sigmoid())
|
117 |
+
|
118 |
+
self.pose_linear = nn.Linear(6, 6)
|
119 |
+
self.linear_audio = nn.Linear(audio_emb_in_size, audio_emb_out_size)
|
120 |
+
|
121 |
+
self.classbias = nn.Parameter(torch.randn(self.num_classes, latent_size))
|
122 |
+
|
123 |
+
def forward(self, batch):
|
124 |
+
|
125 |
+
z = batch['z'] #bs latent_size
|
126 |
+
bs = z.shape[0]
|
127 |
+
class_id = batch['class']
|
128 |
+
ref = batch['ref'] #bs 6
|
129 |
+
audio_in = batch['audio_emb'] # bs seq_len audio_emb_in_size
|
130 |
+
#print('audio_in: ', audio_in[:, :, :10])
|
131 |
+
|
132 |
+
audio_out = self.linear_audio(audio_in) # bs seq_len audio_emb_out_size
|
133 |
+
#print('audio_out: ', audio_out[:, :, :10])
|
134 |
+
audio_out = audio_out.reshape([bs, -1]) # bs seq_len*audio_emb_out_size
|
135 |
+
class_bias = self.classbias[class_id] #bs latent_size
|
136 |
+
|
137 |
+
z = z + class_bias
|
138 |
+
x_in = torch.cat([ref, z, audio_out], dim=-1)
|
139 |
+
x_out = self.MLP(x_in) # bs layer_sizes[-1]
|
140 |
+
x_out = x_out.reshape((bs, self.seq_len, -1))
|
141 |
+
|
142 |
+
#print('x_out: ', x_out)
|
143 |
+
|
144 |
+
pose_emb = self.resunet(x_out.unsqueeze(1)) #bs 1 seq_len 6
|
145 |
+
|
146 |
+
pose_motion_pred = self.pose_linear(pose_emb.squeeze(1)) #bs seq_len 6
|
147 |
+
|
148 |
+
batch.update({'pose_motion_pred':pose_motion_pred})
|
149 |
+
return batch
|
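The sampling step CVAE.forward() relies on is the standard reparameterization trick, shown in isolation below (a sketch with made-up tensor sizes; the latent size of 64 matches LATENT_SIZE in src/config/auido2pose.yaml):

import torch

def reparameterize(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    std = torch.exp(0.5 * logvar)  # log-variance -> standard deviation
    eps = torch.randn_like(std)    # noise sampled from N(0, I)
    return mu + eps * std          # differentiable sample from N(mu, sigma^2)

mu = torch.zeros(4, 64)
logvar = torch.zeros(4, 64)
print(reparameterize(mu, logvar).shape)  # torch.Size([4, 64])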
src/audio2pose_models/discriminator.py
ADDED
@@ -0,0 +1,76 @@
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
from torch import nn
|
4 |
+
|
5 |
+
class ConvNormRelu(nn.Module):
|
6 |
+
def __init__(self, conv_type='1d', in_channels=3, out_channels=64, downsample=False,
|
7 |
+
kernel_size=None, stride=None, padding=None, norm='BN', leaky=False):
|
8 |
+
super().__init__()
|
9 |
+
if kernel_size is None:
|
10 |
+
if downsample:
|
11 |
+
kernel_size, stride, padding = 4, 2, 1
|
12 |
+
else:
|
13 |
+
kernel_size, stride, padding = 3, 1, 1
|
14 |
+
|
15 |
+
if conv_type == '2d':
|
16 |
+
self.conv = nn.Conv2d(
|
17 |
+
in_channels,
|
18 |
+
out_channels,
|
19 |
+
kernel_size,
|
20 |
+
stride,
|
21 |
+
padding,
|
22 |
+
bias=False,
|
23 |
+
)
|
24 |
+
if norm == 'BN':
|
25 |
+
self.norm = nn.BatchNorm2d(out_channels)
|
26 |
+
elif norm == 'IN':
|
27 |
+
self.norm = nn.InstanceNorm2d(out_channels)
|
28 |
+
else:
|
29 |
+
raise NotImplementedError
|
30 |
+
elif conv_type == '1d':
|
31 |
+
self.conv = nn.Conv1d(
|
32 |
+
in_channels,
|
33 |
+
out_channels,
|
34 |
+
kernel_size,
|
35 |
+
stride,
|
36 |
+
padding,
|
37 |
+
bias=False,
|
38 |
+
)
|
39 |
+
if norm == 'BN':
|
40 |
+
self.norm = nn.BatchNorm1d(out_channels)
|
41 |
+
elif norm == 'IN':
|
42 |
+
self.norm = nn.InstanceNorm1d(out_channels)
|
43 |
+
else:
|
44 |
+
raise NotImplementedError
|
45 |
+
nn.init.kaiming_normal_(self.conv.weight)
|
46 |
+
|
47 |
+
self.act = nn.LeakyReLU(negative_slope=0.2, inplace=False) if leaky else nn.ReLU(inplace=True)
|
48 |
+
|
49 |
+
def forward(self, x):
|
50 |
+
x = self.conv(x)
|
51 |
+
if isinstance(self.norm, nn.InstanceNorm1d):
|
52 |
+
x = self.norm(x.permute((0, 2, 1))).permute((0, 2, 1)) # normalize on [C]
|
53 |
+
else:
|
54 |
+
x = self.norm(x)
|
55 |
+
x = self.act(x)
|
56 |
+
return x
|
57 |
+
|
58 |
+
|
59 |
+
class PoseSequenceDiscriminator(nn.Module):
|
60 |
+
def __init__(self, cfg):
|
61 |
+
super().__init__()
|
62 |
+
self.cfg = cfg
|
63 |
+
leaky = self.cfg.MODEL.DISCRIMINATOR.LEAKY_RELU
|
64 |
+
|
65 |
+
self.seq = nn.Sequential(
|
66 |
+
ConvNormRelu('1d', cfg.MODEL.DISCRIMINATOR.INPUT_CHANNELS, 256, downsample=True, leaky=leaky), # B, 256, 64
|
67 |
+
ConvNormRelu('1d', 256, 512, downsample=True, leaky=leaky), # B, 512, 32
|
68 |
+
ConvNormRelu('1d', 512, 1024, kernel_size=3, stride=1, padding=1, leaky=leaky), # B, 1024, 16
|
69 |
+
nn.Conv1d(1024, 1, kernel_size=3, stride=1, padding=1, bias=True) # B, 1, 16
|
70 |
+
)
|
71 |
+
|
72 |
+
def forward(self, x):
|
73 |
+
x = x.reshape(x.size(0), x.size(1), -1).transpose(1, 2)
|
74 |
+
x = self.seq(x)
|
75 |
+
x = x.squeeze(1)
|
76 |
+
return x
|
src/audio2pose_models/networks.py
ADDED
@@ -0,0 +1,140 @@
1 |
+
import torch.nn as nn
|
2 |
+
import torch
|
3 |
+
|
4 |
+
|
5 |
+
class ResidualConv(nn.Module):
|
6 |
+
def __init__(self, input_dim, output_dim, stride, padding):
|
7 |
+
super(ResidualConv, self).__init__()
|
8 |
+
|
9 |
+
self.conv_block = nn.Sequential(
|
10 |
+
nn.BatchNorm2d(input_dim),
|
11 |
+
nn.ReLU(),
|
12 |
+
nn.Conv2d(
|
13 |
+
input_dim, output_dim, kernel_size=3, stride=stride, padding=padding
|
14 |
+
),
|
15 |
+
nn.BatchNorm2d(output_dim),
|
16 |
+
nn.ReLU(),
|
17 |
+
nn.Conv2d(output_dim, output_dim, kernel_size=3, padding=1),
|
18 |
+
)
|
19 |
+
self.conv_skip = nn.Sequential(
|
20 |
+
nn.Conv2d(input_dim, output_dim, kernel_size=3, stride=stride, padding=1),
|
21 |
+
nn.BatchNorm2d(output_dim),
|
22 |
+
)
|
23 |
+
|
24 |
+
def forward(self, x):
|
25 |
+
|
26 |
+
return self.conv_block(x) + self.conv_skip(x)
|
27 |
+
|
28 |
+
|
29 |
+
class Upsample(nn.Module):
|
30 |
+
def __init__(self, input_dim, output_dim, kernel, stride):
|
31 |
+
super(Upsample, self).__init__()
|
32 |
+
|
33 |
+
self.upsample = nn.ConvTranspose2d(
|
34 |
+
input_dim, output_dim, kernel_size=kernel, stride=stride
|
35 |
+
)
|
36 |
+
|
37 |
+
def forward(self, x):
|
38 |
+
return self.upsample(x)
|
39 |
+
|
40 |
+
|
41 |
+
class Squeeze_Excite_Block(nn.Module):
|
42 |
+
def __init__(self, channel, reduction=16):
|
43 |
+
super(Squeeze_Excite_Block, self).__init__()
|
44 |
+
self.avg_pool = nn.AdaptiveAvgPool2d(1)
|
45 |
+
self.fc = nn.Sequential(
|
46 |
+
nn.Linear(channel, channel // reduction, bias=False),
|
47 |
+
nn.ReLU(inplace=True),
|
48 |
+
nn.Linear(channel // reduction, channel, bias=False),
|
49 |
+
nn.Sigmoid(),
|
50 |
+
)
|
51 |
+
|
52 |
+
def forward(self, x):
|
53 |
+
b, c, _, _ = x.size()
|
54 |
+
y = self.avg_pool(x).view(b, c)
|
55 |
+
y = self.fc(y).view(b, c, 1, 1)
|
56 |
+
return x * y.expand_as(x)
|
57 |
+
|
58 |
+
|
59 |
+
class ASPP(nn.Module):
|
60 |
+
def __init__(self, in_dims, out_dims, rate=[6, 12, 18]):
|
61 |
+
super(ASPP, self).__init__()
|
62 |
+
|
63 |
+
self.aspp_block1 = nn.Sequential(
|
64 |
+
nn.Conv2d(
|
65 |
+
in_dims, out_dims, 3, stride=1, padding=rate[0], dilation=rate[0]
|
66 |
+
),
|
67 |
+
nn.ReLU(inplace=True),
|
68 |
+
nn.BatchNorm2d(out_dims),
|
69 |
+
)
|
70 |
+
self.aspp_block2 = nn.Sequential(
|
71 |
+
nn.Conv2d(
|
72 |
+
in_dims, out_dims, 3, stride=1, padding=rate[1], dilation=rate[1]
|
73 |
+
),
|
74 |
+
nn.ReLU(inplace=True),
|
75 |
+
nn.BatchNorm2d(out_dims),
|
76 |
+
)
|
77 |
+
self.aspp_block3 = nn.Sequential(
|
78 |
+
nn.Conv2d(
|
79 |
+
in_dims, out_dims, 3, stride=1, padding=rate[2], dilation=rate[2]
|
80 |
+
),
|
81 |
+
nn.ReLU(inplace=True),
|
82 |
+
nn.BatchNorm2d(out_dims),
|
83 |
+
)
|
84 |
+
|
85 |
+
self.output = nn.Conv2d(len(rate) * out_dims, out_dims, 1)
|
86 |
+
self._init_weights()
|
87 |
+
|
88 |
+
def forward(self, x):
|
89 |
+
x1 = self.aspp_block1(x)
|
90 |
+
x2 = self.aspp_block2(x)
|
91 |
+
x3 = self.aspp_block3(x)
|
92 |
+
out = torch.cat([x1, x2, x3], dim=1)
|
93 |
+
return self.output(out)
|
94 |
+
|
95 |
+
def _init_weights(self):
|
96 |
+
for m in self.modules():
|
97 |
+
if isinstance(m, nn.Conv2d):
|
98 |
+
nn.init.kaiming_normal_(m.weight)
|
99 |
+
elif isinstance(m, nn.BatchNorm2d):
|
100 |
+
m.weight.data.fill_(1)
|
101 |
+
m.bias.data.zero_()
|
102 |
+
|
103 |
+
|
104 |
+
class Upsample_(nn.Module):
|
105 |
+
def __init__(self, scale=2):
|
106 |
+
super(Upsample_, self).__init__()
|
107 |
+
|
108 |
+
self.upsample = nn.Upsample(mode="bilinear", scale_factor=scale)
|
109 |
+
|
110 |
+
def forward(self, x):
|
111 |
+
return self.upsample(x)
|
112 |
+
|
113 |
+
|
114 |
+
class AttentionBlock(nn.Module):
|
115 |
+
def __init__(self, input_encoder, input_decoder, output_dim):
|
116 |
+
super(AttentionBlock, self).__init__()
|
117 |
+
|
118 |
+
self.conv_encoder = nn.Sequential(
|
119 |
+
nn.BatchNorm2d(input_encoder),
|
120 |
+
nn.ReLU(),
|
121 |
+
nn.Conv2d(input_encoder, output_dim, 3, padding=1),
|
122 |
+
nn.MaxPool2d(2, 2),
|
123 |
+
)
|
124 |
+
|
125 |
+
self.conv_decoder = nn.Sequential(
|
126 |
+
nn.BatchNorm2d(input_decoder),
|
127 |
+
nn.ReLU(),
|
128 |
+
nn.Conv2d(input_decoder, output_dim, 3, padding=1),
|
129 |
+
)
|
130 |
+
|
131 |
+
self.conv_attn = nn.Sequential(
|
132 |
+
nn.BatchNorm2d(output_dim),
|
133 |
+
nn.ReLU(),
|
134 |
+
nn.Conv2d(output_dim, 1, 1),
|
135 |
+
)
|
136 |
+
|
137 |
+
def forward(self, x1, x2):
|
138 |
+
out = self.conv_encoder(x1) + self.conv_decoder(x2)
|
139 |
+
out = self.conv_attn(out)
|
140 |
+
return out * x2
|
src/audio2pose_models/res_unet.py
ADDED
@@ -0,0 +1,65 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from src.audio2pose_models.networks import ResidualConv, Upsample
|
4 |
+
|
5 |
+
|
6 |
+
class ResUnet(nn.Module):
|
7 |
+
def __init__(self, channel=1, filters=[32, 64, 128, 256]):
|
8 |
+
super(ResUnet, self).__init__()
|
9 |
+
|
10 |
+
self.input_layer = nn.Sequential(
|
11 |
+
nn.Conv2d(channel, filters[0], kernel_size=3, padding=1),
|
12 |
+
nn.BatchNorm2d(filters[0]),
|
13 |
+
nn.ReLU(),
|
14 |
+
nn.Conv2d(filters[0], filters[0], kernel_size=3, padding=1),
|
15 |
+
)
|
16 |
+
self.input_skip = nn.Sequential(
|
17 |
+
nn.Conv2d(channel, filters[0], kernel_size=3, padding=1)
|
18 |
+
)
|
19 |
+
|
20 |
+
self.residual_conv_1 = ResidualConv(filters[0], filters[1], stride=(2,1), padding=1)
|
21 |
+
self.residual_conv_2 = ResidualConv(filters[1], filters[2], stride=(2,1), padding=1)
|
22 |
+
|
23 |
+
self.bridge = ResidualConv(filters[2], filters[3], stride=(2,1), padding=1)
|
24 |
+
|
25 |
+
self.upsample_1 = Upsample(filters[3], filters[3], kernel=(2,1), stride=(2,1))
|
26 |
+
self.up_residual_conv1 = ResidualConv(filters[3] + filters[2], filters[2], stride=1, padding=1)
|
27 |
+
|
28 |
+
self.upsample_2 = Upsample(filters[2], filters[2], kernel=(2,1), stride=(2,1))
|
29 |
+
self.up_residual_conv2 = ResidualConv(filters[2] + filters[1], filters[1], stride=1, padding=1)
|
30 |
+
|
31 |
+
self.upsample_3 = Upsample(filters[1], filters[1], kernel=(2,1), stride=(2,1))
|
32 |
+
self.up_residual_conv3 = ResidualConv(filters[1] + filters[0], filters[0], stride=1, padding=1)
|
33 |
+
|
34 |
+
self.output_layer = nn.Sequential(
|
35 |
+
nn.Conv2d(filters[0], 1, 1, 1),
|
36 |
+
nn.Sigmoid(),
|
37 |
+
)
|
38 |
+
|
39 |
+
def forward(self, x):
|
40 |
+
# Encode
|
41 |
+
x1 = self.input_layer(x) + self.input_skip(x)
|
42 |
+
x2 = self.residual_conv_1(x1)
|
43 |
+
x3 = self.residual_conv_2(x2)
|
44 |
+
# Bridge
|
45 |
+
x4 = self.bridge(x3)
|
46 |
+
|
47 |
+
# Decode
|
48 |
+
x4 = self.upsample_1(x4)
|
49 |
+
x5 = torch.cat([x4, x3], dim=1)
|
50 |
+
|
51 |
+
x6 = self.up_residual_conv1(x5)
|
52 |
+
|
53 |
+
x6 = self.upsample_2(x6)
|
54 |
+
x7 = torch.cat([x6, x2], dim=1)
|
55 |
+
|
56 |
+
x8 = self.up_residual_conv2(x7)
|
57 |
+
|
58 |
+
x8 = self.upsample_3(x8)
|
59 |
+
x9 = torch.cat([x8, x1], dim=1)
|
60 |
+
|
61 |
+
x10 = self.up_residual_conv3(x9)
|
62 |
+
|
63 |
+
output = self.output_layer(x10)
|
64 |
+
|
65 |
+
return output
|
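ResUnet preserves the pose-motion tensor's shape: three stride-(2,1) downsamplings along the time axis are undone by three transposed convolutions before the 1x1 output layer. A shape sketch (the batch size is arbitrary; seq_len only needs to be divisible by 8, and 32 matches SEQ_LEN in the configs below):

import torch
from src.audio2pose_models.res_unet import ResUnet

net = ResUnet()
pose_motion = torch.randn(4, 1, 32, 6)  # bs x 1 x seq_len x 6
out = net(pose_motion)
print(out.shape)                        # torch.Size([4, 1, 32, 6])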
src/config/auido2exp.yaml
ADDED
@@ -0,0 +1,58 @@
1 |
+
DATASET:
|
2 |
+
TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/train.txt
|
3 |
+
EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/file_list/val.txt
|
4 |
+
TRAIN_BATCH_SIZE: 32
|
5 |
+
EVAL_BATCH_SIZE: 32
|
6 |
+
EXP: True
|
7 |
+
EXP_DIM: 64
|
8 |
+
FRAME_LEN: 32
|
9 |
+
COEFF_LEN: 73
|
10 |
+
NUM_CLASSES: 46
|
11 |
+
AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
|
12 |
+
COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav2lip_3dmm
|
13 |
+
LMDB_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
|
14 |
+
DEBUG: True
|
15 |
+
NUM_REPEATS: 2
|
16 |
+
T: 40
|
17 |
+
|
18 |
+
|
19 |
+
MODEL:
|
20 |
+
FRAMEWORK: V2
|
21 |
+
AUDIOENCODER:
|
22 |
+
LEAKY_RELU: True
|
23 |
+
NORM: 'IN'
|
24 |
+
DISCRIMINATOR:
|
25 |
+
LEAKY_RELU: False
|
26 |
+
INPUT_CHANNELS: 6
|
27 |
+
CVAE:
|
28 |
+
AUDIO_EMB_IN_SIZE: 512
|
29 |
+
AUDIO_EMB_OUT_SIZE: 128
|
30 |
+
SEQ_LEN: 32
|
31 |
+
LATENT_SIZE: 256
|
32 |
+
ENCODER_LAYER_SIZES: [192, 1024]
|
33 |
+
DECODER_LAYER_SIZES: [1024, 192]
|
34 |
+
|
35 |
+
|
36 |
+
TRAIN:
|
37 |
+
MAX_EPOCH: 300
|
38 |
+
GENERATOR:
|
39 |
+
LR: 2.0e-5
|
40 |
+
DISCRIMINATOR:
|
41 |
+
LR: 1.0e-5
|
42 |
+
LOSS:
|
43 |
+
W_FEAT: 0
|
44 |
+
W_COEFF_EXP: 2
|
45 |
+
W_LM: 1.0e-2
|
46 |
+
W_LM_MOUTH: 0
|
47 |
+
W_REG: 0
|
48 |
+
W_SYNC: 0
|
49 |
+
W_COLOR: 0
|
50 |
+
W_EXPRESSION: 0
|
51 |
+
W_LIPREADING: 0.01
|
52 |
+
W_LIPREADING_VV: 0
|
53 |
+
W_EYE_BLINK: 4
|
54 |
+
|
55 |
+
TAG:
|
56 |
+
NAME: small_dataset
|
57 |
+
|
58 |
+
|
src/config/auido2pose.yaml
ADDED
@@ -0,0 +1,49 @@
1 |
+
DATASET:
|
2 |
+
TRAIN_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/train_33.txt
|
3 |
+
EVAL_FILE_LIST: /apdcephfs_cq2/share_1290939/wenxuazhang/code/audio2pose_unet_noAudio/dataset/val.txt
|
4 |
+
TRAIN_BATCH_SIZE: 64
|
5 |
+
EVAL_BATCH_SIZE: 1
|
6 |
+
EXP: True
|
7 |
+
EXP_DIM: 64
|
8 |
+
FRAME_LEN: 32
|
9 |
+
COEFF_LEN: 73
|
10 |
+
NUM_CLASSES: 46
|
11 |
+
AUDIO_ROOT_PATH: /apdcephfs_cq2/share_1290939/wenxuazhang/voxceleb1/wav
|
12 |
+
COEFF_ROOT_PATH: /apdcephfs_cq2/share_1290939/shadowcun/datasets/VoxCeleb/v1/imdb
|
13 |
+
DEBUG: True
|
14 |
+
|
15 |
+
|
16 |
+
MODEL:
|
17 |
+
AUDIOENCODER:
|
18 |
+
LEAKY_RELU: True
|
19 |
+
NORM: 'IN'
|
20 |
+
DISCRIMINATOR:
|
21 |
+
LEAKY_RELU: False
|
22 |
+
INPUT_CHANNELS: 6
|
23 |
+
CVAE:
|
24 |
+
AUDIO_EMB_IN_SIZE: 512
|
25 |
+
AUDIO_EMB_OUT_SIZE: 6
|
26 |
+
SEQ_LEN: 32
|
27 |
+
LATENT_SIZE: 64
|
28 |
+
ENCODER_LAYER_SIZES: [192, 128]
|
29 |
+
DECODER_LAYER_SIZES: [128, 192]
|
30 |
+
|
31 |
+
|
32 |
+
TRAIN:
|
33 |
+
MAX_EPOCH: 150
|
34 |
+
GENERATOR:
|
35 |
+
LR: 1.0e-4
|
36 |
+
DISCRIMINATOR:
|
37 |
+
LR: 1.0e-4
|
38 |
+
LOSS:
|
39 |
+
LAMBDA_REG: 1
|
40 |
+
LAMBDA_LANDMARKS: 0
|
41 |
+
LAMBDA_VERTICES: 0
|
42 |
+
LAMBDA_GAN_MOTION: 0.7
|
43 |
+
LAMBDA_GAN_COEFF: 0
|
44 |
+
LAMBDA_KL: 1
|
45 |
+
|
46 |
+
TAG:
|
47 |
+
NAME: cvae_UNET_useAudio_usewav2lipAudioEncoder
|
48 |
+
|
49 |
+
|
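These YAML files are read through yacs-style attribute access (cfg.MODEL.CVAE.SEQ_LEN and so on, as used in audio2pose.py and cvae.py). A loading sketch, assuming yacs==0.1.8 from requirements.txt:

from yacs.config import CfgNode as CN

with open("src/config/auido2pose.yaml") as f:
    cfg = CN.load_cfg(f)
print(cfg.MODEL.CVAE.SEQ_LEN, cfg.MODEL.CVAE.LATENT_SIZE)  # 32 64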
src/config/facerender.yaml
ADDED
@@ -0,0 +1,45 @@
1 |
+
model_params:
|
2 |
+
common_params:
|
3 |
+
num_kp: 15
|
4 |
+
image_channel: 3
|
5 |
+
feature_channel: 32
|
6 |
+
estimate_jacobian: False # True
|
7 |
+
kp_detector_params:
|
8 |
+
temperature: 0.1
|
9 |
+
block_expansion: 32
|
10 |
+
max_features: 1024
|
11 |
+
scale_factor: 0.25 # 0.25
|
12 |
+
num_blocks: 5
|
13 |
+
reshape_channel: 16384 # 16384 = 1024 * 16
|
14 |
+
reshape_depth: 16
|
15 |
+
he_estimator_params:
|
16 |
+
block_expansion: 64
|
17 |
+
max_features: 2048
|
18 |
+
num_bins: 66
|
19 |
+
generator_params:
|
20 |
+
block_expansion: 64
|
21 |
+
max_features: 512
|
22 |
+
num_down_blocks: 2
|
23 |
+
reshape_channel: 32
|
24 |
+
reshape_depth: 16 # 512 = 32 * 16
|
25 |
+
num_resblocks: 6
|
26 |
+
estimate_occlusion_map: True
|
27 |
+
dense_motion_params:
|
28 |
+
block_expansion: 32
|
29 |
+
max_features: 1024
|
30 |
+
num_blocks: 5
|
31 |
+
reshape_depth: 16
|
32 |
+
compress: 4
|
33 |
+
discriminator_params:
|
34 |
+
scales: [1]
|
35 |
+
block_expansion: 32
|
36 |
+
max_features: 512
|
37 |
+
num_blocks: 4
|
38 |
+
sn: True
|
39 |
+
mapping_params:
|
40 |
+
coeff_nc: 70
|
41 |
+
descriptor_nc: 1024
|
42 |
+
layer: 3
|
43 |
+
num_kp: 15
|
44 |
+
num_bins: 66
|
45 |
+
|
src/config/facerender_still.yaml
ADDED
@@ -0,0 +1,45 @@
model_params:
  common_params:
    num_kp: 15
    image_channel: 3
    feature_channel: 32
    estimate_jacobian: False   # True
  kp_detector_params:
    temperature: 0.1
    block_expansion: 32
    max_features: 1024
    scale_factor: 0.25         # 0.25
    num_blocks: 5
    reshape_channel: 16384     # 16384 = 1024 * 16
    reshape_depth: 16
  he_estimator_params:
    block_expansion: 64
    max_features: 2048
    num_bins: 66
  generator_params:
    block_expansion: 64
    max_features: 512
    num_down_blocks: 2
    reshape_channel: 32
    reshape_depth: 16          # 512 = 32 * 16
    num_resblocks: 6
    estimate_occlusion_map: True
    dense_motion_params:
      block_expansion: 32
      max_features: 1024
      num_blocks: 5
      reshape_depth: 16
      compress: 4
  discriminator_params:
    scales: [1]
    block_expansion: 32
    max_features: 512
    num_blocks: 4
    sn: True
  mapping_params:
    coeff_nc: 73
    descriptor_nc: 1024
    layer: 3
    num_kp: 15
    num_bins: 66
src/config/similarity_Lm3D_all.mat
ADDED
Binary file (994 Bytes).
src/face3d/data/__init__.py
ADDED
@@ -0,0 +1,116 @@
"""This package includes all the modules related to data loading and preprocessing

 To add a custom dataset class called 'dummy', you need to add a file called 'dummy_dataset.py' and define a subclass 'DummyDataset' inherited from BaseDataset.
 You need to implement four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point from data loader.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.

Now you can use the dataset class by specifying flag '--dataset_mode dummy'.
See our template dataset class 'template_dataset.py' for more details.
"""
import numpy as np
import importlib
import torch.utils.data
from face3d.data.base_dataset import BaseDataset


def find_dataset_using_name(dataset_name):
    """Import the module "data/[dataset_name]_dataset.py".

    In the file, the class called DatasetNameDataset() will
    be instantiated. It has to be a subclass of BaseDataset,
    and it is case-insensitive.
    """
    dataset_filename = "data." + dataset_name + "_dataset"
    datasetlib = importlib.import_module(dataset_filename)

    dataset = None
    target_dataset_name = dataset_name.replace('_', '') + 'dataset'
    for name, cls in datasetlib.__dict__.items():
        if name.lower() == target_dataset_name.lower() \
           and issubclass(cls, BaseDataset):
            dataset = cls

    if dataset is None:
        raise NotImplementedError("In %s.py, there should be a subclass of BaseDataset with class name that matches %s in lowercase." % (dataset_filename, target_dataset_name))

    return dataset


def get_option_setter(dataset_name):
    """Return the static method <modify_commandline_options> of the dataset class."""
    dataset_class = find_dataset_using_name(dataset_name)
    return dataset_class.modify_commandline_options


def create_dataset(opt, rank=0):
    """Create a dataset given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from data import create_dataset
        >>> dataset = create_dataset(opt)
    """
    data_loader = CustomDatasetDataLoader(opt, rank=rank)
    dataset = data_loader.load_data()
    return dataset

class CustomDatasetDataLoader():
    """Wrapper class of Dataset class that performs multi-threaded data loading"""

    def __init__(self, opt, rank=0):
        """Initialize this class

        Step 1: create a dataset instance given the name [dataset_mode]
        Step 2: create a multi-threaded data loader.
        """
        self.opt = opt
        dataset_class = find_dataset_using_name(opt.dataset_mode)
        self.dataset = dataset_class(opt)
        self.sampler = None
        print("rank %d %s dataset [%s] was created" % (rank, self.dataset.name, type(self.dataset).__name__))
        if opt.use_ddp and opt.isTrain:
            world_size = opt.world_size
            self.sampler = torch.utils.data.distributed.DistributedSampler(
                self.dataset,
                num_replicas=world_size,
                rank=rank,
                shuffle=not opt.serial_batches
            )
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                sampler=self.sampler,
                num_workers=int(opt.num_threads / world_size),
                batch_size=int(opt.batch_size / world_size),
                drop_last=True)
        else:
            self.dataloader = torch.utils.data.DataLoader(
                self.dataset,
                batch_size=opt.batch_size,
                shuffle=(not opt.serial_batches) and opt.isTrain,
                num_workers=int(opt.num_threads),
                drop_last=True
            )

    def set_epoch(self, epoch):
        self.dataset.current_epoch = epoch
        if self.sampler is not None:
            self.sampler.set_epoch(epoch)

    def load_data(self):
        return self

    def __len__(self):
        """Return the number of data in the dataset"""
        return min(len(self.dataset), self.opt.max_dataset_size)

    def __iter__(self):
        """Return a batch of data"""
        for i, data in enumerate(self.dataloader):
            if i * self.opt.batch_size >= self.opt.max_dataset_size:
                break
            yield data
src/face3d/data/base_dataset.py
ADDED
@@ -0,0 +1,125 @@
"""This module implements an abstract base class (ABC) 'BaseDataset' for datasets.

It also includes common transformation functions (e.g., get_transform, __scale_width), which can be later used in subclasses.
"""
import random
import numpy as np
import torch.utils.data as data
from PIL import Image
import torchvision.transforms as transforms
from abc import ABC, abstractmethod


class BaseDataset(data.Dataset, ABC):
    """This class is an abstract base class (ABC) for datasets.

    To create a subclass, you need to implement the following four functions:
    -- <__init__>: initialize the class, first call BaseDataset.__init__(self, opt).
    -- <__len__>: return the size of dataset.
    -- <__getitem__>: get a data point.
    -- <modify_commandline_options>: (optionally) add dataset-specific options and set default options.
    """

    def __init__(self, opt):
        """Initialize the class; save the options in the class

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        self.opt = opt
        # self.root = opt.dataroot
        self.current_epoch = 0

    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        return parser

    @abstractmethod
    def __len__(self):
        """Return the total number of images in the dataset."""
        return 0

    @abstractmethod
    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.
        """
        pass


def get_transform(grayscale=False):
    transform_list = []
    if grayscale:
        transform_list.append(transforms.Grayscale(1))
    transform_list += [transforms.ToTensor()]
    return transforms.Compose(transform_list)

def get_affine_mat(opt, size):
    shift_x, shift_y, scale, rot_angle, flip = 0., 0., 1., 0., False
    w, h = size

    if 'shift' in opt.preprocess:
        shift_pixs = int(opt.shift_pixs)
        shift_x = random.randint(-shift_pixs, shift_pixs)
        shift_y = random.randint(-shift_pixs, shift_pixs)
    if 'scale' in opt.preprocess:
        scale = 1 + opt.scale_delta * (2 * random.random() - 1)
    if 'rot' in opt.preprocess:
        rot_angle = opt.rot_angle * (2 * random.random() - 1)
        rot_rad = -rot_angle * np.pi/180
    if 'flip' in opt.preprocess:
        flip = random.random() > 0.5

    shift_to_origin = np.array([1, 0, -w//2, 0, 1, -h//2, 0, 0, 1]).reshape([3, 3])
    flip_mat = np.array([-1 if flip else 1, 0, 0, 0, 1, 0, 0, 0, 1]).reshape([3, 3])
    shift_mat = np.array([1, 0, shift_x, 0, 1, shift_y, 0, 0, 1]).reshape([3, 3])
    rot_mat = np.array([np.cos(rot_rad), np.sin(rot_rad), 0, -np.sin(rot_rad), np.cos(rot_rad), 0, 0, 0, 1]).reshape([3, 3])
    scale_mat = np.array([scale, 0, 0, 0, scale, 0, 0, 0, 1]).reshape([3, 3])
    shift_to_center = np.array([1, 0, w//2, 0, 1, h//2, 0, 0, 1]).reshape([3, 3])

    affine = shift_to_center @ scale_mat @ rot_mat @ shift_mat @ flip_mat @ shift_to_origin
    affine_inv = np.linalg.inv(affine)
    return affine, affine_inv, flip

def apply_img_affine(img, affine_inv, method=Image.BICUBIC):
    return img.transform(img.size, Image.AFFINE, data=affine_inv.flatten()[:6], resample=Image.BICUBIC)

def apply_lm_affine(landmark, affine, flip, size):
    _, h = size
    lm = landmark.copy()
    lm[:, 1] = h - 1 - lm[:, 1]
    lm = np.concatenate((lm, np.ones([lm.shape[0], 1])), -1)
    lm = lm @ np.transpose(affine)
    lm[:, :2] = lm[:, :2] / lm[:, 2:]
    lm = lm[:, :2]
    lm[:, 1] = h - 1 - lm[:, 1]
    if flip:
        lm_ = lm.copy()
        lm_[:17] = lm[16::-1]
        lm_[17:22] = lm[26:21:-1]
        lm_[22:27] = lm[21:16:-1]
        lm_[31:36] = lm[35:30:-1]
        lm_[36:40] = lm[45:41:-1]
        lm_[40:42] = lm[47:45:-1]
        lm_[42:46] = lm[39:35:-1]
        lm_[46:48] = lm[41:39:-1]
        lm_[48:55] = lm[54:47:-1]
        lm_[55:60] = lm[59:54:-1]
        lm_[60:65] = lm[64:59:-1]
        lm_[65:68] = lm[67:64:-1]
        lm = lm_
    return lm
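For reference, `get_affine_mat` composes flip, shift, rotation and scale around the image centre in homogeneous coordinates, and `apply_lm_affine` applies the result per landmark. A minimal, self-contained sketch of that per-point application (pure NumPy, with a toy matrix rather than one produced by the function above):

```python
import numpy as np

# Toy 3x3 affine: shift by +10 px in x, applied to one (x, y) point in
# homogeneous coordinates, mirroring what apply_lm_affine does per landmark.
affine = np.array([[1., 0., 10.],
                   [0., 1.,  0.],
                   [0., 0.,  1.]])
pt = np.array([30.0, 40.0, 1.0])   # (x, y, 1)
x, y, w = affine @ pt
print(x / w, y / w)                # 40.0 40.0
```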
src/face3d/data/flist_dataset.py
ADDED
@@ -0,0 +1,125 @@
"""This script defines the custom dataset for Deep3DFaceRecon_pytorch
"""

import os.path
from data.base_dataset import BaseDataset, get_transform, get_affine_mat, apply_img_affine, apply_lm_affine
from data.image_folder import make_dataset
from PIL import Image
import random
import util.util as util
import numpy as np
import json
import torch
from scipy.io import loadmat, savemat
import pickle
from util.preprocess import align_img, estimate_norm
from util.load_mats import load_lm3d


def default_flist_reader(flist):
    """
    flist format: impath label\nimpath label\n ... (same as caffe's filelist)
    """
    imlist = []
    with open(flist, 'r') as rf:
        for line in rf.readlines():
            impath = line.strip()
            imlist.append(impath)

    return imlist

def jason_flist_reader(flist):
    with open(flist, 'r') as fp:
        info = json.load(fp)
    return info

def parse_label(label):
    return torch.tensor(np.array(label).astype(np.float32))


class FlistDataset(BaseDataset):
    """
    It requires one directory to host training images '/path/to/data/train'
    You can train the model with the dataset flag '--dataroot /path/to/data'.
    """

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
        """
        BaseDataset.__init__(self, opt)

        self.lm3d_std = load_lm3d(opt.bfm_folder)

        msk_names = default_flist_reader(opt.flist)
        self.msk_paths = [os.path.join(opt.data_root, i) for i in msk_names]

        self.size = len(self.msk_paths)
        self.opt = opt

        self.name = 'train' if opt.isTrain else 'val'
        if '_' in opt.flist:
            self.name += '_' + opt.flist.split(os.sep)[-1].split('_')[0]


    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index (int) -- a random integer for data indexing

        Returns a dictionary that contains:
            img (tensor)    -- an image in the input domain
            msk (tensor)    -- its corresponding attention mask
            lm (tensor)     -- its corresponding 3d landmarks
            im_paths (str)  -- image paths
            aug_flag (bool) -- a flag used to tell whether it is raw or augmented
        """
        msk_path = self.msk_paths[index % self.size]  # make sure index is within the range
        img_path = msk_path.replace('mask/', '')
        lm_path = '.'.join(msk_path.replace('mask', 'landmarks').split('.')[:-1]) + '.txt'

        raw_img = Image.open(img_path).convert('RGB')
        raw_msk = Image.open(msk_path).convert('RGB')
        raw_lm = np.loadtxt(lm_path).astype(np.float32)

        _, img, lm, msk = align_img(raw_img, raw_lm, self.lm3d_std, raw_msk)

        aug_flag = self.opt.use_aug and self.opt.isTrain
        if aug_flag:
            img, lm, msk = self._augmentation(img, lm, self.opt, msk)

        _, H = img.size
        M = estimate_norm(lm, H)
        transform = get_transform()
        img_tensor = transform(img)
        msk_tensor = transform(msk)[:1, ...]
        lm_tensor = parse_label(lm)
        M_tensor = parse_label(M)


        return {'imgs': img_tensor,
                'lms': lm_tensor,
                'msks': msk_tensor,
                'M': M_tensor,
                'im_paths': img_path,
                'aug_flag': aug_flag,
                'dataset': self.name}

    def _augmentation(self, img, lm, opt, msk=None):
        affine, affine_inv, flip = get_affine_mat(opt, img.size)
        img = apply_img_affine(img, affine_inv)
        lm = apply_lm_affine(lm, affine, flip, img.size)
        if msk is not None:
            msk = apply_img_affine(msk, affine_inv, method=Image.BILINEAR)
        return img, lm, msk


    def __len__(self):
        """Return the total number of images in the dataset.
        """
        return self.size
src/face3d/data/image_folder.py
ADDED
@@ -0,0 +1,66 @@
"""A modified image folder class

We modify the official PyTorch image folder (https://github.com/pytorch/vision/blob/master/torchvision/datasets/folder.py)
so that this class can load images from both current directory and its subdirectories.
"""
import numpy as np
import torch.utils.data as data

from PIL import Image
import os
import os.path

IMG_EXTENSIONS = [
    '.jpg', '.JPG', '.jpeg', '.JPEG',
    '.png', '.PNG', '.ppm', '.PPM', '.bmp', '.BMP',
    '.tif', '.TIF', '.tiff', '.TIFF',
]


def is_image_file(filename):
    return any(filename.endswith(extension) for extension in IMG_EXTENSIONS)


def make_dataset(dir, max_dataset_size=float("inf")):
    images = []
    assert os.path.isdir(dir) or os.path.islink(dir), '%s is not a valid directory' % dir

    for root, _, fnames in sorted(os.walk(dir, followlinks=True)):
        for fname in fnames:
            if is_image_file(fname):
                path = os.path.join(root, fname)
                images.append(path)
    return images[:min(max_dataset_size, len(images))]


def default_loader(path):
    return Image.open(path).convert('RGB')


class ImageFolder(data.Dataset):

    def __init__(self, root, transform=None, return_paths=False,
                 loader=default_loader):
        imgs = make_dataset(root)
        if len(imgs) == 0:
            raise(RuntimeError("Found 0 images in: " + root + "\n"
                               "Supported image extensions are: " + ",".join(IMG_EXTENSIONS)))

        self.root = root
        self.imgs = imgs
        self.transform = transform
        self.return_paths = return_paths
        self.loader = loader

    def __getitem__(self, index):
        path = self.imgs[index]
        img = self.loader(path)
        if self.transform is not None:
            img = self.transform(img)
        if self.return_paths:
            return img, path
        else:
            return img

    def __len__(self):
        return len(self.imgs)
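A minimal usage sketch of the `ImageFolder` class above; the import path and the image directory are assumptions to adapt to your checkout:

```python
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from src.face3d.data.image_folder import ImageFolder  # import path assumption; adjust to your layout

# '/path/to/images' is a placeholder directory containing .jpg/.png files.
dataset = ImageFolder('/path/to/images',
                      transform=transforms.ToTensor(),
                      return_paths=True)
loader = DataLoader(dataset, batch_size=4, shuffle=False)
for imgs, paths in loader:
    print(imgs.shape, paths[0])
    break
```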
src/face3d/data/template_dataset.py
ADDED
@@ -0,0 +1,75 @@
"""Dataset class template

This module provides a template for users to implement custom datasets.
You can specify '--dataset_mode template' to use this dataset.
The class name should be consistent with both the filename and its dataset_mode option.
The filename should be <dataset_mode>_dataset.py
The class name should be <Dataset_mode>Dataset
You need to implement the following functions:
    -- <modify_commandline_options>: Add dataset-specific options and rewrite default values for existing options.
    -- <__init__>: Initialize this dataset class.
    -- <__getitem__>: Return a data point and its metadata information.
    -- <__len__>: Return the number of images.
"""
from data.base_dataset import BaseDataset, get_transform
# from data.image_folder import make_dataset
# from PIL import Image


class TemplateDataset(BaseDataset):
    """A template dataset class for you to implement custom datasets."""
    @staticmethod
    def modify_commandline_options(parser, is_train):
        """Add new dataset-specific options, and rewrite default values for existing options.

        Parameters:
            parser          -- original option parser
            is_train (bool) -- whether training phase or test phase. You can use this flag to add training-specific or test-specific options.

        Returns:
            the modified parser.
        """
        parser.add_argument('--new_dataset_option', type=float, default=1.0, help='new dataset option')
        parser.set_defaults(max_dataset_size=10, new_dataset_option=2.0)  # specify dataset-specific default values
        return parser

    def __init__(self, opt):
        """Initialize this dataset class.

        Parameters:
            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions

        A few things can be done here.
        - save the options (have been done in BaseDataset)
        - get image paths and meta information of the dataset.
        - define the image transformation.
        """
        # save the option and dataset root
        BaseDataset.__init__(self, opt)
        # get the image paths of your dataset;
        self.image_paths = []  # You can call sorted(make_dataset(self.root, opt.max_dataset_size)) to get all the image paths under the directory self.root
        # define the default transform function. You can use <base_dataset.get_transform>; You can also define your custom transform function
        self.transform = get_transform(opt)

    def __getitem__(self, index):
        """Return a data point and its metadata information.

        Parameters:
            index -- a random integer for data indexing

        Returns:
            a dictionary of data with their names. It usually contains the data itself and its metadata information.

        Step 1: get a random image path: e.g., path = self.image_paths[index]
        Step 2: load your data from the disk: e.g., image = Image.open(path).convert('RGB').
        Step 3: convert your data to a PyTorch tensor. You can use helper functions such as self.transform. e.g., data = self.transform(image)
        Step 4: return a data point as a dictionary.
        """
        path = 'temp'    # needs to be a string
        data_A = None    # needs to be a tensor
        data_B = None    # needs to be a tensor
        return {'data_A': data_A, 'data_B': data_B, 'path': path}

    def __len__(self):
        """Return the total number of images."""
        return len(self.image_paths)
src/face3d/extract_kp_videos.py
ADDED
@@ -0,0 +1,108 @@
import os
import cv2
import time
import glob
import argparse
import face_alignment
import numpy as np
from PIL import Image
from tqdm import tqdm
from itertools import cycle

from torch.multiprocessing import Pool, Process, set_start_method

class KeypointExtractor():
    def __init__(self, device):
        self.detector = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D,
                                                     device=device)

    def extract_keypoint(self, images, name=None, info=True):
        if isinstance(images, list):
            keypoints = []
            if info:
                i_range = tqdm(images, desc='landmark Det:')
            else:
                i_range = images

            for image in i_range:
                current_kp = self.extract_keypoint(image)
                if np.mean(current_kp) == -1 and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])

            keypoints = np.concatenate(keypoints, 0)
            np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints
        else:
            while True:
                try:
                    keypoints = self.detector.get_landmarks_from_image(np.array(images))[0]
                    break
                except RuntimeError as e:
                    if str(e).startswith('CUDA'):
                        print("Warning: out of memory, sleep for 1s")
                        time.sleep(1)
                    else:
                        print(e)
                        break
                except TypeError:
                    print('No face detected in this image')
                    shape = [68, 2]
                    keypoints = -1. * np.ones(shape)
                    break
            if name is not None:
                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints

def read_video(filename):
    frames = []
    cap = cv2.VideoCapture(filename)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        else:
            break
    cap.release()
    return frames

def run(data):
    filename, opt, device = data
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    # __init__ requires a device; 'cuda' resolves to the single GPU masked via CUDA_VISIBLE_DEVICES above
    kp_extractor = KeypointExtractor(device='cuda')
    images = read_video(filename)
    name = filename.split('/')[-2:]
    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
    kp_extractor.extract_keypoint(
        images,
        name=os.path.join(opt.output_dir, name[-2], name[-1])
    )

if __name__ == '__main__':
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)

    opt = parser.parse_args()
    filenames = list()
    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS

    for ext in extensions:
        os.listdir(f'{opt.input_dir}')
        print(f'{opt.input_dir}/*.{ext}')
        filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
    print('Total number of videos:', len(filenames))
    pool = Pool(opt.workers)
    args_list = cycle([opt])
    device_ids = opt.device_ids.split(",")
    device_ids = cycle(device_ids)
    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
        pass
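A minimal sketch of calling `KeypointExtractor` directly on a single frame, assuming `face_alignment` and a CUDA device are available; the import path and image filename are placeholders:

```python
from PIL import Image
from src.face3d.extract_kp_videos import KeypointExtractor  # import path assumption; adjust to your layout

extractor = KeypointExtractor(device='cuda')
frame = Image.open('example_frame.png').convert('RGB')      # placeholder image
kp = extractor.extract_keypoint(frame)                      # (68, 2) array, or all -1 if no face is found
print(kp.shape)
```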
src/face3d/extract_kp_videos_safe.py
ADDED
@@ -0,0 +1,151 @@
import os
import cv2
import time
import glob
import argparse
import numpy as np
from PIL import Image
import torch
from tqdm import tqdm
from itertools import cycle
from torch.multiprocessing import Pool, Process, set_start_method

from facexlib.alignment import landmark_98_to_68
from facexlib.detection import init_detection_model

from facexlib.utils import load_file_from_url
from src.face3d.util.my_awing_arch import FAN

def init_alignment_model(model_name, half=False, device='cuda', model_rootpath=None):
    if model_name == 'awing_fan':
        model = FAN(num_modules=4, num_landmarks=98, device=device)
        model_url = 'https://github.com/xinntao/facexlib/releases/download/v0.1.0/alignment_WFLW_4HG.pth'
    else:
        raise NotImplementedError(f'{model_name} is not implemented.')

    model_path = load_file_from_url(
        url=model_url, model_dir='facexlib/weights', progress=True, file_name=None, save_dir=model_rootpath)
    model.load_state_dict(torch.load(model_path, map_location=device)['state_dict'], strict=True)
    model.eval()
    model = model.to(device)
    return model


class KeypointExtractor():
    def __init__(self, device='cuda'):

        ### gfpgan/weights
        try:
            import webui  # in webui
            root_path = 'extensions/SadTalker/gfpgan/weights'

        except:
            root_path = 'gfpgan/weights'

        self.detector = init_alignment_model('awing_fan', device=device, model_rootpath=root_path)
        self.det_net = init_detection_model('retinaface_resnet50', half=False, device=device, model_rootpath=root_path)

    def extract_keypoint(self, images, name=None, info=True):
        if isinstance(images, list):
            keypoints = []
            if info:
                i_range = tqdm(images, desc='landmark Det:')
            else:
                i_range = images

            for image in i_range:
                current_kp = self.extract_keypoint(image)
                # current_kp = self.detector.get_landmarks(np.array(image))
                if np.mean(current_kp) == -1 and keypoints:
                    keypoints.append(keypoints[-1])
                else:
                    keypoints.append(current_kp[None])

            keypoints = np.concatenate(keypoints, 0)
            np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints
        else:
            while True:
                try:
                    with torch.no_grad():
                        # face detection -> face alignment.
                        img = np.array(images)
                        bboxes = self.det_net.detect_faces(images, 0.97)

                        bboxes = bboxes[0]
                        img = img[int(bboxes[1]):int(bboxes[3]), int(bboxes[0]):int(bboxes[2]), :]

                        keypoints = landmark_98_to_68(self.detector.get_landmarks(img))  # [0]

                        #### keypoints to the original location
                        keypoints[:, 0] += int(bboxes[0])
                        keypoints[:, 1] += int(bboxes[1])

                        break
                except RuntimeError as e:
                    if str(e).startswith('CUDA'):
                        print("Warning: out of memory, sleep for 1s")
                        time.sleep(1)
                    else:
                        print(e)
                        break
                except TypeError:
                    print('No face detected in this image')
                    shape = [68, 2]
                    keypoints = -1. * np.ones(shape)
                    break
            if name is not None:
                np.savetxt(os.path.splitext(name)[0]+'.txt', keypoints.reshape(-1))
            return keypoints

def read_video(filename):
    frames = []
    cap = cv2.VideoCapture(filename)
    while cap.isOpened():
        ret, frame = cap.read()
        if ret:
            frame = cv2.cvtColor(frame, cv2.COLOR_BGR2RGB)
            frame = Image.fromarray(frame)
            frames.append(frame)
        else:
            break
    cap.release()
    return frames

def run(data):
    filename, opt, device = data
    os.environ['CUDA_VISIBLE_DEVICES'] = device
    kp_extractor = KeypointExtractor()
    images = read_video(filename)
    name = filename.split('/')[-2:]
    os.makedirs(os.path.join(opt.output_dir, name[-2]), exist_ok=True)
    kp_extractor.extract_keypoint(
        images,
        name=os.path.join(opt.output_dir, name[-2], name[-1])
    )

if __name__ == '__main__':
    set_start_method('spawn')
    parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter)
    parser.add_argument('--input_dir', type=str, help='the folder of the input files')
    parser.add_argument('--output_dir', type=str, help='the folder of the output files')
    parser.add_argument('--device_ids', type=str, default='0,1')
    parser.add_argument('--workers', type=int, default=4)

    opt = parser.parse_args()
    filenames = list()
    VIDEO_EXTENSIONS_LOWERCASE = {'mp4'}
    VIDEO_EXTENSIONS = VIDEO_EXTENSIONS_LOWERCASE.union({f.upper() for f in VIDEO_EXTENSIONS_LOWERCASE})
    extensions = VIDEO_EXTENSIONS

    for ext in extensions:
        os.listdir(f'{opt.input_dir}')
        print(f'{opt.input_dir}/*.{ext}')
        filenames = sorted(glob.glob(f'{opt.input_dir}/*.{ext}'))
    print('Total number of videos:', len(filenames))
    pool = Pool(opt.workers)
    args_list = cycle([opt])
    device_ids = opt.device_ids.split(",")
    device_ids = cycle(device_ids)
    for data in tqdm(pool.imap_unordered(run, zip(filenames, args_list, device_ids))):
        pass
src/face3d/models/__init__.py
ADDED
@@ -0,0 +1,67 @@
"""This package contains modules related to objective functions, optimizations, and network architectures.

To add a custom model class called 'dummy', you need to add a file called 'dummy_model.py' and define a subclass DummyModel inherited from BaseModel.
You need to implement the following five functions:
    -- <__init__>: initialize the class; first call BaseModel.__init__(self, opt).
    -- <set_input>: unpack data from dataset and apply preprocessing.
    -- <forward>: produce intermediate results.
    -- <optimize_parameters>: calculate loss, gradients, and update network weights.
    -- <modify_commandline_options>: (optionally) add model-specific options and set default options.

In the function <__init__>, you need to define four lists:
    -- self.loss_names (str list): specify the training losses that you want to plot and save.
    -- self.model_names (str list): define networks used in our training.
    -- self.visual_names (str list): specify the images that you want to display and save.
    -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for example usage.

Now you can use the model class by specifying flag '--model dummy'.
See our template model class 'template_model.py' for more details.
"""

import importlib
from src.face3d.models.base_model import BaseModel


def find_model_using_name(model_name):
    """Import the module "models/[model_name]_model.py".

    In the file, the class called DatasetNameModel() will
    be instantiated. It has to be a subclass of BaseModel,
    and it is case-insensitive.
    """
    model_filename = "face3d.models." + model_name + "_model"
    modellib = importlib.import_module(model_filename)
    model = None
    target_model_name = model_name.replace('_', '') + 'model'
    for name, cls in modellib.__dict__.items():
        if name.lower() == target_model_name.lower() \
           and issubclass(cls, BaseModel):
            model = cls

    if model is None:
        print("In %s.py, there should be a subclass of BaseModel with class name that matches %s in lowercase." % (model_filename, target_model_name))
        exit(0)

    return model


def get_option_setter(model_name):
    """Return the static method <modify_commandline_options> of the model class."""
    model_class = find_model_using_name(model_name)
    return model_class.modify_commandline_options


def create_model(opt):
    """Create a model given the option.

    This function wraps the class CustomDatasetDataLoader.
    This is the main interface between this package and 'train.py'/'test.py'

    Example:
        >>> from models import create_model
        >>> model = create_model(opt)
    """
    model = find_model_using_name(opt.model)
    instance = model(opt)
    print("model [%s] was created" % type(instance).__name__)
    return instance
src/face3d/models/arcface_torch/README.md
ADDED
@@ -0,0 +1,164 @@
# Distributed Arcface Training in Pytorch

This is a deep learning library that makes face recognition efficient and effective, and which can train tens of millions of
identities on a single server.

## Requirements

- Install [pytorch](http://pytorch.org) (torch>=1.6.0); see our doc [install.md](docs/install.md).
- `pip install -r requirements.txt`.
- Download the dataset from [https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_](https://github.com/deepinsight/insightface/tree/master/recognition/_datasets_).

## How to Train

To train a model, run `train.py` with the path to the configs:

### 1. Single node, 8 GPUs:

```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r50
```

### 2. Multiple nodes, each node 8 GPUs:

Node 0:

```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=0 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
```

Node 1:

```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=2 --node_rank=1 --master_addr="ip1" --master_port=1234 train.py configs/ms1mv3_r50
```

### 3. Training resnet2060 with 8 GPUs:

```shell
python -m torch.distributed.launch --nproc_per_node=8 --nnodes=1 --node_rank=0 --master_addr="127.0.0.1" --master_port=1234 train.py configs/ms1mv3_r2060.py
```

## Model Zoo

- The models are available for non-commercial research purposes only.
- All models can be found here:
- [Baidu Yun Pan](https://pan.baidu.com/s/1CL-l4zWqsI1oDuEEYVhj-g): e8pw
- [onedrive](https://1drv.ms/u/s!AswpsDO2toNKq0lWY69vN58GR6mw?e=p9Ov5d)

### Performance on [**ICCV2021-MFR**](http://iccv21-mfr.com/)

The ICCV2021-MFR testset consists of non-celebrities, so it has very little overlap with publicly available face
recognition training sets such as MS1M and CASIA, which are mostly collected from online celebrities.
As a result, we can evaluate the fair performance of different algorithms.

For the **ICCV2021-MFR-ALL** set, TAR is measured on an all-to-all 1:1 protocol, with FAR less than 0.000001 (1e-6). The
globalised multi-racial testset contains 242,143 identities and 1,624,305 images.

For the **ICCV2021-MFR-MASK** set, TAR is measured on a mask-to-nonmask 1:1 protocol, with FAR less than 0.0001 (1e-4).
The mask testset contains 6,964 identities, 6,964 masked images and 13,928 non-masked images.
In total there are 13,928 positive pairs and 96,983,824 negative pairs.

| Datasets | backbone | Training throughput | Size / MB | **ICCV2021-MFR-MASK** | **ICCV2021-MFR-ALL** |
| :---: | :--- | :--- | :--- |:--- |:--- |
| MS1MV3 | r18 | - | 91 | **47.85** | **68.33** |
| Glint360k | r18 | 8536 | 91 | **53.32** | **72.07** |
| MS1MV3 | r34 | - | 130 | **58.72** | **77.36** |
| Glint360k | r34 | 6344 | 130 | **65.10** | **83.02** |
| MS1MV3 | r50 | 5500 | 166 | **63.85** | **80.53** |
| Glint360k | r50 | 5136 | 166 | **70.23** | **87.08** |
| MS1MV3 | r100 | - | 248 | **69.09** | **84.31** |
| Glint360k | r100 | 3332 | 248 | **75.57** | **90.66** |
| MS1MV3 | mobilefacenet | 12185 | 7.8 | **41.52** | **65.26** |
| Glint360k | mobilefacenet | 11197 | 7.8 | **44.52** | **66.48** |

### Performance on IJB-C and Verification Datasets

| Datasets | backbone | IJBC(1e-05) | IJBC(1e-04) | agedb30 | cfp_fp | lfw | log |
| :---: | :--- | :--- | :--- | :--- |:--- |:--- |:--- |
| MS1MV3 | r18 | 92.07 | 94.66 | 97.77 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r18_fp16/training.log)|
| MS1MV3 | r34 | 94.10 | 95.90 | 98.10 | 98.67 | 99.80 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r34_fp16/training.log)|
| MS1MV3 | r50 | 94.79 | 96.46 | 98.35 | 98.96 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r50_fp16/training.log)|
| MS1MV3 | r100 | 95.31 | 96.81 | 98.48 | 99.06 | 99.85 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r100_fp16/training.log)|
| MS1MV3 | **r2060**| 95.34 | 97.11 | 98.67 | 99.24 | 99.87 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/ms1mv3_arcface_r2060_fp16/training.log)|
| Glint360k |r18-0.1 | 93.16 | 95.33 | 97.72 | 97.73 | 99.77 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r18_fp16_0.1/training.log)|
| Glint360k |r34-0.1 | 95.16 | 96.56 | 98.33 | 98.78 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r34_fp16_0.1/training.log)|
| Glint360k |r50-0.1 | 95.61 | 96.97 | 98.38 | 99.20 | 99.83 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r50_fp16_0.1/training.log)|
| Glint360k |r100-0.1 | 95.88 | 97.32 | 98.48 | 99.29 | 99.82 |[log](https://raw.githubusercontent.com/anxiangsir/insightface_arcface_log/master/glint360k_cosface_r100_fp16_0.1/training.log)|

[comment]: <> (More details see [model.md](docs/modelzoo.md) in docs.)


## [Speed Benchmark](docs/speed_benchmark.md)

**Arcface Torch** can train large-scale face recognition training sets efficiently and quickly. When the number of
classes in the training set is greater than 300K and training is sufficient, the partial FC sampling strategy reaches the same
accuracy several times faster and with a smaller GPU memory footprint.
Partial FC is a sparse variant of the model parallel architecture for large scale face recognition. Partial FC uses a
sparse softmax, where each batch dynamically samples a subset of class centers for training. In each iteration, only a
sparse part of the parameters is updated, which saves a lot of GPU memory and computation. With Partial FC,
we can scale the training set to 29 million identities, the largest to date. Partial FC also supports multi-machine distributed
training and mixed precision training.

![Image text](https://github.com/anxiangsir/insightface_arcface_log/blob/master/partial_fc_v2.png)

For more details see
[speed_benchmark.md](docs/speed_benchmark.md) in docs.

### 1. Training speed of different parallel methods (samples / second), Tesla V100 32GB * 8. (Larger is better)

`-` means training failed because of gpu memory limitations.

| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 4681 | 4824 | 5004 |
|1400000 | **1672** | 3043 | 4738 |
|5500000 | **-** | **1389** | 3975 |
|8000000 | **-** | **-** | 3565 |
|16000000 | **-** | **-** | 2679 |
|29000000 | **-** | **-** | **1855** |

### 2. GPU memory cost of different parallel methods (MB per GPU), Tesla V100 32GB * 8. (Smaller is better)

| Number of Identities in Dataset | Data Parallel | Model Parallel | Partial FC 0.1 |
| :--- | :--- | :--- | :--- |
|125000 | 7358 | 5306 | 4868 |
|1400000 | 32252 | 11178 | 6056 |
|5500000 | **-** | 32188 | 9854 |
|8000000 | **-** | **-** | 12310 |
|16000000 | **-** | **-** | 19950 |
|29000000 | **-** | **-** | 32324 |

## Evaluation on ICCV2021-MFR and IJB-C

For more details see [eval.md](docs/eval.md) in docs.

## Test

We tested many versions of PyTorch. Please create an issue if you are having trouble.

- [x] torch 1.6.0
- [x] torch 1.7.1
- [x] torch 1.8.0
- [x] torch 1.9.0

## Citation

```
@inproceedings{deng2019arcface,
  title={Arcface: Additive angular margin loss for deep face recognition},
  author={Deng, Jiankang and Guo, Jia and Xue, Niannan and Zafeiriou, Stefanos},
  booktitle={Proceedings of the IEEE Conference on Computer Vision and Pattern Recognition},
  pages={4690--4699},
  year={2019}
}
@inproceedings{an2020partical_fc,
  title={Partial FC: Training 10 Million Identities on a Single Machine},
  author={An, Xiang and Zhu, Xuhan and Xiao, Yang and Wu, Lan and Zhang, Ming and Gao, Yuan and Qin, Bin and
          Zhang, Debing and Fu Ying},
  booktitle={Arxiv 2010.05222},
  year={2020}
}
```
src/face3d/models/arcface_torch/backbones/__init__.py
ADDED
@@ -0,0 +1,25 @@
from .iresnet import iresnet18, iresnet34, iresnet50, iresnet100, iresnet200
from .mobilefacenet import get_mbf


def get_model(name, **kwargs):
    # resnet
    if name == "r18":
        return iresnet18(False, **kwargs)
    elif name == "r34":
        return iresnet34(False, **kwargs)
    elif name == "r50":
        return iresnet50(False, **kwargs)
    elif name == "r100":
        return iresnet100(False, **kwargs)
    elif name == "r200":
        return iresnet200(False, **kwargs)
    elif name == "r2060":
        from .iresnet2060 import iresnet2060
        return iresnet2060(False, **kwargs)
    elif name == "mbf":
        fp16 = kwargs.get("fp16", False)
        num_features = kwargs.get("num_features", 512)
        return get_mbf(fp16=fp16, num_features=num_features)
    else:
        raise ValueError()
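A minimal sketch of building one of the backbones above and running a dummy batch. The 112x112 input size is what ArcFace models are usually trained with (treat that as an assumption here), and the import path depends on your checkout:

```python
import torch
from src.face3d.models.arcface_torch.backbones import get_model  # assumes the repo root is on PYTHONPATH

# iresnet50 with a 512-d embedding head; kwargs are forwarded to the IResNet constructor.
net = get_model("r50", fp16=False, num_features=512)
net.eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))
print(emb.shape)   # torch.Size([1, 512])
```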
src/face3d/models/arcface_torch/backbones/iresnet.py
ADDED
@@ -0,0 +1,187 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch import nn
|
3 |
+
|
4 |
+
__all__ = ['iresnet18', 'iresnet34', 'iresnet50', 'iresnet100', 'iresnet200']
|
5 |
+
|
6 |
+
|
7 |
+
def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
|
8 |
+
"""3x3 convolution with padding"""
|
9 |
+
return nn.Conv2d(in_planes,
|
10 |
+
out_planes,
|
11 |
+
kernel_size=3,
|
12 |
+
stride=stride,
|
13 |
+
padding=dilation,
|
14 |
+
groups=groups,
|
15 |
+
bias=False,
|
16 |
+
dilation=dilation)
|
17 |
+
|
18 |
+
|
19 |
+
def conv1x1(in_planes, out_planes, stride=1):
|
20 |
+
"""1x1 convolution"""
|
21 |
+
return nn.Conv2d(in_planes,
|
22 |
+
out_planes,
|
23 |
+
kernel_size=1,
|
24 |
+
stride=stride,
|
25 |
+
bias=False)
|
26 |
+
|
27 |
+
|
28 |
+
class IBasicBlock(nn.Module):
|
29 |
+
expansion = 1
|
30 |
+
def __init__(self, inplanes, planes, stride=1, downsample=None,
|
31 |
+
groups=1, base_width=64, dilation=1):
|
32 |
+
super(IBasicBlock, self).__init__()
|
33 |
+
if groups != 1 or base_width != 64:
|
34 |
+
raise ValueError('BasicBlock only supports groups=1 and base_width=64')
|
35 |
+
if dilation > 1:
|
36 |
+
raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
|
37 |
+
self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05,)
|
38 |
+
self.conv1 = conv3x3(inplanes, planes)
|
39 |
+
self.bn2 = nn.BatchNorm2d(planes, eps=1e-05,)
|
40 |
+
self.prelu = nn.PReLU(planes)
|
41 |
+
self.conv2 = conv3x3(planes, planes, stride)
|
42 |
+
self.bn3 = nn.BatchNorm2d(planes, eps=1e-05,)
|
43 |
+
self.downsample = downsample
|
44 |
+
self.stride = stride
|
45 |
+
|
46 |
+
def forward(self, x):
|
47 |
+
identity = x
|
48 |
+
out = self.bn1(x)
|
49 |
+
out = self.conv1(out)
|
50 |
+
out = self.bn2(out)
|
51 |
+
out = self.prelu(out)
|
52 |
+
out = self.conv2(out)
|
53 |
+
out = self.bn3(out)
|
54 |
+
if self.downsample is not None:
|
55 |
+
identity = self.downsample(x)
|
56 |
+
        out += identity
        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05,)
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.layer2(x)
            x = self.layer3(x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model


def iresnet18(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet18', IBasicBlock, [2, 2, 2, 2], pretrained,
                    progress, **kwargs)


def iresnet34(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet34', IBasicBlock, [3, 4, 6, 3], pretrained,
                    progress, **kwargs)


def iresnet50(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet50', IBasicBlock, [3, 4, 14, 3], pretrained,
                    progress, **kwargs)


def iresnet100(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet100', IBasicBlock, [3, 13, 30, 3], pretrained,
                    progress, **kwargs)


def iresnet200(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet200', IBasicBlock, [6, 26, 60, 6], pretrained,
                    progress, **kwargs)
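For reference, a minimal usage sketch of the factory functions added above (not part of the commit). It assumes the usual ArcFace input of pre-aligned 112x112 crops, which is consistent with fc_scale = 7 * 7 after the four stride-2 stages, and that the module is importable under the path shown in the diff.

# Hypothetical usage sketch -- the import path and 112x112 input size are assumptions.
import torch
from src.face3d.models.arcface_torch.backbones.iresnet import iresnet50

model = iresnet50(num_features=512).eval()
with torch.no_grad():
    faces = torch.randn(4, 3, 112, 112)  # dummy batch of aligned face crops
    embeddings = model(faces)            # shape: (4, 512)
print(embeddings.shape)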
src/face3d/models/arcface_torch/backbones/iresnet2060.py
ADDED
@@ -0,0 +1,176 @@
import torch
from torch import nn

assert torch.__version__ >= "1.8.1"
from torch.utils.checkpoint import checkpoint_sequential

__all__ = ['iresnet2060']


def conv3x3(in_planes, out_planes, stride=1, groups=1, dilation=1):
    """3x3 convolution with padding"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=3,
                     stride=stride,
                     padding=dilation,
                     groups=groups,
                     bias=False,
                     dilation=dilation)


def conv1x1(in_planes, out_planes, stride=1):
    """1x1 convolution"""
    return nn.Conv2d(in_planes,
                     out_planes,
                     kernel_size=1,
                     stride=stride,
                     bias=False)


class IBasicBlock(nn.Module):
    expansion = 1

    def __init__(self, inplanes, planes, stride=1, downsample=None,
                 groups=1, base_width=64, dilation=1):
        super(IBasicBlock, self).__init__()
        if groups != 1 or base_width != 64:
            raise ValueError('BasicBlock only supports groups=1 and base_width=64')
        if dilation > 1:
            raise NotImplementedError("Dilation > 1 not supported in BasicBlock")
        self.bn1 = nn.BatchNorm2d(inplanes, eps=1e-05, )
        self.conv1 = conv3x3(inplanes, planes)
        self.bn2 = nn.BatchNorm2d(planes, eps=1e-05, )
        self.prelu = nn.PReLU(planes)
        self.conv2 = conv3x3(planes, planes, stride)
        self.bn3 = nn.BatchNorm2d(planes, eps=1e-05, )
        self.downsample = downsample
        self.stride = stride

    def forward(self, x):
        identity = x
        out = self.bn1(x)
        out = self.conv1(out)
        out = self.bn2(out)
        out = self.prelu(out)
        out = self.conv2(out)
        out = self.bn3(out)
        if self.downsample is not None:
            identity = self.downsample(x)
        out += identity
        return out


class IResNet(nn.Module):
    fc_scale = 7 * 7

    def __init__(self,
                 block, layers, dropout=0, num_features=512, zero_init_residual=False,
                 groups=1, width_per_group=64, replace_stride_with_dilation=None, fp16=False):
        super(IResNet, self).__init__()
        self.fp16 = fp16
        self.inplanes = 64
        self.dilation = 1
        if replace_stride_with_dilation is None:
            replace_stride_with_dilation = [False, False, False]
        if len(replace_stride_with_dilation) != 3:
            raise ValueError("replace_stride_with_dilation should be None "
                             "or a 3-element tuple, got {}".format(replace_stride_with_dilation))
        self.groups = groups
        self.base_width = width_per_group
        self.conv1 = nn.Conv2d(3, self.inplanes, kernel_size=3, stride=1, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(self.inplanes, eps=1e-05)
        self.prelu = nn.PReLU(self.inplanes)
        self.layer1 = self._make_layer(block, 64, layers[0], stride=2)
        self.layer2 = self._make_layer(block,
                                       128,
                                       layers[1],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[0])
        self.layer3 = self._make_layer(block,
                                       256,
                                       layers[2],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[1])
        self.layer4 = self._make_layer(block,
                                       512,
                                       layers[3],
                                       stride=2,
                                       dilate=replace_stride_with_dilation[2])
        self.bn2 = nn.BatchNorm2d(512 * block.expansion, eps=1e-05, )
        self.dropout = nn.Dropout(p=dropout, inplace=True)
        self.fc = nn.Linear(512 * block.expansion * self.fc_scale, num_features)
        self.features = nn.BatchNorm1d(num_features, eps=1e-05)
        nn.init.constant_(self.features.weight, 1.0)
        self.features.weight.requires_grad = False

        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.normal_(m.weight, 0, 0.1)
            elif isinstance(m, (nn.BatchNorm2d, nn.GroupNorm)):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

        if zero_init_residual:
            for m in self.modules():
                if isinstance(m, IBasicBlock):
                    nn.init.constant_(m.bn2.weight, 0)

    def _make_layer(self, block, planes, blocks, stride=1, dilate=False):
        downsample = None
        previous_dilation = self.dilation
        if dilate:
            self.dilation *= stride
            stride = 1
        if stride != 1 or self.inplanes != planes * block.expansion:
            downsample = nn.Sequential(
                conv1x1(self.inplanes, planes * block.expansion, stride),
                nn.BatchNorm2d(planes * block.expansion, eps=1e-05, ),
            )
        layers = []
        layers.append(
            block(self.inplanes, planes, stride, downsample, self.groups,
                  self.base_width, previous_dilation))
        self.inplanes = planes * block.expansion
        for _ in range(1, blocks):
            layers.append(
                block(self.inplanes,
                      planes,
                      groups=self.groups,
                      base_width=self.base_width,
                      dilation=self.dilation))

        return nn.Sequential(*layers)

    def checkpoint(self, func, num_seg, x):
        if self.training:
            return checkpoint_sequential(func, num_seg, x)
        else:
            return func(x)

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.conv1(x)
            x = self.bn1(x)
            x = self.prelu(x)
            x = self.layer1(x)
            x = self.checkpoint(self.layer2, 20, x)
            x = self.checkpoint(self.layer3, 100, x)
            x = self.layer4(x)
            x = self.bn2(x)
            x = torch.flatten(x, 1)
            x = self.dropout(x)
        x = self.fc(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def _iresnet(arch, block, layers, pretrained, progress, **kwargs):
    model = IResNet(block, layers, **kwargs)
    if pretrained:
        raise ValueError()
    return model


def iresnet2060(pretrained=False, progress=True, **kwargs):
    return _iresnet('iresnet2060', IBasicBlock, [3, 128, 1024 - 128, 3], pretrained, progress, **kwargs)
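The main functional difference from iresnet.py above is the checkpoint helper: during training, layer2 and layer3 are run through torch.utils.checkpoint.checkpoint_sequential so the very deep network trades recomputation for activation memory. A toy sketch of that mechanism follows; the Sequential module and segment count are invented for illustration, not taken from the repo.

# Toy illustration of checkpoint_sequential: segments are re-run during the
# backward pass instead of keeping all intermediate activations in memory.
import torch
from torch import nn
from torch.utils.checkpoint import checkpoint_sequential

seq = nn.Sequential(*[nn.Linear(64, 64) for _ in range(8)])
x = torch.randn(2, 64, requires_grad=True)
out = checkpoint_sequential(seq, 4, x)  # 4 checkpointed segments
out.sum().backward()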
src/face3d/models/arcface_torch/backbones/mobilefacenet.py
ADDED
@@ -0,0 +1,130 @@
'''
Adapted from https://github.com/cavalleria/cavaface.pytorch/blob/master/backbone/mobilefacenet.py
Original author cavalleria
'''

import torch.nn as nn
from torch.nn import Linear, Conv2d, BatchNorm1d, BatchNorm2d, PReLU, Sequential, Module
import torch


class Flatten(Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ConvBlock(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(ConvBlock, self).__init__()
        self.layers = nn.Sequential(
            Conv2d(in_c, out_c, kernel, groups=groups, stride=stride, padding=padding, bias=False),
            BatchNorm2d(num_features=out_c),
            PReLU(num_parameters=out_c)
        )

    def forward(self, x):
        return self.layers(x)


class LinearBlock(Module):
    def __init__(self, in_c, out_c, kernel=(1, 1), stride=(1, 1), padding=(0, 0), groups=1):
        super(LinearBlock, self).__init__()
        self.layers = nn.Sequential(
            Conv2d(in_c, out_c, kernel, stride, padding, groups=groups, bias=False),
            BatchNorm2d(num_features=out_c)
        )

    def forward(self, x):
        return self.layers(x)


class DepthWise(Module):
    def __init__(self, in_c, out_c, residual=False, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=1):
        super(DepthWise, self).__init__()
        self.residual = residual
        self.layers = nn.Sequential(
            ConvBlock(in_c, out_c=groups, kernel=(1, 1), padding=(0, 0), stride=(1, 1)),
            ConvBlock(groups, groups, groups=groups, kernel=kernel, padding=padding, stride=stride),
            LinearBlock(groups, out_c, kernel=(1, 1), padding=(0, 0), stride=(1, 1))
        )

    def forward(self, x):
        short_cut = None
        if self.residual:
            short_cut = x
        x = self.layers(x)
        if self.residual:
            output = short_cut + x
        else:
            output = x
        return output


class Residual(Module):
    def __init__(self, c, num_block, groups, kernel=(3, 3), stride=(1, 1), padding=(1, 1)):
        super(Residual, self).__init__()
        modules = []
        for _ in range(num_block):
            modules.append(DepthWise(c, c, True, kernel, stride, padding, groups))
        self.layers = Sequential(*modules)

    def forward(self, x):
        return self.layers(x)


class GDC(Module):
    def __init__(self, embedding_size):
        super(GDC, self).__init__()
        self.layers = nn.Sequential(
            LinearBlock(512, 512, groups=512, kernel=(7, 7), stride=(1, 1), padding=(0, 0)),
            Flatten(),
            Linear(512, embedding_size, bias=False),
            BatchNorm1d(embedding_size))

    def forward(self, x):
        return self.layers(x)


class MobileFaceNet(Module):
    def __init__(self, fp16=False, num_features=512):
        super(MobileFaceNet, self).__init__()
        scale = 2
        self.fp16 = fp16
        self.layers = nn.Sequential(
            ConvBlock(3, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1)),
            ConvBlock(64 * scale, 64 * scale, kernel=(3, 3), stride=(1, 1), padding=(1, 1), groups=64),
            DepthWise(64 * scale, 64 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=128),
            Residual(64 * scale, num_block=4, groups=128, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(64 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=256),
            Residual(128 * scale, num_block=6, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
            DepthWise(128 * scale, 128 * scale, kernel=(3, 3), stride=(2, 2), padding=(1, 1), groups=512),
            Residual(128 * scale, num_block=2, groups=256, kernel=(3, 3), stride=(1, 1), padding=(1, 1)),
        )
        self.conv_sep = ConvBlock(128 * scale, 512, kernel=(1, 1), stride=(1, 1), padding=(0, 0))
        self.features = GDC(num_features)
        self._initialize_weights()

    def _initialize_weights(self):
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                m.weight.data.fill_(1)
                m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
                if m.bias is not None:
                    m.bias.data.zero_()

    def forward(self, x):
        with torch.cuda.amp.autocast(self.fp16):
            x = self.layers(x)
        x = self.conv_sep(x.float() if self.fp16 else x)
        x = self.features(x)
        return x


def get_mbf(fp16, num_features):
    return MobileFaceNet(fp16, num_features)
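A hedged usage sketch for the MobileFaceNet backbone (not part of the commit); the 112x112 input size is an assumption consistent with the (7, 7) kernel in GDC after the four stride-2 stages, and the import path mirrors the file location in the diff.

# Hypothetical usage -- import path and input size are assumptions.
import torch
from src.face3d.models.arcface_torch.backbones.mobilefacenet import get_mbf

net = get_mbf(fp16=False, num_features=512).eval()
with torch.no_grad():
    emb = net(torch.randn(1, 3, 112, 112))  # shape: (1, 512)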
src/face3d/models/arcface_torch/configs/3millions.py
ADDED
@@ -0,0 +1,23 @@
from easydict import EasyDict as edict

# configs for test speed

config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "synthetic"
config.num_classes = 300 * 10000
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = []
src/face3d/models/arcface_torch/configs/3millions_pfc.py
ADDED
@@ -0,0 +1,23 @@
from easydict import EasyDict as edict

# configs for test speed

config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.1
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "synthetic"
config.num_classes = 300 * 10000
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 16, 22]
config.val_targets = []
src/face3d/models/arcface_torch/configs/__init__.py
ADDED
File without changes
src/face3d/models/arcface_torch/configs/base.py
ADDED
@@ -0,0 +1,56 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.loss = "arcface"
config.network = "r50"
config.resume = False
config.output = "ms1mv3_arcface_r50"

config.dataset = "ms1m-retinaface-t1"
config.embedding_size = 512
config.sample_rate = 1
config.fp16 = False
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

if config.dataset == "emore":
    config.rec = "/train_tmp/faces_emore"
    config.num_classes = 85742
    config.num_image = 5822653
    config.num_epoch = 16
    config.warmup_epoch = -1
    config.decay_epoch = [8, 14, ]
    config.val_targets = ["lfw", ]

elif config.dataset == "ms1m-retinaface-t1":
    config.rec = "/train_tmp/ms1m-retinaface-t1"
    config.num_classes = 93431
    config.num_image = 5179510
    config.num_epoch = 25
    config.warmup_epoch = -1
    config.decay_epoch = [11, 17, 22]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

elif config.dataset == "glint360k":
    config.rec = "/train_tmp/glint360k"
    config.num_classes = 360232
    config.num_image = 17091657
    config.num_epoch = 20
    config.warmup_epoch = -1
    config.decay_epoch = [8, 12, 15, 18]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]

elif config.dataset == "webface":
    config.rec = "/train_tmp/faces_webface_112x112"
    config.num_classes = 10572
    config.num_image = "forget"
    config.num_epoch = 34
    config.warmup_epoch = -1
    config.decay_epoch = [20, 28, 32]
    config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
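Each of these config modules exposes a single edict named config; one plausible way to select among them at runtime is a small importlib-based loader like the sketch below. This loader is illustrative only and not code from this repo; the package path is an assumption based on the file locations in the diff.

# Illustrative config loader (hypothetical, not part of the commit).
import importlib

def load_config(name: str):
    # e.g. name = "glint360k_r50" -> configs/glint360k_r50.py
    module = importlib.import_module(
        f"src.face3d.models.arcface_torch.configs.{name}")
    return module.config

cfg = load_config("base")
print(cfg.network, cfg.batch_size, cfg.lr)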
src/face3d/models/arcface_torch/configs/glint360k_mbf.py
ADDED
@@ -0,0 +1,26 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "mbf"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 0.1
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 2e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/glint360k_r100.py
ADDED
@@ -0,0 +1,26 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "r100"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/glint360k_r18.py
ADDED
@@ -0,0 +1,26 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "r18"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/glint360k_r34.py
ADDED
@@ -0,0 +1,26 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "r34"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/glint360k_r50.py
ADDED
@@ -0,0 +1,26 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.loss = "cosface"
config.network = "r50"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 5e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/glint360k"
config.num_classes = 360232
config.num_image = 17091657
config.num_epoch = 20
config.warmup_epoch = -1
config.decay_epoch = [8, 12, 15, 18]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
src/face3d/models/arcface_torch/configs/ms1mv3_mbf.py
ADDED
@@ -0,0 +1,26 @@
from easydict import EasyDict as edict

# make training faster
# our RAM is 256G
# mount -t tmpfs -o size=140G tmpfs /train_tmp

config = edict()
config.loss = "arcface"
config.network = "mbf"
config.resume = False
config.output = None
config.embedding_size = 512
config.sample_rate = 1.0
config.fp16 = True
config.momentum = 0.9
config.weight_decay = 2e-4
config.batch_size = 128
config.lr = 0.1  # batch size is 512

config.rec = "/train_tmp/ms1m-retinaface-t1"
config.num_classes = 93431
config.num_image = 5179510
config.num_epoch = 30
config.warmup_epoch = -1
config.decay_epoch = [10, 20, 25]
config.val_targets = ["lfw", "cfp_fp", "agedb_30"]
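All of these configs pair lr = 0.1 (tuned, per the inline comment, for a total batch size of 512) with a list of decay_epoch milestones. A hedged sketch of how those fields map onto a standard step scheduler, using the ms1mv3_mbf values; the gamma of 0.1 and the stand-in Linear model are assumptions, not values from the repo.

# Sketch only: lr, momentum, weight_decay and milestones come from ms1mv3_mbf above;
# the decay factor and the toy model are invented for illustration.
import torch.nn as nn
from torch.optim import SGD
from torch.optim.lr_scheduler import MultiStepLR

model = nn.Linear(512, 93431)  # stand-in for backbone plus classification head
opt = SGD(model.parameters(), lr=0.1, momentum=0.9, weight_decay=2e-4)
sched = MultiStepLR(opt, milestones=[10, 20, 25], gamma=0.1)  # decay_epoch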