valhalla commited on
Commit
b442155
1 Parent(s): 4e3891e
CITATION.cff ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ cff-version: 1.2.0
2
+ message: "If you find this repository useful in your research, please cite"
3
+ authors:
4
+ - family-names: Kim
5
+ given-names: Saehoon
6
+ - family-names: Cho
7
+ given-names: Sanghun
8
+ - family-names: Kim
9
+ given-names: Chiheon
10
+ - family-names: Lee
11
+ given-names: Doyup
12
+ - family-names: Baek
13
+ given-names: Woonhyuk
14
+ title: "minDALL-E on Conceptual Captions"
15
+ version: 0.1
16
+ date-released: 2021-12-14
17
+ repository-code: https://github.com/kakaobrain/minDALL-E
LICENSE ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ The `source codes` are licensed under [Apache 2.0](LICENSE.apache-2.0) License.
2
+ The `stage2 pretrained weights` are licensed under [CC-BY-NC-SA 4.0](https://creativecommons.org/licenses/by-nc-sa/4.0/) License.
LICENSE.apache-2.0 ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [2021.11.13] [Kakao Brain]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
LICENSE.cc-by-nc-sa-4.0 ADDED
@@ -0,0 +1,437 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Attribution-NonCommercial-ShareAlike 4.0 International
2
+
3
+ =======================================================================
4
+
5
+ Creative Commons Corporation ("Creative Commons") is not a law firm and
6
+ does not provide legal services or legal advice. Distribution of
7
+ Creative Commons public licenses does not create a lawyer-client or
8
+ other relationship. Creative Commons makes its licenses and related
9
+ information available on an "as-is" basis. Creative Commons gives no
10
+ warranties regarding its licenses, any material licensed under their
11
+ terms and conditions, or any related information. Creative Commons
12
+ disclaims all liability for damages resulting from their use to the
13
+ fullest extent possible.
14
+
15
+ Using Creative Commons Public Licenses
16
+
17
+ Creative Commons public licenses provide a standard set of terms and
18
+ conditions that creators and other rights holders may use to share
19
+ original works of authorship and other material subject to copyright
20
+ and certain other rights specified in the public license below. The
21
+ following considerations are for informational purposes only, are not
22
+ exhaustive, and do not form part of our licenses.
23
+
24
+ Considerations for licensors: Our public licenses are
25
+ intended for use by those authorized to give the public
26
+ permission to use material in ways otherwise restricted by
27
+ copyright and certain other rights. Our licenses are
28
+ irrevocable. Licensors should read and understand the terms
29
+ and conditions of the license they choose before applying it.
30
+ Licensors should also secure all rights necessary before
31
+ applying our licenses so that the public can reuse the
32
+ material as expected. Licensors should clearly mark any
33
+ material not subject to the license. This includes other CC-
34
+ licensed material, or material used under an exception or
35
+ limitation to copyright. More considerations for licensors:
36
+ wiki.creativecommons.org/Considerations_for_licensors
37
+
38
+ Considerations for the public: By using one of our public
39
+ licenses, a licensor grants the public permission to use the
40
+ licensed material under specified terms and conditions. If
41
+ the licensor's permission is not necessary for any reason--for
42
+ example, because of any applicable exception or limitation to
43
+ copyright--then that use is not regulated by the license. Our
44
+ licenses grant only permissions under copyright and certain
45
+ other rights that a licensor has authority to grant. Use of
46
+ the licensed material may still be restricted for other
47
+ reasons, including because others have copyright or other
48
+ rights in the material. A licensor may make special requests,
49
+ such as asking that all changes be marked or described.
50
+ Although not required by our licenses, you are encouraged to
51
+ respect those requests where reasonable. More_considerations
52
+ for the public:
53
+ wiki.creativecommons.org/Considerations_for_licensees
54
+
55
+ =======================================================================
56
+
57
+ Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
58
+ Public License
59
+
60
+ By exercising the Licensed Rights (defined below), You accept and agree
61
+ to be bound by the terms and conditions of this Creative Commons
62
+ Attribution-NonCommercial-ShareAlike 4.0 International Public License
63
+ ("Public License"). To the extent this Public License may be
64
+ interpreted as a contract, You are granted the Licensed Rights in
65
+ consideration of Your acceptance of these terms and conditions, and the
66
+ Licensor grants You such rights in consideration of benefits the
67
+ Licensor receives from making the Licensed Material available under
68
+ these terms and conditions.
69
+
70
+
71
+ Section 1 -- Definitions.
72
+
73
+ a. Adapted Material means material subject to Copyright and Similar
74
+ Rights that is derived from or based upon the Licensed Material
75
+ and in which the Licensed Material is translated, altered,
76
+ arranged, transformed, or otherwise modified in a manner requiring
77
+ permission under the Copyright and Similar Rights held by the
78
+ Licensor. For purposes of this Public License, where the Licensed
79
+ Material is a musical work, performance, or sound recording,
80
+ Adapted Material is always produced where the Licensed Material is
81
+ synched in timed relation with a moving image.
82
+
83
+ b. Adapter's License means the license You apply to Your Copyright
84
+ and Similar Rights in Your contributions to Adapted Material in
85
+ accordance with the terms and conditions of this Public License.
86
+
87
+ c. BY-NC-SA Compatible License means a license listed at
88
+ creativecommons.org/compatiblelicenses, approved by Creative
89
+ Commons as essentially the equivalent of this Public License.
90
+
91
+ d. Copyright and Similar Rights means copyright and/or similar rights
92
+ closely related to copyright including, without limitation,
93
+ performance, broadcast, sound recording, and Sui Generis Database
94
+ Rights, without regard to how the rights are labeled or
95
+ categorized. For purposes of this Public License, the rights
96
+ specified in Section 2(b)(1)-(2) are not Copyright and Similar
97
+ Rights.
98
+
99
+ e. Effective Technological Measures means those measures that, in the
100
+ absence of proper authority, may not be circumvented under laws
101
+ fulfilling obligations under Article 11 of the WIPO Copyright
102
+ Treaty adopted on December 20, 1996, and/or similar international
103
+ agreements.
104
+
105
+ f. Exceptions and Limitations means fair use, fair dealing, and/or
106
+ any other exception or limitation to Copyright and Similar Rights
107
+ that applies to Your use of the Licensed Material.
108
+
109
+ g. License Elements means the license attributes listed in the name
110
+ of a Creative Commons Public License. The License Elements of this
111
+ Public License are Attribution, NonCommercial, and ShareAlike.
112
+
113
+ h. Licensed Material means the artistic or literary work, database,
114
+ or other material to which the Licensor applied this Public
115
+ License.
116
+
117
+ i. Licensed Rights means the rights granted to You subject to the
118
+ terms and conditions of this Public License, which are limited to
119
+ all Copyright and Similar Rights that apply to Your use of the
120
+ Licensed Material and that the Licensor has authority to license.
121
+
122
+ j. Licensor means the individual(s) or entity(ies) granting rights
123
+ under this Public License.
124
+
125
+ k. NonCommercial means not primarily intended for or directed towards
126
+ commercial advantage or monetary compensation. For purposes of
127
+ this Public License, the exchange of the Licensed Material for
128
+ other material subject to Copyright and Similar Rights by digital
129
+ file-sharing or similar means is NonCommercial provided there is
130
+ no payment of monetary compensation in connection with the
131
+ exchange.
132
+
133
+ l. Share means to provide material to the public by any means or
134
+ process that requires permission under the Licensed Rights, such
135
+ as reproduction, public display, public performance, distribution,
136
+ dissemination, communication, or importation, and to make material
137
+ available to the public including in ways that members of the
138
+ public may access the material from a place and at a time
139
+ individually chosen by them.
140
+
141
+ m. Sui Generis Database Rights means rights other than copyright
142
+ resulting from Directive 96/9/EC of the European Parliament and of
143
+ the Council of 11 March 1996 on the legal protection of databases,
144
+ as amended and/or succeeded, as well as other essentially
145
+ equivalent rights anywhere in the world.
146
+
147
+ n. You means the individual or entity exercising the Licensed Rights
148
+ under this Public License. Your has a corresponding meaning.
149
+
150
+
151
+ Section 2 -- Scope.
152
+
153
+ a. License grant.
154
+
155
+ 1. Subject to the terms and conditions of this Public License,
156
+ the Licensor hereby grants You a worldwide, royalty-free,
157
+ non-sublicensable, non-exclusive, irrevocable license to
158
+ exercise the Licensed Rights in the Licensed Material to:
159
+
160
+ a. reproduce and Share the Licensed Material, in whole or
161
+ in part, for NonCommercial purposes only; and
162
+
163
+ b. produce, reproduce, and Share Adapted Material for
164
+ NonCommercial purposes only.
165
+
166
+ 2. Exceptions and Limitations. For the avoidance of doubt, where
167
+ Exceptions and Limitations apply to Your use, this Public
168
+ License does not apply, and You do not need to comply with
169
+ its terms and conditions.
170
+
171
+ 3. Term. The term of this Public License is specified in Section
172
+ 6(a).
173
+
174
+ 4. Media and formats; technical modifications allowed. The
175
+ Licensor authorizes You to exercise the Licensed Rights in
176
+ all media and formats whether now known or hereafter created,
177
+ and to make technical modifications necessary to do so. The
178
+ Licensor waives and/or agrees not to assert any right or
179
+ authority to forbid You from making technical modifications
180
+ necessary to exercise the Licensed Rights, including
181
+ technical modifications necessary to circumvent Effective
182
+ Technological Measures. For purposes of this Public License,
183
+ simply making modifications authorized by this Section 2(a)
184
+ (4) never produces Adapted Material.
185
+
186
+ 5. Downstream recipients.
187
+
188
+ a. Offer from the Licensor -- Licensed Material. Every
189
+ recipient of the Licensed Material automatically
190
+ receives an offer from the Licensor to exercise the
191
+ Licensed Rights under the terms and conditions of this
192
+ Public License.
193
+
194
+ b. Additional offer from the Licensor -- Adapted Material.
195
+ Every recipient of Adapted Material from You
196
+ automatically receives an offer from the Licensor to
197
+ exercise the Licensed Rights in the Adapted Material
198
+ under the conditions of the Adapter's License You apply.
199
+
200
+ c. No downstream restrictions. You may not offer or impose
201
+ any additional or different terms or conditions on, or
202
+ apply any Effective Technological Measures to, the
203
+ Licensed Material if doing so restricts exercise of the
204
+ Licensed Rights by any recipient of the Licensed
205
+ Material.
206
+
207
+ 6. No endorsement. Nothing in this Public License constitutes or
208
+ may be construed as permission to assert or imply that You
209
+ are, or that Your use of the Licensed Material is, connected
210
+ with, or sponsored, endorsed, or granted official status by,
211
+ the Licensor or others designated to receive attribution as
212
+ provided in Section 3(a)(1)(A)(i).
213
+
214
+ b. Other rights.
215
+
216
+ 1. Moral rights, such as the right of integrity, are not
217
+ licensed under this Public License, nor are publicity,
218
+ privacy, and/or other similar personality rights; however, to
219
+ the extent possible, the Licensor waives and/or agrees not to
220
+ assert any such rights held by the Licensor to the limited
221
+ extent necessary to allow You to exercise the Licensed
222
+ Rights, but not otherwise.
223
+
224
+ 2. Patent and trademark rights are not licensed under this
225
+ Public License.
226
+
227
+ 3. To the extent possible, the Licensor waives any right to
228
+ collect royalties from You for the exercise of the Licensed
229
+ Rights, whether directly or through a collecting society
230
+ under any voluntary or waivable statutory or compulsory
231
+ licensing scheme. In all other cases the Licensor expressly
232
+ reserves any right to collect such royalties, including when
233
+ the Licensed Material is used other than for NonCommercial
234
+ purposes.
235
+
236
+
237
+ Section 3 -- License Conditions.
238
+
239
+ Your exercise of the Licensed Rights is expressly made subject to the
240
+ following conditions.
241
+
242
+ a. Attribution.
243
+
244
+ 1. If You Share the Licensed Material (including in modified
245
+ form), You must:
246
+
247
+ a. retain the following if it is supplied by the Licensor
248
+ with the Licensed Material:
249
+
250
+ i. identification of the creator(s) of the Licensed
251
+ Material and any others designated to receive
252
+ attribution, in any reasonable manner requested by
253
+ the Licensor (including by pseudonym if
254
+ designated);
255
+
256
+ ii. a copyright notice;
257
+
258
+ iii. a notice that refers to this Public License;
259
+
260
+ iv. a notice that refers to the disclaimer of
261
+ warranties;
262
+
263
+ v. a URI or hyperlink to the Licensed Material to the
264
+ extent reasonably practicable;
265
+
266
+ b. indicate if You modified the Licensed Material and
267
+ retain an indication of any previous modifications; and
268
+
269
+ c. indicate the Licensed Material is licensed under this
270
+ Public License, and include the text of, or the URI or
271
+ hyperlink to, this Public License.
272
+
273
+ 2. You may satisfy the conditions in Section 3(a)(1) in any
274
+ reasonable manner based on the medium, means, and context in
275
+ which You Share the Licensed Material. For example, it may be
276
+ reasonable to satisfy the conditions by providing a URI or
277
+ hyperlink to a resource that includes the required
278
+ information.
279
+ 3. If requested by the Licensor, You must remove any of the
280
+ information required by Section 3(a)(1)(A) to the extent
281
+ reasonably practicable.
282
+
283
+ b. ShareAlike.
284
+
285
+ In addition to the conditions in Section 3(a), if You Share
286
+ Adapted Material You produce, the following conditions also apply.
287
+
288
+ 1. The Adapter's License You apply must be a Creative Commons
289
+ license with the same License Elements, this version or
290
+ later, or a BY-NC-SA Compatible License.
291
+
292
+ 2. You must include the text of, or the URI or hyperlink to, the
293
+ Adapter's License You apply. You may satisfy this condition
294
+ in any reasonable manner based on the medium, means, and
295
+ context in which You Share Adapted Material.
296
+
297
+ 3. You may not offer or impose any additional or different terms
298
+ or conditions on, or apply any Effective Technological
299
+ Measures to, Adapted Material that restrict exercise of the
300
+ rights granted under the Adapter's License You apply.
301
+
302
+
303
+ Section 4 -- Sui Generis Database Rights.
304
+
305
+ Where the Licensed Rights include Sui Generis Database Rights that
306
+ apply to Your use of the Licensed Material:
307
+
308
+ a. for the avoidance of doubt, Section 2(a)(1) grants You the right
309
+ to extract, reuse, reproduce, and Share all or a substantial
310
+ portion of the contents of the database for NonCommercial purposes
311
+ only;
312
+
313
+ b. if You include all or a substantial portion of the database
314
+ contents in a database in which You have Sui Generis Database
315
+ Rights, then the database in which You have Sui Generis Database
316
+ Rights (but not its individual contents) is Adapted Material,
317
+ including for purposes of Section 3(b); and
318
+
319
+ c. You must comply with the conditions in Section 3(a) if You Share
320
+ all or a substantial portion of the contents of the database.
321
+
322
+ For the avoidance of doubt, this Section 4 supplements and does not
323
+ replace Your obligations under this Public License where the Licensed
324
+ Rights include other Copyright and Similar Rights.
325
+
326
+
327
+ Section 5 -- Disclaimer of Warranties and Limitation of Liability.
328
+
329
+ a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
330
+ EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
331
+ AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
332
+ ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
333
+ IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
334
+ WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
335
+ PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
336
+ ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
337
+ KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
338
+ ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
339
+
340
+ b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
341
+ TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
342
+ NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
343
+ INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
344
+ COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
345
+ USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
346
+ ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
347
+ DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
348
+ IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
349
+
350
+ c. The disclaimer of warranties and limitation of liability provided
351
+ above shall be interpreted in a manner that, to the extent
352
+ possible, most closely approximates an absolute disclaimer and
353
+ waiver of all liability.
354
+
355
+
356
+ Section 6 -- Term and Termination.
357
+
358
+ a. This Public License applies for the term of the Copyright and
359
+ Similar Rights licensed here. However, if You fail to comply with
360
+ this Public License, then Your rights under this Public License
361
+ terminate automatically.
362
+
363
+ b. Where Your right to use the Licensed Material has terminated under
364
+ Section 6(a), it reinstates:
365
+
366
+ 1. automatically as of the date the violation is cured, provided
367
+ it is cured within 30 days of Your discovery of the
368
+ violation; or
369
+
370
+ 2. upon express reinstatement by the Licensor.
371
+
372
+ For the avoidance of doubt, this Section 6(b) does not affect any
373
+ right the Licensor may have to seek remedies for Your violations
374
+ of this Public License.
375
+
376
+ c. For the avoidance of doubt, the Licensor may also offer the
377
+ Licensed Material under separate terms or conditions or stop
378
+ distributing the Licensed Material at any time; however, doing so
379
+ will not terminate this Public License.
380
+
381
+ d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
382
+ License.
383
+
384
+
385
+ Section 7 -- Other Terms and Conditions.
386
+
387
+ a. The Licensor shall not be bound by any additional or different
388
+ terms or conditions communicated by You unless expressly agreed.
389
+
390
+ b. Any arrangements, understandings, or agreements regarding the
391
+ Licensed Material not stated herein are separate from and
392
+ independent of the terms and conditions of this Public License.
393
+
394
+
395
+ Section 8 -- Interpretation.
396
+
397
+ a. For the avoidance of doubt, this Public License does not, and
398
+ shall not be interpreted to, reduce, limit, restrict, or impose
399
+ conditions on any use of the Licensed Material that could lawfully
400
+ be made without permission under this Public License.
401
+
402
+ b. To the extent possible, if any provision of this Public License is
403
+ deemed unenforceable, it shall be automatically reformed to the
404
+ minimum extent necessary to make it enforceable. If the provision
405
+ cannot be reformed, it shall be severed from this Public License
406
+ without affecting the enforceability of the remaining terms and
407
+ conditions.
408
+
409
+ c. No term or condition of this Public License will be waived and no
410
+ failure to comply consented to unless expressly agreed to by the
411
+ Licensor.
412
+
413
+ d. Nothing in this Public License constitutes or may be interpreted
414
+ as a limitation upon, or waiver of, any privileges and immunities
415
+ that apply to the Licensor or You, including from the legal
416
+ processes of any jurisdiction or authority.
417
+
418
+ =======================================================================
419
+
420
+ Creative Commons is not a party to its public
421
+ licenses. Notwithstanding, Creative Commons may elect to apply one of
422
+ its public licenses to material it publishes and in those instances
423
+ will be considered the “Licensor.” The text of the Creative Commons
424
+ public licenses is dedicated to the public domain under the CC0 Public
425
+ Domain Dedication. Except for the limited purpose of indicating that
426
+ material is shared under a Creative Commons public license or as
427
+ otherwise permitted by the Creative Commons policies published at
428
+ creativecommons.org/policies, Creative Commons does not authorize the
429
+ use of the trademark "Creative Commons" or any other trademark or logo
430
+ of Creative Commons without its prior written consent including,
431
+ without limitation, in connection with any unauthorized modifications
432
+ to any of its public licenses or any other arrangements,
433
+ understandings, or agreements concerning use of licensed material. For
434
+ the avoidance of doubt, this paragraph does not form part of the
435
+ public licenses.
436
+
437
+ Creative Commons may be contacted at creativecommons.org.
README.md CHANGED
@@ -1,9 +1,9 @@
1
  ---
2
- title: MinDALLE
3
- emoji: 🏃
4
- colorFrom: blue
5
- colorTo: blue
6
- sdk: streamlit
7
  app_file: app.py
8
  pinned: false
9
  ---
 
1
  ---
2
+ title: MinDALL E
3
+ emoji: 🔥
4
+ colorFrom: red
5
+ colorTo: yellow
6
+ sdk: gradio
7
  app_file: app.py
8
  pinned: false
9
  ---
app.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ import numpy as np
5
+ import streamlit as st
6
+ from PIL import Image
7
+
8
+ # import clip
9
+
10
+
11
+ # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
12
+
13
+ # import gradio as gr
14
+ # from dalle.models import Dalle
15
+ # from dalle.utils.utils import clip_score, set_seed
16
+
17
+
18
+ device = "cpu"
19
+ # model = Dalle.from_pretrained("minDALL-E/1.3B") # This will automatically download the pretrained model.
20
+ # model.to(device=device)
21
+
22
+ # model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
23
+ # model_clip.to(device=device)
24
+
25
+
26
+ # def sample(prompt):
27
+ # # Sampling
28
+ # images = (
29
+ # model.sampling(prompt=prompt, top_k=256, top_p=None, softmax_temperature=1.0, num_candidates=3, device=device)
30
+ # .cpu()
31
+ # .numpy()
32
+ # )
33
+ # images = np.transpose(images, (0, 2, 3, 1))
34
+
35
+ # # CLIP Re-ranking
36
+ # rank = clip_score(
37
+ # prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device
38
+ # )
39
+
40
+ # # Save images
41
+ # images = images[rank]
42
+ # # print(rank, images.shape)
43
+ # pil_images = []
44
+ # for i in range(len(images)):
45
+ # im = Image.fromarray((images[i] * 255).astype(np.uint8))
46
+ # pil_images.append(im)
47
+
48
+ # # im = Image.fromarray((images[0] * 255).astype(np.uint8))
49
+ # return pil_images
50
+
51
+
52
+ # title = "Interactive demo: ImageGPT"
53
+ # description = "Demo for OpenAI's ImageGPT: Generative Pretraining from Pixels. To use it, simply upload an image or use the example image below and click 'submit'. Results will show up in a few seconds."
54
+ # article = "<p style='text-align: center'><a href='https://arxiv.org/abs/2109.10282'>ImageGPT: Generative Pretraining from Pixels</a> | <a href='https://openai.com/blog/image-gpt/'>Official blog</a></p>"
55
+
56
+ # iface = gr.Interface(
57
+ # fn=sample,
58
+ # inputs=[gr.inputs.Textbox(label="What would you like to see?")],
59
+ # outputs=gr.outputs.Image(type="pil", label="Model input + completions"),
60
+ # title=title,
61
+ # description=description,
62
+ # article=article,
63
+ # #examples=examples,
64
+ # enable_queue=True,
65
+ # )
66
+ # iface.launch(debug=True)
67
+
68
+ #!/usr/bin/env python
69
+ # coding: utf-8
70
+
71
+
72
+ st.sidebar.markdown(
73
+ """
74
+ <style>
75
+ .aligncenter {
76
+ text-align: center;
77
+ }
78
+ </style>
79
+ <p class="aligncenter">
80
+ <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/img/logo.png"/>
81
+ </p>
82
+ """,
83
+ unsafe_allow_html=True,
84
+ )
85
+ st.sidebar.markdown(
86
+ """
87
+ ___
88
+ <p style='text-align: center'>
89
+ DALL·E mini is an AI model that generates images from any prompt you give!
90
+ </p>
91
+
92
+ <p style='text-align: center'>
93
+ Created by Boris Dayma et al. 2021
94
+ <br/>
95
+ <a href="https://github.com/borisdayma/dalle-mini" target="_blank">GitHub</a> | <a href="https://wandb.ai/dalle-mini/dalle-mini/reports/DALL-E-mini--Vmlldzo4NjIxODA" target="_blank">Project Report</a>
96
+ </p>
97
+ """,
98
+ unsafe_allow_html=True,
99
+ )
100
+
101
+ st.header("DALL·E mini")
102
+ st.subheader("Generate images from text")
103
+
104
+ prompt = st.text_input("What do you want to see?")
105
+
106
+ DEBUG = False
107
+ # if prompt != "":
108
+ # container = st.empty()
109
+ # container.markdown(
110
+ # f"""
111
+ # <style> p {{ margin:0 }} div {{ margin:0 }} </style>
112
+ # <div data-stale="false" class="element-container css-1e5imcs e1tzin5v1">
113
+ # <div class="stAlert">
114
+ # <div role="alert" data-baseweb="notification" class="st-ae st-af st-ag st-ah st-ai st-aj st-ak st-g3 st-am st-b8 st-ao st-ap st-aq st-ar st-as st-at st-au st-av st-aw st-ax st-ay st-az st-b9 st-b1 st-b2 st-b3 st-b4 st-b5 st-b6">
115
+ # <div class="st-b7">
116
+ # <div class="css-whx05o e13vu3m50">
117
+ # <div data-testid="stMarkdownContainer" class="css-1ekf893 e16nr0p30">
118
+ # <img src="https://raw.githubusercontent.com/borisdayma/dalle-mini/main/app/streamlit/img/loading.gif" width="30"/>
119
+ # Generating predictions for: <b>{prompt}</b>
120
+ # </div>
121
+ # </div>
122
+ # </div>
123
+ # </div>
124
+ # </div>
125
+ # </div>
126
+ # <small><i>Predictions may take up to 40s under high load. Please stand by.</i></small>
127
+ # """,
128
+ # unsafe_allow_html=True,
129
+ # )
130
+
131
+ # print(f"Getting selections: {prompt}")
132
+ # selected = sample(prompt)
133
+
134
+ # margin = 0.1 # for better position of zoom in arrow
135
+ # n_columns = 3
136
+ # cols = st.columns([1] + [margin, 1] * (n_columns - 1))
137
+ # for i, img in enumerate(selected):
138
+ # cols[(i % n_columns) * 2].image(img)
139
+ # container.markdown(f"**{prompt}**")
140
+
141
+ # st.button("Again!", key="again_button")
configs/dalle-1.3B.yaml ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ stage1:
2
+ type: vqgan
3
+ embed_dim: 256
4
+ n_embed: 16384
5
+ hparams:
6
+ double_z: False
7
+ z_channels: 256
8
+ resolution: 256
9
+ in_channels: 3
10
+ out_ch: 3
11
+ ch: 128
12
+ ch_mult: [1, 1, 2, 2, 4]
13
+ num_res_blocks: 2
14
+ attn_resolutions: [16]
15
+ pdrop: 0.0
16
+
17
+ stage2:
18
+ type: transformer1d
19
+ vocab_size_txt: 16384
20
+ vocab_size_img: 16384
21
+ hparams:
22
+ embed_dim: 1536
23
+ n_layers: 42
24
+ n_heads: 24
25
+ n_dense_layers: 42
26
+ ctx_len_img: 256
27
+ ctx_len_txt: 64
28
+ embd_pdrop: 0.0
29
+ resid_pdrop: 0.0
30
+ attn_pdrop: 0.0
31
+ mlp_bias: True
32
+ attn_bias: True
33
+ gelu_use_approx: False
configs/transfer-imagenet-clscond-gen.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ dataset: imagenet
3
+ image_resolution: 256
4
+
5
+ stage1:
6
+ type: vqgan
7
+ embed_dim: 256
8
+ n_embed: 16384
9
+ hparams:
10
+ double_z: False
11
+ z_channels: 256
12
+ resolution: 256
13
+ in_channels: 3
14
+ out_ch: 3
15
+ ch: 128
16
+ ch_mult: [1, 1, 2, 2, 4]
17
+ num_res_blocks: 2
18
+ attn_resolutions: [16]
19
+ pdrop: 0.0
20
+
21
+ stage2:
22
+ type: igpt
23
+ use_cls_cond: True
24
+ vocab_size_img: 16384
25
+ hparams:
26
+ embed_dim: 1536
27
+ n_layers: 42
28
+ n_heads: 24
29
+ n_dense_layers: 42
30
+ ctx_len_img: 256
31
+ embd_pdrop: 0.0
32
+ resid_pdrop: 0.0
33
+ attn_pdrop: 0.0
34
+ mlp_bias: True
35
+ attn_bias: True
36
+ gelu_use_approx: False
37
+ n_classes: 1000
38
+
39
+ optimizer:
40
+ opt_type: adamW
41
+ base_lr: 1e-4
42
+ weight_decay: 0.0
43
+ betas: [0.9, 0.95]
44
+ grad_clip_norm: 4.0
45
+
46
+ experiment:
47
+ local_batch_size: 2
48
+ total_batch_size: 512
49
+ epochs: 8
configs/transfer-imagenet-uncond-gen.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ dataset:
2
+ dataset: imagenet
3
+ image_resolution: 256
4
+
5
+ stage1:
6
+ type: vqgan
7
+ embed_dim: 256
8
+ n_embed: 16384
9
+ hparams:
10
+ double_z: False
11
+ z_channels: 256
12
+ resolution: 256
13
+ in_channels: 3
14
+ out_ch: 3
15
+ ch: 128
16
+ ch_mult: [1, 1, 2, 2, 4]
17
+ num_res_blocks: 2
18
+ attn_resolutions: [16]
19
+ pdrop: 0.0
20
+
21
+ stage2:
22
+ type: igpt
23
+ use_cls_cond: False
24
+ vocab_size_img: 16384
25
+ hparams:
26
+ embed_dim: 1536
27
+ n_layers: 42
28
+ n_heads: 24
29
+ n_dense_layers: 42
30
+ ctx_len_img: 256
31
+ embd_pdrop: 0.0
32
+ resid_pdrop: 0.0
33
+ attn_pdrop: 0.0
34
+ mlp_bias: True
35
+ attn_bias: True
36
+ gelu_use_approx: False
37
+
38
+ optimizer:
39
+ opt_type: adamW
40
+ base_lr: 1e-4
41
+ weight_decay: 0.0
42
+ betas: [0.9, 0.95]
43
+ grad_clip_norm: 4.0
44
+
45
+ experiment:
46
+ local_batch_size: 2
47
+ total_batch_size: 512
48
+ epochs: 8
dalle/__init__.py ADDED
File without changes
dalle/models/__init__.py ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import torch
9
+ import torch.nn as nn
10
+ import pytorch_lightning as pl
11
+ from typing import Optional, Tuple
12
+ from omegaconf import OmegaConf
13
+ from torch.cuda.amp import autocast
14
+ from torch.optim.lr_scheduler import CosineAnnealingLR
15
+ from torch.nn import functional as F
16
+ from .stage1.vqgan import VQGAN
17
+ from .stage2.transformer import Transformer1d, iGPT
18
+ from .. import utils
19
+ from ..utils.config import get_base_config
20
+ from ..utils.sampling import sampling, sampling_igpt
21
+ from .tokenizer import build_tokenizer
22
+
23
+ _MODELS = {
24
+ 'minDALL-E/1.3B': 'https://arena.kakaocdn.net/brainrepo/models/minDALL-E/57b008f02ceaa02b779c8b7463143315/1.3B.tar.gz'
25
+ }
26
+
27
+
28
+ class Dalle(nn.Module):
29
+ def __init__(self,
30
+ config: OmegaConf) -> None:
31
+ super().__init__()
32
+ self.tokenizer = None
33
+ self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
34
+ embed_dim=config.stage1.embed_dim,
35
+ hparams=config.stage1.hparams)
36
+ self.stage2 = Transformer1d(vocab_size_txt=config.stage2.vocab_size_txt,
37
+ vocab_size_img=config.stage2.vocab_size_img,
38
+ hparams=config.stage2.hparams)
39
+ self.config_stage1 = config.stage1
40
+ self.config_stage2 = config.stage2
41
+ self.config_dataset = config.dataset
42
+
43
+ @classmethod
44
+ def from_pretrained(cls,
45
+ path: str) -> nn.Module:
46
+ path = _MODELS[path] if path in _MODELS else path
47
+ path = utils.realpath_url_or_path(path, root=os.path.expanduser("~/.cache/minDALL-E"))
48
+
49
+ config_base = get_base_config()
50
+ config_new = OmegaConf.load(os.path.join(path, 'config.yaml'))
51
+ config_update = OmegaConf.merge(config_base, config_new)
52
+
53
+ model = cls(config_update)
54
+ model.tokenizer = build_tokenizer(os.path.join(path, 'tokenizer'),
55
+ context_length=model.config_dataset.context_length,
56
+ lowercase=True,
57
+ dropout=None)
58
+ model.stage1.from_ckpt(os.path.join(path, 'stage1_last.ckpt'))
59
+ model.stage2.from_ckpt(os.path.join(path, 'stage2_last.ckpt'))
60
+ return model
61
+
62
+ @torch.no_grad()
63
+ def sampling(self,
64
+ prompt: str,
65
+ top_k: int = 256,
66
+ top_p: Optional[float] = None,
67
+ softmax_temperature: float = 1.0,
68
+ num_candidates: int = 96,
69
+ device: str = 'cuda:0',
70
+ use_fp16: bool = True) -> torch.FloatTensor:
71
+ self.stage1.eval()
72
+ self.stage2.eval()
73
+
74
+ tokens = self.tokenizer.encode(prompt)
75
+ tokens = torch.LongTensor(tokens.ids)
76
+ tokens = torch.repeat_interleave(tokens.unsqueeze(0), num_candidates, dim=0)
77
+
78
+ # Check if the encoding works as intended
79
+ # print(self.tokenizer.decode_batch(tokens.tolist(), skip_special_tokens=True)[0])
80
+
81
+ tokens = tokens.to(device)
82
+ codes = sampling(self.stage2,
83
+ tokens,
84
+ top_k=top_k,
85
+ top_p=top_p,
86
+ softmax_temperature=softmax_temperature,
87
+ use_fp16=use_fp16)
88
+ codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
89
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
90
+ return pixels
91
+
92
+
93
+ class ImageGPT(pl.LightningModule):
94
+ def __init__(self,
95
+ config: OmegaConf) -> None:
96
+ super().__init__()
97
+ self.stage1 = VQGAN(n_embed=config.stage1.n_embed,
98
+ embed_dim=config.stage1.embed_dim,
99
+ hparams=config.stage1.hparams)
100
+ self.stage2 = iGPT(vocab_size_img=config.stage2.vocab_size_img,
101
+ use_cls_cond=config.stage2.use_cls_cond,
102
+ hparams=config.stage2.hparams)
103
+ self.config = config
104
+ self.use_cls_cond = config.stage2.use_cls_cond
105
+
106
+ # make the parameters in stage 1 not trainable
107
+ self.stage1.eval()
108
+ for p in self.stage1.parameters():
109
+ p.requires_grad = False
110
+
111
+ @classmethod
112
+ def from_pretrained(cls,
113
+ path_upstream: str,
114
+ path_downstream: str) -> Tuple[nn.Module, OmegaConf]:
115
+ config_base = get_base_config(use_default=False)
116
+ config_down = OmegaConf.load(path_downstream)
117
+ config_down = OmegaConf.merge(config_base, config_down)
118
+
119
+ model = cls(config_down)
120
+ model.stage1.from_ckpt(os.path.join(path_upstream, 'stage1_last.ckpt'), strict=True)
121
+ model.stage2.from_ckpt(os.path.join(path_upstream, 'stage2_last.ckpt'), strict=False)
122
+ return model, config_down
123
+
124
+ def sample(self,
125
+ cls_idx: Optional[int] = None,
126
+ top_k: int = 256,
127
+ top_p: Optional[float] = None,
128
+ softmax_temperature: float = 1.0,
129
+ num_candidates: int = 16,
130
+ device: str = 'cuda:0',
131
+ use_fp16: bool = True,
132
+ is_tqdm: bool = True) -> torch.FloatTensor:
133
+ self.stage1.eval()
134
+ self.stage2.eval()
135
+
136
+ if cls_idx is None:
137
+ sos = self.stage2.sos.repeat(num_candidates, 1, 1)
138
+ else:
139
+ sos = torch.LongTensor([cls_idx]).to(device=device)
140
+ sos = sos.repeat(num_candidates)
141
+ sos = self.stage2.sos(sos).unsqueeze(1)
142
+
143
+ codes = sampling_igpt(self.stage2,
144
+ sos=sos,
145
+ top_k=top_k,
146
+ top_p=top_p,
147
+ softmax_temperature=softmax_temperature,
148
+ use_fp16=use_fp16,
149
+ is_tqdm=is_tqdm)
150
+ codes = codes.view(num_candidates, 16, 16) # [B, 16, 16]
151
+ pixels = torch.clamp(self.stage1.decode_code(codes) * 0.5 + 0.5, 0, 1) # [B, 256, 256]
152
+ return pixels
153
+
154
+ def forward(self,
155
+ images: torch.FloatTensor,
156
+ labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
157
+ B, C, H, W = images.shape
158
+ with torch.no_grad():
159
+ with autocast(enabled=False):
160
+ codes = self.stage1.get_codes(images).detach()
161
+ logits = self.stage2(codes, labels)
162
+ return logits, codes
163
+
164
+ def training_step(self, batch, batch_idx):
165
+ images, labels = batch
166
+ logits, codes = self(images, labels=labels if self.use_cls_cond else None)
167
+ loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
168
+ self.log("train/loss", loss, on_step=True, on_epoch=True, prog_bar=False, logger=True)
169
+ return loss
170
+
171
+ def validation_step(self, batch, batch_idx):
172
+ images, labels = batch
173
+ logits, codes = self(images, labels=labels if self.use_cls_cond else None)
174
+ loss = F.cross_entropy(logits.view(-1, logits.shape[-1]), codes.view(-1))
175
+ self.log("val/loss", loss, on_step=False, on_epoch=True, prog_bar=False, logger=True)
176
+ return loss
177
+
178
+ def configure_optimizers(self):
179
+ assert self.config.optimizer.opt_type == 'adamW'
180
+ assert self.config.optimizer.sched_type == 'cosine'
181
+
182
+ opt = torch.optim.AdamW(self.parameters(),
183
+ lr=self.config.optimizer.base_lr,
184
+ betas=self.config.optimizer.betas,
185
+ weight_decay=self.config.optimizer.weight_decay)
186
+ sched = CosineAnnealingLR(opt,
187
+ T_max=self.config.optimizer.max_steps,
188
+ eta_min=self.config.optimizer.min_lr)
189
+ sched = {
190
+ 'scheduler': sched,
191
+ 'name': 'cosine'
192
+ }
193
+ return [opt], [sched]
194
+
195
+ def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure,
196
+ on_tpu=False, using_native_amp=False, using_lbfgs=False):
197
+ optimizer.step(closure=optimizer_closure)
198
+ self.lr_schedulers().step()
199
+ self.log("lr", self.lr_schedulers().get_last_lr()[0], on_step=True, on_epoch=False, prog_bar=True, logger=True)
200
+
201
+ def on_epoch_start(self):
202
+ self.stage1.eval()
dalle/models/stage1/layers.py ADDED
@@ -0,0 +1,373 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Modified from VQGAN (https://github.com/CompVis/taming-transformers)
3
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4
+ # ------------------------------------------------------------------------------------
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import Tuple, Optional
9
+
10
+
11
+ def nonlinearity(x):
12
+ # swish
13
+ return x*torch.sigmoid(x)
14
+
15
+
16
+ def Normalize(in_channels):
17
+ return torch.nn.GroupNorm(num_groups=32,
18
+ num_channels=in_channels,
19
+ eps=1e-6,
20
+ affine=True)
21
+
22
+
23
+ class Upsample(nn.Module):
24
+ def __init__(self, in_channels, with_conv):
25
+ super().__init__()
26
+ self.with_conv = with_conv
27
+ if self.with_conv:
28
+ self.conv = torch.nn.Conv2d(in_channels,
29
+ in_channels,
30
+ kernel_size=3,
31
+ stride=1,
32
+ padding=1)
33
+
34
+ def forward(self, x):
35
+ x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
36
+ if self.with_conv:
37
+ x = self.conv(x)
38
+ return x
39
+
40
+
41
+ class Downsample(nn.Module):
42
+ def __init__(self, in_channels, with_conv):
43
+ super().__init__()
44
+ self.with_conv = with_conv
45
+ if self.with_conv:
46
+ # no asymmetric padding in torch conv, must do it ourselves
47
+ self.conv = torch.nn.Conv2d(in_channels,
48
+ in_channels,
49
+ kernel_size=3,
50
+ stride=2,
51
+ padding=0)
52
+
53
+ def forward(self, x):
54
+ if self.with_conv:
55
+ pad = (0, 1, 0, 1)
56
+ x = torch.nn.functional.pad(x, pad, mode="constant", value=0)
57
+ x = self.conv(x)
58
+ else:
59
+ x = torch.nn.functional.avg_pool2d(x, kernel_size=2, stride=2)
60
+ return x
61
+
62
+
63
+ class ResnetBlock(nn.Module):
64
+ def __init__(self, *, in_channels, out_channels=None, conv_shortcut=False,
65
+ dropout, temb_channels=512):
66
+ assert temb_channels == 0
67
+ super().__init__()
68
+ self.in_channels = in_channels
69
+ out_channels = in_channels if out_channels is None else out_channels
70
+ self.out_channels = out_channels
71
+ self.use_conv_shortcut = conv_shortcut
72
+
73
+ self.norm1 = Normalize(in_channels)
74
+ self.conv1 = torch.nn.Conv2d(in_channels,
75
+ out_channels,
76
+ kernel_size=3,
77
+ stride=1,
78
+ padding=1)
79
+ self.norm2 = Normalize(out_channels)
80
+ self.dropout = torch.nn.Dropout(dropout)
81
+ self.conv2 = torch.nn.Conv2d(out_channels,
82
+ out_channels,
83
+ kernel_size=3,
84
+ stride=1,
85
+ padding=1)
86
+ if self.in_channels != self.out_channels:
87
+ if self.use_conv_shortcut:
88
+ self.conv_shortcut = torch.nn.Conv2d(in_channels,
89
+ out_channels,
90
+ kernel_size=3,
91
+ stride=1,
92
+ padding=1)
93
+ else:
94
+ self.nin_shortcut = torch.nn.Conv2d(in_channels,
95
+ out_channels,
96
+ kernel_size=1,
97
+ stride=1,
98
+ padding=0)
99
+
100
+ def forward(self, x, temb=None):
101
+ assert temb is None
102
+
103
+ h = x
104
+ h = self.norm1(h)
105
+ h = nonlinearity(h)
106
+ h = self.conv1(h)
107
+
108
+ h = self.norm2(h)
109
+ h = nonlinearity(h)
110
+ h = self.dropout(h)
111
+ h = self.conv2(h)
112
+
113
+ if self.in_channels != self.out_channels:
114
+ if self.use_conv_shortcut:
115
+ x = self.conv_shortcut(x)
116
+ else:
117
+ x = self.nin_shortcut(x)
118
+ return x+h
119
+
120
+
121
+ class AttnBlock(nn.Module):
122
+ def __init__(self, in_channels):
123
+ super().__init__()
124
+ self.in_channels = in_channels
125
+
126
+ self.norm = Normalize(in_channels)
127
+ self.q = torch.nn.Conv2d(in_channels,
128
+ in_channels,
129
+ kernel_size=1,
130
+ stride=1,
131
+ padding=0)
132
+ self.k = torch.nn.Conv2d(in_channels,
133
+ in_channels,
134
+ kernel_size=1,
135
+ stride=1,
136
+ padding=0)
137
+ self.v = torch.nn.Conv2d(in_channels,
138
+ in_channels,
139
+ kernel_size=1,
140
+ stride=1,
141
+ padding=0)
142
+ self.proj_out = torch.nn.Conv2d(in_channels,
143
+ in_channels,
144
+ kernel_size=1,
145
+ stride=1,
146
+ padding=0)
147
+
148
+ def forward(self, x):
149
+ h_ = x
150
+ h_ = self.norm(h_)
151
+ q = self.q(h_)
152
+ k = self.k(h_)
153
+ v = self.v(h_)
154
+
155
+ # compute attention
156
+ b, c, h, w = q.shape
157
+ q = q.reshape(b, c, h*w)
158
+ q = q.permute(0, 2, 1) # b,hw,c
159
+ k = k.reshape(b, c, h*w) # b,c,hw
160
+ w_ = torch.bmm(q, k) # b,hw,hw w[b,i,j]=sum_c q[b,i,c]k[b,c,j]
161
+ w_ = w_ * (int(c)**(-0.5))
162
+ w_ = torch.nn.functional.softmax(w_, dim=2)
163
+
164
+ # attend to values
165
+ v = v.reshape(b, c, h*w)
166
+ w_ = w_.permute(0, 2, 1) # b,hw,hw (first hw of k, second of q)
167
+ h_ = torch.bmm(v, w_) # b, c,hw (hw of q) h_[b,c,j] = sum_i v[b,c,i] w_[b,i,j]
168
+ h_ = h_.reshape(b, c, h, w)
169
+
170
+ h_ = self.proj_out(h_)
171
+ return x+h_
172
+
173
+
174
+ class Encoder(nn.Module):
175
+ def __init__(self,
176
+ *, # forced to use named arguments
177
+ ch: int,
178
+ out_ch: int,
179
+ ch_mult: Tuple[int] = (1, 2, 4, 8),
180
+ num_res_blocks: int,
181
+ attn_resolutions: Tuple[int],
182
+ pdrop: float = 0.0,
183
+ resamp_with_conv: bool = True,
184
+ in_channels: int,
185
+ resolution: int,
186
+ z_channels: int,
187
+ double_z: Optional[bool] = None) -> None:
188
+ super().__init__()
189
+ self.ch = ch
190
+ self.temb_ch = 0
191
+ self.num_resolutions = len(ch_mult)
192
+ self.num_res_blocks = num_res_blocks
193
+ self.resolution = resolution
194
+ self.in_channels = in_channels
195
+
196
+ # downsampling
197
+ self.conv_in = torch.nn.Conv2d(in_channels,
198
+ self.ch,
199
+ kernel_size=3,
200
+ stride=1,
201
+ padding=1)
202
+
203
+ curr_res = resolution
204
+ in_ch_mult = (1,)+tuple(ch_mult)
205
+ self.down = nn.ModuleList()
206
+ for i_level in range(self.num_resolutions):
207
+ block = nn.ModuleList()
208
+ attn = nn.ModuleList()
209
+ block_in = ch*in_ch_mult[i_level]
210
+ block_out = ch*ch_mult[i_level]
211
+ for i_block in range(self.num_res_blocks):
212
+ block.append(ResnetBlock(in_channels=block_in,
213
+ out_channels=block_out,
214
+ temb_channels=self.temb_ch,
215
+ dropout=pdrop))
216
+ block_in = block_out
217
+ if curr_res in attn_resolutions:
218
+ attn.append(AttnBlock(block_in))
219
+ down = nn.Module()
220
+ down.block = block
221
+ down.attn = attn
222
+ if i_level != self.num_resolutions-1:
223
+ down.downsample = Downsample(block_in, resamp_with_conv)
224
+ curr_res = curr_res // 2
225
+ self.down.append(down)
226
+
227
+ # middle
228
+ self.mid = nn.Module()
229
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
230
+ out_channels=block_in,
231
+ temb_channels=self.temb_ch,
232
+ dropout=pdrop)
233
+ self.mid.attn_1 = AttnBlock(block_in)
234
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
235
+ out_channels=block_in,
236
+ temb_channels=self.temb_ch,
237
+ dropout=pdrop)
238
+
239
+ # end
240
+ self.norm_out = Normalize(block_in)
241
+ self.conv_out = torch.nn.Conv2d(block_in,
242
+ 2*z_channels if double_z else z_channels,
243
+ kernel_size=3,
244
+ stride=1,
245
+ padding=1)
246
+
247
+ def forward(self, x):
248
+ assert x.shape[2] == x.shape[3] == self.resolution, \
249
+ "{}, {}".format(x.shape, self.resolution)
250
+
251
+ # downsampling
252
+ h = self.conv_in(x)
253
+ for i_level in range(self.num_resolutions):
254
+ for i_block in range(self.num_res_blocks):
255
+ h = self.down[i_level].block[i_block](h)
256
+ if len(self.down[i_level].attn) > 0:
257
+ h = self.down[i_level].attn[i_block](h)
258
+ if i_level != self.num_resolutions-1:
259
+ h = self.down[i_level].downsample(h)
260
+
261
+ # middle
262
+ h = self.mid.block_1(h)
263
+ h = self.mid.attn_1(h)
264
+ h = self.mid.block_2(h)
265
+
266
+ # end
267
+ h = self.norm_out(h)
268
+ h = nonlinearity(h)
269
+ h = self.conv_out(h)
270
+ return h
271
+
272
+
273
+ class Decoder(nn.Module):
274
+ def __init__(self,
275
+ *, # forced to use named arguments
276
+ ch: int,
277
+ out_ch: int,
278
+ ch_mult: Tuple[int] = (1, 2, 4, 8),
279
+ num_res_blocks: int,
280
+ attn_resolutions: Tuple[int],
281
+ pdrop: float = 0.0,
282
+ resamp_with_conv: bool = True,
283
+ in_channels: int,
284
+ resolution: int,
285
+ z_channels: int,
286
+ double_z: bool) -> None:
287
+ super().__init__()
288
+ self.ch = ch
289
+ self.temb_ch = 0
290
+ self.num_resolutions = len(ch_mult)
291
+ self.num_res_blocks = num_res_blocks
292
+ self.resolution = resolution
293
+ self.in_channels = in_channels
294
+
295
+ # compute in_ch_mult, block_in and curr_res at lowest res
296
+ block_in = ch*ch_mult[self.num_resolutions-1]
297
+ curr_res = resolution // 2**(self.num_resolutions-1)
298
+ self.z_shape = (1, z_channels, curr_res, curr_res)
299
+
300
+ # z to block_in
301
+ self.conv_in = torch.nn.Conv2d(z_channels,
302
+ block_in,
303
+ kernel_size=3,
304
+ stride=1,
305
+ padding=1)
306
+
307
+ # middle
308
+ self.mid = nn.Module()
309
+ self.mid.block_1 = ResnetBlock(in_channels=block_in,
310
+ out_channels=block_in,
311
+ temb_channels=self.temb_ch,
312
+ dropout=pdrop)
313
+ self.mid.attn_1 = AttnBlock(block_in)
314
+ self.mid.block_2 = ResnetBlock(in_channels=block_in,
315
+ out_channels=block_in,
316
+ temb_channels=self.temb_ch,
317
+ dropout=pdrop)
318
+
319
+ # upsampling
320
+ self.up = nn.ModuleList()
321
+ for i_level in reversed(range(self.num_resolutions)):
322
+ block = nn.ModuleList()
323
+ attn = nn.ModuleList()
324
+ block_out = ch*ch_mult[i_level]
325
+ for i_block in range(self.num_res_blocks+1):
326
+ block.append(ResnetBlock(in_channels=block_in,
327
+ out_channels=block_out,
328
+ temb_channels=self.temb_ch,
329
+ dropout=pdrop))
330
+ block_in = block_out
331
+ if curr_res in attn_resolutions:
332
+ attn.append(AttnBlock(block_in))
333
+ up = nn.Module()
334
+ up.block = block
335
+ up.attn = attn
336
+ if i_level != 0:
337
+ up.upsample = Upsample(block_in, resamp_with_conv)
338
+ curr_res = curr_res * 2
339
+ self.up.insert(0, up) # prepend to get consistent order
340
+
341
+ # end
342
+ self.norm_out = Normalize(block_in)
343
+ self.conv_out = torch.nn.Conv2d(block_in,
344
+ out_ch,
345
+ kernel_size=3,
346
+ stride=1,
347
+ padding=1)
348
+
349
+ def forward(self, z):
350
+ assert z.shape[1:] == self.z_shape[1:]
351
+ self.last_z_shape = z.shape
352
+
353
+ # z to block_in
354
+ h = self.conv_in(z)
355
+
356
+ # middle
357
+ h = self.mid.block_1(h)
358
+ h = self.mid.attn_1(h)
359
+ h = self.mid.block_2(h)
360
+
361
+ # upsampling
362
+ for i_level in reversed(range(self.num_resolutions)):
363
+ for i_block in range(self.num_res_blocks+1):
364
+ h = self.up[i_level].block[i_block](h)
365
+ if len(self.up[i_level].attn) > 0:
366
+ h = self.up[i_level].attn[i_block](h)
367
+ if i_level != 0:
368
+ h = self.up[i_level].upsample(h)
369
+
370
+ h = self.norm_out(h)
371
+ h = nonlinearity(h)
372
+ h = self.conv_out(h)
373
+ return h
dalle/models/stage1/vqgan.py ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Modified from VQGAN (https://github.com/CompVis/taming-transformers)
3
+ # Copyright (c) 2020 Patrick Esser and Robin Rombach and Björn Ommer. All Rights Reserved.
4
+ # ------------------------------------------------------------------------------------
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from typing import List, Tuple, Optional
9
+ from einops import rearrange
10
+ from omegaconf import OmegaConf
11
+ from .layers import Encoder, Decoder
12
+
13
+
14
+ class VectorQuantizer(nn.Module):
15
+ """
16
+ Simplified VectorQuantizer in the original VQGAN repository
17
+ by removing unncessary modules for sampling
18
+ """
19
+ def __init__(self, dim: int, n_embed: int, beta: float) -> None:
20
+ super().__init__()
21
+ self.n_embed = n_embed
22
+ self.dim = dim
23
+ self.beta = beta
24
+
25
+ self.embedding = nn.Embedding(self.n_embed, self.dim)
26
+ self.embedding.weight.data.uniform_(-1.0 / self.n_embed, 1.0 / self.n_embed)
27
+
28
+ def forward(self,
29
+ z: torch.FloatTensor) -> Tuple[torch.FloatTensor, torch.LongTensor]:
30
+ z = rearrange(z, 'b c h w -> b h w c').contiguous() # [B,C,H,W] -> [B,H,W,C]
31
+ z_flattened = z.view(-1, self.dim)
32
+
33
+ d = torch.sum(z_flattened ** 2, dim=1, keepdim=True) + \
34
+ torch.sum(self.embedding.weight**2, dim=1) - 2 * \
35
+ torch.einsum('bd,dn->bn', z_flattened, rearrange(self.embedding.weight, 'n d -> d n'))
36
+
37
+ min_encoding_indices = torch.argmin(d, dim=1)
38
+ z_q = self.embedding(min_encoding_indices).view(z.shape)
39
+ return z_q, min_encoding_indices
40
+
41
+ def get_codebook_entry(self,
42
+ indices: torch.LongTensor,
43
+ shape: Optional[List[int]] = None) -> torch.FloatTensor:
44
+ z_q = self.embedding(indices)
45
+ if shape is not None:
46
+ z_q = z_q.view(shape)
47
+ z_q = z_q.permute(0, 3, 1, 2).contiguous()
48
+ return z_q
49
+
50
+
51
+ class VQGAN(nn.Module):
52
+ def __init__(self, n_embed: int, embed_dim: int, hparams: OmegaConf) -> None:
53
+ super().__init__()
54
+ self.encoder = Encoder(**hparams)
55
+ self.decoder = Decoder(**hparams)
56
+ self.quantize = VectorQuantizer(dim=embed_dim, n_embed=n_embed, beta=0.25)
57
+ self.quant_conv = torch.nn.Conv2d(hparams.z_channels, embed_dim, 1)
58
+ self.post_quant_conv = torch.nn.Conv2d(embed_dim, hparams.z_channels, 1)
59
+ self.latent_dim = hparams.attn_resolutions[0]
60
+
61
+ def forward(self, x: torch.FloatTensor) -> torch.FloatTensor:
62
+ quant = self.encode(x)
63
+ dec = self.decode(quant)
64
+ return dec
65
+
66
+ def encode(self, x: torch.FloatTensor) -> torch.FloatTensor:
67
+ h = self.encoder(x)
68
+ h = self.quant_conv(h)
69
+ quant = self.quantize(h)[0]
70
+ quant = rearrange(quant, 'b h w c -> b c h w').contiguous()
71
+ return quant
72
+
73
+ def decode(self, quant: torch.FloatTensor) -> torch.FloatTensor:
74
+ quant = self.post_quant_conv(quant)
75
+ dec = self.decoder(quant)
76
+ return dec
77
+
78
+ def decode_code(self, code: torch.LongTensor) -> torch.FloatTensor:
79
+ quant = self.quantize.get_codebook_entry(code)
80
+ quant = quant.permute(0, 3, 1, 2)
81
+ dec = self.decode(quant)
82
+ return dec
83
+
84
+ def get_codes(self, x: torch.FloatTensor) -> torch.LongTensor:
85
+ h = self.encoder(x)
86
+ h = self.quant_conv(h)
87
+ codes = self.quantize(h)[1].view(x.shape[0], self.latent_dim ** 2)
88
+ return codes
89
+
90
+ def from_ckpt(self, path: str, strict: bool = True) -> None:
91
+ ckpt = torch.load(path, map_location='cpu')['state_dict']
92
+ self.load_state_dict(ckpt, strict=strict)
93
+ print(f'{path} successfully restored..')
dalle/models/stage2/layers.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+ # Modified from minGPT (https://github.com/karpathy/minGPT)
7
+ # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
8
+ # ------------------------------------------------------------------------------------
9
+
10
+ import math
11
+ import torch
12
+ import torch.nn as nn
13
+ from torch.nn import functional as F
14
+
15
+
16
+ class GELU(nn.Module):
17
+ def __init__(self, use_approx=False):
18
+ super().__init__()
19
+ self.use_approx = use_approx
20
+
21
+ def forward(self, x):
22
+ if self.use_approx:
23
+ return x * torch.sigmoid(1.702 * x)
24
+ else:
25
+ return F.gelu(x)
26
+
27
+
28
+ class MultiHeadSelfAttention(nn.Module):
29
+
30
+ def __init__(self,
31
+ ctx_len: int,
32
+ embed_dim: int,
33
+ n_heads: int,
34
+ resid_pdrop: float,
35
+ attn_pdrop: float,
36
+ attn_bias: bool,
37
+ use_mask: bool = True):
38
+ super().__init__()
39
+ assert embed_dim % n_heads == 0
40
+
41
+ # key, query, value projections for all heads
42
+ self.key = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
43
+ self.query = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
44
+ self.value = nn.Linear(embed_dim, embed_dim, bias=attn_bias)
45
+
46
+ # regularization
47
+ self.attn_drop = nn.Dropout(attn_pdrop)
48
+ self.resid_drop = nn.Dropout(resid_pdrop)
49
+
50
+ # output projection
51
+ self.proj = nn.Linear(embed_dim, embed_dim, attn_bias)
52
+
53
+ self.n_heads = n_heads
54
+ self.ctx_len = ctx_len
55
+ self.use_mask = use_mask
56
+ if self.use_mask:
57
+ self.register_buffer("mask", torch.ones(ctx_len, ctx_len), persistent=False)
58
+ self.mask = torch.tril(self.mask).view(1, ctx_len, ctx_len)
59
+
60
+ def forward(self, x, use_cache=False, layer_past=None):
61
+ B, T, C = x.shape
62
+ x = x.transpose(0, 1).contiguous() # (B, T, C) -> (T, B, C)
63
+
64
+ # calculate query, key, values for all heads in batch and move head forward to be the batch dim
65
+ k = self.key(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
66
+ q = self.query(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
67
+ v = self.value(x).view(T, B*self.n_heads, C//self.n_heads).transpose(0, 1) # (B*nh, T, hs)
68
+
69
+ if use_cache:
70
+ present = torch.stack([k, v])
71
+
72
+ if layer_past is not None:
73
+ past_key, past_value = layer_past
74
+ k = torch.cat([past_key, k], dim=-2)
75
+ v = torch.cat([past_value, v], dim=-2)
76
+
77
+ if use_cache and layer_past is not None:
78
+ # Tensor shape below: (B * nh, 1, hs) X (B * nh, hs, K) -> (B * nh, 1, K)
79
+ att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
80
+ att = F.softmax(att, dim=-1)
81
+ att = self.attn_drop(att)
82
+ y = torch.bmm(att, v) # (B*nh, 1, K) X (B*nh, K, hs) -> (B*nh, 1, hs)
83
+ else:
84
+ # Tensor shape below: (B * nh, T, hs) X (B * nh, hs, T) -> (B * nh, T, T)
85
+ att = torch.bmm(q, (k.transpose(-2, -1)) * (1.0 / math.sqrt(k.size(-1))))
86
+ if self.use_mask:
87
+ mask = self.mask if T == self.ctx_len else self.mask[:, :T, :T]
88
+ att = att.masked_fill(mask == 0, float('-inf'))
89
+ att = F.softmax(att, dim=-1)
90
+ att = self.attn_drop(att)
91
+ y = torch.bmm(att, v) # (B*nh, T, T) X (B*nh, T, hs) -> (B*nh, T, hs)
92
+ y = y.transpose(0, 1).contiguous().view(T, B, C) # re-assemble all head outputs side by side
93
+
94
+ # output projection
95
+ y = self.resid_drop(self.proj(y))
96
+ if use_cache:
97
+ return y.transpose(0, 1).contiguous(), present # (T, B, C) -> (B, T, C)
98
+ else:
99
+ return y.transpose(0, 1).contiguous() # (T, B, C) -> (B, T, C)
100
+
101
+
102
+ class Block(nn.Module):
103
+
104
+ def __init__(self,
105
+ ctx_len: int,
106
+ embed_dim: int,
107
+ n_heads: int,
108
+ mlp_bias: bool,
109
+ attn_bias: bool,
110
+ resid_pdrop: bool,
111
+ attn_pdrop: bool,
112
+ gelu_use_approx: bool):
113
+ super().__init__()
114
+ self.ln1 = nn.LayerNorm(embed_dim)
115
+ self.ln2 = nn.LayerNorm(embed_dim)
116
+
117
+ self.attn = MultiHeadSelfAttention(ctx_len=ctx_len,
118
+ embed_dim=embed_dim,
119
+ n_heads=n_heads,
120
+ attn_pdrop=attn_pdrop,
121
+ resid_pdrop=resid_pdrop,
122
+ attn_bias=attn_bias,
123
+ use_mask=True)
124
+ self.mlp = nn.Sequential(
125
+ nn.Linear(embed_dim, 4 * embed_dim, bias=mlp_bias),
126
+ GELU(gelu_use_approx),
127
+ nn.Linear(4 * embed_dim, embed_dim, bias=mlp_bias),
128
+ nn.Dropout(resid_pdrop),
129
+ )
130
+
131
+ def forward(self, x):
132
+ x = x + self.attn(self.ln1(x))
133
+ x = x + self.mlp(self.ln2(x))
134
+ return x
135
+
136
+ def sample(self, x, layer_past=None):
137
+ attn, present = self.attn(self.ln1(x), use_cache=True, layer_past=layer_past)
138
+ x = x + attn
139
+ x = x + self.mlp(self.ln2(x))
140
+ return x, present
dalle/models/stage2/transformer.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+ # Modified from minGPT (https://github.com/karpathy/minGPT)
7
+ # Copyright (c) 2020 Andrej Karpathy. All Rights Reserved.
8
+ # ------------------------------------------------------------------------------------
9
+
10
+ import torch
11
+ import torch.nn as nn
12
+ from typing import Optional, Tuple, List
13
+ from torch.cuda.amp import autocast
14
+ from omegaconf import OmegaConf
15
+ from .layers import Block
16
+
17
+
18
+ class Transformer1d(nn.Module):
19
+
20
+ def __init__(self,
21
+ vocab_size_txt: int,
22
+ vocab_size_img: int,
23
+ hparams: OmegaConf) -> None:
24
+ super().__init__()
25
+ assert hparams.n_layers == hparams.n_dense_layers
26
+
27
+ # input embedding for image and text
28
+ self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
29
+ self.tok_emb_txt = nn.Embedding(vocab_size_txt, hparams.embed_dim)
30
+
31
+ self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
32
+ self.pos_emb_txt = nn.Embedding(hparams.ctx_len_txt, hparams.embed_dim)
33
+
34
+ self.drop = nn.Dropout(hparams.embd_pdrop)
35
+
36
+ # transformer blocks
37
+ self.blocks = [Block(ctx_len=hparams.ctx_len_img + hparams.ctx_len_txt,
38
+ embed_dim=hparams.embed_dim,
39
+ n_heads=hparams.n_heads,
40
+ mlp_bias=hparams.mlp_bias,
41
+ attn_bias=hparams.attn_bias,
42
+ resid_pdrop=hparams.resid_pdrop,
43
+ attn_pdrop=hparams.attn_pdrop,
44
+ gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
45
+ self.blocks = nn.Sequential(*self.blocks)
46
+
47
+ # heads for image and text
48
+ self.ln_f = nn.LayerNorm(hparams.embed_dim)
49
+ self.head_img = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
50
+ self.head_txt = nn.Linear(hparams.embed_dim, vocab_size_txt, bias=False)
51
+
52
+ self.ctx_len_img = hparams.ctx_len_img
53
+ self.ctx_len_txt = hparams.ctx_len_txt
54
+ self.n_layers = hparams.n_layers
55
+
56
+ self.apply(self._init_weights)
57
+
58
+ def _init_weights(self, module: nn.Module) -> None:
59
+ if isinstance(module, (nn.Linear, nn.Embedding)):
60
+ module.weight.data.normal_(mean=0.0, std=0.02)
61
+ if isinstance(module, nn.Linear) and module.bias is not None:
62
+ module.bias.data.zero_()
63
+ elif isinstance(module, nn.LayerNorm):
64
+ module.bias.data.zero_()
65
+ module.weight.data.fill_(1.0)
66
+
67
+ def forward(self,
68
+ images: torch.LongTensor,
69
+ texts: torch.LongTensor,
70
+ pos_images: torch.LongTensor,
71
+ pos_texts: torch.LongTensor) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
72
+ B, T = images.shape
73
+ _, N = texts.shape
74
+
75
+ assert T <= self.ctx_len_img, "Already reached the maximum context length (image)."
76
+ assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
77
+
78
+ texts = self.tok_emb_txt(texts)
79
+ images = self.tok_emb_img(images)
80
+
81
+ texts = texts + self.pos_emb_txt(pos_texts)
82
+ images = images + self.pos_emb_img(pos_images)
83
+
84
+ x = torch.cat([texts, images], axis=1).contiguous()
85
+ x = self.drop(x)
86
+ x = self.blocks(x)
87
+ x = self.ln_f(x)
88
+
89
+ texts = x[:, :N-1].contiguous()
90
+ images = x[:, N-1:-1].contiguous()
91
+
92
+ logits_txt = self.head_txt(texts)
93
+ logits_img = self.head_img(images)
94
+ return logits_img, logits_txt
95
+
96
+ @torch.no_grad()
97
+ def sampling(self,
98
+ images: torch.LongTensor,
99
+ texts: torch.LongTensor,
100
+ pos_images: torch.LongTensor,
101
+ pos_texts: torch.LongTensor,
102
+ use_fp16: bool = True,
103
+ past: Optional[List[torch.Tensor]] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
104
+ _, N = texts.shape
105
+ assert N == self.ctx_len_txt, "Already reached the maximum context length (text)."
106
+
107
+ with autocast(enabled=use_fp16):
108
+ if images is None:
109
+ assert past is None
110
+
111
+ texts = self.tok_emb_txt(texts)
112
+ x = texts + self.pos_emb_txt(pos_texts)
113
+ x = self.drop(x)
114
+
115
+ presents = []
116
+ for i, block in enumerate(self.blocks):
117
+ x, present = block.sample(x, layer_past=None)
118
+ presents.append(present)
119
+ x = self.ln_f(x)
120
+ x = x[:, N-1].contiguous()
121
+ logits = self.head_img(x)
122
+ else:
123
+ if past is None:
124
+ texts = self.tok_emb_txt(texts)
125
+ images = self.tok_emb_img(images)
126
+ texts = texts + self.pos_emb_txt(pos_texts)
127
+ images = images + self.pos_emb_img(pos_images)
128
+ x = torch.cat([texts, images], axis=1).contiguous()
129
+ else:
130
+ images = self.tok_emb_img(images)
131
+ x = images + self.pos_emb_img(pos_images)
132
+ x = self.drop(x)
133
+
134
+ if past is not None:
135
+ past = torch.cat(past, dim=-2)
136
+ presents = []
137
+ for i, block in enumerate(self.blocks):
138
+ x, present = block.sample(x, layer_past=None if past is None else past[i])
139
+ presents.append(present)
140
+ x = self.ln_f(x)
141
+ x = x[:, -1].contiguous()
142
+ logits = self.head_img(x)
143
+ return logits, presents
144
+
145
+ def from_ckpt(self, path: str) -> None:
146
+ ckpt = torch.load(path, map_location='cpu')['state_dict']
147
+ self.load_state_dict(ckpt, strict=True)
148
+ print(f'{path} succesfully restored..')
149
+
150
+
151
+ class iGPT(nn.Module):
152
+ def __init__(self,
153
+ vocab_size_img: int,
154
+ use_cls_cond: bool,
155
+ hparams: OmegaConf) -> None:
156
+ super().__init__()
157
+ self.use_cls_cond = use_cls_cond
158
+
159
+ # sos token embedding
160
+ if self.use_cls_cond:
161
+ self.sos = nn.Embedding(hparams.n_classes, hparams.embed_dim)
162
+ else:
163
+ self.sos = nn.Parameter(torch.randn(1, 1, hparams.embed_dim))
164
+
165
+ # input embedding
166
+ self.tok_emb_img = nn.Embedding(vocab_size_img, hparams.embed_dim)
167
+ self.pos_emb_img = nn.Embedding(hparams.ctx_len_img, hparams.embed_dim)
168
+
169
+ self.drop = nn.Dropout(hparams.embd_pdrop)
170
+
171
+ # transformer blocks
172
+ self.blocks = [Block(ctx_len=hparams.ctx_len_img + 1,
173
+ embed_dim=hparams.embed_dim,
174
+ n_heads=hparams.n_heads,
175
+ mlp_bias=hparams.mlp_bias,
176
+ attn_bias=hparams.attn_bias,
177
+ resid_pdrop=hparams.resid_pdrop,
178
+ attn_pdrop=hparams.attn_pdrop,
179
+ gelu_use_approx=hparams.gelu_use_approx) for i in range(1, hparams.n_layers+1)]
180
+ self.blocks = nn.Sequential(*self.blocks)
181
+
182
+ # head
183
+ self.ln_f = nn.LayerNorm(hparams.embed_dim)
184
+ self.head = nn.Linear(hparams.embed_dim, vocab_size_img, bias=False)
185
+
186
+ self.ctx_len_img = hparams.ctx_len_img
187
+ self.n_layers = hparams.n_layers
188
+
189
+ self.apply(self._init_weights)
190
+
191
+ def _init_weights(self, module: nn.Module) -> None:
192
+ if isinstance(module, (nn.Linear, nn.Embedding)):
193
+ module.weight.data.normal_(mean=0.0, std=0.02)
194
+ if isinstance(module, nn.Linear) and module.bias is not None:
195
+ module.bias.data.zero_()
196
+ elif isinstance(module, nn.LayerNorm):
197
+ module.bias.data.zero_()
198
+ module.weight.data.fill_(1.0)
199
+
200
+ @torch.no_grad()
201
+ def sampling(self,
202
+ sos: torch.FloatTensor,
203
+ codes: torch.LongTensor,
204
+ pos_codes: torch.LongTensor,
205
+ n_samples: int = 16,
206
+ use_fp16: bool = True,
207
+ past: Optional[torch.Tensor] = None) -> Tuple[torch.FloatTensor, List[torch.FloatTensor]]:
208
+ with autocast(enabled=use_fp16):
209
+ if codes is None:
210
+ assert past is None
211
+ xs = self.drop(sos)
212
+ presents = []
213
+ for i, block in enumerate(self.blocks):
214
+ xs, present = block.sample(xs, layer_past=None)
215
+ presents.append(present)
216
+ xs = self.ln_f(xs)
217
+ logits = self.head(xs)[:, -1]
218
+ else:
219
+ if past is None:
220
+ xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
221
+ xs = torch.cat([sos, xs], dim=1)
222
+ else:
223
+ xs = self.tok_emb_img(codes) + self.pos_emb_img(pos_codes)
224
+ xs = self.drop(xs)
225
+
226
+ past = torch.cat(past, dim=-2) if past is not None else past
227
+ presents = []
228
+ for i, block in enumerate(self.blocks):
229
+ xs, present = block.sample(xs, layer_past=None if past is None else past[i])
230
+ presents.append(present)
231
+
232
+ xs = self.ln_f(xs)
233
+ logits = self.head(xs)[:, -1]
234
+ return logits, presents
235
+
236
+ def forward(self,
237
+ codes: torch.LongTensor,
238
+ labels: Optional[torch.LongTensor] = None) -> torch.FloatTensor:
239
+ B, T = codes.shape
240
+ xps = torch.arange(T, device=codes.device).repeat((B, 1))
241
+ sos = self.sos.repeat((B, 1, 1)) if labels is None else self.sos(labels).unsqueeze(1)
242
+
243
+ h = self.tok_emb_img(codes) + self.pos_emb_img(xps)
244
+ h = torch.cat([sos, h[:, :-1]], dim=1).contiguous()
245
+
246
+ h = self.drop(h)
247
+ h = self.blocks(h)
248
+ h = self.ln_f(h)
249
+ logits = self.head(h)
250
+ return logits
251
+
252
+ def from_ckpt(self, path: str, strict: bool = True) -> None:
253
+ ckpt = torch.load(path, map_location='cpu')['state_dict']
254
+ self.load_state_dict(ckpt, strict=strict)
255
+ print(f'{path} successfully restored..')
dalle/models/tokenizer.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ from functools import partial
9
+ from tokenizers import CharBPETokenizer
10
+
11
+
12
+ def build_tokenizer(path: str,
13
+ context_length: int = 64,
14
+ *args,
15
+ **kwargs):
16
+ from_file = partial(CharBPETokenizer.from_file,
17
+ vocab_filename=os.path.join(path, 'bpe-16k-vocab.json'),
18
+ merges_filename=os.path.join(path, 'bpe-16k-merges.txt'),
19
+ unk_token='[UNK]')
20
+ tokenizer = from_file(*args, **kwargs)
21
+ tokenizer.add_special_tokens(['[PAD]'])
22
+ tokenizer.enable_padding(length=context_length,
23
+ pad_id=tokenizer.token_to_id('[PAD]'))
24
+ tokenizer.enable_truncation(max_length=context_length)
25
+ print(f'{path} successfully restored..')
26
+ return tokenizer
dalle/utils/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from .utils import *
2
+ from .config import *
3
+ from .sampling import *
dalle/utils/config.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ from typing import Optional, List
8
+ from dataclasses import dataclass, field
9
+ from omegaconf import OmegaConf
10
+
11
+
12
+ @dataclass
13
+ class DataConfig:
14
+ dataset: Optional[str] = None
15
+ tokenizer_type: str = 'CharBPE'
16
+ context_length: int = 64
17
+ image_resolution: int = 256
18
+ transforms: str = 'dalle-vqvae'
19
+ bpe_pdrop: Optional[float] = None
20
+
21
+
22
+ @dataclass
23
+ class Stage1Hparams:
24
+ double_z: bool = False
25
+ z_channels: int = 256
26
+ resolution: int = 256
27
+ in_channels: int = 3
28
+ out_ch: int = 3
29
+ ch: int = 128
30
+ ch_mult: List[int] = field(default_factory=lambda: [1, 1, 2, 2, 4])
31
+ num_res_blocks: int = 2
32
+ attn_resolutions: List[int] = field(default_factory=lambda: [16])
33
+ pdrop: float = 0.0
34
+
35
+
36
+ @dataclass
37
+ class Stage2Hparams:
38
+ embed_dim: int = 1536
39
+ n_layers: int = 42
40
+ n_heads: int = 24
41
+ n_dense_layers: int = 42
42
+ ctx_len_img: int = 256
43
+ ctx_len_txt: int = 64
44
+ embd_pdrop: float = 0.0
45
+ resid_pdrop: float = 0.0
46
+ attn_pdrop: float = 0.0
47
+ mlp_bias: bool = True
48
+ attn_bias: bool = True
49
+ gelu_use_approx: bool = False
50
+ use_head_txt: bool = True
51
+ n_classes: Optional[int] = None
52
+
53
+
54
+ @dataclass
55
+ class Stage1Config:
56
+ type: str = 'vqgan'
57
+ embed_dim: int = 256
58
+ n_embed: int = 16384
59
+ hparams: Stage1Hparams = Stage1Hparams()
60
+
61
+
62
+ @dataclass
63
+ class Stage2Config:
64
+ type: str = 'transformer1d'
65
+ vocab_size_txt: int = 16384
66
+ vocab_size_img: int = 16384
67
+ use_cls_cond: Optional[bool] = None
68
+ hparams: Stage2Hparams = Stage2Hparams()
69
+
70
+
71
+ @dataclass
72
+ class WarmupConfig:
73
+ epoch: int = 1
74
+ multiplier: int = 1
75
+ buffer_epoch: int = 0
76
+ min_lr: float = 0.0
77
+ mode: str = 'fix'
78
+ peak_lr: float = 1e-4
79
+ start_from_zero: bool = True
80
+
81
+
82
+ @dataclass
83
+ class OptConfig:
84
+ opt_type: str = 'adamW'
85
+ base_lr: float = 1e-4
86
+ weight_decay: float = 1e-4
87
+ betas: List[float] = field(default_factory=lambda: [0.9, 0.99])
88
+ grad_clip_norm: float = 1.0
89
+
90
+ sched_type: str = 'cosine'
91
+ max_steps: int = 0
92
+ min_lr: float = 0.0
93
+
94
+
95
+ @dataclass
96
+ class ExpConfig:
97
+ local_batch_size: int = 4
98
+ total_batch_size: int = 512
99
+ valid_batch_size: int = 32
100
+ epochs: int = 10
101
+ save_ckpt_freq: int = 2
102
+ test_freq: int = 1
103
+ use_amp: bool = True
104
+
105
+
106
+ @dataclass
107
+ class DefaultConfig:
108
+ dataset: DataConfig = DataConfig()
109
+ stage1: Stage1Config = Stage1Config()
110
+ stage2: Stage2Config = Stage2Config()
111
+
112
+
113
+ @dataclass
114
+ class FineTuningConfig:
115
+ dataset: DataConfig = DataConfig()
116
+ stage1: Stage1Config = Stage1Config()
117
+ stage2: Stage2Config = Stage2Config()
118
+ optimizer: OptConfig = OptConfig()
119
+ experiment: ExpConfig = ExpConfig()
120
+
121
+
122
+ def get_base_config(use_default=True):
123
+ return OmegaConf.structured(DefaultConfig if use_default else FineTuningConfig)
dalle/utils/sampling.py ADDED
@@ -0,0 +1,152 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import torch
8
+ from typing import Optional
9
+ from tqdm import tqdm
10
+ from torch.nn import functional as F
11
+
12
+
13
+ def cutoff_topk_logits(logits: torch.FloatTensor, k: int) -> torch.FloatTensor:
14
+ if k is None:
15
+ return logits
16
+ else:
17
+ v, ix = torch.topk(logits, k)
18
+ out = logits.clone()
19
+ out[out < v[:, [-1]]] = -float('Inf')
20
+ return out
21
+
22
+
23
+ def cutoff_topp_probs(probs: torch.FloatTensor, p: float) -> torch.FloatTensor:
24
+ if p is None:
25
+ return probs
26
+ else:
27
+ sorted_probs, sorted_indices = torch.sort(probs, dim=-1, descending=True)
28
+ cum_probs = torch.cumsum(sorted_probs, dim=-1)
29
+
30
+ sorted_idx_remove_cond = cum_probs >= p
31
+
32
+ sorted_idx_remove_cond[..., 1:] = sorted_idx_remove_cond[..., :-1].clone()
33
+ sorted_idx_remove_cond[..., 0] = 0
34
+
35
+ indices_to_remove = sorted_idx_remove_cond.scatter(-1, sorted_indices, sorted_idx_remove_cond)
36
+ probs = probs.masked_fill(indices_to_remove, 0.0)
37
+ norm_probs = probs / torch.sum(probs, dim=-1, keepdim=True)
38
+ return norm_probs
39
+
40
+
41
+ def get_positional_encoding(inputs: torch.LongTensor, mode: str = '1d') -> torch.LongTensor:
42
+ device = inputs.device
43
+ if mode == '1d':
44
+ B, N = inputs.shape
45
+ xs_pos = torch.arange(N, device=device).repeat((B, 1))
46
+ elif mode == '2d':
47
+ B, H, W = inputs.shape
48
+ xs_pos_h = torch.arange(H, device=device).repeat(B, W, 1).transpose(1, 2)
49
+ xs_pos_w = torch.arange(W, device=device).repeat(B, H, 1)
50
+ xs_pos = (xs_pos_h, xs_pos_w)
51
+ else:
52
+ raise ValueError('%s positional encoding invalid' % mode)
53
+ return xs_pos
54
+
55
+
56
+ @torch.no_grad()
57
+ def sampling(model: torch.nn.Module,
58
+ tokens: torch.LongTensor,
59
+ top_k: Optional[float] = None,
60
+ top_p: Optional[float] = None,
61
+ softmax_temperature: float = 1.0,
62
+ is_tqdm: bool = True,
63
+ use_fp16: bool = True,
64
+ max_seq_len: int = 256) -> torch.LongTensor:
65
+ code = None
66
+ past = None
67
+
68
+ pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
69
+ pos_enc_tokens = get_positional_encoding(tokens, mode='1d')
70
+
71
+ for cnt, h in enumerate(pbar):
72
+ if code is None:
73
+ code_ = None
74
+ pos_enc_code_ = None
75
+ else:
76
+ code_ = code.clone().detach()
77
+ pos_enc_code_ = get_positional_encoding(code_, mode='1d')
78
+ code_ = code_[:, cnt-1].unsqueeze(-1)
79
+ pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
80
+
81
+ logits, present = model.sampling(images=code_,
82
+ texts=tokens,
83
+ pos_images=pos_enc_code_,
84
+ pos_texts=pos_enc_tokens,
85
+ use_fp16=use_fp16,
86
+ past=past)
87
+ logits = logits.to(dtype=torch.float32)
88
+ logits = logits / softmax_temperature
89
+
90
+ present = torch.stack(present).clone().detach()
91
+ if past is None:
92
+ past = [present]
93
+ else:
94
+ past.append(present)
95
+
96
+ logits = cutoff_topk_logits(logits, top_k)
97
+ probs = F.softmax(logits, dim=-1)
98
+ probs = cutoff_topp_probs(probs, top_p)
99
+
100
+ idx = torch.multinomial(probs, num_samples=1).clone().detach()
101
+ code = idx if code is None else torch.cat([code, idx], axis=1)
102
+
103
+ del past
104
+ return code
105
+
106
+
107
+ @torch.no_grad()
108
+ def sampling_igpt(model: torch.nn.Module,
109
+ sos: torch.FloatTensor,
110
+ top_k: Optional[float] = None,
111
+ top_p: Optional[float] = None,
112
+ softmax_temperature: float = 1.0,
113
+ is_tqdm: bool = True,
114
+ use_fp16: bool = True,
115
+ max_seq_len: int = 256) -> torch.LongTensor:
116
+ code = None
117
+ past = None
118
+ pbar = tqdm(range(max_seq_len), total=max_seq_len) if is_tqdm else range(max_seq_len)
119
+
120
+ for cnt, h in enumerate(pbar):
121
+ if code is None:
122
+ code_ = None
123
+ pos_enc_code_ = None
124
+ else:
125
+ code_ = code.clone().detach()
126
+ pos_enc_code_ = get_positional_encoding(code_, mode='1d')
127
+ code_ = code_[:, cnt-1].unsqueeze(-1)
128
+ pos_enc_code_ = pos_enc_code_[:, cnt-1].unsqueeze(-1)
129
+
130
+ logits, present = model.sampling(sos=sos,
131
+ codes=code_,
132
+ pos_codes=pos_enc_code_,
133
+ use_fp16=use_fp16,
134
+ past=past)
135
+ logits = logits.to(dtype=torch.float32)
136
+ logits = logits / softmax_temperature
137
+
138
+ present = torch.stack(present).clone().detach()
139
+ if past is None:
140
+ past = [present]
141
+ else:
142
+ past.append(present)
143
+
144
+ logits = cutoff_topk_logits(logits, top_k)
145
+ probs = F.softmax(logits, dim=-1)
146
+ probs = cutoff_topp_probs(probs, top_p)
147
+
148
+ idx = torch.multinomial(probs, num_samples=1).clone().detach()
149
+ code = idx if code is None else torch.cat([code, idx], axis=1)
150
+
151
+ del past
152
+ return code
dalle/utils/utils.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import random
9
+ import urllib
10
+ import hashlib
11
+ import tarfile
12
+ import torch
13
+ import clip
14
+ import numpy as np
15
+ from PIL import Image
16
+ from torch.nn import functional as F
17
+ from tqdm import tqdm
18
+
19
+
20
+ def set_seed(seed: int):
21
+ random.seed(seed)
22
+ np.random.seed(seed)
23
+ torch.manual_seed(seed)
24
+ torch.cuda.manual_seed_all(seed)
25
+
26
+
27
+ @torch.no_grad()
28
+ def clip_score(prompt: str,
29
+ images: np.ndarray,
30
+ model_clip: torch.nn.Module,
31
+ preprocess_clip,
32
+ device: str) -> np.ndarray:
33
+ images = [preprocess_clip(Image.fromarray((image*255).astype(np.uint8))) for image in images]
34
+ images = torch.stack(images, dim=0).to(device=device)
35
+ texts = clip.tokenize(prompt).to(device=device)
36
+ texts = torch.repeat_interleave(texts, images.shape[0], dim=0)
37
+
38
+ image_features = model_clip.encode_image(images)
39
+ text_features = model_clip.encode_text(texts)
40
+
41
+ scores = F.cosine_similarity(image_features, text_features).squeeze()
42
+ rank = torch.argsort(scores, descending=True).cpu().numpy()
43
+ return rank
44
+
45
+
46
+ def download(url: str, root: str) -> str:
47
+ os.makedirs(root, exist_ok=True)
48
+ filename = os.path.basename(url)
49
+ pathname = filename[:-len('.tar.gz')]
50
+
51
+ expected_md5 = url.split("/")[-2]
52
+ download_target = os.path.join(root, filename)
53
+ result_path = os.path.join(root, pathname)
54
+
55
+ if os.path.isfile(download_target) and (os.path.exists(result_path) and not os.path.isfile(result_path)):
56
+ return result_path
57
+
58
+ with urllib.request.urlopen(url) as source, open(download_target, 'wb') as output:
59
+ with tqdm(total=int(source.info().get('Content-Length')), ncols=80, unit='iB', unit_scale=True,
60
+ unit_divisor=1024) as loop:
61
+ while True:
62
+ buffer = source.read(8192)
63
+ if not buffer:
64
+ break
65
+
66
+ output.write(buffer)
67
+ loop.update(len(buffer))
68
+
69
+ if hashlib.md5(open(download_target, 'rb').read()).hexdigest() != expected_md5:
70
+ raise RuntimeError(f'Model has been downloaded but the md5 checksum does not not match')
71
+
72
+ with tarfile.open(download_target, 'r:gz') as f:
73
+ pbar = tqdm(f.getmembers(), total=len(f.getmembers()))
74
+ for member in pbar:
75
+ pbar.set_description(f'extracting: {member.name} (size:{member.size // (1024 * 1024)}MB)')
76
+ f.extract(member=member, path=root)
77
+
78
+ return result_path
79
+
80
+
81
+ def realpath_url_or_path(url_or_path: str, root: str = None) -> str:
82
+ if urllib.parse.urlparse(url_or_path).scheme in ('http', 'https'):
83
+ return download(url_or_path, root)
84
+ return url_or_path
examples/sampling_ex.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import sys
9
+ import argparse
10
+ import clip
11
+ import numpy as np
12
+ from PIL import Image
13
+
14
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
15
+
16
+ from dalle.models import Dalle
17
+ from dalle.utils.utils import set_seed, clip_score
18
+
19
+
20
+ parser = argparse.ArgumentParser()
21
+ parser.add_argument('-n', '--num_candidates', type=int, default=96)
22
+ parser.add_argument('--prompt', type=str, default='A painting of a tree on the ocean')
23
+ parser.add_argument('--softmax-temperature', type=float, default=1.0)
24
+ parser.add_argument('--top-k', type=int, default=256)
25
+ parser.add_argument('--top-p', type=float, default=None, help='0.0 <= top-p <= 1.0')
26
+ parser.add_argument('--seed', type=int, default=0)
27
+
28
+ args = parser.parse_args()
29
+
30
+ # Setup
31
+ assert args.top_k <= 256, "It is recommended that top_k is set lower than 256."
32
+
33
+ set_seed(args.seed)
34
+ device = 'cuda:0'
35
+ model = Dalle.from_pretrained('minDALL-E/1.3B') # This will automatically download the pretrained model.
36
+ model.to(device=device)
37
+
38
+ # Sampling
39
+ images = model.sampling(prompt=args.prompt,
40
+ top_k=args.top_k,
41
+ top_p=args.top_p,
42
+ softmax_temperature=args.softmax_temperature,
43
+ num_candidates=args.num_candidates,
44
+ device=device).cpu().numpy()
45
+ images = np.transpose(images, (0, 2, 3, 1))
46
+
47
+ # CLIP Re-ranking
48
+ model_clip, preprocess_clip = clip.load("ViT-B/32", device=device)
49
+ model_clip.to(device=device)
50
+ rank = clip_score(prompt=args.prompt,
51
+ images=images,
52
+ model_clip=model_clip,
53
+ preprocess_clip=preprocess_clip,
54
+ device=device)
55
+
56
+ # Save images
57
+ images = images[rank]
58
+ print(rank, images.shape)
59
+ if not os.path.exists('./figures'):
60
+ os.makedirs('./figures')
61
+ for i in range(min(16, args.num_candidates)):
62
+ im = Image.fromarray((images[i]*255).astype(np.uint8))
63
+ im.save(f'./figures/{args.prompt}_{i}.png')
examples/sampling_interactive_demo.ipynb ADDED
@@ -0,0 +1,298 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": 1,
6
+ "id": "cdf36725-ec00-4027-95d6-374340c2264e",
7
+ "metadata": {},
8
+ "outputs": [
9
+ {
10
+ "name": "stderr",
11
+ "output_type": "stream",
12
+ "text": [
13
+ "100%|█████████████████████████████████████| 4.72G/4.72G [02:04<00:00, 40.7MiB/s]\n",
14
+ "extracting: ./1.3B/tokenizer/bpe-16k-vocab.json (size:0MB): 100%|██████████| 7/7 [00:59<00:00, 8.51s/it]\n"
15
+ ]
16
+ },
17
+ {
18
+ "name": "stdout",
19
+ "output_type": "stream",
20
+ "text": [
21
+ "/root/.cache/minDALL-E/1.3B/tokenizer successfully restored..\n",
22
+ "/root/.cache/minDALL-E/1.3B/stage1_last.ckpt successfully restored..\n"
23
+ ]
24
+ },
25
+ {
26
+ "name": "stderr",
27
+ "output_type": "stream",
28
+ "text": [
29
+ " 0%| | 0.00/338M [00:00<?, ?iB/s]"
30
+ ]
31
+ },
32
+ {
33
+ "name": "stdout",
34
+ "output_type": "stream",
35
+ "text": [
36
+ "/root/.cache/minDALL-E/1.3B/stage2_last.ckpt succesfully restored..\n"
37
+ ]
38
+ },
39
+ {
40
+ "name": "stderr",
41
+ "output_type": "stream",
42
+ "text": [
43
+ "100%|███████████████████████████████████████| 338M/338M [00:09<00:00, 38.5MiB/s]\n"
44
+ ]
45
+ }
46
+ ],
47
+ "source": [
48
+ "import os\n",
49
+ "import sys\n",
50
+ "import math\n",
51
+ "import argparse\n",
52
+ "import clip\n",
53
+ "import numpy as np\n",
54
+ "%matplotlib inline\n",
55
+ "import matplotlib.pyplot as plt\n",
56
+ "from PIL import Image\n",
57
+ "\n",
58
+ "sys.path.append(os.path.dirname(os.getcwd()))\n",
59
+ "\n",
60
+ "from dalle.models import Dalle\n",
61
+ "from dalle.utils.utils import set_seed, clip_score\n",
62
+ "\n",
63
+ "device = 'cuda:0'\n",
64
+ "model = Dalle.from_pretrained(\"minDALL-E/1.3B\")\n",
65
+ "model_clip, preprocess_clip = clip.load(\"ViT-B/32\", device=device)\n",
66
+ "\n",
67
+ "model_clip.to(device=device)\n",
68
+ "model.to(device=device)\n",
69
+ "\n",
70
+ "def sampling(prompt, top_k, softmax_temperature, seed, num_candidates=96, num_samples_for_display=36):\n",
71
+ " # Setup\n",
72
+ " n_row = int(math.sqrt(num_samples_for_display))\n",
73
+ " n_col = int(math.sqrt(num_samples_for_display))\n",
74
+ " set_seed(seed)\n",
75
+ " \n",
76
+ " # Sampling\n",
77
+ " images = model.sampling(prompt=prompt,\n",
78
+ " top_k=top_k,\n",
79
+ " top_p=None,\n",
80
+ " softmax_temperature=softmax_temperature,\n",
81
+ " num_candidates=num_candidates,\n",
82
+ " device=device).cpu().numpy()\n",
83
+ " images = np.transpose(images, (0, 2, 3, 1))\n",
84
+ "\n",
85
+ " # CLIP Re-ranking\n",
86
+ " rank = clip_score(prompt=prompt, images=images, model_clip=model_clip, preprocess_clip=preprocess_clip, device=device)\n",
87
+ " images = images[rank]\n",
88
+ " \n",
89
+ " images = images[:num_samples_for_display]\n",
90
+ " fig = plt.figure(figsize=(8*n_row, 8*n_col))\n",
91
+ "\n",
92
+ " for i in range(num_samples_for_display):\n",
93
+ " ax = fig.add_subplot(n_row, n_col, i+1)\n",
94
+ " ax.imshow(images[i])\n",
95
+ " ax.set_axis_off()\n",
96
+ "\n",
97
+ " plt.tight_layout()\n",
98
+ " plt.show()"
99
+ ]
100
+ },
101
+ {
102
+ "cell_type": "code",
103
+ "execution_count": 2,
104
+ "id": "619add15-073e-40f4-9a97-06b89d647c81",
105
+ "metadata": {},
106
+ "outputs": [
107
+ {
108
+ "data": {
109
+ "application/vnd.jupyter.widget-view+json": {
110
+ "model_id": "ee477531ea0e4b86b20d997f8cb83767",
111
+ "version_major": 2,
112
+ "version_minor": 0
113
+ },
114
+ "text/plain": [
115
+ "IntSlider(value=0, description='RND SEED: ', max=1024)"
116
+ ]
117
+ },
118
+ "metadata": {},
119
+ "output_type": "display_data"
120
+ },
121
+ {
122
+ "data": {
123
+ "application/vnd.jupyter.widget-view+json": {
124
+ "model_id": "d63edc4725ef4f4e8a6f03f7693a481d",
125
+ "version_major": 2,
126
+ "version_minor": 0
127
+ },
128
+ "text/plain": [
129
+ "FloatSlider(value=1.0, description='SOFTMAX TEMPERATURE:', max=5.0, step=0.2)"
130
+ ]
131
+ },
132
+ "metadata": {},
133
+ "output_type": "display_data"
134
+ },
135
+ {
136
+ "data": {
137
+ "application/vnd.jupyter.widget-view+json": {
138
+ "model_id": "5bb9170e9e8b4686a661799d8aff3901",
139
+ "version_major": 2,
140
+ "version_minor": 0
141
+ },
142
+ "text/plain": [
143
+ "IntSlider(value=256, description='TOP-K:', max=512, step=16)"
144
+ ]
145
+ },
146
+ "metadata": {},
147
+ "output_type": "display_data"
148
+ },
149
+ {
150
+ "data": {
151
+ "application/vnd.jupyter.widget-view+json": {
152
+ "model_id": "6b97b49debfc4f7ab002748e9fd89864",
153
+ "version_major": 2,
154
+ "version_minor": 0
155
+ },
156
+ "text/plain": [
157
+ "Text(value='A painting of a monkey with sunglasses in the frame', description='String:', placeholder='Text pro…"
158
+ ]
159
+ },
160
+ "metadata": {},
161
+ "output_type": "display_data"
162
+ },
163
+ {
164
+ "data": {
165
+ "application/vnd.jupyter.widget-view+json": {
166
+ "model_id": "a520b10d8c0b4dd0bb6db56dc37b4422",
167
+ "version_major": 2,
168
+ "version_minor": 0
169
+ },
170
+ "text/plain": [
171
+ "Button(description='Generate!', style=ButtonStyle())"
172
+ ]
173
+ },
174
+ "metadata": {},
175
+ "output_type": "display_data"
176
+ },
177
+ {
178
+ "data": {
179
+ "application/vnd.jupyter.widget-view+json": {
180
+ "model_id": "5a98437abf964636a467677dc4f816bb",
181
+ "version_major": 2,
182
+ "version_minor": 0
183
+ },
184
+ "text/plain": [
185
+ "Output()"
186
+ ]
187
+ },
188
+ "metadata": {},
189
+ "output_type": "display_data"
190
+ },
191
+ {
192
+ "data": {
193
+ "application/vnd.jupyter.widget-view+json": {
194
+ "model_id": "90d05006d50e4d88b8fb7c36095b12e7",
195
+ "version_major": 2,
196
+ "version_minor": 0
197
+ },
198
+ "text/plain": [
199
+ "Output()"
200
+ ]
201
+ },
202
+ "metadata": {},
203
+ "output_type": "display_data"
204
+ }
205
+ ],
206
+ "source": [
207
+ "import ipywidgets as widgets\n",
208
+ "from IPython.display import display\n",
209
+ "from IPython.display import clear_output\n",
210
+ "\n",
211
+ "output = widgets.Output()\n",
212
+ "plot_output = widgets.Output()\n",
213
+ "\n",
214
+ "def btn_eventhandler(obj):\n",
215
+ " output.clear_output()\n",
216
+ " plot_output.clear_output()\n",
217
+ " \n",
218
+ " with output:\n",
219
+ " print(f'SEED: {slider_seed.value}')\n",
220
+ " print(f'Softmax Temperature: {slider_temp.value}')\n",
221
+ " print(f'Top-K: {slider_topk.value}')\n",
222
+ " print(f'Text prompt: {wd_text.value}')\n",
223
+ " \n",
224
+ " with plot_output:\n",
225
+ " sampling(prompt=wd_text.value, top_k=slider_topk.value, softmax_temperature=slider_temp.value, seed=slider_seed.value)\n",
226
+ " \n",
227
+ "slider_seed = widgets.IntSlider(\n",
228
+ " min=0,\n",
229
+ " max=1024,\n",
230
+ " step=1,\n",
231
+ " description='RND SEED: ',\n",
232
+ " value=0\n",
233
+ ")\n",
234
+ "slider_topk = widgets.IntSlider(\n",
235
+ " min=0,\n",
236
+ " max=512,\n",
237
+ " step=16,\n",
238
+ " description='TOP-K:',\n",
239
+ " value=256\n",
240
+ ")\n",
241
+ "slider_temp = widgets.FloatSlider(\n",
242
+ " min=0.0,\n",
243
+ " max=5.0,\n",
244
+ " step=0.2,\n",
245
+ " description='SOFTMAX TEMPERATURE:',\n",
246
+ " value=1.0\n",
247
+ ")\n",
248
+ "wd_text = widgets.Text(\n",
249
+ " value='A painting of a monkey with sunglasses in the frame',\n",
250
+ " placeholder='Text prompt',\n",
251
+ " description='String:',\n",
252
+ " disabled=False\n",
253
+ ")\n",
254
+ "\n",
255
+ "display(slider_seed)\n",
256
+ "display(slider_temp)\n",
257
+ "display(slider_topk)\n",
258
+ "display(wd_text)\n",
259
+ "\n",
260
+ "btn = widgets.Button(description='Generate!')\n",
261
+ "display(btn)\n",
262
+ "btn.on_click(btn_eventhandler)\n",
263
+ "\n",
264
+ "display(output)\n",
265
+ "display(plot_output)"
266
+ ]
267
+ },
268
+ {
269
+ "cell_type": "code",
270
+ "execution_count": null,
271
+ "id": "20571236-3b9a-426e-ab29-96b643c8cbe1",
272
+ "metadata": {},
273
+ "outputs": [],
274
+ "source": []
275
+ }
276
+ ],
277
+ "metadata": {
278
+ "kernelspec": {
279
+ "display_name": "Python 3 (ipykernel)",
280
+ "language": "python",
281
+ "name": "python3"
282
+ },
283
+ "language_info": {
284
+ "codemirror_mode": {
285
+ "name": "ipython",
286
+ "version": 3
287
+ },
288
+ "file_extension": ".py",
289
+ "mimetype": "text/x-python",
290
+ "name": "python",
291
+ "nbconvert_exporter": "python",
292
+ "pygments_lexer": "ipython3",
293
+ "version": "3.7.7"
294
+ }
295
+ },
296
+ "nbformat": 4,
297
+ "nbformat_minor": 5
298
+ }
examples/transfer_learning_ex.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ------------------------------------------------------------------------------------
2
+ # Minimal DALL-E
3
+ # Copyright (c) 2021 KakaoBrain. All Rights Reserved.
4
+ # Licensed under the Apache License, Version 2.0 [see LICENSE for details]
5
+ # ------------------------------------------------------------------------------------
6
+
7
+ import os
8
+ import sys
9
+ import argparse
10
+ from typing import Optional
11
+ from datetime import datetime
12
+
13
+ import torch
14
+ from torch.utils.data import DataLoader
15
+ import torchvision
16
+ import torchvision.transforms as transforms
17
+ import pytorch_lightning as pl
18
+ from pytorch_lightning.callbacks import ModelCheckpoint, Callback
19
+ from pytorch_lightning.loggers import TensorBoardLogger
20
+ from pytorch_lightning.utilities.distributed import rank_zero_only
21
+
22
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
23
+
24
+ from dalle.models import ImageGPT
25
+
26
+
27
+ parser = argparse.ArgumentParser()
28
+
29
+ parser.add_argument('-d', '--config-downstream', type=str, default=None, required=True)
30
+ parser.add_argument('-u', '--path-upstream', type=str, default=None, required=True)
31
+ parser.add_argument('-r', '--result-path', type=str, default=None, required=True)
32
+ parser.add_argument('--imagenet-path', type=str, default=None, required=True)
33
+
34
+ parser.add_argument('--n-gpus', type=int, default=1)
35
+ parser.add_argument('--seed', type=int, default=0)
36
+
37
+
38
+ args = parser.parse_args()
39
+
40
+
41
+ class ImageLogger(Callback):
42
+ def __init__(self):
43
+ super().__init__()
44
+
45
+ @rank_zero_only
46
+ def log_img(self, pl_module, batch, current_epoch, split="train"):
47
+ with torch.no_grad():
48
+ images, labels = batch
49
+ recons = pl_module.stage1(images)
50
+ images = images.cpu()
51
+ recons = recons.cpu()
52
+
53
+ grid_org = (torchvision.utils.make_grid(images, nrow=8) + 1.0) / 2.0
54
+ grid_rec = (torchvision.utils.make_grid(recons, nrow=8) + 1.0) / 2.0
55
+ grid_rec = torch.clip(grid_rec, min=0, max=1)
56
+
57
+ pl_module.logger.experiment.add_image(f"images_org/{split}", grid_org, global_step=current_epoch)
58
+ pl_module.logger.experiment.add_image(f"images_rec/{split}", grid_rec, global_step=current_epoch)
59
+
60
+ def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
61
+ if batch_idx == 0 and trainer.current_epoch < 5:
62
+ self.log_img(pl_module, batch, current_epoch=trainer.current_epoch, split="train")
63
+
64
+ def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
65
+ if batch_idx == 0 and trainer.current_epoch < 5:
66
+ self.log_img(pl_module, batch, current_epoch=trainer.current_epoch, split="test")
67
+
68
+
69
+ class ImageNetDataModule(pl.LightningDataModule):
70
+ def __init__(self,
71
+ data_dir: Optional[str] = None,
72
+ image_resolution: int = 256,
73
+ train_batch_size: int = 2,
74
+ valid_batch_size: int = 32,
75
+ num_workers: int = 8):
76
+ super().__init__()
77
+
78
+ self.data_dir = data_dir
79
+ self.image_resolution = image_resolution
80
+ self.train_batch_size = train_batch_size
81
+ self.valid_batch_size = valid_batch_size
82
+ self.num_workers = num_workers
83
+
84
+ self.train_transform = transforms.Compose(
85
+ [transforms.Resize(image_resolution),
86
+ transforms.RandomCrop(image_resolution),
87
+ transforms.ToTensor(),
88
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
89
+ )
90
+ self.valid_transform = transforms.Compose(
91
+ [transforms.Resize(image_resolution),
92
+ transforms.CenterCrop(image_resolution),
93
+ transforms.ToTensor(),
94
+ transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5])]
95
+ )
96
+
97
+ def setup(self, stage=None):
98
+ self.trainset = torchvision.datasets.ImageNet(root=self.data_dir, split='train', transform=self.train_transform)
99
+ self.validset = torchvision.datasets.ImageNet(root=self.data_dir, split='val', transform=self.valid_transform)
100
+
101
+ def train_dataloader(self):
102
+ return DataLoader(self.trainset,
103
+ batch_size=self.train_batch_size,
104
+ num_workers=self.num_workers,
105
+ pin_memory=True)
106
+
107
+ def valid_dataloader(self):
108
+ return DataLoader(self.validset,
109
+ batch_size=self.valid_batch_size,
110
+ num_workers=self.num_workers,
111
+ pin_memory=True)
112
+
113
+
114
+ def setup_callbacks(config):
115
+ # Setup callbacks
116
+ now = datetime.now().strftime('%d%m%Y_%H%M%S')
117
+ result_path = os.path.join(args.result_path,
118
+ os.path.basename(args.config_downstream).split('.')[0],
119
+ now)
120
+ ckpt_path = os.path.join(result_path, 'ckpt')
121
+ log_path = os.path.join(result_path, 'log')
122
+
123
+ checkpoint_callback = ModelCheckpoint(
124
+ dirpath=ckpt_path,
125
+ filename="imagenet-clscond-gen-{epoch:02d}" if config.stage2.use_cls_cond else
126
+ "imagenet-uncond-gen-{epoch:02d}",
127
+ every_n_epochs=config.experiment.save_ckpt_freq,
128
+ save_weights_only=True,
129
+ save_last=True
130
+ )
131
+ logger = TensorBoardLogger(log_path, name="iGPT")
132
+ logger_img = ImageLogger()
133
+ return checkpoint_callback, logger, logger_img
134
+
135
+
136
+ if __name__ == '__main__':
137
+ pl.seed_everything(args.seed)
138
+
139
+ # Build iGPT
140
+ model, config = ImageGPT.from_pretrained(args.path_upstream, args.config_downstream)
141
+
142
+ # Setup callbacks
143
+ ckpt_callback, logger, logger_img = setup_callbacks(config)
144
+
145
+ # Build data modules
146
+ dataset = ImageNetDataModule(data_dir=args.imagenet_path,
147
+ image_resolution=config.dataset.image_resolution,
148
+ train_batch_size=config.experiment.local_batch_size,
149
+ valid_batch_size=config.experiment.valid_batch_size,
150
+ num_workers=16)
151
+ dataset.setup()
152
+ train_dataloader = dataset.train_dataloader()
153
+ valid_dataloader = dataset.valid_dataloader()
154
+ print(f"len(train_dataset) = {len(dataset.trainset)}")
155
+ print(f"len(valid_dataset) = {len(dataset.validset)}")
156
+
157
+ # Calculate how many batches are accumulated
158
+ assert config.experiment.total_batch_size % (config.experiment.local_batch_size * args.n_gpus) == 0
159
+ grad_accm_steps = config.experiment.total_batch_size // (config.experiment.local_batch_size * args.n_gpus)
160
+ config.optimizer.max_steps = len(dataset.trainset) // config.experiment.total_batch_size * config.experiment.epochs
161
+
162
+ # Build trainer
163
+ trainer = pl.Trainer(max_epochs=config.experiment.epochs,
164
+ accumulate_grad_batches=grad_accm_steps,
165
+ gradient_clip_val=config.optimizer.grad_clip_norm,
166
+ precision=16 if config.experiment.use_amp else 32,
167
+ callbacks=[ckpt_callback, logger_img],
168
+ accelerator="gpu",
169
+ devices=args.n_gpus,
170
+ strategy="ddp",
171
+ logger=logger)
172
+ trainer.fit(model, train_dataloader, valid_dataloader)
requirements.txt ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ torch==1.8.0
2
+ torchvision>=0.8.2
3
+ tokenizers>=0.10.2
4
+ pyflakes>=2.2.0
5
+ tqdm>=4.46.0
6
+ pytorch-lightning>=1.5
7
+ einops
8
+ omegaconf
9
+ git+https://github.com/openai/CLIP.git
10
+ matplotlib
setup.cfg ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ [flake8]
2
+ max-line-length = 120
3
+ ignore = E226, E402, W504