init
Browse files

- CITATION.cff +17 -0
- LICENSE +2 -0
- LICENSE.apache-2.0 +201 -0
- LICENSE.cc-by-nc-sa-4.0 +437 -0
- README.md +5 -5
- app.py +141 -0
- configs/dalle-1.3B.yaml +33 -0
- configs/transfer-imagenet-clscond-gen.yaml +49 -0
- configs/transfer-imagenet-uncond-gen.yaml +48 -0
- dalle/__init__.py +0 -0
- dalle/models/__init__.py +202 -0
- dalle/models/stage1/layers.py +373 -0
- dalle/models/stage1/vqgan.py +93 -0
- dalle/models/stage2/layers.py +140 -0
- dalle/models/stage2/transformer.py +255 -0
- dalle/models/tokenizer.py +26 -0
- dalle/utils/__init__.py +3 -0
- dalle/utils/config.py +123 -0
- dalle/utils/sampling.py +152 -0
- dalle/utils/utils.py +84 -0
- examples/sampling_ex.py +63 -0
- examples/sampling_interactive_demo.ipynb +298 -0
- examples/transfer_learning_ex.py +172 -0
- requirements.txt +10 -0
- setup.cfg +3 -0
CITATION.cff
ADDED
@@ -0,0 +1,17 @@
cff-version: 1.2.0
message: "If you find this repository useful in your research, please cite"
authors:
- family-names: Kim
  given-names: Saehoon
- family-names: Cho
  given-names: Sanghun
- family-names: Kim
  given-names: Chiheon
- family-names: Lee
  given-names: Doyup
- family-names: Baek
  given-names: Woonhyuk
title: "minDALL-E on Conceptual Captions"
version: 0.1
date-released: 2021-12-14
repository-code: https://github.com/kakaobrain/minDALL-E
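Since CITATION.cff is plain YAML, the same metadata can be consumed programmatically. A minimal sketch, assuming PyYAML is installed (it is not listed in this commit's requirements.txt, so that is an assumption):

```python
# Minimal sketch: render the CITATION.cff above as a one-line citation string.
# Assumes PyYAML is available; the field names are exactly those in the file.
import yaml

with open("CITATION.cff") as f:
    meta = yaml.safe_load(f)

authors = ", ".join(
    f"{a['family-names']}, {a['given-names']}" for a in meta["authors"]
)
citation = (
    f"{authors}. {meta['title']} (v{meta['version']}), "
    f"{meta['date-released']}. {meta['repository-code']}"
)
print(citation)
```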
LICENSE
ADDED
@@ -0,0 +1,2 @@
The `source codes` are licensed under [Apache 2.0](LICENSE.apache-2.0) License.
The `stage2 pretrained weights` are licensed under [CC-BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) License.
LICENSE.apache-2.0
ADDED
@@ -0,0 +1,201 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/

TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION

[Sections 1-9 of the standard Apache License 2.0 text: Definitions; Grant of
Copyright License; Grant of Patent License; Redistribution; Submission of
Contributions; Trademarks; Disclaimer of Warranty; Limitation of Liability;
Accepting Warranty or Additional Liability.]

END OF TERMS AND CONDITIONS

APPENDIX: How to apply the Apache License to your work.

Copyright [2021.11.13] [Kakao Brain]

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
LICENSE.cc-by-nc-sa-4.0
ADDED
@@ -0,0 +1,437 @@
Attribution-NonCommercial-ShareAlike 4.0 International

=======================================================================

Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
Public License

[The standard CC BY-NC-SA 4.0 legal code: Section 1 -- Definitions;
Section 2 -- Scope; Section 3 -- License Conditions; Section 4 -- Sui
Generis Database Rights; Section 5 -- Disclaimer of Warranties and
Limitation of Liability; Section 6 -- Term and Termination; Section 7 --
Other Terms and Conditions; Section 8 -- Interpretation.]

=======================================================================

Creative Commons may be contacted at creativecommons.org.
README.md
CHANGED
@@ -1,9 +1,9 @@
 ---
-title:
-emoji:
-colorFrom:
-colorTo:
-sdk:
+title: MinDALL E
+emoji: 🔥
+colorFrom: red
+colorTo: yellow
+sdk: gradio
 app_file: app.py
 pinned: false
 ---
app.py
ADDED
@@ -0,0 +1,141 @@
import os
import sys

import numpy as np
import streamlit as st
from PIL import Image

# import clip


# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

# import gradio as gr
# from dalle.models import Dalle
# from dalle.utils.utils import clip_score, set_seed


device = "cpu"
# model = Dalle.from_pretrained("minDALL-E/1.3B")  # This will automatically download the pretrained model.
# model.to(device=device)

# model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
# model_clip.to(device=device)


# def sample(prompt):
#     # Sampling
#     images = (
#         model.sampling(prompt=prompt, top_k=256, top_p=None, softmax_temperature=1.0, num_candidates=3, device=device)
#         .cpu()
#         .numpy()
#     )
#     images = np.transpose(images, (0, 2, 3, 1))

#     # CLIP Re-ranking
#     rank = clip_score(
#         prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device
#     )

#     # Save images
#     images = images[rank]
#     # print(rank, images.shape)
#     pil_images = []
#     for i in range(len(images)):
#         im = Image.fromarray((images[i] * 255).astype(np.uint8))
#         pil_images.append(im)

#     # im = Image.fromarray((images[0] * 255).astype(np.uint8))
#     return pil_images


# title = "Interactive demo: ImageGPT"
# description = "Demo for OpenAI's ImageGPT: Generative Pretraining from Pixels. To use it, simply upload an image or use the example image below and click 'submit'. Results will show up in a few seconds."
# article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2109.10282'>ImageGPT: Generative Pretraining from Pixels</a> | <a href='https://openai.com/blog/image-gpt/'>Official blog</a></p>"

# iface = gr.Interface(
#     fn=sample,
#     inputs=[gr.inputs.Textbox(label="What would you like to see?")],
#     outputs=gr.outputs.Image(type="pil", label="Model input + completions"),
#     title=title,
#     description=description,
#     article=article,
#     #examples=examples,
#     enable_queue=True,
# )
# iface.launch(debug=True)

#!/usr/bin/env python
# coding: utf-8


st.sidebar.markdown(
    """
<style>
.aligncenter {
    text-align: center;
}
</style>
<p class="aligncenter">
    <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
</p>
""",
    unsafe_allow_html=True,
)
st.sidebar.markdown(
    """
___
<p style='text-align: center'>
DALL·E mini is an AI model that generates images from any prompt you give!
</p>

<p style='text-align: center'>
Created by Boris Dayma et al. 2021
<br/>
<a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
</p>
""",
    unsafe_allow_html=True,
)

st.header("DALL·E mini")
st.subheader("Generate images from text")

prompt = st.text_input("What do you want to see?")

DEBUG = False
# if prompt != "":
#     container = st.empty()
#     container.markdown(
#         f"""
#         <style> p {{ margin:0 }} div {{ margin:0 }} </style>
#         <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
#         <div class="stAlert">
#         <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
#         <div class="st-b7">
#         <div class="css-whx05o e13vu3m50">
#         <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
#         <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
#         Generating predictions for: <b>{prompt}</b>
#         </div>
#         </div>
#         </div>
#         </div>
#         </div>
#         </div>
#         <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
#         """,
#         unsafe_allow_html=True,
#     )

#     print(f"Getting selections: {prompt}")
#     selected = sample(prompt)

#     margin = 0.1  # for better position of zoom in arrow
#     n_columns = 3
#     cols = st.columns([1] + [margin, 1] * (n_columns - 1))
#     for i, img in enumerate(selected):
#         cols[(i % n_columns) * 2].image(img)
#     container.markdown(f"**{prompt}**")

# st.button("Again!", key="again_button")
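The commented-out blocks above sketch the intended pipeline: generate candidate images with minDALL-E, re-rank them with CLIP, and display the best ones. A minimal sketch of that pipeline using the `Dalle` API and `clip_score` helper that this commit adds under `dalle/`; it assumes OpenAI's `clip` package and a CUDA device are available, and the prompt and candidate count are example values:

```python
# Sketch of the sampling + CLIP re-ranking path that app.py leaves commented out.
# Assumes the `dalle` package from this commit and OpenAI's `clip` are importable.
import clip
import numpy as np
from PIL import Image

from dalle.models import Dalle
from dalle.utils.utils import clip_score, set_seed

device = "cuda:0"
set_seed(0)

model = Dalle.from_pretrained("minDALL-E/1.3B")  # downloads the pretrained weights
model.to(device=device)
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)

prompt = "a lighthouse on a cliff at sunset"  # example prompt
images = model.sampling(prompt=prompt, top_k=256, top_p=None,
                        softmax_temperature=1.0, num_candidates=8,
                        device=device).cpu().numpy()
images = np.transpose(images, (0, 2, 3, 1))  # NCHW -> NHWC

# Re-rank the candidates by CLIP similarity and keep the best one.
rank = clip_score(prompt=prompt, images=images,
                  model_clip=model_clip, preprocess_clip=preprocess_clip,
                  device=device)
best = Image.fromarray((images[rank[0]] * 255).astype(np.uint8))
best.save("sample.png")
```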
configs/dalle-1.3B.yaml
ADDED
@@ -0,0 +1,33 @@
stage1:
  type: vqgan
  embed_dim: 256
  n_embed: 16384
  hparams:
    double_z: False
    z_channels: 256
    resolution: 256
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 1, 2, 2, 4]
    num_res_blocks: 2
    attn_resolutions: [16]
    pdrop: 0.0

stage2:
  type: transformer1d
  vocab_size_txt: 16384
  vocab_size_img: 16384
  hparams:
    embed_dim: 1536
    n_layers: 42
    n_heads: 24
    n_dense_layers: 42
    ctx_len_img: 256
    ctx_len_txt: 64
    embd_pdrop: 0.0
    resid_pdrop: 0.0
    attn_pdrop: 0.0
    mlp_bias: True
    attn_bias: True
    gelu_use_approx: False
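This is the configuration that `Dalle.from_pretrained` in `dalle/models/__init__.py` (added later in this commit) merges on top of the package's base config before building the two stages. A minimal sketch of that loading step, assuming the file is available locally at the path shown (the path is only an example):

```python
# Minimal sketch: load and merge a model config the same way
# Dalle.from_pretrained does in dalle/models/__init__.py below.
from omegaconf import OmegaConf
from dalle.utils.config import get_base_config

config_base = get_base_config()
config_new = OmegaConf.load("configs/dalle-1.3B.yaml")  # example path
config = OmegaConf.merge(config_base, config_new)

print(config.stage2.hparams.n_layers)  # 42
print(config.stage1.hparams.ch_mult)   # [1, 1, 2, 2, 4]
```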
configs/transfer-imagenet-clscond-gen.yaml
ADDED
@@ -0,0 +1,49 @@
dataset:
  dataset: imagenet
  image_resolution: 256

stage1:
  type: vqgan
  embed_dim: 256
  n_embed: 16384
  hparams:
    double_z: False
    z_channels: 256
    resolution: 256
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 1, 2, 2, 4]
    num_res_blocks: 2
    attn_resolutions: [16]
    pdrop: 0.0

stage2:
  type: igpt
  use_cls_cond: True
  vocab_size_img: 16384
  hparams:
    embed_dim: 1536
    n_layers: 42
    n_heads: 24
    n_dense_layers: 42
    ctx_len_img: 256
    embd_pdrop: 0.0
    resid_pdrop: 0.0
    attn_pdrop: 0.0
    mlp_bias: True
    attn_bias: True
    gelu_use_approx: False
    n_classes: 1000

optimizer:
  opt_type: adamW
  base_lr: 1e-4
  weight_decay: 0.0
  betas: [0.9, 0.95]
  grad_clip_norm: 4.0

experiment:
  local_batch_size: 2
  total_batch_size: 512
  epochs: 8
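The `experiment` block pairs a per-device batch size of 2 with a total batch size of 512. Presumably the gap is closed by gradient accumulation across devices and steps; the training entry point is not part of this commit, so the wiring below is an assumption, shown only to make the arithmetic concrete:

```python
# Sketch: derive gradient-accumulation steps from the experiment block.
# Assumes total_batch_size = local_batch_size * n_devices * accumulation_steps;
# the actual trainer wiring is not included in this commit.
local_batch_size = 2
total_batch_size = 512
n_devices = 8  # example value

accumulation_steps, rem = divmod(total_batch_size, local_batch_size * n_devices)
assert rem == 0, "total_batch_size must be divisible by local_batch_size * n_devices"
print(accumulation_steps)  # 32
```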
configs/transfer-imagenet-uncond-gen.yaml
ADDED
@@ -0,0 +1,48 @@
dataset:
  dataset: imagenet
  image_resolution: 256

stage1:
  type: vqgan
  embed_dim: 256
  n_embed: 16384
  hparams:
    double_z: False
    z_channels: 256
    resolution: 256
    in_channels: 3
    out_ch: 3
    ch: 128
    ch_mult: [1, 1, 2, 2, 4]
    num_res_blocks: 2
    attn_resolutions: [16]
    pdrop: 0.0

stage2:
  type: igpt
  use_cls_cond: False
  vocab_size_img: 16384
  hparams:
    embed_dim: 1536
    n_layers: 42
    n_heads: 24
    n_dense_layers: 42
    ctx_len_img: 256
    embd_pdrop: 0.0
    resid_pdrop: 0.0
    attn_pdrop: 0.0
    mlp_bias: True
    attn_bias: True
    gelu_use_approx: False

optimizer:
  opt_type: adamW
  base_lr: 1e-4
  weight_decay: 0.0
  betas: [0.9, 0.95]
  grad_clip_norm: 4.0

experiment:
  local_batch_size: 2
  total_batch_size: 512
  epochs: 8
dalle/__init__.py
ADDED
File without changes
dalle/models/__init__.py
ADDED
@@ -0,0 +1,202 @@
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Minimal DALL-E
|
3 |
+
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------
|
6 |
+
|
7 |
+
import os
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
import pytorch_lightning as pl
|
11 |
+
from typing import Optional, Tuple
|
12 |
+
from omegaconf import OmegaConf
|
13 |
+
from torch.cuda.amp import autocast
|
14 |
+
from torch.optim.lr_scheduler import CosineAnnealingLR
|
15 |
+
from torch.nn import functional as F
|
16 |
+
from .stage1.vqgan import VQGAN
|
17 |
+
from .stage2.transformer import Transformer1d, iGPT
|
18 |
+
from .. import utils
|
19 |
+
from ..utils.config import get_base_config
|
20 |
+
from ..utils.sampling import sampling, sampling_igpt
|
21 |
+
from .tokenizer import build_tokenizer
|
22 |
+
|
23 |
+
_MODELS = {
|
24 |
+
'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
|
25 |
+
}
|
26 |
+
|
27 |
+
|
28 |
+
class Dalle(nn.Module):
|
29 |
+
def __init__(self,
|
30 |
+
config: OmegaConf) -> None:
|
31 |
+
super().__init__()
|
32 |
+
self.tokenizer = None
|
33 |
+
self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
|
34 |
+
embed_dim=config.stage1.embed_dim,
|
35 |
+
hparams=config.stage1.hparams)
|
36 |
+
self.stage2 = Transformer1d(vocab_size_txt=config.stage2.vocab_size_txt,
|
37 |
+
vocab_size_img=config.stage2.vocab_size_img,
|
38 |
+
hparams=config.stage2.hparams)
|
39 |
+
self.config_stage1 = config.stage1
|
40 |
+
self.config_stage2 = config.stage2
|
41 |
+
self.config_dataset = config.dataset
|
42 |
+
|
43 |
+
@classmethod
|
44 |
+
def from_pretrained(cls,
|
45 |
+
path: str) -> nn.Module:
|
46 |
+
path = _MODELS[path] if path in _MODELS else path
|
47 |
+
path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
|
48 |
+
|
49 |
+
config_base = get_base_config()
|
50 |
+
config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
|
51 |
+
config_update = OmegaConf.merge(config_base, config_new)
|
52 |
+
|
53 |
+
model = cls(config_update)
|
54 |
+
model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
|
55 |
+
context_length=model.config_dataset.context_length,
|
56 |
+
lowercase=True,
|
57 |
+
dropout=None)
|
58 |
+
model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
|
59 |
+
model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
|
60 |
+
return model
|
61 |
+
|
62 |
+
@torch.no_grad()
|
63 |
+
    def sampling(self,
                 prompt: str,
                 top_k: int = 256,
                 top_p: Optional[float] = None,
                 softmax_temperature: float = 1.0,
                 num_candidates: int = 96,
                 device: str = 'cuda:0',
                 use_fp16: bool = True) -> torch.FloatTensor:
        self.stage1.eval()
        self.stage2.eval()

        tokens = self.tokenizer.encode(prompt)
        tokens = torch.LongTensor(tokens.ids)
        tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)

        # Check if the encoding works as intended
        # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])

        tokens = tokens.to(device)
        codes = sampling(self.stage2,
                         tokens,
                         top_k=top_k,
                         top_p=top_p,
                         softmax_temperature=softmax_temperature,
                         use_fp16=use_fp16)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]
        return pixels


class ImageGPT(pl.LightningModule):
    def __init__(self,
                 config: OmegaConf) -> None:
        super().__init__()
        self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
                            embed_dim=config.stage1.embed_dim,
                            hparams=config.stage1.hparams)
        self.stage2 = iGPT(vocab_size_img=config.stage2.vocab_size_img,
                           use_cls_cond=config.stage2.use_cls_cond,
                           hparams=config.stage2.hparams)
        self.config = config
        self.use_cls_cond = config.stage2.use_cls_cond

        # make the parameters in stage 1 not trainable
        self.stage1.eval()
        for p in self.stage1.parameters():
            p.requires_grad = False

    @classmethod
    def from_pretrained(cls,
                        path_upstream: str,
                        path_downstream: str) -> Tuple[nn.Module, OmegaConf]:
        config_base = get_base_config(use_default=False)
        config_down = OmegaConf.load(path_downstream)
        config_down = OmegaConf.merge(config_base, config_down)

        model = cls(config_down)
        model.stage1.from_ckpt(os.path.join(path_upstream, 'stage1_last.ckpt'), strict=True)
        model.stage2.from_ckpt(os.path.join(path_upstream, 'stage2_last.ckpt'), strict=False)
        return model, config_down

    def sample(self,
               cls_idx: Optional[int] = None,
               top_k: int = 256,
               top_p: Optional[float] = None,
               softmax_temperature: float = 1.0,
               num_candidates: int = 16,
               device: str = 'cuda:0',
               use_fp16: bool = True,
               is_tqdm: bool = True) -> torch.FloatTensor:
        self.stage1.eval()
        self.stage2.eval()

        if cls_idx is None:
            sos = self.stage2.sos.repeat(num_candidates, 1, 1)
        else:
            sos = torch.LongTensor([cls_idx]).to(device=device)
            sos = sos.repeat(num_candidates)
            sos = self.stage2.sos(sos).unsqueeze(1)

        codes = sampling_igpt(self.stage2,
                              sos=sos,
                              top_k=top_k,
                              top_p=top_p,
                              softmax_temperature=softmax_temperature,
                              use_fp16=use_fp16,
                              is_tqdm=is_tqdm)
        codes = codes.view(num_candidates, 16, 16)  # [B, 16, 16]
        pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [B, 256, 256]
        return pixels

    def forward(self,
                images: torch.FloatTensor,
                labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
        B, C, H, W = images.shape
        with torch.no_grad():
            with autocast(enabled=False):
                codes = self.stage1.get_codes(images).detach()
        logits = self.stage2(codes, labels)
        return logits, codes

    def training_step(self, batch, batch_idx):
        images, labels = batch
        logits, codes = self(images, labels=labels if self.use_cls_cond else None)
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
        self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def validation_step(self, batch, batch_idx):
        images, labels = batch
        logits, codes = self(images, labels=labels if self.use_cls_cond else None)
        loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
        self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
        return loss

    def configure_optimizers(self):
        assert self.config.optimizer.opt_type == 'adamW'
        assert self.config.optimizer.sched_type == 'cosine'

        opt = torch.optim.AdamW(self.parameters(),
                                lr=self.config.optimizer.base_lr,
                                betas=self.config.optimizer.betas,
                                weight_decay=self.config.optimizer.weight_decay)
        sched = CosineAnnealingLR(opt,
                                  T_max=self.config.optimizer.max_steps,
                                  eta_min=self.config.optimizer.min_lr)
        sched = {
            'scheduler': sched,
            'name': 'cosine'
        }
        return [opt], [sched]

    def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
                       on_tpu=False, using_native_amp=False, using_lbfgs=False):
        optimizer.step(closure=optimizer_closure)
        self.lr_schedulers().step()
        self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)

    def on_epoch_start(self):
        self.stage1.eval()
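A minimal usage sketch for the text-conditional sampling entry point above (illustrative only; it assumes a CUDA device and that the pretrained 'minDALL-E/1.3B' weights can be downloaded, and the prompt and sampling parameters are arbitrary):

import numpy as np
from dalle.models import Dalle

model = Dalle.from_pretrained('minDALL-E/1.3B')   # downloads and restores stage1 + stage2 weights
model.to(device='cuda:0')
pixels = model.sampling(prompt='a painting of a tree on the ocean',
                        top_k=64, num_candidates=4, device='cuda:0')
images = np.transpose(pixels.cpu().numpy(), (0, 2, 3, 1))  # [4, 256, 256, 3], values in [0, 1]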
dalle/models/stage1/layers.py
ADDED
@@ -0,0 +1,373 @@
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Modified from VQGAN (https://github.com/CompVis/taming-transformers)
|
3 |
+
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
|
4 |
+
# ------------------------------------------------------------------------------------
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from typing import Tuple, Optional
|
9 |
+
|
10 |
+
|
11 |
+
def nonlinearity(x):
|
12 |
+
# swish
|
13 |
+
return x*torch.sigmoid(x)
|
14 |
+
|
15 |
+
|
16 |
+
def Normalize(in_channels):
|
17 |
+
return torch.nn.GroupNorm(num_groups=32,
|
18 |
+
num_channels=in_channels,
|
19 |
+
eps=1e-6,
|
20 |
+
affine=True)
|
21 |
+
|
22 |
+
|
23 |
+
class Upsample(nn.Module):
|
24 |
+
def __init__(self, in_channels, with_conv):
|
25 |
+
super().__init__()
|
26 |
+
self.with_conv = with_conv
|
27 |
+
if self.with_conv:
|
28 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
29 |
+
in_channels,
|
30 |
+
kernel_size=3,
|
31 |
+
stride=1,
|
32 |
+
padding=1)
|
33 |
+
|
34 |
+
def forward(self, x):
|
35 |
+
x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
|
36 |
+
if self.with_conv:
|
37 |
+
x = self.conv(x)
|
38 |
+
return x
|
39 |
+
|
40 |
+
|
41 |
+
class Downsample(nn.Module):
|
42 |
+
def __init__(self, in_channels, with_conv):
|
43 |
+
super().__init__()
|
44 |
+
self.with_conv = with_conv
|
45 |
+
if self.with_conv:
|
46 |
+
# no asymmetric padding in torch conv, must do it ourselves
|
47 |
+
self.conv = torch.nn.Conv2d(in_channels,
|
48 |
+
in_channels,
|
49 |
+
kernel_size=3,
|
50 |
+
stride=2,
|
51 |
+
padding=0)
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
if self.with_conv:
|
55 |
+
pad = (0, 1, 0, 1)
|
56 |
+
x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
|
57 |
+
x = self.conv(x)
|
58 |
+
else:
|
59 |
+
x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
|
60 |
+
return x
|
61 |
+
|
62 |
+
|
63 |
+
class ResnetBlock(nn.Module):
|
64 |
+
def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
|
65 |
+
dropout, temb_channels=512):
|
66 |
+
assert temb_channels == 0
|
67 |
+
super().__init__()
|
68 |
+
self.in_channels = in_channels
|
69 |
+
out_channels = in_channels if out_channels is None else out_channels
|
70 |
+
self.out_channels = out_channels
|
71 |
+
self.use_conv_shortcut = conv_shortcut
|
72 |
+
|
73 |
+
self.norm1 = Normalize(in_channels)
|
74 |
+
self.conv1 = torch.nn.Conv2d(in_channels,
|
75 |
+
out_channels,
|
76 |
+
kernel_size=3,
|
77 |
+
stride=1,
|
78 |
+
padding=1)
|
79 |
+
self.norm2 = Normalize(out_channels)
|
80 |
+
self.dropout = torch.nn.Dropout(dropout)
|
81 |
+
self.conv2 = torch.nn.Conv2d(out_channels,
|
82 |
+
out_channels,
|
83 |
+
kernel_size=3,
|
84 |
+
stride=1,
|
85 |
+
padding=1)
|
86 |
+
if self.in_channels != self.out_channels:
|
87 |
+
if self.use_conv_shortcut:
|
88 |
+
self.conv_shortcut = torch.nn.Conv2d(in_channels,
|
89 |
+
out_channels,
|
90 |
+
kernel_size=3,
|
91 |
+
stride=1,
|
92 |
+
padding=1)
|
93 |
+
else:
|
94 |
+
self.nin_shortcut = torch.nn.Conv2d(in_channels,
|
95 |
+
out_channels,
|
96 |
+
kernel_size=1,
|
97 |
+
stride=1,
|
98 |
+
padding=0)
|
99 |
+
|
100 |
+
def forward(self, x, temb=None):
|
101 |
+
assert temb is None
|
102 |
+
|
103 |
+
h = x
|
104 |
+
h = self.norm1(h)
|
105 |
+
h = nonlinearity(h)
|
106 |
+
h = self.conv1(h)
|
107 |
+
|
108 |
+
h = self.norm2(h)
|
109 |
+
h = nonlinearity(h)
|
110 |
+
h = self.dropout(h)
|
111 |
+
h = self.conv2(h)
|
112 |
+
|
113 |
+
if self.in_channels != self.out_channels:
|
114 |
+
if self.use_conv_shortcut:
|
115 |
+
x = self.conv_shortcut(x)
|
116 |
+
else:
|
117 |
+
x = self.nin_shortcut(x)
|
118 |
+
return x+h
|
119 |
+
|
120 |
+
|
121 |
+
class AttnBlock(nn.Module):
|
122 |
+
def __init__(self, in_channels):
|
123 |
+
super().__init__()
|
124 |
+
self.in_channels = in_channels
|
125 |
+
|
126 |
+
self.norm = Normalize(in_channels)
|
127 |
+
self.q = torch.nn.Conv2d(in_channels,
|
128 |
+
in_channels,
|
129 |
+
kernel_size=1,
|
130 |
+
stride=1,
|
131 |
+
padding=0)
|
132 |
+
self.k = torch.nn.Conv2d(in_channels,
|
133 |
+
in_channels,
|
134 |
+
kernel_size=1,
|
135 |
+
stride=1,
|
136 |
+
padding=0)
|
137 |
+
self.v = torch.nn.Conv2d(in_channels,
|
138 |
+
in_channels,
|
139 |
+
kernel_size=1,
|
140 |
+
stride=1,
|
141 |
+
padding=0)
|
142 |
+
self.proj_out = torch.nn.Conv2d(in_channels,
|
143 |
+
in_channels,
|
144 |
+
kernel_size=1,
|
145 |
+
stride=1,
|
146 |
+
padding=0)
|
147 |
+
|
148 |
+
def forward(self, x):
|
149 |
+
h_ = x
|
150 |
+
h_ = self.norm(h_)
|
151 |
+
q = self.q(h_)
|
152 |
+
k = self.k(h_)
|
153 |
+
v = self.v(h_)
|
154 |
+
|
155 |
+
# compute attention
|
156 |
+
b, c, h, w = q.shape
|
157 |
+
q = q.reshape(b, c, h*w)
|
158 |
+
q = q.permute(0, 2, 1) # b,hw,c
|
159 |
+
k = k.reshape(b, c, h*w) # b,c,hw
|
160 |
+
w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
|
161 |
+
w_ = w_ * (int(c)**(-0.5))
|
162 |
+
w_ = torch.nn.functional.softmax(w_, dim=2)
|
163 |
+
|
164 |
+
# attend to values
|
165 |
+
v = v.reshape(b, c, h*w)
|
166 |
+
w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
|
167 |
+
h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
|
168 |
+
h_ = h_.reshape(b, c, h, w)
|
169 |
+
|
170 |
+
h_ = self.proj_out(h_)
|
171 |
+
return x+h_
|
172 |
+
|
173 |
+
|
174 |
+
class Encoder(nn.Module):
|
175 |
+
def __init__(self,
|
176 |
+
*, # forced to use named arguments
|
177 |
+
ch: int,
|
178 |
+
out_ch: int,
|
179 |
+
ch_mult: Tuple[int] = (1, 2, 4, 8),
|
180 |
+
num_res_blocks: int,
|
181 |
+
attn_resolutions: Tuple[int],
|
182 |
+
pdrop: float = 0.0,
|
183 |
+
resamp_with_conv: bool = True,
|
184 |
+
in_channels: int,
|
185 |
+
resolution: int,
|
186 |
+
z_channels: int,
|
187 |
+
double_z: Optional[bool] = None) -> None:
|
188 |
+
super().__init__()
|
189 |
+
self.ch = ch
|
190 |
+
self.temb_ch = 0
|
191 |
+
self.num_resolutions = len(ch_mult)
|
192 |
+
self.num_res_blocks = num_res_blocks
|
193 |
+
self.resolution = resolution
|
194 |
+
self.in_channels = in_channels
|
195 |
+
|
196 |
+
# downsampling
|
197 |
+
self.conv_in = torch.nn.Conv2d(in_channels,
|
198 |
+
self.ch,
|
199 |
+
kernel_size=3,
|
200 |
+
stride=1,
|
201 |
+
padding=1)
|
202 |
+
|
203 |
+
curr_res = resolution
|
204 |
+
in_ch_mult = (1,)+tuple(ch_mult)
|
205 |
+
self.down = nn.ModuleList()
|
206 |
+
for i_level in range(self.num_resolutions):
|
207 |
+
block = nn.ModuleList()
|
208 |
+
attn = nn.ModuleList()
|
209 |
+
block_in = ch*in_ch_mult[i_level]
|
210 |
+
block_out = ch*ch_mult[i_level]
|
211 |
+
for i_block in range(self.num_res_blocks):
|
212 |
+
block.append(ResnetBlock(in_channels=block_in,
|
213 |
+
out_channels=block_out,
|
214 |
+
temb_channels=self.temb_ch,
|
215 |
+
dropout=pdrop))
|
216 |
+
block_in = block_out
|
217 |
+
if curr_res in attn_resolutions:
|
218 |
+
attn.append(AttnBlock(block_in))
|
219 |
+
down = nn.Module()
|
220 |
+
down.block = block
|
221 |
+
down.attn = attn
|
222 |
+
if i_level != self.num_resolutions-1:
|
223 |
+
down.downsample = Downsample(block_in, resamp_with_conv)
|
224 |
+
curr_res = curr_res // 2
|
225 |
+
self.down.append(down)
|
226 |
+
|
227 |
+
# middle
|
228 |
+
self.mid = nn.Module()
|
229 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
230 |
+
out_channels=block_in,
|
231 |
+
temb_channels=self.temb_ch,
|
232 |
+
dropout=pdrop)
|
233 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
234 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
235 |
+
out_channels=block_in,
|
236 |
+
temb_channels=self.temb_ch,
|
237 |
+
dropout=pdrop)
|
238 |
+
|
239 |
+
# end
|
240 |
+
self.norm_out = Normalize(block_in)
|
241 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
242 |
+
2*z_channels if double_z else z_channels,
|
243 |
+
kernel_size=3,
|
244 |
+
stride=1,
|
245 |
+
padding=1)
|
246 |
+
|
247 |
+
def forward(self, x):
|
248 |
+
assert x.shape[2] == x.shape[3] == self.resolution, \
|
249 |
+
"{}, {}".format(x.shape, self.resolution)
|
250 |
+
|
251 |
+
# downsampling
|
252 |
+
h = self.conv_in(x)
|
253 |
+
for i_level in range(self.num_resolutions):
|
254 |
+
for i_block in range(self.num_res_blocks):
|
255 |
+
h = self.down[i_level].block[i_block](h)
|
256 |
+
if len(self.down[i_level].attn) > 0:
|
257 |
+
h = self.down[i_level].attn[i_block](h)
|
258 |
+
if i_level != self.num_resolutions-1:
|
259 |
+
h = self.down[i_level].downsample(h)
|
260 |
+
|
261 |
+
# middle
|
262 |
+
h = self.mid.block_1(h)
|
263 |
+
h = self.mid.attn_1(h)
|
264 |
+
h = self.mid.block_2(h)
|
265 |
+
|
266 |
+
# end
|
267 |
+
h = self.norm_out(h)
|
268 |
+
h = nonlinearity(h)
|
269 |
+
h = self.conv_out(h)
|
270 |
+
return h
|
271 |
+
|
272 |
+
|
273 |
+
class Decoder(nn.Module):
|
274 |
+
def __init__(self,
|
275 |
+
*, # forced to use named arguments
|
276 |
+
ch: int,
|
277 |
+
out_ch: int,
|
278 |
+
ch_mult: Tuple[int] = (1, 2, 4, 8),
|
279 |
+
num_res_blocks: int,
|
280 |
+
attn_resolutions: Tuple[int],
|
281 |
+
pdrop: float = 0.0,
|
282 |
+
resamp_with_conv: bool = True,
|
283 |
+
in_channels: int,
|
284 |
+
resolution: int,
|
285 |
+
z_channels: int,
|
286 |
+
double_z: bool) -> None:
|
287 |
+
super().__init__()
|
288 |
+
self.ch = ch
|
289 |
+
self.temb_ch = 0
|
290 |
+
self.num_resolutions = len(ch_mult)
|
291 |
+
self.num_res_blocks = num_res_blocks
|
292 |
+
self.resolution = resolution
|
293 |
+
self.in_channels = in_channels
|
294 |
+
|
295 |
+
# compute in_ch_mult, block_in and curr_res at lowest res
|
296 |
+
block_in = ch*ch_mult[self.num_resolutions-1]
|
297 |
+
curr_res = resolution // 2**(self.num_resolutions-1)
|
298 |
+
self.z_shape = (1, z_channels, curr_res, curr_res)
|
299 |
+
|
300 |
+
# z to block_in
|
301 |
+
self.conv_in = torch.nn.Conv2d(z_channels,
|
302 |
+
block_in,
|
303 |
+
kernel_size=3,
|
304 |
+
stride=1,
|
305 |
+
padding=1)
|
306 |
+
|
307 |
+
# middle
|
308 |
+
self.mid = nn.Module()
|
309 |
+
self.mid.block_1 = ResnetBlock(in_channels=block_in,
|
310 |
+
out_channels=block_in,
|
311 |
+
temb_channels=self.temb_ch,
|
312 |
+
dropout=pdrop)
|
313 |
+
self.mid.attn_1 = AttnBlock(block_in)
|
314 |
+
self.mid.block_2 = ResnetBlock(in_channels=block_in,
|
315 |
+
out_channels=block_in,
|
316 |
+
temb_channels=self.temb_ch,
|
317 |
+
dropout=pdrop)
|
318 |
+
|
319 |
+
# upsampling
|
320 |
+
self.up = nn.ModuleList()
|
321 |
+
for i_level in reversed(range(self.num_resolutions)):
|
322 |
+
block = nn.ModuleList()
|
323 |
+
attn = nn.ModuleList()
|
324 |
+
block_out = ch*ch_mult[i_level]
|
325 |
+
for i_block in range(self.num_res_blocks+1):
|
326 |
+
block.append(ResnetBlock(in_channels=block_in,
|
327 |
+
out_channels=block_out,
|
328 |
+
temb_channels=self.temb_ch,
|
329 |
+
dropout=pdrop))
|
330 |
+
block_in = block_out
|
331 |
+
if curr_res in attn_resolutions:
|
332 |
+
attn.append(AttnBlock(block_in))
|
333 |
+
up = nn.Module()
|
334 |
+
up.block = block
|
335 |
+
up.attn = attn
|
336 |
+
if i_level != 0:
|
337 |
+
up.upsample = Upsample(block_in, resamp_with_conv)
|
338 |
+
curr_res = curr_res * 2
|
339 |
+
self.up.insert(0, up) # prepend to get consistent order
|
340 |
+
|
341 |
+
# end
|
342 |
+
self.norm_out = Normalize(block_in)
|
343 |
+
self.conv_out = torch.nn.Conv2d(block_in,
|
344 |
+
out_ch,
|
345 |
+
kernel_size=3,
|
346 |
+
stride=1,
|
347 |
+
padding=1)
|
348 |
+
|
349 |
+
def forward(self, z):
|
350 |
+
assert z.shape[1:] == self.z_shape[1:]
|
351 |
+
self.last_z_shape = z.shape
|
352 |
+
|
353 |
+
# z to block_in
|
354 |
+
h = self.conv_in(z)
|
355 |
+
|
356 |
+
# middle
|
357 |
+
h = self.mid.block_1(h)
|
358 |
+
h = self.mid.attn_1(h)
|
359 |
+
h = self.mid.block_2(h)
|
360 |
+
|
361 |
+
# upsampling
|
362 |
+
for i_level in reversed(range(self.num_resolutions)):
|
363 |
+
for i_block in range(self.num_res_blocks+1):
|
364 |
+
h = self.up[i_level].block[i_block](h)
|
365 |
+
if len(self.up[i_level].attn) > 0:
|
366 |
+
h = self.up[i_level].attn[i_block](h)
|
367 |
+
if i_level != 0:
|
368 |
+
h = self.up[i_level].upsample(h)
|
369 |
+
|
370 |
+
h = self.norm_out(h)
|
371 |
+
h = nonlinearity(h)
|
372 |
+
h = self.conv_out(h)
|
373 |
+
return h
|
dalle/models/stage1/vqgan.py
ADDED
@@ -0,0 +1,93 @@
# ------------------------------------------------------------------------------------
# Modified from VQGAN (https://github.com/CompVis/taming-transformers)
# Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
# ------------------------------------------------------------------------------------

import torch
import torch.nn as nn
from typing import List, Tuple, Optional
from einops import rearrange
from omegaconf import OmegaConf
from .layers import Encoder, Decoder


class VectorQuantizer(nn.Module):
    """
    Simplified VectorQuantizer from the original VQGAN repository,
    with the modules unnecessary for sampling removed.
    """
    def __init__(self, dim: int, n_embed: int, beta: float) -> None:
        super().__init__()
        self.n_embed = n_embed
        self.dim = dim
        self.beta = beta

        self.embedding = nn.Embedding(self.n_embed, self.dim)
        self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed)

    def forward(self,
                z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
        z = rearrange(z, 'b c h w -> b h w c').contiguous()  # [B,C,H,W] -> [B,H,W,C]
        z_flattened = z.view(-1, self.dim)

        d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
            torch.sum(self.embedding.weight**2, dim=1) - 2 * \
            torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))

        min_encoding_indices = torch.argmin(d, dim=1)
        z_q = self.embedding(min_encoding_indices).view(z.shape)
        return z_q, min_encoding_indices

    def get_codebook_entry(self,
                           indices: torch.LongTensor,
                           shape: Optional[List[int]] = None) -> torch.FloatTensor:
        z_q = self.embedding(indices)
        if shape is not None:
            z_q = z_q.view(shape)
            z_q = z_q.permute(0, 3, 1, 2).contiguous()
        return z_q


class VQGAN(nn.Module):
    def __init__(self, n_embed: int, embed_dim: int, hparams: OmegaConf) -> None:
        super().__init__()
        self.encoder = Encoder(**hparams)
        self.decoder = Decoder(**hparams)
        self.quantize = VectorQuantizer(dim=embed_dim, n_embed=n_embed, beta=0.25)
        self.quant_conv = torch.nn.Conv2d(hparams.z_channels, embed_dim, 1)
        self.post_quant_conv = torch.nn.Conv2d(embed_dim, hparams.z_channels, 1)
        self.latent_dim = hparams.attn_resolutions[0]

    def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
        quant = self.encode(x)
        dec = self.decode(quant)
        return dec

    def encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
        h = self.encoder(x)
        h = self.quant_conv(h)
        quant = self.quantize(h)[0]
        quant = rearrange(quant, 'b h w c -> b c h w').contiguous()
        return quant

    def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
        quant = self.post_quant_conv(quant)
        dec = self.decoder(quant)
        return dec

    def decode_code(self, code: torch.LongTensor) -> torch.FloatTensor:
        quant = self.quantize.get_codebook_entry(code)
        quant = quant.permute(0, 3, 1, 2)
        dec = self.decode(quant)
        return dec

    def get_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
        h = self.encoder(x)
        h = self.quant_conv(h)
        codes = self.quantize(h)[1].view(x.shape[0], self.latent_dim ** 2)
        return codes

    def from_ckpt(self, path: str, strict: bool = True) -> None:
        ckpt = torch.load(path, map_location='cpu')['state_dict']
        self.load_state_dict(ckpt, strict=strict)
        print(f'{path} successfully restored..')
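A minimal sketch of how decode_code connects to the samplers above: a [B, 16, 16] grid of codebook indices is embedded, decoded to a [B, 3, 256, 256] image, and mapped to [0, 1] exactly as in Dalle.sampling. Here the VQGAN is built from the default Stage1 hyperparameters and remains randomly initialized unless restored with from_ckpt; the random codes are purely illustrative:

import torch
from dalle.utils.config import get_base_config
from dalle.models.stage1.vqgan import VQGAN

cfg = get_base_config().stage1
vqgan = VQGAN(n_embed=cfg.n_embed, embed_dim=cfg.embed_dim, hparams=cfg.hparams)
codes = torch.randint(0, cfg.n_embed, (1, 16, 16))            # [B, 16, 16] codebook indices
with torch.no_grad():
    pixels = torch.clamp(vqgan.decode_code(codes) * 0.5 + 0.5, 0, 1)  # [1, 3, 256, 256]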
dalle/models/stage2/layers.py
ADDED
@@ -0,0 +1,140 @@
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------
# Modified from minGPT (https://github.com/karpathy/minGPT)
# Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
# ------------------------------------------------------------------------------------

import math
import torch
import torch.nn as nn
from torch.nn import functional as F


class GELU(nn.Module):
    def __init__(self, use_approx=False):
        super().__init__()
        self.use_approx = use_approx

    def forward(self, x):
        if self.use_approx:
            return x * torch.sigmoid(1.702 * x)
        else:
            return F.gelu(x)


class MultiHeadSelfAttention(nn.Module):

    def __init__(self,
                 ctx_len: int,
                 embed_dim: int,
                 n_heads: int,
                 resid_pdrop: float,
                 attn_pdrop: float,
                 attn_bias: bool,
                 use_mask: bool = True):
        super().__init__()
        assert embed_dim % n_heads == 0

        # key, query, value projections for all heads
        self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
        self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
        self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias)

        # regularization
        self.attn_drop = nn.Dropout(attn_pdrop)
        self.resid_drop = nn.Dropout(resid_pdrop)

        # output projection
        self.proj = nn.Linear(embed_dim, embed_dim, attn_bias)

        self.n_heads = n_heads
        self.ctx_len = ctx_len
        self.use_mask = use_mask
        if self.use_mask:
            self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False)
            self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len)

    def forward(self, x, use_cache=False, layer_past=None):
        B, T, C = x.shape
        x = x.transpose(0, 1).contiguous()  # (B, T, C) -> (T, B, C)

        # calculate query, key, values for all heads in batch and move head forward to be the batch dim
        k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1)  # (B*nh, T, hs)
        q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1)  # (B*nh, T, hs)
        v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1)  # (B*nh, T, hs)

        if use_cache:
            present = torch.stack([k, v])

        if layer_past is not None:
            past_key, past_value = layer_past
            k = torch.cat([past_key, k], dim=-2)
            v = torch.cat([past_value, v], dim=-2)

        if use_cache and layer_past is not None:
            # Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K)
            att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
            att = F.softmax(att, dim=-1)
            att = self.attn_drop(att)
            y = torch.bmm(att, v)  # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs)
        else:
            # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T)
            att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
            if self.use_mask:
                mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T]
                att = att.masked_fill(mask == 0, float('-inf'))
            att = F.softmax(att, dim=-1)
            att = self.attn_drop(att)
            y = torch.bmm(att, v)  # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
        y = y.transpose(0, 1).contiguous().view(T, B, C)  # re-assemble all head outputs side by side

        # output projection
        y = self.resid_drop(self.proj(y))
        if use_cache:
            return y.transpose(0, 1).contiguous(), present  # (T, B, C) -> (B, T, C)
        else:
            return y.transpose(0, 1).contiguous()  # (T, B, C) -> (B, T, C)


class Block(nn.Module):

    def __init__(self,
                 ctx_len: int,
                 embed_dim: int,
                 n_heads: int,
                 mlp_bias: bool,
                 attn_bias: bool,
                 resid_pdrop: bool,
                 attn_pdrop: bool,
                 gelu_use_approx: bool):
        super().__init__()
        self.ln1 = nn.LayerNorm(embed_dim)
        self.ln2 = nn.LayerNorm(embed_dim)

        self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
                                           embed_dim=embed_dim,
                                           n_heads=n_heads,
                                           attn_pdrop=attn_pdrop,
                                           resid_pdrop=resid_pdrop,
                                           attn_bias=attn_bias,
                                           use_mask=True)
        self.mlp = nn.Sequential(
            nn.Linear(embed_dim, 4 * embed_dim, bias=mlp_bias),
            GELU(gelu_use_approx),
            nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias),
            nn.Dropout(resid_pdrop),
        )

    def forward(self, x):
        x = x + self.attn(self.ln1(x))
        x = x + self.mlp(self.ln2(x))
        return x

    def sample(self, x, layer_past=None):
        attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
        x = x + attn
        x = x + self.mlp(self.ln2(x))
        return x, present
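A small sketch of the key/value cache used during sampling: Block.sample returns the present keys and values, which can be fed back as layer_past so that the next step only attends over cached states. The tiny hyperparameters below are illustrative, not the 1.3B settings:

import torch
from dalle.models.stage2.layers import Block

block = Block(ctx_len=257, embed_dim=64, n_heads=4, mlp_bias=True, attn_bias=True,
              resid_pdrop=0.0, attn_pdrop=0.0, gelu_use_approx=False)
x0, past = block.sample(torch.randn(1, 1, 64), layer_past=None)   # cache k/v of step 0
x1, _ = block.sample(torch.randn(1, 1, 64), layer_past=past)      # step 1 attends to cached step 0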
dalle/models/stage2/transformer.py
ADDED
@@ -0,0 +1,255 @@
1 |
+
# ------------------------------------------------------------------------------------
|
2 |
+
# Minimal DALL-E
|
3 |
+
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
|
4 |
+
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
|
5 |
+
# ------------------------------------------------------------------------------------
|
6 |
+
# Modified from minGPT (https://github.com/karpathy/minGPT)
|
7 |
+
# Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
|
8 |
+
# ------------------------------------------------------------------------------------
|
9 |
+
|
10 |
+
import torch
|
11 |
+
import torch.nn as nn
|
12 |
+
from typing import Optional, Tuple, List
|
13 |
+
from torch.cuda.amp import autocast
|
14 |
+
from omegaconf import OmegaConf
|
15 |
+
from .layers import Block
|
16 |
+
|
17 |
+
|
18 |
+
class Transformer1d(nn.Module):
|
19 |
+
|
20 |
+
def __init__(self,
|
21 |
+
vocab_size_txt: int,
|
22 |
+
vocab_size_img: int,
|
23 |
+
hparams: OmegaConf) -> None:
|
24 |
+
super().__init__()
|
25 |
+
assert hparams.n_layers == hparams.n_dense_layers
|
26 |
+
|
27 |
+
# input embedding for image and text
|
28 |
+
self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
|
29 |
+
self.tok_emb_txt = nn.Embedding(vocab_size_txt, hparams.embed_dim)
|
30 |
+
|
31 |
+
self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
|
32 |
+
self.pos_emb_txt = nn.Embedding(hparams.ctx_len_txt, hparams.embed_dim)
|
33 |
+
|
34 |
+
self.drop = nn.Dropout(hparams.embd_pdrop)
|
35 |
+
|
36 |
+
# transformer blocks
|
37 |
+
self.blocks = [Block(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
|
38 |
+
embed_dim=hparams.embed_dim,
|
39 |
+
n_heads=hparams.n_heads,
|
40 |
+
mlp_bias=hparams.mlp_bias,
|
41 |
+
attn_bias=hparams.attn_bias,
|
42 |
+
resid_pdrop=hparams.resid_pdrop,
|
43 |
+
attn_pdrop=hparams.attn_pdrop,
|
44 |
+
gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
|
45 |
+
self.blocks = nn.Sequential(*self.blocks)
|
46 |
+
|
47 |
+
# heads for image and text
|
48 |
+
self.ln_f = nn.LayerNorm(hparams.embed_dim)
|
49 |
+
self.head_img = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
|
50 |
+
self.head_txt = nn.Linear(hparams.embed_dim, vocab_size_txt, bias=False)
|
51 |
+
|
52 |
+
self.ctx_len_img = hparams.ctx_len_img
|
53 |
+
self.ctx_len_txt = hparams.ctx_len_txt
|
54 |
+
self.n_layers = hparams.n_layers
|
55 |
+
|
56 |
+
self.apply(self._init_weights)
|
57 |
+
|
58 |
+
def _init_weights(self, module: nn.Module) -> None:
|
59 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
60 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
61 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
62 |
+
module.bias.data.zero_()
|
63 |
+
elif isinstance(module, nn.LayerNorm):
|
64 |
+
module.bias.data.zero_()
|
65 |
+
module.weight.data.fill_(1.0)
|
66 |
+
|
67 |
+
def forward(self,
|
68 |
+
images: torch.LongTensor,
|
69 |
+
texts: torch.LongTensor,
|
70 |
+
pos_images: torch.LongTensor,
|
71 |
+
pos_texts: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
72 |
+
B, T = images.shape
|
73 |
+
_, N = texts.shape
|
74 |
+
|
75 |
+
assert T <= self.ctx_len_img, "Already reached the maximum context length (image)."
|
76 |
+
assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
|
77 |
+
|
78 |
+
texts = self.tok_emb_txt(texts)
|
79 |
+
images = self.tok_emb_img(images)
|
80 |
+
|
81 |
+
texts = texts + self.pos_emb_txt(pos_texts)
|
82 |
+
images = images + self.pos_emb_img(pos_images)
|
83 |
+
|
84 |
+
x = torch.cat([texts, images], axis=1).contiguous()
|
85 |
+
x = self.drop(x)
|
86 |
+
x = self.blocks(x)
|
87 |
+
x = self.ln_f(x)
|
88 |
+
|
89 |
+
texts = x[:, :N-1].contiguous()
|
90 |
+
images = x[:, N-1:-1].contiguous()
|
91 |
+
|
92 |
+
logits_txt = self.head_txt(texts)
|
93 |
+
logits_img = self.head_img(images)
|
94 |
+
return logits_img, logits_txt
|
95 |
+
|
96 |
+
@torch.no_grad()
|
97 |
+
def sampling(self,
|
98 |
+
images: torch.LongTensor,
|
99 |
+
texts: torch.LongTensor,
|
100 |
+
pos_images: torch.LongTensor,
|
101 |
+
pos_texts: torch.LongTensor,
|
102 |
+
use_fp16: bool = True,
|
103 |
+
past: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
|
104 |
+
_, N = texts.shape
|
105 |
+
assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
|
106 |
+
|
107 |
+
with autocast(enabled=use_fp16):
|
108 |
+
if images is None:
|
109 |
+
assert past is None
|
110 |
+
|
111 |
+
texts = self.tok_emb_txt(texts)
|
112 |
+
x = texts + self.pos_emb_txt(pos_texts)
|
113 |
+
x = self.drop(x)
|
114 |
+
|
115 |
+
presents = []
|
116 |
+
for i, block in enumerate(self.blocks):
|
117 |
+
x, present = block.sample(x, layer_past=None)
|
118 |
+
presents.append(present)
|
119 |
+
x = self.ln_f(x)
|
120 |
+
x = x[:, N-1].contiguous()
|
121 |
+
logits = self.head_img(x)
|
122 |
+
else:
|
123 |
+
if past is None:
|
124 |
+
texts = self.tok_emb_txt(texts)
|
125 |
+
images = self.tok_emb_img(images)
|
126 |
+
texts = texts + self.pos_emb_txt(pos_texts)
|
127 |
+
images = images + self.pos_emb_img(pos_images)
|
128 |
+
x = torch.cat([texts, images], axis=1).contiguous()
|
129 |
+
else:
|
130 |
+
images = self.tok_emb_img(images)
|
131 |
+
x = images + self.pos_emb_img(pos_images)
|
132 |
+
x = self.drop(x)
|
133 |
+
|
134 |
+
if past is not None:
|
135 |
+
past = torch.cat(past, dim=-2)
|
136 |
+
presents = []
|
137 |
+
for i, block in enumerate(self.blocks):
|
138 |
+
x, present = block.sample(x, layer_past=None if past is None else past[i])
|
139 |
+
presents.append(present)
|
140 |
+
x = self.ln_f(x)
|
141 |
+
x = x[:, -1].contiguous()
|
142 |
+
logits = self.head_img(x)
|
143 |
+
return logits, presents
|
144 |
+
|
145 |
+
def from_ckpt(self, path: str) -> None:
|
146 |
+
ckpt = torch.load(path, map_location='cpu')['state_dict']
|
147 |
+
self.load_state_dict(ckpt, strict=True)
|
148 |
+
print(f'{path} successfully restored..')
|
149 |
+
|
150 |
+
|
151 |
+
class iGPT(nn.Module):
|
152 |
+
def __init__(self,
|
153 |
+
vocab_size_img: int,
|
154 |
+
use_cls_cond: bool,
|
155 |
+
hparams: OmegaConf) -> None:
|
156 |
+
super().__init__()
|
157 |
+
self.use_cls_cond = use_cls_cond
|
158 |
+
|
159 |
+
# sos token embedding
|
160 |
+
if self.use_cls_cond:
|
161 |
+
self.sos = nn.Embedding(hparams.n_classes, hparams.embed_dim)
|
162 |
+
else:
|
163 |
+
self.sos = nn.Parameter(torch.randn(1, 1, hparams.embed_dim))
|
164 |
+
|
165 |
+
# input embedding
|
166 |
+
self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
|
167 |
+
self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
|
168 |
+
|
169 |
+
self.drop = nn.Dropout(hparams.embd_pdrop)
|
170 |
+
|
171 |
+
# transformer blocks
|
172 |
+
self.blocks = [Block(ctx_len=hparams.ctx_len_img + 1,
|
173 |
+
embed_dim=hparams.embed_dim,
|
174 |
+
n_heads=hparams.n_heads,
|
175 |
+
mlp_bias=hparams.mlp_bias,
|
176 |
+
attn_bias=hparams.attn_bias,
|
177 |
+
resid_pdrop=hparams.resid_pdrop,
|
178 |
+
attn_pdrop=hparams.attn_pdrop,
|
179 |
+
gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
|
180 |
+
self.blocks = nn.Sequential(*self.blocks)
|
181 |
+
|
182 |
+
# head
|
183 |
+
self.ln_f = nn.LayerNorm(hparams.embed_dim)
|
184 |
+
self.head = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
|
185 |
+
|
186 |
+
self.ctx_len_img = hparams.ctx_len_img
|
187 |
+
self.n_layers = hparams.n_layers
|
188 |
+
|
189 |
+
self.apply(self._init_weights)
|
190 |
+
|
191 |
+
def _init_weights(self, module: nn.Module) -> None:
|
192 |
+
if isinstance(module, (nn.Linear, nn.Embedding)):
|
193 |
+
module.weight.data.normal_(mean=0.0, std=0.02)
|
194 |
+
if isinstance(module, nn.Linear) and module.bias is not None:
|
195 |
+
module.bias.data.zero_()
|
196 |
+
elif isinstance(module, nn.LayerNorm):
|
197 |
+
module.bias.data.zero_()
|
198 |
+
module.weight.data.fill_(1.0)
|
199 |
+
|
200 |
+
@torch.no_grad()
|
201 |
+
def sampling(self,
|
202 |
+
sos: torch.FloatTensor,
|
203 |
+
codes: torch.LongTensor,
|
204 |
+
pos_codes: torch.LongTensor,
|
205 |
+
n_samples: int = 16,
|
206 |
+
use_fp16: bool = True,
|
207 |
+
past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
|
208 |
+
with autocast(enabled=use_fp16):
|
209 |
+
if codes is None:
|
210 |
+
assert past is None
|
211 |
+
xs = self.drop(sos)
|
212 |
+
presents = []
|
213 |
+
for i, block in enumerate(self.blocks):
|
214 |
+
xs, present = block.sample(xs, layer_past=None)
|
215 |
+
presents.append(present)
|
216 |
+
xs = self.ln_f(xs)
|
217 |
+
logits = self.head(xs)[:, -1]
|
218 |
+
else:
|
219 |
+
if past is None:
|
220 |
+
xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
|
221 |
+
xs = torch.cat([sos, xs], dim=1)
|
222 |
+
else:
|
223 |
+
xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
|
224 |
+
xs = self.drop(xs)
|
225 |
+
|
226 |
+
past = torch.cat(past, dim=-2) if past is not None else past
|
227 |
+
presents = []
|
228 |
+
for i, block in enumerate(self.blocks):
|
229 |
+
xs, present = block.sample(xs, layer_past=None if past is None else past[i])
|
230 |
+
presents.append(present)
|
231 |
+
|
232 |
+
xs = self.ln_f(xs)
|
233 |
+
logits = self.head(xs)[:, -1]
|
234 |
+
return logits, presents
|
235 |
+
|
236 |
+
def forward(self,
|
237 |
+
codes: torch.LongTensor,
|
238 |
+
labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
|
239 |
+
B, T = codes.shape
|
240 |
+
xps = torch.arange(T, device=codes.device).repeat((B, 1))
|
241 |
+
sos = self.sos.repeat((B, 1, 1)) if labels is None else self.sos(labels).unsqueeze(1)
|
242 |
+
|
243 |
+
h = self.tok_emb_img(codes) + self.pos_emb_img(xps)
|
244 |
+
h = torch.cat([sos, h[:, :-1]], dim=1).contiguous()
|
245 |
+
|
246 |
+
h = self.drop(h)
|
247 |
+
h = self.blocks(h)
|
248 |
+
h = self.ln_f(h)
|
249 |
+
logits = self.head(h)
|
250 |
+
return logits
|
251 |
+
|
252 |
+
def from_ckpt(self, path: str, strict: bool = True) -> None:
|
253 |
+
ckpt = torch.load(path, map_location='cpu')['state_dict']
|
254 |
+
self.load_state_dict(ckpt, strict=strict)
|
255 |
+
print(f'{path} successfully restored..')
|
dalle/models/tokenizer.py
ADDED
@@ -0,0 +1,26 @@
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
from functools import partial
from tokenizers import CharBPETokenizer


def build_tokenizer(path: str,
                    context_length: int = 64,
                    *args,
                    **kwargs):
    from_file = partial(CharBPETokenizer.from_file,
                        vocab_filename=os.path.join(path, 'bpe-16k-vocab.json'),
                        merges_filename=os.path.join(path, 'bpe-16k-merges.txt'),
                        unk_token='[UNK]')
    tokenizer = from_file(*args, **kwargs)
    tokenizer.add_special_tokens(['[PAD]'])
    tokenizer.enable_padding(length=context_length,
                             pad_id=tokenizer.token_to_id('[PAD]'))
    tokenizer.enable_truncation(max_length=context_length)
    print(f'{path} successfully restored..')
    return tokenizer
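A short usage sketch (the tokenizer directory below is illustrative; it is expected to contain the bpe-16k vocab and merges files restored by the download helper):

from dalle.models.tokenizer import build_tokenizer

tokenizer = build_tokenizer('./1.3B/tokenizer', context_length=64)
encoded = tokenizer.encode('a painting of a tree on the ocean')
print(len(encoded.ids))  # 64: padded (or truncated) to the text context length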
dalle/utils/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .utils import *
from .config import *
from .sampling import *
dalle/utils/config.py
ADDED
@@ -0,0 +1,123 @@
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

from typing import Optional, List
from dataclasses import dataclass, field
from omegaconf import OmegaConf


@dataclass
class DataConfig:
    dataset: Optional[str] = None
    tokenizer_type: str = 'CharBPE'
    context_length: int = 64
    image_resolution: int = 256
    transforms: str = 'dalle-vqvae'
    bpe_pdrop: Optional[float] = None


@dataclass
class Stage1Hparams:
    double_z: bool = False
    z_channels: int = 256
    resolution: int = 256
    in_channels: int = 3
    out_ch: int = 3
    ch: int = 128
    ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
    num_res_blocks: int = 2
    attn_resolutions: List[int] = field(default_factory=lambda: [16])
    pdrop: float = 0.0


@dataclass
class Stage2Hparams:
    embed_dim: int = 1536
    n_layers: int = 42
    n_heads: int = 24
    n_dense_layers: int = 42
    ctx_len_img: int = 256
    ctx_len_txt: int = 64
    embd_pdrop: float = 0.0
    resid_pdrop: float = 0.0
    attn_pdrop: float = 0.0
    mlp_bias: bool = True
    attn_bias: bool = True
    gelu_use_approx: bool = False
    use_head_txt: bool = True
    n_classes: Optional[int] = None


@dataclass
class Stage1Config:
    type: str = 'vqgan'
    embed_dim: int = 256
    n_embed: int = 16384
    hparams: Stage1Hparams = Stage1Hparams()


@dataclass
class Stage2Config:
    type: str = 'transformer1d'
    vocab_size_txt: int = 16384
    vocab_size_img: int = 16384
    use_cls_cond: Optional[bool] = None
    hparams: Stage2Hparams = Stage2Hparams()


@dataclass
class WarmupConfig:
    epoch: int = 1
    multiplier: int = 1
    buffer_epoch: int = 0
    min_lr: float = 0.0
    mode: str = 'fix'
    peak_lr: float = 1e-4
    start_from_zero: bool = True


@dataclass
class OptConfig:
    opt_type: str = 'adamW'
    base_lr: float = 1e-4
    weight_decay: float = 1e-4
    betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
    grad_clip_norm: float = 1.0

    sched_type: str = 'cosine'
    max_steps: int = 0
    min_lr: float = 0.0


@dataclass
class ExpConfig:
    local_batch_size: int = 4
    total_batch_size: int = 512
    valid_batch_size: int = 32
    epochs: int = 10
    save_ckpt_freq: int = 2
    test_freq: int = 1
    use_amp: bool = True


@dataclass
class DefaultConfig:
    dataset: DataConfig = DataConfig()
    stage1: Stage1Config = Stage1Config()
    stage2: Stage2Config = Stage2Config()


@dataclass
class FineTuningConfig:
    dataset: DataConfig = DataConfig()
    stage1: Stage1Config = Stage1Config()
    stage2: Stage2Config = Stage2Config()
    optimizer: OptConfig = OptConfig()
    experiment: ExpConfig = ExpConfig()


def get_base_config(use_default=True):
    return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)
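A sketch of how these structured defaults combine with a YAML override, mirroring ImageGPT.from_pretrained (the config path is one of the files added in this commit; which fields the YAML actually overrides depends on that file):

from omegaconf import OmegaConf
from dalle.utils.config import get_base_config

config_base = get_base_config(use_default=False)         # FineTuningConfig schema
config_down = OmegaConf.load('configs/transfer-imagenet-clscond-gen.yaml')
config = OmegaConf.merge(config_base, config_down)        # YAML values override the defaults
print(OmegaConf.to_yaml(config.optimizer))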
dalle/utils/sampling.py
ADDED
@@ -0,0 +1,152 @@
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import torch
from typing import Optional
from tqdm import tqdm
from torch.nn import functional as F


def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
    if k is None:
        return logits
    else:
        v, ix = torch.topk(logits, k)
        out = logits.clone()
        out[out < v[:, [-1]]] = -float('Inf')
        return out


def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
    if p is None:
        return probs
    else:
        sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
        cum_probs = torch.cumsum(sorted_probs, dim=-1)

        sorted_idx_remove_cond = cum_probs >= p

        sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
        sorted_idx_remove_cond[..., 0] = 0

        indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
        probs = probs.masked_fill(indices_to_remove, 0.0)
        norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
        return norm_probs


def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor:
    device = inputs.device
    if mode == '1d':
        B, N = inputs.shape
        xs_pos = torch.arange(N, device=device).repeat((B, 1))
    elif mode == '2d':
        B, H, W = inputs.shape
        xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2)
        xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1)
        xs_pos = (xs_pos_h, xs_pos_w)
    else:
        raise ValueError('%s positional encoding invalid' % mode)
    return xs_pos


@torch.no_grad()
def sampling(model: torch.nn.Module,
             tokens: torch.LongTensor,
             top_k: Optional[float] = None,
             top_p: Optional[float] = None,
             softmax_temperature: float = 1.0,
             is_tqdm: bool = True,
             use_fp16: bool = True,
             max_seq_len: int = 256) -> torch.LongTensor:
    code = None
    past = None

    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
    pos_enc_tokens = get_positional_encoding(tokens, mode='1d')

    for cnt, h in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            code_ = code.clone().detach()
            pos_enc_code_ = get_positional_encoding(code_, mode='1d')
            code_ = code_[:, cnt-1].unsqueeze(-1)
            pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)

        logits, present = model.sampling(images=code_,
                                         texts=tokens,
                                         pos_images=pos_enc_code_,
                                         pos_texts=pos_enc_tokens,
                                         use_fp16=use_fp16,
                                         past=past)
        logits = logits.to(dtype=torch.float32)
        logits = logits / softmax_temperature

        present = torch.stack(present).clone().detach()
        if past is None:
            past = [present]
        else:
            past.append(present)

        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)

        idx = torch.multinomial(probs, num_samples=1).clone().detach()
        code = idx if code is None else torch.cat([code, idx], axis=1)

    del past
    return code


@torch.no_grad()
def sampling_igpt(model: torch.nn.Module,
                  sos: torch.FloatTensor,
                  top_k: Optional[float] = None,
                  top_p: Optional[float] = None,
                  softmax_temperature: float = 1.0,
                  is_tqdm: bool = True,
                  use_fp16: bool = True,
                  max_seq_len: int = 256) -> torch.LongTensor:
    code = None
    past = None
    pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)

    for cnt, h in enumerate(pbar):
        if code is None:
            code_ = None
            pos_enc_code_ = None
        else:
            code_ = code.clone().detach()
            pos_enc_code_ = get_positional_encoding(code_, mode='1d')
            code_ = code_[:, cnt-1].unsqueeze(-1)
            pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)

        logits, present = model.sampling(sos=sos,
                                         codes=code_,
                                         pos_codes=pos_enc_code_,
                                         use_fp16=use_fp16,
                                         past=past)
        logits = logits.to(dtype=torch.float32)
        logits = logits / softmax_temperature

        present = torch.stack(present).clone().detach()
        if past is None:
            past = [present]
        else:
            past.append(present)

        logits = cutoff_topk_logits(logits, top_k)
        probs = F.softmax(logits, dim=-1)
        probs = cutoff_topp_probs(probs, top_p)

        idx = torch.multinomial(probs, num_samples=1).clone().detach()
        code = idx if code is None else torch.cat([code, idx], axis=1)

    del past
    return code
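A toy example of the two cutoff helpers above (the logits and thresholds are illustrative):

import torch
from torch.nn import functional as F
from dalle.utils.sampling import cutoff_topk_logits, cutoff_topp_probs

logits = torch.tensor([[2.0, 1.0, 0.5, -1.0]])
topk = cutoff_topk_logits(logits, k=2)                       # logits outside the top-2 become -inf
probs = cutoff_topp_probs(F.softmax(topk, dim=-1), p=0.9)    # renormalized nucleus distribution
print(topk, probs)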
dalle/utils/utils.py
ADDED
@@ -0,0 +1,84 @@
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
import random
import urllib
import hashlib
import tarfile
import torch
import clip
import numpy as np
from PIL import Image
from torch.nn import functional as F
from tqdm import tqdm


def set_seed(seed: int):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)


@torch.no_grad()
def clip_score(prompt: str,
               images: np.ndarray,
               model_clip: torch.nn.Module,
               preprocess_clip,
               device: str) -> np.ndarray:
    images = [preprocess_clip(Image.fromarray((image*255).astype(np.uint8))) for image in images]
    images = torch.stack(images, dim=0).to(device=device)
    texts = clip.tokenize(prompt).to(device=device)
    texts = torch.repeat_interleave(texts, images.shape[0], dim=0)

    image_features = model_clip.encode_image(images)
    text_features = model_clip.encode_text(texts)

    scores = F.cosine_similarity(image_features, text_features).squeeze()
    rank = torch.argsort(scores, descending=True).cpu().numpy()
    return rank


def download(url: str, root: str) -> str:
    os.makedirs(root, exist_ok=True)
    filename = os.path.basename(url)
    pathname = filename[:-len('.tar.gz')]

    expected_md5 = url.split("/")[-2]
    download_target = os.path.join(root, filename)
    result_path = os.path.join(root, pathname)

    if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
        return result_path

    with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
        with tqdm(total=int(source.info().get('Content-Length')), ncols=80, unit='iB', unit_scale=True,
                  unit_divisor=1024) as loop:
            while True:
                buffer = source.read(8192)
                if not buffer:
                    break

                output.write(buffer)
                loop.update(len(buffer))

    if hashlib.md5(open(download_target, 'rb').read()).hexdigest() != expected_md5:
        raise RuntimeError(f'Model has been downloaded but the md5 checksum does not match')

    with tarfile.open(download_target, 'r:gz') as f:
        pbar = tqdm(f.getmembers(), total=len(f.getmembers()))
        for member in pbar:
            pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
            f.extract(member=member, path=root)

    return result_path


def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
    if urllib.parse.urlparse(url_or_path).scheme in ('http', 'https'):
        return download(url_or_path, root)
    return url_or_path
examples/sampling_ex.py
ADDED
@@ -0,0 +1,63 @@
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
import sys
import argparse
import clip
import numpy as np
from PIL import Image

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from dalle.models import Dalle
from dalle.utils.utils import set_seed, clip_score


parser = argparse.ArgumentParser()
parser.add_argument('-n', '--num_candidates', type=int, default=96)
parser.add_argument('--prompt', type=str, default='A painting of a tree on the ocean')
parser.add_argument('--softmax-temperature', type=float, default=1.0)
parser.add_argument('--top-k', type=int, default=256)
parser.add_argument('--top-p', type=float, default=None, help='0.0 <= top-p <= 1.0')
parser.add_argument('--seed', type=int, default=0)

args = parser.parse_args()

# Setup
assert args.top_k <= 256, "It is recommended that top_k is set lower than 256."

set_seed(args.seed)
device = 'cuda:0'
model = Dalle.from_pretrained('minDALL-E/1.3B')  # This will automatically download the pretrained model.
model.to(device=device)

# Sampling
images = model.sampling(prompt=args.prompt,
                        top_k=args.top_k,
                        top_p=args.top_p,
                        softmax_temperature=args.softmax_temperature,
                        num_candidates=args.num_candidates,
                        device=device).cpu().numpy()
images = np.transpose(images, (0, 2, 3, 1))

# CLIP Re-ranking
model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
model_clip.to(device=device)
rank = clip_score(prompt=args.prompt,
                  images=images,
                  model_clip=model_clip,
                  preprocess_clip=preprocess_clip,
                  device=device)

# Save images
images = images[rank]
print(rank, images.shape)
if not os.path.exists('./figures'):
    os.makedirs('./figures')
for i in range(min(16, args.num_candidates)):
    im = Image.fromarray((images[i]*255).astype(np.uint8))
    im.save(f'./figures/{args.prompt}_{i}.png')
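The script above can be run directly once the requirements are installed; a typical invocation (prompt and options are illustrative) is:

python examples/sampling_ex.py --prompt "A painting of a tree on the ocean" --top-k 128 -n 16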
examples/sampling_interactive_demo.ipynb
ADDED
@@ -0,0 +1,298 @@
{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 1,
   "id": "cdf36725-ec00-4027-95d6-374340c2264e",
   "metadata": {},
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|█████████████████████████████████████| 4.72G/4.72G [02:04<00:00, 40.7MiB/s]\n",
      "extracting: ./1.3B/tokenizer/bpe-16k-vocab.json (size:0MB): 100%|██████████| 7/7 [00:59<00:00, 8.51s/it]\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/root/.cache/minDALL-E/1.3B/tokenizer successfully restored..\n",
      "/root/.cache/minDALL-E/1.3B/stage1_last.ckpt successfully restored..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      " 0%| | 0.00/338M [00:00<?, ?iB/s]"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/root/.cache/minDALL-E/1.3B/stage2_last.ckpt succesfully restored..\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "100%|███████████████████████████████████████| 338M/338M [00:09<00:00, 38.5MiB/s]\n"
     ]
    }
   ],
   "source": [
    "import os\n",
    "import sys\n",
    "import math\n",
    "import argparse\n",
    "import clip\n",
    "import numpy as np\n",
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "from PIL import Image\n",
    "\n",
    "sys.path.append(os.path.dirname(os.getcwd()))\n",
    "\n",
    "from dalle.models import Dalle\n",
    "from dalle.utils.utils import set_seed, clip_score\n",
    "\n",
    "device = 'cuda:0'\n",
    "model = Dalle.from_pretrained(\"minDALL-E/1.3B\")\n",
    "model_clip, preprocess_clip = clip.load(\"ViT-B/32\", device=device)\n",
    "\n",
    "model_clip.to(device=device)\n",
    "model.to(device=device)\n",
    "\n",
    "def sampling(prompt, top_k, softmax_temperature, seed, num_candidates=96, num_samples_for_display=36):\n",
    "    # Setup\n",
    "    n_row = int(math.sqrt(num_samples_for_display))\n",
    "    n_col = int(math.sqrt(num_samples_for_display))\n",
    "    set_seed(seed)\n",
    "    \n",
    "    # Sampling\n",
    "    images = model.sampling(prompt=prompt,\n",
    "                            top_k=top_k,\n",
    "                            top_p=None,\n",
    "                            softmax_temperature=softmax_temperature,\n",
    "                            num_candidates=num_candidates,\n",
    "                            device=device).cpu().numpy()\n",
    "    images = np.transpose(images, (0, 2, 3, 1))\n",
    "\n",
    "    # CLIP Re-ranking\n",
    "    rank = clip_score(prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device)\n",
    "    images = images[rank]\n",
    "    \n",
    "    images = images[:num_samples_for_display]\n",
    "    fig = plt.figure(figsize=(8*n_row, 8*n_col))\n",
    "\n",
    "    for i in range(num_samples_for_display):\n",
    "        ax = fig.add_subplot(n_row, n_col, i+1)\n",
    "        ax.imshow(images[i])\n",
    "        ax.set_axis_off()\n",
    "\n",
    "    plt.tight_layout()\n",
    "    plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "id": "619add15-073e-40f4-9a97-06b89d647c81",
   "metadata": {},
   "outputs": [
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "ee477531ea0e4b86b20d997f8cb83767",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntSlider(value=0, description='RND SEED: ', max=1024)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "d63edc4725ef4f4e8a6f03f7693a481d",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "FloatSlider(value=1.0, description='SOFTMAX TEMPERATURE:', max=5.0, step=0.2)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5bb9170e9e8b4686a661799d8aff3901",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "IntSlider(value=256, description='TOP-K:', max=512, step=16)"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "6b97b49debfc4f7ab002748e9fd89864",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Text(value='A painting of a monkey with sunglasses in the frame', description='String:', placeholder='Text pro…"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "a520b10d8c0b4dd0bb6db56dc37b4422",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Button(description='Generate!', style=ButtonStyle())"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "5a98437abf964636a467677dc4f816bb",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    },
    {
     "data": {
      "application/vnd.jupyter.widget-view+json": {
       "model_id": "90d05006d50e4d88b8fb7c36095b12e7",
       "version_major": 2,
       "version_minor": 0
      },
      "text/plain": [
       "Output()"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "import ipywidgets as widgets\n",
    "from IPython.display import display\n",
    "from IPython.display import clear_output\n",
    "\n",
    "output = widgets.Output()\n",
    "plot_output = widgets.Output()\n",
    "\n",
    "def btn_eventhandler(obj):\n",
    "    output.clear_output()\n",
    "    plot_output.clear_output()\n",
    "    \n",
    "    with output:\n",
    "        print(f'SEED: {slider_seed.value}')\n",
    "        print(f'Softmax Temperature: {slider_temp.value}')\n",
    "        print(f'Top-K: {slider_topk.value}')\n",
    "        print(f'Text prompt: {wd_text.value}')\n",
    "    \n",
    "    with plot_output:\n",
    "        sampling(prompt=wd_text.value, top_k=slider_topk.value, softmax_temperature=slider_temp.value, seed=slider_seed.value)\n",
    "    \n",
    "slider_seed = widgets.IntSlider(\n",
    "    min=0,\n",
    "    max=1024,\n",
    "    step=1,\n",
    "    description='RND SEED: ',\n",
    "    value=0\n",
    ")\n",
    "slider_topk = widgets.IntSlider(\n",
    "    min=0,\n",
    "    max=512,\n",
    "    step=16,\n",
    "    description='TOP-K:',\n",
    "    value=256\n",
    ")\n",
    "slider_temp = widgets.FloatSlider(\n",
    "    min=0.0,\n",
    "    max=5.0,\n",
    "    step=0.2,\n",
    "    description='SOFTMAX TEMPERATURE:',\n",
    "    value=1.0\n",
    ")\n",
    "wd_text = widgets.Text(\n",
    "    value='A painting of a monkey with sunglasses in the frame',\n",
    "    placeholder='Text prompt',\n",
    "    description='String:',\n",
    "    disabled=False\n",
    ")\n",
    "\n",
    "display(slider_seed)\n",
    "display(slider_temp)\n",
    "display(slider_topk)\n",
    "display(wd_text)\n",
    "\n",
    "btn = widgets.Button(description='Generate!')\n",
    "display(btn)\n",
    "btn.on_click(btn_eventhandler)\n",
    "\n",
    "display(output)\n",
    "display(plot_output)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "20571236-3b9a-426e-ab29-96b643c8cbe1",
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.7"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}
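The notebook's second cell only wires ipywidgets controls to the sampling() helper defined in the first cell, so the same helper can also be driven directly once that first cell has run. A minimal sketch, using an example prompt and the widget defaults as arguments:

    # Assumes the first notebook cell above has been executed, so `sampling`, `model`, and `model_clip` exist.
    sampling(prompt='A painting of a monkey with sunglasses in the frame',
             top_k=256,
             softmax_temperature=1.0,
             seed=0,
             num_candidates=96,
             num_samples_for_display=36)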
examples/transfer_learning_ex.py
ADDED
@@ -0,0 +1,172 @@
# ------------------------------------------------------------------------------------
# Minimal DALL-E
# Copyright (c) 2021 KakaoBrain. All Rights Reserved.
# Licensed under the Apache License, Version 2.0 [see LICENSE for details]
# ------------------------------------------------------------------------------------

import os
import sys
import argparse
from typing import Optional
from datetime import datetime

import torch
from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as transforms
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning.loggers import TensorBoardLogger
from pytorch_lightning.utilities.distributed import rank_zero_only

sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

from dalle.models import ImageGPT


parser = argparse.ArgumentParser()

parser.add_argument('-d', '--config-downstream', type=str, default=None, required=True)
parser.add_argument('-u', '--path-upstream', type=str, default=None, required=True)
parser.add_argument('-r', '--result-path', type=str, default=None, required=True)
parser.add_argument('--imagenet-path', type=str, default=None, required=True)

parser.add_argument('--n-gpus', type=int, default=1)
parser.add_argument('--seed', type=int, default=0)


args = parser.parse_args()


class ImageLogger(Callback):
    def __init__(self):
        super().__init__()

    @rank_zero_only
    def log_img(self, pl_module, batch, current_epoch, split="train"):
        with torch.no_grad():
            images, labels = batch
            recons = pl_module.stage1(images)
            images = images.cpu()
            recons = recons.cpu()

            grid_org = (torchvision.utils.make_grid(images, nrow=8) + 1.0) / 2.0
            grid_rec = (torchvision.utils.make_grid(recons, nrow=8) + 1.0) / 2.0
            grid_rec = torch.clip(grid_rec, min=0, max=1)

            pl_module.logger.experiment.add_image(f"images_org/{split}", grid_org, global_step=current_epoch)
            pl_module.logger.experiment.add_image(f"images_rec/{split}", grid_rec, global_step=current_epoch)

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if batch_idx == 0 and trainer.current_epoch < 5:
            self.log_img(pl_module, batch, current_epoch=trainer.current_epoch, split="train")

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if batch_idx == 0 and trainer.current_epoch < 5:
            self.log_img(pl_module, batch, current_epoch=trainer.current_epoch, split="test")


class ImageNetDataModule(pl.LightningDataModule):
    def __init__(self,
                 data_dir: Optional[str] = None,
                 image_resolution: int = 256,
                 train_batch_size: int = 2,
                 valid_batch_size: int = 32,
                 num_workers: int = 8):
        super().__init__()

        self.data_dir = data_dir
        self.image_resolution = image_resolution
        self.train_batch_size = train_batch_size
        self.valid_batch_size = valid_batch_size
        self.num_workers = num_workers

        self.train_transform = transforms.Compose(
            [transforms.Resize(image_resolution),
             transforms.RandomCrop(image_resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
        )
        self.valid_transform = transforms.Compose(
            [transforms.Resize(image_resolution),
             transforms.CenterCrop(image_resolution),
             transforms.ToTensor(),
             transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
        )

    def setup(self, stage=None):
        self.trainset = torchvision.datasets.ImageNet(root=self.data_dir, split='train', transform=self.train_transform)
        self.validset = torchvision.datasets.ImageNet(root=self.data_dir, split='val', transform=self.valid_transform)

    def train_dataloader(self):
        return DataLoader(self.trainset,
                          batch_size=self.train_batch_size,
                          num_workers=self.num_workers,
                          pin_memory=True)

    def valid_dataloader(self):
        return DataLoader(self.validset,
                          batch_size=self.valid_batch_size,
                          num_workers=self.num_workers,
                          pin_memory=True)


def setup_callbacks(config):
    # Setup callbacks
    now = datetime.now().strftime('%d%m%Y_%H%M%S')
    result_path = os.path.join(args.result_path,
                               os.path.basename(args.config_downstream).split('.')[0],
                               now)
    ckpt_path = os.path.join(result_path, 'ckpt')
    log_path = os.path.join(result_path, 'log')

    checkpoint_callback = ModelCheckpoint(
        dirpath=ckpt_path,
        filename="imagenet-clscond-gen-{epoch:02d}" if config.stage2.use_cls_cond else
                 "imagenet-uncond-gen-{epoch:02d}",
        every_n_epochs=config.experiment.save_ckpt_freq,
        save_weights_only=True,
        save_last=True
    )
    logger = TensorBoardLogger(log_path, name="iGPT")
    logger_img = ImageLogger()
    return checkpoint_callback, logger, logger_img


if __name__ == '__main__':
    pl.seed_everything(args.seed)

    # Build iGPT
    model, config = ImageGPT.from_pretrained(args.path_upstream, args.config_downstream)

    # Setup callbacks
    ckpt_callback, logger, logger_img = setup_callbacks(config)

    # Build data modules
    dataset = ImageNetDataModule(data_dir=args.imagenet_path,
                                 image_resolution=config.dataset.image_resolution,
                                 train_batch_size=config.experiment.local_batch_size,
                                 valid_batch_size=config.experiment.valid_batch_size,
                                 num_workers=16)
    dataset.setup()
    train_dataloader = dataset.train_dataloader()
    valid_dataloader = dataset.valid_dataloader()
    print(f"len(train_dataset) = {len(dataset.trainset)}")
    print(f"len(valid_dataset) = {len(dataset.validset)}")

    # Calculate how many batches are accumulated
    assert config.experiment.total_batch_size % (config.experiment.local_batch_size * args.n_gpus) == 0
    grad_accm_steps = config.experiment.total_batch_size // (config.experiment.local_batch_size * args.n_gpus)
    config.optimizer.max_steps = len(dataset.trainset) // config.experiment.total_batch_size * config.experiment.epochs

    # Build trainer
    trainer = pl.Trainer(max_epochs=config.experiment.epochs,
                         accumulate_grad_batches=grad_accm_steps,
                         gradient_clip_val=config.optimizer.grad_clip_norm,
                         precision=16 if config.experiment.use_amp else 32,
                         callbacks=[ckpt_callback, logger_img],
                         accelerator="gpu",
                         devices=args.n_gpus,
                         strategy="ddp",
                         logger=logger)
    trainer.fit(model, train_dataloader, valid_dataloader)
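As a usage sketch, a launch of this transfer-learning script could look like the command below; the config file name refers to one of the transfer configs added in this same commit, while the upstream checkpoint directory, result path, and ImageNet location are placeholders. Note that gradient accumulation is derived from the config rather than the command line: if, say, total_batch_size were 512 and local_batch_size 4 on 8 GPUs, grad_accm_steps would be 512 // (4 * 8) = 16, and the assert above rejects combinations that do not divide evenly.

    python examples/transfer_learning_ex.py \
        -d configs/transfer-imagenet-clscond-gen.yaml \
        -u ./1.3B \
        -r ./results \
        --imagenet-path /path/to/imagenet \
        --n-gpus 8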
requirements.txt
ADDED
@@ -0,0 +1,10 @@
torch==1.8.0
torchvision>=0.8.2
tokenizers>=0.10.2
pyflakes>=2.2.0
tqdm>=4.46.0
pytorch-lightning>=1.5
einops
omegaconf
git+https://github.com/openai/CLIP.git
matplotlib
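Installation is the standard pip flow; note that the CLIP dependency is pulled directly from GitHub, so git must be available in the environment. For example, inside a fresh virtual environment:

    pip install -r requirements.txt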
setup.cfg
ADDED
@@ -0,0 +1,3 @@
[flake8]
max-line-length = 120
ignore = E226, E402, W504
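This flake8 configuration raises the line-length limit to 120 and ignores E226, E402, and W504; E402 in particular tolerates the sys.path.append(...) calls that precede the dalle imports in the example scripts. It is picked up automatically when flake8 is run from the repository root, for instance (assuming flake8 itself is installed, since only pyflakes is pinned in requirements.txt):

    flake8 dalle examples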