wangqinghehe committed on
Commit
3ab16a9
1 Parent(s): 726206e

0515_first_upload

LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
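As the appendix above notes, the notice is meant to be wrapped in the comment syntax of each file format. A minimal sketch of what that could look like at the top of a Python source file in this repo; the bracketed year and owner fields are placeholders for the copyright holder to fill in:

```python
# Copyright [yyyy] [name of copyright owner]
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
```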
README.md CHANGED
@@ -1,5 +1,5 @@
 ---
-title: CharacterFactory
+title: 'CharacterFactory'
 emoji: 🖼
 colorFrom: purple
 colorTo: red
@@ -7,7 +7,6 @@ sdk: gradio
 sdk_version: 4.26.0
 app_file: app.py
 pinned: false
-license: apache-2.0
 ---
 
 Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
app.py CHANGED
@@ -1,146 +1,395 @@
1
  import gradio as gr
2
- import numpy as np
 
3
  import random
4
- from diffusers import DiffusionPipeline
5
  import torch
6
 
7
- device = "cuda" if torch.cuda.is_available() else "cpu"
 
8
 
9
- if torch.cuda.is_available():
10
- torch.cuda.max_memory_allocated(device=device)
11
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", torch_dtype=torch.float16, variant="fp16", use_safetensors=True)
12
- pipe.enable_xformers_memory_efficient_attention()
13
- pipe = pipe.to(device)
14
- else:
15
- pipe = DiffusionPipeline.from_pretrained("stabilityai/sdxl-turbo", use_safetensors=True)
16
- pipe = pipe.to(device)
17
 
18
- MAX_SEED = np.iinfo(np.int32).max
19
- MAX_IMAGE_SIZE = 1024
20
 
21
- def infer(prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps):
22
 
23
- if randomize_seed:
24
- seed = random.randint(0, MAX_SEED)
25
-
26
- generator = torch.Generator().manual_seed(seed)
27
-
28
- image = pipe(
29
- prompt = prompt,
30
- negative_prompt = negative_prompt,
31
- guidance_scale = guidance_scale,
32
- num_inference_steps = num_inference_steps,
33
- width = width,
34
- height = height,
35
- generator = generator
36
- ).images[0]
37
-
38
- return image
39
-
40
- examples = [
41
- "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k",
42
- "An astronaut riding a green horse",
43
- "A delicious ceviche cheesecake slice",
44
- ]
45
-
46
- css="""
47
- #col-container {
48
- margin: 0 auto;
49
- max-width: 520px;
50
  }
51
  """
52
 
53
- if torch.cuda.is_available():
54
- power_device = "GPU"
55
- else:
56
- power_device = "CPU"
57
 
58
- with gr.Blocks(css=css) as demo:
59
 
60
- with gr.Column(elem_id="col-container"):
61
- gr.Markdown(f"""
62
- # Text-to-Image Gradio Template
63
- Currently running on {power_device}.
64
- """)
65
-
66
- with gr.Row():
67
-
68
- prompt = gr.Text(
69
- label="Prompt",
70
- show_label=False,
71
- max_lines=1,
72
- placeholder="Enter your prompt",
73
- container=False,
74
- )
75
-
76
- run_button = gr.Button("Run", scale=0)
77
 
78
- result = gr.Image(label="Result", show_label=False)
79
 
80
- with gr.Accordion("Advanced Settings", open=False):
81
-
82
- negative_prompt = gr.Text(
83
- label="Negative prompt",
84
- max_lines=1,
85
- placeholder="Enter a negative prompt",
86
- visible=False,
87
- )
88
-
89
- seed = gr.Slider(
90
- label="Seed",
91
- minimum=0,
92
- maximum=MAX_SEED,
93
- step=1,
94
- value=0,
95
- )
96
 
97
- randomize_seed = gr.Checkbox(label="Randomize seed", value=True)
98
 
99
- with gr.Row():
100
-
101
- width = gr.Slider(
102
- label="Width",
103
- minimum=256,
104
- maximum=MAX_IMAGE_SIZE,
105
- step=32,
106
- value=512,
107
- )
108
-
109
- height = gr.Slider(
110
- label="Height",
111
- minimum=256,
112
- maximum=MAX_IMAGE_SIZE,
113
- step=32,
114
- value=512,
115
- )
116
 
117
- with gr.Row():
118
-
119
- guidance_scale = gr.Slider(
120
- label="Guidance scale",
121
- minimum=0.0,
122
- maximum=10.0,
123
- step=0.1,
124
- value=0.0,
125
- )
126
-
127
- num_inference_steps = gr.Slider(
128
- label="Number of inference steps",
129
- minimum=1,
130
- maximum=12,
131
- step=1,
132
- value=2,
133
- )
134
-
135
- gr.Examples(
136
- examples = examples,
137
- inputs = [prompt]
138
  )
 
139
 
140
- run_button.click(
141
- fn = infer,
142
- inputs = [prompt, negative_prompt, seed, randomize_seed, width, height, guidance_scale, num_inference_steps],
143
- outputs = [result]
 
 
144
  )
145
 
146
- demo.queue().launch()
 
1
+ import os
2
+ import io
3
+ import IPython.display
4
+ from PIL import Image
5
+ import base64
6
+ import io
7
+ from PIL import Image
8
  import gradio as gr
9
+ import requests
10
+ import time
11
  import random
12
+ import numpy as np
13
  import torch
14
+ import os
15
+ from transformers import ViTModel, ViTImageProcessor
16
+ from utils import text_encoder_forward
17
+ from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler
18
+ from utils import latents_to_images, downsampling, merge_and_save_images
19
+ from omegaconf import OmegaConf
20
+ from accelerate.utils import set_seed
21
+ from tqdm import tqdm
22
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
23
+ from PIL import Image
24
+ from models.celeb_embeddings import embedding_forward
25
+ import models.embedding_manager
26
+ import importlib
27
+ import time
28
 
29
+ import os
30
+ os.environ['GRADIO_TEMP_DIR'] = 'qinghewang/tmp'
31
 
32
+ title = r"""
33
+ <h1 align="center">CharacterFactory: Sampling Consistent Characters with GANs for Diffusion Models</h1>
34
+ """
 
 
 
 
 
35
 
36
+ description = r"""
37
+ <b>Official Gradio demo</b> for <a href='https://qinghew.github.io/CharacterFactory/' target='_blank'><b>CharacterFactory: Sampling Consistent Characters with GANs for Diffusion Models</b></a>.<br>
38
 
39
+ How to use:<br>
40
+ 1. Enter prompts (use "a person" as the character placeholder); each line generates one image.
41
+ 2. You can create a new character or keep using the current one. We have provided some examples; click one below to use it.
42
+ 3. You can choose the Normal version (the gender is random), the Man version, or the Woman version.
43
+ 4. Click the <b>Generate</b> button to begin (Images are generated one by one).
44
+ 5. Our method can be applied to illustrating books and stories, creating brand ambassadors, developing presentations, art design, identity-consistent data construction, and more. We look forward to your explorations!😊
45
+ 6. If CharacterFactory is helpful, please ⭐ the <a href='https://github.com/qinghew/CharacterFactory' target='_blank'>GitHub Repo</a>. Thanks!
46
+ """
47
 
48
+ article = r"""
49
+ ---
50
+ 📝 **Citation**
51
+ <br>
52
+ If our work is helpful for your research or applications, please cite us via:
53
+ ```bibtex
54
+ @article{wang2024characterfactory,
55
+ title={CharacterFactory: Sampling Consistent Characters with GANs for Diffusion Models},
56
+ author={Wang, Qinghe and Li, Baolu and Li, Xiaomin and Cao, Bing and Ma, Liqian and Lu, Huchuan and Jia, Xu},
57
+ journal={arXiv preprint arXiv:2404.15677},
58
+ year={2024}
59
  }
60
+ ```
61
+ 📧 **Contact**
62
+ <br>
63
+ If you have any questions, please feel free to open an issue or reach out to us directly at <b>[email protected]</b>.
64
  """
65
 
66
+ css = '''
67
+ #color-bg{display:flex;justify-content: center;align-items: center;}
68
+ .color-bg-item{width: 100%; height: 32px}
69
+ #main_button{width:100%}
70
71
+ '''
72
+
73
+ model_id = "stabilityai/stable-diffusion-2-1-base"
74
+ # model_path = "/home/qinghewang/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6"
75
+ pipe = StableDiffusionPipeline.from_pretrained(model_id) # , torch_dtype=torch.float16
76
+ pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
77
+ pipe = pipe.to("cuda")
78
+
79
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
80
+
81
+ vae = pipe.vae
82
+ unet = pipe.unet
83
+ text_encoder = pipe.text_encoder
84
+ tokenizer = pipe.tokenizer
85
+ scheduler = pipe.scheduler
86
+
87
+ input_dim = 64
88
+
89
+ original_forward = text_encoder.text_model.embeddings.forward
90
+ text_encoder.text_model.embeddings.forward = embedding_forward.__get__(text_encoder.text_model.embeddings)
91
+ embedding_manager_config = OmegaConf.load("datasets_face/identity_space.yaml")
92
+
93
+ normal_Embedding_Manager = models.embedding_manager.EmbeddingManagerId_adain(
94
+ tokenizer,
95
+ text_encoder,
96
+ device = device,
97
+ training = True,
98
+ experiment_name = "normal_GAN",
99
+ num_embeds_per_token = embedding_manager_config.model.personalization_config.params.num_embeds_per_token,
100
+ token_dim = embedding_manager_config.model.personalization_config.params.token_dim,
101
+ mlp_depth = embedding_manager_config.model.personalization_config.params.mlp_depth,
102
+ loss_type = embedding_manager_config.model.personalization_config.params.loss_type,
103
+ vit_out_dim = input_dim,
104
+ )
105
+
106
+ man_Embedding_Manager = models.embedding_manager.EmbeddingManagerId_adain(
107
+ tokenizer,
108
+ text_encoder,
109
+ device = device,
110
+ training = True,
111
+ experiment_name = "man_GAN",
112
+ num_embeds_per_token = embedding_manager_config.model.personalization_config.params.num_embeds_per_token,
113
+ token_dim = embedding_manager_config.model.personalization_config.params.token_dim,
114
+ mlp_depth = embedding_manager_config.model.personalization_config.params.mlp_depth,
115
+ loss_type = embedding_manager_config.model.personalization_config.params.loss_type,
116
+ vit_out_dim = input_dim,
117
+ )
118
+
119
+ woman_Embedding_Manager = models.embedding_manager.EmbeddingManagerId_adain(
120
+ tokenizer,
121
+ text_encoder,
122
+ device = device,
123
+ training = True,
124
+ experiment_name = "woman_GAN",
125
+ num_embeds_per_token = embedding_manager_config.model.personalization_config.params.num_embeds_per_token,
126
+ token_dim = embedding_manager_config.model.personalization_config.params.token_dim,
127
+ mlp_depth = embedding_manager_config.model.personalization_config.params.mlp_depth,
128
+ loss_type = embedding_manager_config.model.personalization_config.params.loss_type,
129
+ vit_out_dim = input_dim,
130
+ )
131
+
132
 
133
+ DEFAULT_STYLE_NAME = "Watercolor"
134
+ MAX_SEED = np.iinfo(np.int32).max
135
+
136
+ def remove_tips():
137
+ return gr.update(visible=False)
138
+
139
+ def response(choice, gender_GAN):
140
+ c = ""
141
+ e = ""
142
+ if choice == "Create a new character":
143
+ c = "create"
144
+ elif choice == "Still use this character":
145
+ c = "continue"
146
 
147
+ if gender_GAN == "Normal":
148
+ e = "normal_GAN"
149
+ elif gender_GAN == "Man":
150
+ e = "man_GAN"
151
+ elif gender_GAN == "Woman":
152
+ e = "woman_GAN"
153
+
154
+ return c, e
155
+
156
+ def replace_phrases(prompt):
157
+ replacements = {
158
+ "a person": "v1* v2*",
159
+ "a man": "v1* v2*",
160
+ "a woman": "v1* v2*",
161
+ "a boy": "v1* v2*",
162
+ "a girl": "v1* v2*"
163
+ }
164
+ for phrase, replacement in replacements.items():
165
+ prompt = prompt.replace(phrase, replacement)
166
+ return prompt
167
+
168
+
169
+ def handle_prompts(prompts_array):
170
+ prompts = prompts_array.splitlines()
171
+ prompts = [prompt + ', facing to camera, best quality, ultra high res' for prompt in prompts]
172
+ prompts = [replace_phrases(prompt) for prompt in prompts]
173
+ return prompts
174
+
175
+
176
+
177
+ def generate_image(experiment_name, label, prompts_array, chose_emb):
178
+ prompts = handle_prompts(prompts_array)
179
+
180
+ print("experiment_name:",experiment_name)
181
+
182
+ if experiment_name == "normal_GAN":
183
+ steps = 10000
184
+ Embedding_Manager = normal_Embedding_Manager
185
+ elif experiment_name == "man_GAN":
186
+ steps = 7000
187
+ Embedding_Manager = man_Embedding_Manager
188
+ elif experiment_name == "woman_GAN":
189
+ steps = 6000
190
+ Embedding_Manager = woman_Embedding_Manager
191
+ else:
192
+ print("Hello, please notice this ^_^")
193
+ assert 0
194
+
195
+ embedding_path = os.path.join("training_weight", experiment_name, "embeddings_manager-{}.pt".format(str(steps)))
196
+ Embedding_Manager.load(embedding_path)
197
+ print("embedding_path:",embedding_path)
198
+ print("label:",label)
199
+
200
+ index = "0"
201
+ save_dir = os.path.join("test_results/" + experiment_name, index)
202
+ os.makedirs(save_dir, exist_ok=True)
203
+ ran_emb_path = os.path.join(save_dir, "ran_embeddings.pt")
204
+ test_emb_path = os.path.join(save_dir, "id_embeddings.pt")
205
+
206
+ if label == "create":
207
+ print("new")
208
+ random_embedding = torch.randn(1, 1, input_dim).to(device)
209
+ torch.save(random_embedding, ran_emb_path)
210
+ _, emb_dict = Embedding_Manager(tokenized_text=None, embedded_text=None, name_batch=None, random_embeddings = random_embedding, timesteps = None,)
211
+ text_encoder.text_model.embeddings.forward = original_forward
212
+ test_emb = emb_dict["adained_total_embedding"].to(device)
213
+ torch.save(test_emb, test_emb_path)
214
+ elif label == "continue":
215
+ print("old")
216
+ test_emb = torch.load(chose_emb).cuda()
217
+ text_encoder.text_model.embeddings.forward = original_forward
218
+
219
+ v1_emb = test_emb[:, 0]
220
+ v2_emb = test_emb[:, 1]
221
+ embeddings = [v1_emb, v2_emb]
222
+
223
+ tokens = ["v1*", "v2*"]
224
+ tokenizer.add_tokens(tokens)
225
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
226
+
227
+ text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of = 8)
228
+ for token_id, embedding in zip(token_ids, embeddings):
229
+ text_encoder.get_input_embeddings().weight.data[token_id] = embedding
230
+
231
+ total_results = []
232
+ for prompt in prompts:
233
+ image = pipe(prompt, guidance_scale = 8.5).images
234
+ total_results = image + total_results
235
+ yield total_results, test_emb_path
236
+
237
+ def get_example():
238
+ case = [
239
+ [
240
+ 'demo_embeddings/example_1.pt',
241
+ "Normal",
242
+ "Still use this character",
243
+ "a photo of a person\na person as a small child\na person as a 20 years old person\na person as a 80 years old person\na person reading a book\na person in the sunset\n",
244
+ ],
245
+ [
246
+ 'demo_embeddings/example_2.pt',
247
+ "Man",
248
+ "Still use this character",
249
+ "a photo of a person\na person with a mustache and a hat\na person wearing headphoneswith red hair\na person with his dog\n",
250
+ ],
251
+ [
252
+ 'demo_embeddings/example_3.pt',
253
+ "Woman",
254
+ "Still use this character",
255
+ "a photo of a person\na person at a beach\na person as a police officer\na person wearing a birthday hat\n",
256
+ ],
257
+ [
258
+ 'demo_embeddings/example_4.pt',
259
+ "Man",
260
+ "Still use this character",
261
+ "a photo of a person\na person holding a bunch of flowers\na person in a lab coat\na person speaking at a podium\n",
262
+ ],
263
+ [
264
+ 'demo_embeddings/example_5.pt',
265
+ "Woman",
266
+ "Still use this character",
267
+ "a photo of a person\na person wearing a kimono\na person in Van Gogh style\nEthereal fantasy concept art of a person\n",
268
+ ],
269
+ [
270
+ 'demo_embeddings/example_6.pt',
271
+ "Man",
272
+ "Still use this character",
273
+ "a photo of a person\na person in the rain\na person meditating\na pencil sketch of a person\n",
274
+ ],
275
+ ]
276
+ return case
277
+
278
+ def run_for_examples(example_emb, gender_GAN, choice, prompts_array):
279
+ prompts = handle_prompts(prompts_array)
280
+ label, experiment_name = response(choice, gender_GAN)
281
+ if experiment_name == "normal_GAN":
282
+ steps = 10000
283
+ Embedding_Manager = normal_Embedding_Manager
284
+ elif experiment_name == "man_GAN":
285
+ steps = 7000
286
+ Embedding_Manager = man_Embedding_Manager
287
+ elif experiment_name == "woman_GAN":
288
+ steps = 6000
289
+ Embedding_Manager = woman_Embedding_Manager
290
+ else:
291
+ print("Hello, please notice this ^_^")
292
+ assert 0
293
 
294
+ embedding_path = os.path.join("training_weight", experiment_name, "embeddings_manager-{}.pt".format(str(steps)))
295
+ Embedding_Manager.load(embedding_path)
296
+ print("embedding_path:",embedding_path)
297
+ print("label:",label)
298
+
299
+ test_emb = torch.load(example_emb).cuda()
300
+ text_encoder.text_model.embeddings.forward = original_forward
301
+ v1_emb = test_emb[:, 0]
302
+ v2_emb = test_emb[:, 1]
303
+ embeddings = [v1_emb, v2_emb]
304
+
305
+ tokens = ["v1*", "v2*"]
306
+ tokenizer.add_tokens(tokens)
307
+ token_ids = tokenizer.convert_tokens_to_ids(tokens)
308
+
309
+ text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of = 8)
310
+ for token_id, embedding in zip(token_ids, embeddings):
311
+ text_encoder.get_input_embeddings().weight.data[token_id] = embedding
312
+
313
+ total_results = []
314
+ i = 0
315
+ for prompt in prompts:
316
+ image = pipe(prompt, guidance_scale = 8.5).images
317
+ total_results = image + total_results
318
+ i+=1
319
+ if i < len(prompts):
320
+ yield total_results, gr.update(visible=True, value="<h3>(Not Finished) Generating ···</h3>")
321
+ else:
322
+ yield total_results, gr.update(visible=True, value="<h3>Generation Finished</h3>")
323
 
324
+
325
+ def set_text_unfinished():
326
+ return gr.update(visible=True, value="<h3>(Not Finished) Generating ···</h3>")
327
+
328
+ def set_text_finished():
329
+ return gr.update(visible=True, value="<h3>Generation Finished</h3>")
330
+
331
+
332
+
333
+
334
+ with gr.Blocks(css=css) as demo: # css=css
335
+ # binary_matrixes = gr.State([])
336
+ # color_layout = gr.State([])
337
+
338
+ # gr.Markdown(logo)
339
+ gr.Markdown(title)
340
+ gr.Markdown(description)
341
+
342
+ with gr.Row():
343
+ with gr.Column():
344
+ prompts_array = gr.Textbox(lines = 3,
345
+ label="Prompts (each line corresponds to a frame).",
346
+ info="Give simple prompt is enough to achieve good face fidelity",
347
+ # placeholder="A photo of a person",
348
+ value="a photo of a person\na person in front of the Great Wall\na person reading a book\na person wearing a Christmas hat\n",
349
+ interactive=True)
350
+ choice = gr.Radio(choices=["Create a new character", "Still use this character"], label="Choose your action")
351
+
352
+ gender_GAN = gr.Radio(choices=["Normal", "Man", "Woman"], label="Choose your model version")
353
 
354
+ label = gr.Text(label="Select the action you want to take", visible=False)
355
+ experiment_name = gr.Text(label="Select the GAN you want to take", visible=False)
356
+ chose_emb = gr.File(label="Uploaded files", type="filepath", visible=False)
357
+ example_emb = gr.File(label="Uploaded files", type="filepath", visible=False)
358
 
359
+ generate = gr.Button("Generate!😊", variant="primary")
360
+
361
+ with gr.Column():
362
+ gallery = gr.Gallery(label="Generated Images", columns=2, height='auto')
363
+ generated_information = gr.Markdown(label="Generation Details", value="",visible=False)
364
 
365
+ generate.click(
366
+ fn=set_text_unfinished,
367
+ outputs=generated_information
368
+ ).then(
369
+ fn=response,
370
+ inputs=[choice, gender_GAN],
371
+ outputs=[label, experiment_name],
372
+ ).then(
373
+ fn=generate_image,
374
+ inputs=[experiment_name, label, prompts_array, chose_emb],
375
+ outputs=[gallery, chose_emb]
376
+ ).then(
377
+ fn=set_text_finished,
378
+ outputs=generated_information
379
  )
380
+
381
 
382
+ gr.Examples(
383
+ examples=get_example(),
384
+ inputs=[example_emb, gender_GAN, choice, prompts_array],
385
+ run_on_click=True,
386
+ fn=run_for_examples,
387
+ outputs=[gallery, generated_information],
388
  )
389
+
390
+ gr.Markdown(article)
391
+ # demo.launch(server_name="0.0.0.0", share = False)
392
+ # share_link = demo.launch(share=True)
393
+ # print("Share this link: ", share_link)
394
 
395
+ demo.launch() # share=True
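As a rough usage sketch (not part of this commit), the "Still use this character" path implemented in generate_image and run_for_examples above reduces to the following; it assumes a saved identity embedding such as demo_embeddings/example_1.pt holding a tensor of shape [1, 2, token_dim] and a CUDA device, as in app.py:

```python
# Editorial sketch under the assumptions stated above; mirrors the token-binding logic in app.py.
import torch
from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler

pipe = StableDiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-2-1-base")
pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)
pipe = pipe.to("cuda")
tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder

# Assumed shape [1, 2, token_dim], as saved by generate_image() above.
test_emb = torch.load("demo_embeddings/example_1.pt").cuda()

# Register the two placeholder tokens and write the identity embedding into their rows.
tokens = ["v1*", "v2*"]
tokenizer.add_tokens(tokens)
token_ids = tokenizer.convert_tokens_to_ids(tokens)
text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
for token_id, emb in zip(token_ids, (test_emb[0, 0], test_emb[0, 1])):
    text_encoder.get_input_embeddings().weight.data[token_id] = emb

# "a person" is the character placeholder; app.py rewrites it to "v1* v2*".
prompt = "a photo of v1* v2*, facing to camera, best quality, ultra high res"
image = pipe(prompt, guidance_scale=8.5).images[0]
image.save("character.png")
```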
datasets_face/face_id.py ADDED
@@ -0,0 +1,1007 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data import Dataset
4
+ from torchvision.transforms import transforms
5
+ import torch.nn.functional as F
6
+ import random
7
+
8
+ imagenet_templates_small = [
9
+ 'a photo of {}',
10
+ '{} is sitting at a desk, writing in a notebook',
11
+ '{} is standing in a kitchen, cooking a meal',
12
+ 'In a garden, there are flowers and trees. {} is walking around',
13
+ '{} is playing a piano in a music room',
14
+ 'In a garden, {} is watering plants',
15
+ 'In a gym, there are many machines. {} is lifting weights',
16
+ '{} is reading a book in a cozy armchair',
17
+ '{} is painting on a canvas in a studio',
18
+ 'In a museum, there are many paintings. {} is admiring them',
19
+ '{} is jogging on a trail in the woods',
20
+ 'In an office, {} is working on a computer',
21
+ '{} is playing with a dog in a backyard',
22
+ '{} is taking a photograph in a city street',
23
+ 'In a concert, there are many people. {} is enjoying the music',
24
+ '{} is playing chess in a park',
25
+ 'In a library, {} is browsing through books',
26
+ '{} is riding a bicycle on a city street',
27
+ '{} is watching a movie in a living room',
28
+ 'In a café, {} is drinking coffee and using a laptop',
29
+ '{} is hiking in the mountains',
30
+ '{} is playing a violin in a concert hall',
31
+ 'In a gym, {} is lifting weights',
32
+ '{} is gardening in his backyard',
33
+ '{} is swimming in a pool',
34
+ '{} is shopping in a grocery store',
35
+ 'In a museum, {} is admiring a painting',
36
+ 'In a studio, {} is recording a podcast',
37
+ '{} is doing yoga in a peaceful room',
38
+ '{} is cooking barbecue in a park',
39
+ 'In a laboratory, {} is conducting an experiment',
40
+ 'In an airport, {} is waiting for a flight',
41
+ '{} is sitting on a bench, smiling at the camera',
42
+ '{} is holding a book in his hands',
43
+ 'In a room, {} is sitting on a chair',
44
+ '{} is standing in front of a window',
45
+ 'In a kitchen, {} is cooking food',
46
+ 'In a living room, {} is watching TV',
47
+ 'In a bedroom, {} is sleeping',
48
+ '{} is holding a cup of coffee',
49
+ 'In a classroom, {} is writing on a whiteboard',
50
+ 'In a gym, {} is lifting weights',
51
+ '{} is holding a microphone',
52
+ 'In a restaurant, {} is eating food',
53
+ '{} is holding a pen and writing on a paper',
54
+ 'In a store, {} is shopping for clothes',
55
+ 'In a museum, {} is looking at an exhibit',
56
+ '{} is holding a camera and taking a photo',
57
+ '{} is holding a baby in his arms',
58
+ 'In a laboratory, {} is conducting an experiment',
59
+ '{} is holding a guitar',
60
+ 'In a swimming pool, {} is swimming',
61
+ 'In a cafe, {} is drinking tea',
62
+ 'In a garden, {} is watering plants',
63
+ '{} is sitting on a bench in a park',
64
+ 'In a classroom, {} is writing on a whiteboard',
65
+ '{} is holding a pen and writing on a paper',
66
+ '{} is standing in front of a building',
67
+ 'In a museum, {} is looking at an exhibit',
68
+ 'In a theater, {} is watching a movie',
69
+ '{} is standing in front of a car',
70
+ '{} is standing in front of a tree',
71
+ 'In a meeting room, {} is giving a presentation',
72
+ 'In a stadium, {} is watching a game',
73
+ 'In a garage, {} is fixing a car',
74
+ '{} is holding a paintbrush and painting a picture',
75
+ 'In a classroom, {} is listening to a lecture',
76
+ '{} is standing in front of a mountain',
77
+ 'In a park, {} is flying a kite',
78
+ 'a rendering of a {}',
79
+ 'a cropped photo of the {}',
80
+ 'the photo of a {}',
81
+ 'a photo of a clean {}',
82
+ 'a photo of a dirty {}',
83
+ 'a dark photo of the {}',
84
+ 'a photo of my {}',
85
+ 'a photo of the cool {}',
86
+ 'a close-up photo of a {}',
87
+ 'a bright photo of the {}',
88
+ 'a cropped photo of a {}',
89
+ 'a photo of the {}',
90
+ 'a good photo of the {}',
91
+ 'a photo of one {}',
92
+ 'a close-up photo of the {}',
93
+ 'a rendition of the {}',
94
+ 'a photo of the clean {}',
95
+ 'a rendition of a {}',
96
+ 'a photo of a nice {}',
97
+ 'a good photo of a {}',
98
+ 'a photo of the nice {}',
99
+ 'a photo of the small {}',
100
+ 'a photo of the weird {}',
101
+ 'a photo of the large {}',
102
+ 'a photo of a cool {}',
103
+ 'a photo of a small {}',
104
+ 'an illustration of a {}',
105
+ 'a rendering of a {}',
106
+ 'a cropped photo of the {}',
107
+ 'the photo of a {}',
108
+ 'an illustration of a clean {}',
109
+ 'an illustration of a dirty {}',
110
+ 'a dark photo of the {}',
111
+ 'an illustration of my {}',
112
+ 'an illustration of the cool {}',
113
+ 'a close-up photo of a {}',
114
+ 'a bright photo of the {}',
115
+ 'a cropped photo of a {}',
116
+ 'an illustration of the {}',
117
+ 'a good photo of the {}',
118
+ 'an illustration of one {}',
119
+ 'a close-up photo of the {}',
120
+ 'a rendition of the {}',
121
+ 'an illustration of the clean {}',
122
+ 'a rendition of a {}',
123
+ 'an illustration of a nice {}',
124
+ 'a good photo of a {}',
125
+ 'an illustration of the nice {}',
126
+ 'an illustration of the small {}',
127
+ 'an illustration of the weird {}',
128
+ 'an illustration of the large {}',
129
+ 'an illustration of a cool {}',
130
+ 'an illustration of a small {}',
131
+ 'a depiction of a {}',
132
+ 'a rendering of a {}',
133
+ 'a cropped photo of the {}',
134
+ 'the photo of a {}',
135
+ 'a depiction of a clean {}',
136
+ 'a depiction of a dirty {}',
137
+ 'a dark photo of the {}',
138
+ 'a depiction of my {}',
139
+ 'a depiction of the cool {}',
140
+ 'a close-up photo of a {}',
141
+ 'a bright photo of the {}',
142
+ 'a cropped photo of a {}',
143
+ 'a depiction of the {}',
144
+ 'a good photo of the {}',
145
+ 'a depiction of one {}',
146
+ 'a close-up photo of the {}',
147
+ 'a rendition of the {}',
148
+ 'a depiction of the clean {}',
149
+ 'a rendition of a {}',
150
+ 'a depiction of a nice {}',
151
+ 'a good photo of a {}',
152
+ 'a depiction of the nice {}',
153
+ 'a depiction of the small {}',
154
+ 'a depiction of the weird {}',
155
+ 'a depiction of the large {}',
156
+ 'a depiction of a cool {}',
157
+ 'a depiction of a small {}',
158
+ '{} reads a newspaper in a cozy coffee shop.',
159
+ '{} jogs along a winding trail at sunrise.',
160
+ '{} takes a photograph of the bustling cityscape.',
161
+ '{} bakes cookies in a warm, inviting kitchen.',
162
+ '{} paints a mural on a large outdoor wall.',
163
+ '{} plants a tree in a sunny backyard.',
164
+ '{} repairs an old bicycle in the garage.',
165
+ '{} sketches a portrait with charcoal.',
166
+ '{} dances freely at a lively festival.',
167
+ '{} sows seeds in a flourishing vegetable garden.',
168
+ '{} plays a violin in a quiet room.',
169
+ '{} writes a poem under the shade of an oak tree.',
170
+ '{} feeds ducks at a peaceful lake.',
171
+ '{} practices yoga on a tranquil beach at dawn.',
172
+ '{} repairs a watch with skilled hands.',
173
+ '{} constructs a model airplane with precision.',
174
+ '{} decorates a cake with elaborate icing designs.',
175
+ '{} climbs a rock wall with determination.',
176
+ '{} meditates in a serene temple garden.',
177
+ '{} knits a colorful scarf by the fireside.',
178
+ '{} assembles a puzzle on a rainy afternoon.',
179
+ '{} examines artifacts at a history museum.',
180
+ '{} tends to a beehive in protective gear.',
181
+ '{} composes a new song on a keyboard.',
182
+ '{} stretches before starting a marathon.',
183
+ '{} recites lines for an upcoming play.',
184
+ '{} harvests apples in an orchard.',
185
+ '{} leads a tour group through ancient ruins.',
186
+ '{} creates a scrapbook filled with memories.',
187
+ '{} tutors a student in mathematics.',
188
+ '{} tries a new recipe from a gourmet cookbook.',
189
+ '{} rides a horse through an open field.',
190
+ '{} collects samples on a nature walk.',
191
+ '{} solves a complex mathematical equation.',
192
+ '{} fills the room with the sound of saxophone music.',
193
+ '{} arranges flowers in a beautiful bouquet.',
194
+ '{} hosts a podcast interview.',
195
+ '{} dives into a crystal-clear swimming pool.',
196
+ '{} studies a map before an expedition.',
197
+ '{} makes ceramic pottery on a spinning wheel.',
198
+ '{} trains a puppy to sit and stay.',
199
+ '{} rehearses for a ballet performance.',
200
+ '{} sails a boat across a calm lake.',
201
+ '{} browses through a second-hand bookstore.',
202
+ '{} explores a cave with a flashlight.',
203
+ '{} restores an old car in their home workshop.',
204
+ '{} conducts an orchestra with passion.',
205
+ '{} volunteers at a community food bank.',
206
+ '{} compiles a report in their office.',
207
+ '{} designs a website on their computer.',
208
+ '{} teaches a child to ride a bike.',
209
+ '{} performs a magic trick at a party.',
210
+ '{} packs a suitcase for a journey.',
211
+ '{} prunes roses in a sunlit garden.',
212
+ '{} crafts handmade jewelry from silver and gems.',
213
+ '{} inspects products for quality in a factory.',
214
+ '{} sculpts a figure from a block of marble.',
215
+ '{} organizes a community cleanup day.',
216
+ '{} swings a golf club on a green fairway.',
217
+ '{} develops photos in a darkroom.',
218
+ '{} directs a small indie film.',
219
+ '{} carves a wooden figure with intricate detail.',
220
+ '{} birdwatches with binoculars in the forest.',
221
+ '{} pilots a hot air balloon at dawn.',
222
+ '{} tutors peers in a university library.',
223
+ '{} rides a skateboard down a city street.',
224
+ '{} decorates a storefront for the holidays.',
225
+ '{} mixes cocktails at a busy bar.',
226
+ '{} cuts hair in a stylish salon.',
227
+ '{} researches genealogy to fill out a family tree.',
228
+ '{} writes calligraphy with elegant strokes.',
229
+ '{} edits a manuscript for publication.',
230
+ '{} lectures on environmental science.',
231
+ '{} designs a new board game.',
232
+ '{} hosts a charity auction.',
233
+ '{} laces up skates for an ice-skating lesson.',
234
+ '{} coordinates a wedding at a picturesque venue.',
235
+ '{} builds a sandcastle on the beach.',
236
+ '{} programs a robot for a competition.',
237
+ '{} captures wildlife photography in the jungle.',
238
+ '{} sets up a tent under the stars.',
239
+ '{} debuts a fashion collection on the runway.',
240
+ '{} curates pieces for an art exhibition.',
241
+ '{} conducts a science experiment in the laboratory.',
242
+ '{} provides a walking tour of a historic city.',
243
+ '{} partakes in a coffee cupping session.',
244
+ '{} negotiates a deal in a boardroom.',
245
+ '{} operates a forklift in a warehouse.',
246
+ '{} leads a yoga retreat in a mountain setting.',
247
+ '{} analyzes data on multiple computer screens.',
248
+ '{} paints a picket fence on a sunny day.',
249
+ '{} trains for gymnastics at the gym.',
250
+ '{} teaches a pottery class, guiding students.',
251
+ '{} cares for animals at a wildlife sanctuary.',
252
+ '{} crafts origami creations from colorful paper.',
253
+ '{} deejays a lively dance party.',
254
+ '{} writes code for a new software application.',
255
+ '{} grows an array of herbs in a window garden.',
256
+ '{} instructs a spin class with high energy.',
257
+ '{} navigates rapids in a whitewater raft.',
258
+ 'Quietly, {} sets the table for dinner.',
259
+ 'Suddenly, {} stops to pick up a fallen object.',
260
+ 'Calmly, {} navigates through the crowd.',
261
+ 'Gently, {} soothes a crying child.',
262
+ 'Quickly, {} dashes out in the rain.',
263
+ 'Joyfully, {} embraces a long-lost friend.',
264
+ 'Firmly, {} stands their ground in debate.',
265
+ 'Loudly, {} cheers on their favorite team.',
266
+ 'Patiently, {} waits for their turn.',
267
+ 'Anxiously, {} fidgets during an interview.',
268
+ 'Easily, {} solves a complex puzzle.',
269
+ 'Sadly, {} waves farewell at the departure gates.',
270
+ 'Meticulously, {} organizes their collection.',
271
+ 'Slyly, {} sneaks a cookie from the jar.',
272
+ 'Defiantly, {} marches for change.',
273
+ 'Warmly, {} greets newcomers.',
274
+ 'Hungrily, {} eyes the banquet table.',
275
+ 'Enthusiastically, {} starts their first day of work.',
276
+ 'Stealthily, {} moves in the game of hide and seek.',
277
+ 'Expertly, {} navigates the rapid waters.',
278
+ 'Seamlessly, {} transitions between tasks.',
279
+ 'Vigorously, {} cleans the cluttered garage.',
280
+ 'Devotedly, {} tends to their garden.',
281
+ 'Silently, {} admires the sunrise.',
282
+ 'Righteously, {} advocates for justice.',
283
+ 'Keenly, {} observes the wildlife.',
284
+ 'Desperately, {} searches for their lost item.',
285
+ 'Reverently, {} visits a historic monument.',
286
+ 'Wistfully, {} looks back on fond memories.',
287
+ 'Ambitiously, {} sets their career goals.',
288
+ 'Rapidly, {} types up an urgent report.',
289
+ 'Generously, {} shares their lunch.',
290
+ 'Skillfully, {} crafts a beautiful piece of pottery.',
291
+ 'Cautiously, {} approaches the unfamiliar dog.',
292
+ 'Inquisitively, {} examines the ancient artifact.',
293
+ 'Effortlessly, {} completes the challenging workout.',
294
+ 'Frantically, {} looks for the exit.',
295
+ 'Discreetly, {} passes a note during class.',
296
+ 'Pensively, {} contemplates their next move.',
297
+ 'Optimistically, {} plans for the future.',
298
+ 'Sorrowfully, {} attends a memorial service.',
299
+ 'Methodically, {} assembles the model airplane.',
300
+ 'Lazily, {} lounges on the hammock.',
301
+ 'Unsuccessfully, {} tries to hail a taxi.',
302
+ 'Faithfully, {} follows the recipe.',
303
+ 'Dramatically, {} reacts to the plot twist.',
304
+ 'Adventurously, {} explores the dense forest.',
305
+ 'Gracefully, {} accepts the award.',
306
+ 'Theatrically, {} recites lines on stage.',
307
+ 'Ardently, {} defends their thesis.',
308
+ 'Abstractedly, {} doodles during the meeting.',
309
+ 'Vivaciously, {} engages in the lively party.',
310
+ 'Stoically, {} endures the challenging ordeal.',
311
+ 'Resolutely, {} decides to change their life.',
312
+ 'Triumphantly, {} crosses the finish line.',
313
+ 'Suspiciously, {} glances over their shoulder.',
314
+ 'Fervently, {} prays for good news.',
315
+ 'Ruefully, {} acknowledges their mistake.',
316
+ 'Industriously, {} works on the project till dusk.',
317
+ 'Compassionately, {} comforts a stranger.',
318
+ 'Sheepishly, {} admits they forgot the appointment.',
319
+ 'Irately, {} disputes the incorrect charge.',
320
+ 'Protectively, {} shields the puppy from the rain.',
321
+ 'Serenely, {} meditates in the morning light.',
322
+ 'Comically, {} slips on the banana peel.',
323
+ 'Impressively, {} juggles multiple objects with ease.',
324
+ 'Apprehensively, {} approaches the spooky house.',
325
+ 'Unwaveringly, {} supports their friend.',
326
+ 'Blissfully, {} soaks in the hot spring.',
327
+ 'Compulsively, {} checks their notifications.',
328
+ 'Tactfully, {} navigates the awkward situation.',
329
+ 'Convincingly, {} sells their innovative idea.',
330
+ 'Dutifully, {} fulfills their obligations.',
331
+ 'Ingeniously, {} solves the critical problem.',
332
+ 'Haphazardly, {} packs their suitcase.',
333
+ 'Deftly, {} maneuvers the playing pieces.',
334
+ 'Intriguedly, {} listens to the mysterious tale.',
335
+ 'Ceremoniously, {} unveils the new sculpture.',
336
+ 'Sterilely, {} organizes the lab equipment.',
337
+ 'Unintentionally, {} overhears a private conversation.',
338
+ 'Forever, {} holds dear the cherished memories.',
339
+ 'Nostalgically, {} revisits their old neighborhood.',
340
+ 'Predictably, {} always laughs at the same joke.',
341
+ 'Politely, {} inquires about the meeting agenda.',
342
+ 'Securely, {} fastens their seatbelt before takeoff.',
343
+ 'Casually, {} strolls through the park.',
344
+ 'Spontaneously, {} decides to go on a road trip.',
345
+ 'Clearly, {} expresses their feelings.',
346
+ 'Merrily, {} decorates for the festive season.',
347
+ 'Valiantly, {} stands up against injustice.',
348
+ 'Diligently, {} studies for the upcoming exam.',
349
+ 'Nonchalantly, {} brushes off the slight mishap.',
350
+ 'Intensely, {} focuses on the target.',
351
+ 'Subtly, {} hints at the surprise party.',
352
+ 'Mysteriously, {} vanishes into the foggy night.',
353
+ 'Decisively, {} makes their final choice.',
354
+ 'Lovingly, {} prepares a home-cooked meal.',
355
+ 'Immaculately, {} arranges the storefront display.',
356
+ 'Vibrantly, {} adds color to the canvas.',
357
+ 'The silhouette of {} casts a long shadow.',
358
+ 'Through the fog, {} emerges slowly.',
359
+ 'Over the hill, {} rides a bicycle.',
360
+ 'After the storm, {} surveys the damage.',
361
+ 'Around the bend, {} sails a boat.',
362
+ 'Under the tree, {} reads a book.',
363
+ 'Beside the fire, {} warms their hands.',
364
+ 'Below the surface, {} discovers coral reefs.',
365
+ 'Beyond the fence, {} tends to horses.',
366
+ 'Above the crowd, {} waves a flag.',
367
+ 'Among the flowers, {} finds peace.',
368
+ 'Across the field, {} flies a kite.',
369
+ 'Near the water, {} sketches the view.',
370
+ 'By the road, {} waits patiently.',
371
+ 'With careful precision, {} repairs a clock.',
372
+ 'In the spotlight, {} performs a solo.',
373
+ 'To the beat, {} dances joyfully.',
374
+ 'On the stage, {} delivers a monologue.',
375
+ 'Underneath the stars, {} makes a wish.',
376
+ 'Beside the window, {} sips morning tea.',
377
+ 'At the corner, {} hails a cab.',
378
+ 'Against the odds, {} triumphs victoriously.',
379
+ 'Beneath the waves, {} finds tranquility.',
380
+ 'Before the race, {} stretches carefully.',
381
+ 'Through the lens, {} captures the moment.',
382
+ 'From the bridge, {} observes the river.',
383
+ 'Since the dawn, {} has been fishing.',
384
+ 'Like a statue, {} stands immovable.',
385
+ 'Inside the house, {} feels safe.',
386
+ 'The smile of {} brightens the room.',
387
+ 'Upon the mountaintop, {} feels awe.',
388
+ 'Without a doubt, {} commits to the goal.',
389
+ 'Reflecting on life, {} contemplates deeply.',
390
+ 'Amidst the chaos, {} remains calm.',
391
+ 'Throughout the day, {} maintains focus.',
392
+ 'During the performance, {} takes the stage.',
393
+ 'Considering all options, {} makes a choice.',
394
+ 'Alongside the path, {} picks wildflowers.',
395
+ 'Toward the horizon, {} gazes expectantly.',
396
+ 'Wrapped in thought, {} ponders life’s mysteries.',
397
+ 'Accompanied by music, {} feels uplifted.',
398
+ 'Surrounded by books, {} indulges in knowledge.',
399
+ 'Guided by intuition, {} chooses a path.',
400
+ 'Entertaining guests, {} tells a tale.',
401
+ 'Admiring the artwork, {} gains inspiration.',
402
+ 'Standing at the crossroads, {} hesitates slightly.',
403
+ 'Lost in music, {} enjoys the concert.',
404
+ 'Besieged by deadlines, {} works diligently.',
405
+ 'Empowered by support, {} achieves greatness.',
406
+ 'Gazing into space, {} dreams of stars.',
407
+ 'Facing the challenge, {} exudes confidence.',
408
+ 'Approaching the podium, {} clears their throat.',
409
+ 'Enclosed in glass, {} admires the terrarium.',
410
+ 'The reflection of {} shimmers on water.',
411
+ 'Clutching the ticket, {} rushes to the gate.',
412
+ 'Heeding the warning, {} takes precaution.',
413
+ 'Observing the traditions, {} learns respect.',
414
+ 'At the museum, {} admires ancient artifacts.',
415
+ 'Following the recipe, {} bakes a cake.',
416
+ 'Adjusting the telescope, {} explores the heavens.',
417
+ 'In the garden, {} relaxes with nature.',
418
+ 'Clinging to hope, {} perseveres through trials.',
419
+ 'The laughter of {} fills the room.',
420
+ 'During the lecture, {} takes diligent notes.',
421
+ 'Sitting by the piano, {} composes a melody.',
422
+ 'The hands of {} shape the clay.',
423
+ 'The courage of {} inspires many others.',
424
+ 'Laid on the canvas, {} begins to paint.',
425
+ 'Carried by wind, {}’s kite ascends higher.',
426
+ 'In the workshop, {} builds a dream.',
427
+ 'Mingled with others, {} shares a story.',
428
+ 'Learning the ropes, {} adapts quickly.',
429
+ 'Fuelled by passion, {} pursues their dreams.',
430
+ 'In the office, {} meets a deadline.',
431
+ 'With each stride, {} closes the distance.',
432
+ 'Mastering the craft, {} excels in their art.',
433
+ 'The vision of {} leads to success.',
434
+ 'Striving for wellness, {} embraces a change.',
435
+ 'Buffeted by wind, {} adjusts their hat.',
436
+ 'Engulfed in aroma, {} enjoys the spices.',
437
+ 'Surrounded by laughter, {} feels joy.',
438
+ 'Avoiding the puddle, {} steps carefully.',
439
+ 'Reacting quickly, {} catches the falling vase.',
440
+ 'Marked by time, {}’s diary tells tales.',
441
+ 'Supported by friends, {} overcomes fear.',
442
+ 'Puzzled by clues, {} solves the riddle.',
443
+ 'Driving through night, {} reaches their destination.',
444
+ 'Splashed by waves, {} laughs heartily.',
445
+ 'Confronted with choices, {} deliberates wisely.',
446
+ 'Hidden by shadows, {} watches the scene.',
447
+ 'Inspired by nature, {} writes poetry.',
448
+ 'Guarded by mystery, {}’s past intrigues.',
449
+ 'Detouring the path, {} discovers new sights.',
450
+ 'Greeted by dawn, {} feels renewed.',
451
+ 'Warmed by sunlight, {} enjoys the afternoon.',
452
+ 'Answering the call, {} takes action.',
453
+ 'Sheltered by canopy, {} escapes the rain.',
454
+ 'Bound by duty, {} fulfills their role.',
455
+ 'Pulled by curiosity, {} enters the store.',
456
+ 'Motivated by change, {} advocates for causes.',
457
+ 'In silence, {} stares into space.',
458
+ 'Lost in thought, {} stands still.',
459
+ 'With excitement, {} opens a gift.',
460
+ 'Amid laughter, {} shares a joke.',
461
+ 'Surrounded by nature, {} takes a deep breath.',
462
+ 'Under the sun, {} stretches out.',
463
+ 'Against a backdrop of mountains, {} gazes afar.',
464
+ 'Among friends, {} enjoys a conversation.',
465
+ 'Before dinner, {} sets the table.',
466
+ 'Behind the counter, {} makes coffee.',
467
+ 'Below the surface, {} snorkels in clear water.',
468
+ 'Beneath the stars, {} lights a campfire.',
469
+ 'Beside a bicycle, {} takes a break.',
470
+ 'By the seaside, {} collects seashells.',
471
+ 'Near the horizon, {} sketches the view.',
472
+ 'On the bridge, {} watches the water flow.',
473
+ 'Through the window, {} waves goodbye.',
474
+ 'To the music, {} taps their feet.',
475
+ 'With a book, {} finds escape.',
476
+ 'Without a care, {} listens to music.',
477
+ 'Around the table, {} shares a story.',
478
+ 'Outside the house, {} does some gardening.',
479
+ 'From the stage, {} delivers a speech.',
480
+ 'After the rain, {} jumps in puddles.',
481
+ 'During the party, {} blows up balloons.',
482
+ 'Following the path, {} takes a stroll.',
483
+ 'Along the river, {} is fishing.',
484
+ 'Inside the room, {} practices yoga.',
485
+ 'Throughout the day, {} takes photos.',
486
+ 'Across the field, {} flies a kite.',
487
+ 'Between the lines, {} reads quietly.',
488
+ 'Behind the lens, {} captures the moment.',
489
+ 'Along the alley, {} walks their dog.',
490
+ 'Before the sunrise, {} enjoys the calm.',
491
+ 'Over the fence, {} talks to a neighbor.',
492
+ 'Under the tree, {} has a picnic.',
493
+ 'Beyond the gate, {} starts their journey.',
494
+ 'Around the fire, {} tells ghost stories.',
495
+ 'Above the clouds, {} skydives.',
496
+ 'Among the crowd, {} cheers loudly.',
497
+ 'Near the pond, {} feeds the ducks.',
498
+ 'On the couch, {} takes a nap.',
499
+ 'Before the show, {} checks their ticket.',
500
+ 'Under the sky, {} flies a drone.',
501
+ 'Behind the wheel, {} sings loudly.',
502
+ 'Above the waves, {} surfs with skill.',
503
+ 'Within the walls, {} paints their dream.',
504
+ 'Beyond the road, {} hikes up the hill.',
505
+ 'Beneath the quilt, {} reads at night.',
506
+ 'Against the odds, {} tries a new trick.',
507
+ 'During the trip, {} savors local cuisine.',
508
+ 'Amid the shelves, {} finds an old book.',
509
+ 'Across the room, {} waves to a friend.',
510
+ 'By the pool, {} basks in the sun.',
511
+ 'Beneath the lights, {} takes center stage.',
512
+ 'Above the city, {} marvels at the view.',
513
+ 'Behind the scenes, {} prepares diligently.',
514
+ 'Over the moon, {} celebrates good news.',
515
+ 'Under the arch, {} takes memorable photos.',
516
+ 'Before the dawn, {} prepares for the day.',
517
+ 'Throughout the match, {} cheers enthusiastically.',
518
+ 'Between workouts, {} hydrates and rests.',
519
+ 'Around the campfire, {} roasts marshmallows.',
520
+ 'By the window, {} enjoys the morning light.',
521
+ 'After the lecture, {} asks thoughtful questions.',
522
+ 'Within the garden, {} admires the flowers.',
523
+ 'Beneath the blanket, {} watches a movie.',
524
+ 'Beyond the wall, {} hears echoes of laughter.',
525
+ 'Behind the book, {} hides a surprise gift.',
526
+ 'Under the bridge, {} sketches the river scene.',
527
+ 'During the concert, {} loses themselves in the music.',
528
+ 'On the terrace, {} sips on iced tea.',
529
+ 'Before the alarm, {} wakes up naturally.',
530
+ 'Above the rooftops, {} spots a passing balloon.',
531
+ 'Across the street, {} helps an elderly neighbor.',
532
+ 'Beside the lamp, {} finishes their novel.',
533
+ 'With the crowd, {} dances to the festival music.',
534
+ 'By the lakeside, {} sets up a fishing rod.',
535
+ 'Before the exercise, {} stretches thoroughly.',
536
+ 'Near the finish line, {} sprints with determination.',
537
+ 'On the balcony, {} tends to potted plants.',
538
+ 'After the storm, {} clears the fallen branches.',
539
+ 'Under the covers, {} snoozes the alarm clock.',
540
+ 'Between the curtains, {} peeks at the sunrise.',
541
+ 'Around the corner, {} discovers a quaint café.',
542
+ 'By the artwork, {} contemplates the message of the painter.',
543
+ 'After the game, {} congratulates the players.',
544
+ 'Within the studio, {} edits a documentary film.',
545
+ 'Beneath the hat, {} grins at a private joke.',
546
+ 'Beyond the dunes, {} takes in the beach view.',
547
+ 'Behind the microphone, {} records a podcast.',
548
+ 'Under the eaves, {} shelters from the rain.',
549
+ 'During the hike, {} spots a rare bird.',
550
+ 'On the platform, {} awaits the next train.',
551
+ 'Before the meal, {} gives thanks.',
552
+ 'Above the fray, {} keeps a level head.',
553
+ 'Across the canvas, {} strokes colors with a brush.',
554
+ 'Beside the hearth, {} warms their hands.',
555
+ 'With affection, {} pets their sleepy cat.',
556
+ 'By the harbor, {} watches boats come and go.',
557
+ 'In a room, {} reads quietly.',
558
+ 'Near the shore, {} fishes calmly.',
559
+ 'Behind the counter, {} smiles warmly.',
560
+ 'Among the trees, {} jogs daily.',
561
+ 'On the bench, {} sits silently.',
562
+ 'With a pen, {} writes diligently.',
563
+ 'At dawn, {} stretches readily.',
564
+ 'Under the stars, {} dreams peacefully.',
565
+ 'With the dog, {} walks leisurely.',
566
+ 'Against the backdrop, {} stands proudly.',
567
+ 'On stage, {} speaks clearly.',
568
+ 'In the garden, {} works happily.',
569
+ 'At the table, {} eats slowly.',
570
+ 'Beside the window, {} gazes thoughtfully.',
571
+ 'Within the crowd, {} laughs loudly.',
572
+ 'By the painting, {} ponders intently.',
573
+ 'On the bridge, {} pauses reflectively.',
574
+ 'Under the umbrella, {} waits patiently.',
575
+ 'Before the game, {} practices routinely.',
576
+ 'Behind the lens, {} captures moments.',
577
+ 'In the cafe, {} sips coffee.',
578
+ 'With a map, {} explores curiously.',
579
+ 'On the couch, {} naps briefly.',
580
+ 'At the wheel, {} drives safely.',
581
+ 'Beside the fire, {} warms up.',
582
+ 'During the concert, {} claps excitedly.',
583
+ 'By the bookshelf, {} selects a novel.',
584
+ 'On the path, {} bikes steadily.',
585
+ 'Under the quilt, {} snoozes comfortably.',
586
+ 'Before the screen, {} types consistently.',
587
+ 'Within the room, {} dances joyfully.',
588
+ 'At the market, {} shops carefully.',
589
+ 'Beside the pool, {} sunbathes lazily.',
590
+ 'On the road, {} hitches northward.',
591
+ 'Against the clock, {} races swiftly.',
592
+ 'By the door, {} knocks promptly.',
593
+ 'In the silence, {} meditates profoundly.',
594
+ 'With a brush, {} paints a canvas.',
595
+ 'On a horse, {} rides boldly.',
596
+ 'At the concert, {} listens attentively.',
597
+ 'Beside the lamp, {} reads a letter.',
598
+ 'On the field, {} throws a ball.',
599
+ 'Under the sun, {} basks leisurely.',
600
+ 'Before the microphone, {} sings softly.',
601
+ 'Within the frame, {} looks stern.',
602
+ 'In the studio, {} records a podcast.',
603
+ 'By the seaside, {} collects shells.',
604
+ 'On the mattress, {} lies awake.',
605
+ 'Behind the bar, {} mixes drinks.',
606
+ 'During the meeting, {} takes notes.',
607
+ 'At the podium, {} delivers a speech.',
608
+ 'Beside the pond, {} feeds ducks.',
609
+ 'On the swing, {} rocks gently.',
610
+ 'Under the sky, {} dreams freely.',
611
+ 'Before the class, {} sets up.',
612
+ 'Within the pages, {} finds adventure.',
613
+ 'At the corner, {} waves hello.',
614
+ 'By the stove, {} cooks breakfast.',
615
+ 'On the terrace, {} breathes deeply.',
616
+ 'Against the wall, {} rests momentarily.',
617
+ 'In the lineup, {} waits calmly.',
618
+ 'With a joystick, {} plays a game.',
619
+ 'On the floor, {} stretches out.',
620
+ 'At the crossroads, {} chooses a path.',
621
+ 'Beside the bag, {} finds keys.',
622
+ 'On the track, {} runs laps.',
623
+ 'Under the tree, {} enjoys shade.',
624
+ 'Before the journey, {} packs essentials.',
625
+ 'Within the box, {} discovers treasures.',
626
+ 'In the mirror, {} sees reflection.',
627
+ 'By the lake, {} skips stones.',
628
+ 'On the steps, {} sits waiting.',
629
+ 'Against the flow, {} stands firm.',
630
+ 'Before the event, {} feels nervous.',
631
+ 'Within the heart, {} holds love.',
632
+ 'At the keyboard, {} composes music.',
633
+ 'By the fence, {} watches sunset.',
634
+ 'On the ledge, {} takes in views.',
635
+ 'Under the moon, {} makes wishes.',
636
+ 'Before the crowd, {} shows courage.',
637
+ 'Within the house, {} calls family.',
638
+ 'At the desk, {} solves puzzles.',
639
+ 'Beside the car, {} checks tires.',
640
+ 'On the peak, {} celebrates triumph.',
641
+ 'Against the odds, {} perseveres always.',
642
+ 'In the foyer, {} welcomes guests.',
643
+ 'With the team, {} collaborates effectively.',
644
+ 'On the grass, {} rolls playfully.',
645
+ 'At the junction, {} signals left.',
646
+ 'Beside the easel, {} studies the painting.',
647
+ 'On the quilt, {} patches holes.',
648
+ 'Under the coat, {} hides a gift.',
649
+ 'Before the dawn, {} dreams of success.',
650
+ 'Within the shadows, {} moves silently.',
651
+ 'At the beach, {} builds castles.',
652
+ 'By the gate, {} waits anxiously.',
653
+ 'On the island, {} finds peace.',
654
+ 'Against the breeze, {} flies a kite.',
655
+ 'Before the altar, {} takes a vow.',
656
+ 'Within the orchestra, {} tunes their instrument.',
657
+ 'An exciting magic trick is being performed by {}.',
658
+ 'A quiet library is being enjoyed by {}.',
659
+ 'A delicious meal is being cooked in the kitchen by {}.',
660
+ 'A challenging rock wall is being climbed by {}.',
661
+ 'A fast-paced basketball game is being played by {}.',
662
+ 'A beautiful melody is being played on a violin by {}.',
663
+ 'A serene lake is being fished by {}.',
664
+ 'An intense workout is being completed in the gym by {}.',
665
+ 'A mysterious book is being read under the tree by {}.',
666
+ 'A spirited dance is being performed on stage by {}.',
667
+ 'A serene afternoon picnic is being enjoyed by {}.',
668
+ 'A thrilling skateboarding trick is being attempted by {}.',
669
+ 'An intricate jigsaw puzzle is being solved by {}.',
670
+ 'A high note is being sung in a rehearsal room by {}.',
671
+ 'A new recipe is being tried out in the kitchen by {}.',
672
+ 'A bookshelf is being organized in the study by {}.',
673
+ 'A large canvas is being painted with bold colors by {}.',
674
+ 'An ancient ruin is being carefully explored by {}.',
675
+ 'A lengthy novel is being written at the desk by {}.',
676
+ 'A pottery wheel is being used to shape clay by {}.',
677
+ 'A soft melody is being played on a guitar by {}.',
678
+ 'A new language is being learned with enthusiasm by {}.',
679
+ 'An early morning jog is being taken along the beach by {}.',
680
+ 'A handmade quilt is being stitched with care by {}.',
681
+ 'A tropical fruit stand is being set up at the market by {}.',
682
+ 'A hot beverage is being brewed in a cozy cafe by {}.',
683
+ 'A winter bonfire is being lit to warm up the night by {}.',
684
+ 'A peaceful kayak trip is being embarked upon by {}.',
685
+ 'Bold graffiti is being sprayed on an urban wall by {}.',
686
+ 'A lively story is being told around the campfire by {}.',
687
+ 'A crafty sculpture is being created from recycled materials by {}.',
688
+ 'A vibrant mural is being painted on a downtown alley by {}.',
689
+ 'A dusty trail is being hiked at dawn by {}.',
690
+ 'A tricky crossword puzzle is being filled out by {}.',
691
+ 'A homemade pie is being baked for a special occasion by {}.',
692
+ 'An elaborate garden is being tended to by {}.',
693
+ 'A suspenseful movie is being watched with excitement by {}.',
694
+ 'A difficult yoga pose is being mastered in the studio by {}.',
695
+ 'A new skateboard is being ridden down a hill by {}.',
696
+ 'A savory soup is being stirred in a pot by {}.',
697
+ 'Cheerful holiday decorations are being hung around the house by {}.',
698
+ 'A thrilling novel is being devoured on a rainy afternoon by {}.',
699
+ 'A chess game is being thoughtfully played in the park by {}.',
700
+ 'A burst of laughter is being shared with friends by {}.',
701
+ 'Bright city lights are being admired from a rooftop by {}.',
702
+ 'An old family recipe is being followed in the kitchen by {}.',
703
+ 'A marshmallow is being roasted over a campfire by {}.',
704
+ 'Careful brush strokes are being applied to a model figurine by {}.',
705
+ 'A challenging video game is being played with focus by {}.',
706
+ 'An evening class is being attended with interest by {}.',
707
+ 'A delicate pastry is being decorated with icing by {}.',
708
+ 'An excited puppy is being trained in the backyard by {}.',
709
+ 'A basketball is being shot into a hoop by {}.',
710
+ 'A lively drumbeat is being played at a concert by {}.',
711
+ 'Colorful fall leaves are being photographed in the woods by {}.',
712
+ 'A new song is being composed on the piano by {}.',
713
+ 'A long-lost friend is being hugged in a warm embrace by {}.',
714
+ 'Bright fireworks are being watched in awe by {}.',
715
+ 'A favorite TV show is being binge-watched by {}.',
716
+ 'A new trail is being biked through the forest by {}.',
717
+ 'Freshly baked cookies are being taken out of the oven by {}.',
718
+ 'A difficult problem is being solved with satisfaction by {}.',
719
+ 'Colorful balloons are being blown up for a party by {}.',
720
+ 'A joyful tune is being whistled while walking by {}.',
721
+ 'An old film camera is being loaded with film by {}.',
722
+ 'An empty canvas is being gazed upon before painting by {}.',
723
+ 'An exciting soccer match is being watched with friends by {}.',
724
+ 'A warm cup of tea is being sipped quietly by {}.',
725
+ 'A good book is being enjoyed in a comfy armchair by {}.',
726
+ 'A gentle horse is being groomed in the stable by {}.',
727
+ 'A tense board game is being strategized over by {}.',
728
+ 'Fresh laundry is being folded neatly by {}.',
729
+ 'A thrilling roller coaster ride is being braved by {}.',
730
+ 'A favorite song is being sung in the shower by {}.',
731
+ 'A rainy day is being spent baking cookies by {}.',
732
+ 'Classic tunes are being listened to on vinyl by {}.',
733
+ 'An interesting documentary is being watched intently by {}.',
734
+ 'A busy day is being relaxed from with a bubble bath by {}.',
735
+ 'A sunflower field is being walked through by {}.',
736
+ 'A new plant is being potted with care by {}.',
737
+ 'A sunny terrace is being enjoyed with a cold drink by {}.',
738
+ 'Morning birds are being listened to at dawn by {}.',
739
+ 'A quiet museum hall is being wandered through by {}.',
740
+ 'An experimental recipe is being tested in the kitchen by {}.',
741
+ 'A homemade kite is being flown on a breezy day by {}.',
742
+ 'A colorful aquarium is being cleaned by {}.',
743
+ 'A new blog post is being composed on a laptop by {}.',
744
+ 'A wild trail is being trekked with enthusiasm by {}.',
745
+ 'An ice cream cone is being savored on a warm day by {}.',
746
+ 'A peaceful sunrise is being watched from a hilltop by {}.',
747
+ 'Freshly ground coffee is being brewed in the morning by {}.',
748
+ 'A comfortable hammock is being swayed in gently by {}.',
749
+ 'A nostalgic video game is being revisited with joy by {}.',
750
+ 'A challenging Sudoku puzzle is being completed by {}.',
751
+ 'A dusty attic is being explored for treasures by {}.',
752
+ 'A hefty stack of pancakes is being devoured for breakfast by {}.',
753
+ 'Delicate origami is being folded by {}.',
754
+ 'A peaceful moment is being cherished on a quiet porch by {}.',
755
+ 'On a quiet street, {} is jogging.',
756
+ 'With a gentle smile, {} offers help.',
757
+ 'Behind the old bookstore, {} reads quietly.',
758
+ 'Near a calm lake, {} sketches the scenery.',
759
+ 'By the bright window, {} sips coffee.',
760
+ 'Under the warm sun, {} relaxes.',
761
+ 'Around the bustling square, {} dances.',
762
+ 'Beside the campfire, {} tells stories.',
763
+ 'Above the city noise, {} daydreams.',
764
+ 'Through the crowded fair, {} navigates.',
765
+ 'Against the evening sky, {} takes photos.',
766
+ 'Among the tall trees, {} hikes.',
767
+ 'Before the morning rush, {} stretches.',
768
+ 'Amid the garden blooms, {} seeks peace.',
769
+ 'Across the open field, {} runs freely.',
770
+ 'During the lively party, {} laughs.',
771
+ 'Following the winding path, {} explores.',
772
+ 'Outside the cozy cottage, {} gazes at stars.',
773
+ 'Within the silent walls, {} contemplates.',
774
+ 'Beneath the ancient arch, {} pauses reflectively.',
775
+ 'Along the riverbank, {} fishes.',
776
+ 'Beside a bubbling brook, {} writes poetry.',
777
+ 'Underneath the vibrant mural, {} admires art.',
778
+ 'Beyond the bustling streets, {} finds quiet.',
779
+ 'Behind the heavy curtain, {} rehearses lines.',
780
+ 'Upon the windswept hill, {} flies a kite.',
781
+ 'Throughout the sunny day, {} tends the shop.',
782
+ 'Despite the hectic pace, {} stays calm.',
783
+ 'Behind the lens of a camera, {} captures moments.',
784
+ 'Inside the warm bakery, {} savors aromas.',
785
+ 'Beneath the star-filled sky, {} makes a wish.',
786
+ 'Beyond the garden gate, {} enters serenity.',
787
+ 'Between the bookshelves, {} finds adventure.',
788
+ 'Across the dance floor, {} moves gracefully.',
789
+ 'Around the festive decorations, {} feels joy.',
790
+ 'Amidst the quiet sanctuary, {} prays.',
791
+ 'Near the bustling café, {} watches the world.',
792
+ 'Under the shade of a tree, {} enjoys a picnic.',
793
+ 'By the glow of the fireplace, {} reads.',
794
+ 'After the long journey, {} rests.',
795
+ 'Outside the lively market, {} samples flavors.',
796
+ 'Upon the old wooden bench, {} sits.',
797
+ 'Around the warm campfire, {} sings.',
798
+ 'Through the busy terminal, {} travels.',
799
+ 'Within the walls of home, {} feels safe.',
800
+ 'Beside the flowing river, {} reflects.',
801
+ 'Against the cool breeze, {} wraps up warm.',
802
+ 'Across the silent library, {} seeks knowledge.',
803
+ 'Beneath the towering cliff, {} gazes up.',
804
+ 'Beyond the colorful horizon, {} dreams.',
805
+ 'Between the office cubicles, {} takes a breath.',
806
+ 'Behind the vibrant easel, {} paints.',
807
+ 'Upon the peaceful shore, {} collects shells.',
808
+ 'Throughout the old village, {} discovers history.',
809
+ 'Despite the falling rain, {} smiles.',
810
+ 'Inside the bustling diner, {} enjoys breakfast.',
811
+ 'By the edge of the fountain, {} tosses a coin.',
812
+ 'Outside the charming bookstore, {} chooses a novel.',
813
+ 'Upon the rooftop terrace, {} views the skyline.',
814
+ 'Through the frosty window, {} longs for spring.',
815
+ 'Within the hushed auditorium, {} listens intently.',
816
+ 'Beside the crackling bonfire, {} cozies up.',
817
+ 'Against the morning chill, {} jogs.',
818
+ 'Across the golden meadow, {} strolls.',
819
+ 'Amidst the echo of laughter, {} joins in.',
820
+ 'Beyond the realm of the city, {} seeks nature.',
821
+ 'Between the lush vines, {} harvests grapes.',
822
+ 'Behind the frosted glass, {} sips tea.',
823
+ 'Upon the creaky floorboards, {} tip-toes.',
824
+ 'Throughout the silent movie, {} is mesmerized.',
825
+ 'Despite the room’s clutter, {} finds order.',
826
+ 'Beneath the bright marquee, {} awaits the opening.',
827
+ 'By the light of the lanterns, {} feels warmth.',
828
+ 'After the rain has passed, {} splashes in puddles.',
829
+ 'Outside the local theater, {} buys a ticket.',
830
+ 'Upon the green expanse, {} practices yoga.',
831
+ 'Through the historic district, {} admires architecture.',
832
+ 'Within the quiet of dawn, {} takes a moment.',
833
+ 'Beside the ice-covered pond, {} feeds the ducks.',
834
+ 'Against the setting sun, {} cherishes the moment.',
835
+ 'Across the crowded room, {} finds a friend.',
836
+ 'Amidst the morning calm, {} sows seeds.',
837
+ 'Beneath the overcast sky, {} contemplates change.',
838
+ 'Beyond the busy crosswalk, {} finds solitude.',
839
+ 'Between two towering pines, {} hangs a hammock.',
840
+ 'Behind the cool shade, {} enjoys an ice cream.',
841
+ 'Upon the deserted path, {} embraces stillness.',
842
+ 'Throughout the lively tune, {} taps their foot.',
843
+ 'Despite the distance apart, {} feels connected.',
844
+ 'Inside the crowded bus, {} daydreams.',
845
+ 'Beneath the vast universe, {} feels wonder.',
846
+ 'By the vibrant mural, {} appreciates art.',
847
+ 'After the final curtain call, {} feels inspired.',
848
+ 'Outside the quaint café, {} inhales the fresh morning air.',
849
+ 'Sitting calmly with a book in their lap is {}.',
850
+ 'Holding the reins of a horse stands {}.',
851
+ 'Laughing at a joke just heard is {}.',
852
+ 'Taking a deep breath of fresh air on a hike is {}.',
853
+ 'Reaching for an apple on a tree is {}.',
854
+ 'Playing a violin with focused attention is {}.',
855
+ 'Taking a photo of the sunset is {}.',
856
+ 'Lying on the grass and looking at the clouds is {}.',
857
+ 'Standing with an umbrella in the rain is {}.',
858
+ 'Throwing a frisbee in the park is {}.',
859
+ 'Riding a skateboard down the sidewalk is {}.',
860
+ 'Juggling three balls skillfully is {}.',
861
+ 'Swinging on a swing with a smile is {}.',
862
+ 'Pulling a suitcase in an airport is {}.',
863
+ 'Dipping a paintbrush into paint before a canvas is {}.',
864
+ 'Stretching before a run along the track is {}.',
865
+ 'Pouring a cup of coffee in the morning is {}.',
866
+ 'Bouncing a basketball on the court is {}.',
867
+ 'Holding an ice cream cone upside down is {}.',
868
+ 'Standing at the podium about to speak is {}.',
869
+ 'Waiting for a train at the station is {}.',
870
+ 'Typing rapidly on a keyboard is {}.',
871
+ 'Riding a bicycle along the river path is {}.',
872
+ 'Blowing out candles on a birthday cake is {}.',
873
+ 'Feeding ducks by the pond is {}.',
874
+ 'Hiking with a backpack up a mountain trail is {}.',
875
+ 'Lifting weights in the gym is {}.',
876
+ 'Contemplating a piece of art in a gallery is {}.',
877
+ 'Sipping a milkshake through a straw is {}.',
878
+ 'Planting seedlings in a garden bed is {}.',
879
+ 'Wading through a stream with a fishing pole is {}.',
880
+ 'Assembling a model airplane with focus is {}.',
881
+ 'Whipping up a smoothie in a blender is {}.',
882
+ 'Rolling out dough for a pie is {}.',
883
+ 'Peering through a telescope at night is {}.',
884
+ 'Flying a kite in the open field is {}.',
885
+ 'Playing chess and contemplating the next move is {}.',
886
+ 'Brushing a horse in the stable is {}.',
887
+ 'Sitting on the pier with feet dangling over water is {}.',
888
+ 'Tuning a guitar before a performance is {}.',
889
+ 'Practicing yoga in a peaceful room is {}.',
890
+ 'Sculpting clay on a pottery wheel is {}.',
891
+ 'Skimming a stone across a lake is {}.',
892
+ 'Building a sandcastle at the beach is {}.',
893
+ 'Fishing at the crack of dawn on a boat is {}.',
894
+ 'Roasting marshmallows over a campfire is {}.',
895
+ 'Watching the horizon from the deck of a ship is {}.',
896
+ 'Admiring the view from the top of a ferris wheel is {}.',
897
+ 'Reading a map under the streetlight is {}.',
898
+ 'Twirling a pen thoughtfully while studying is {}.',
899
+ 'Writing in a journal quietly is {}.',
900
+ 'Inspecting a gadget with curiosity is {}.',
901
+ 'Balancing on a slackline between two trees is {}.',
902
+ 'Mixing ingredients for a recipe is {}.',
903
+ 'Waiting patiently for the crosswalk signal is {}.',
904
+ 'Riding an escalator up to the next floor is {}.',
905
+ 'Sitting on a bench feeding pigeons is {}.',
906
+ 'Standing at the edge of the diving board is {}.',
907
+ 'Looking at merchandise in a shop window is {}.',
908
+ 'Sitting on the floor wrapping gifts is {}.',
909
+ 'Climbing up a ladder to reach a high shelf is {}.',
910
+ 'Waiting for the bus at the bus stop is {}.',
911
+ 'Sipping tea while gazing out the window is {}.',
912
+ 'Swinging a tennis racquet on the court is {}.',
913
+ 'Watching a movie with 3D glasses on is {}.',
914
+ 'Carving a piece of wood into a sculpture is {}.',
915
+ 'Hula hooping in the backyard is {}.',
916
+ 'Rowing a boat down the river is {}.',
917
+ 'Bending down to tie a shoelace is {}.',
918
+ 'Playing the drums with enthusiasm is {}.',
919
+ 'Waiting in line at the grocery store checkout is {}.',
920
+ 'Blowing bubbles with gum is {}.',
921
+ 'Sketching a landscape on a notepad is {}.',
922
+ 'Jumping into a pile of autumn leaves is {}.',
923
+ 'Standing with hands on hips after a workout is {}.',
924
+ 'Conducting an orchestra with intensity is {}.',
925
+ 'Leaning against a fence watching the sunrise is {}.',
926
+ 'Tossing a salad in a bowl for dinner is {}.',
927
+ 'Crossing a footbridge over a stream is {}.',
928
+ 'Bobbing their head to music on headphones is {}.',
929
+ 'Attaching a lock to a bridge railing as a symbol of love is {}.',
930
+ 'Pumping air into a bicycle tire is {}.',
931
+ 'Repairing a computer with various tools is {}.',
932
+ 'Doodling in a notebook during a lecture is {}.',
933
+ 'Lining up a shot with a camera is {}.',
934
+ 'Kneading dough on a floured surface is {}.',
935
+ 'Waving goodbye at the train station is {}.',
936
+ 'Lying on the beach soaking up the sun is {}.',
937
+ 'Reading street signs in an unfamiliar city is {}.',
938
+ 'Casting a fishing line from the shore is {}.',
939
+ 'Blowing on a dandelion with seeds dispersing is {}.',
940
+ 'Dancing alone in the living room is {}.',
941
+ 'Watching the stars with a blanket wrapped around is {}.',
942
+ 'Peeling an orange in one long spiral is {}.',
943
+ 'Picking flowers from a field is {}.',
944
+ 'Studying a museum exhibit with interest is {}.',
945
+ 'Hanging laundry out to dry on a sunny day is {}.',
946
+ 'Cuddling a pet cat on the couch is {}.',
947
+ 'Arranging books on a shelf by color is {}.',
948
+ 'Standing silent in a moment of gratitude is {}.'
949
+ ]
950
+
951
+
952
+ random.shuffle(imagenet_templates_small)
953
+
954
+ per_img_token_list = ['*']
955
+
956
+
957
+ class FaceIdDataset(Dataset):
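+ # Text-only dataset: each item is a templated caption containing the "*" placeholder plus a sampled celebrity name.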
958
+ def __init__(self, experiment_name, **kwargs):
959
+ super(FaceIdDataset, self).__init__()
960
+
961
+ self.experiment_name = experiment_name
962
+ if self.experiment_name == "normal_GAN":
963
+ name_path = "datasets_face/good_names.txt"
964
+ elif self.experiment_name == "man_GAN":
965
+ name_path = "datasets_face/good_names_man.txt"
966
+ elif self.experiment_name == "woman_GAN":
967
+ name_path = "datasets_face/good_names_woman.txt"
968
+ else:
969
+ print("Hello, please notice this ^_^")
970
+ assert 0
971
+ print("now experiment_name:", self.experiment_name)
972
+
973
+ with open(name_path, "r") as f:
974
+ self.names = f.read().splitlines()
975
+
976
+ if self.experiment_name == "normal_GAN":
977
+ with open("datasets_face/good_names_man.txt", "r") as f_man, open("datasets_face/good_names_woman.txt", "r") as f_woman:
978
+ self.man_names = f_man.read().splitlines()
979
+ self.woman_names = f_woman.read().splitlines()
980
+
981
+ self._length = len(self.names)
982
+
983
+
984
+
985
+ def __len__(self):
986
+ return self._length
987
+
988
+ def __getitem__(self, i):
989
+ example = {}
990
+
991
+ name = self.names[i]
992
+
993
+ # if normal_GAN, this trick will be used for gender balance.
994
+ if self.experiment_name == "normal_GAN":
995
+ if random.random() < 0.5:
996
+ name = random.choice(self.man_names)
997
+ else:
998
+ name = random.choice(self.woman_names)
999
+
1000
+ ''' text '''
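+ # fill the template's {} slot with "<placeholder> person"; the "*" placeholder is later swapped for identity embeddings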
1001
+ placeholder_string = per_img_token_list[0]
1002
+ text = random.choice(imagenet_templates_small).format('%s person' % placeholder_string)
1003
+
1004
+ example["caption"] = text
1005
+ example["name"] = name
1006
+
1007
+ return example
datasets_face/good_names.txt ADDED
@@ -0,0 +1,326 @@
1
+ Adam Savage
2
+ Adam Scott
3
+ Alan Alda
4
+ Alan Hale
5
+ Albert Brooks
6
+ Alec Baldwin
7
+ Alec Guinness
8
+ Alice Cooper
9
+ Alicia Alonso
10
+ Amy Adams
11
+ Amy Schumer
12
+ Anderson Cooper
13
+ Andrea Martin
14
+ Andy Richter
15
+ Angelina Jolie
16
+ Ann Curry
17
+ Ann Miller
18
+ Anne Hathaway
19
+ Anne Murray
20
+ Aubrey Plaza
21
+ Audrey Hepburn
22
+ Aziz Ansari
23
+ BD Wong
24
+ Barbara Walters
25
+ Ben Affleck
26
+ Ben Kingsley
27
+ Ben Miller
28
+ Ben Schwartz
29
+ Benedict Cumberbatch
30
+ Bill Burr
31
+ Bill Cosby
32
+ Bill Irwin
33
+ Bill Maher
34
+ Bill Murray
35
+ Bill Nye
36
+ Billy Chow
37
+ Billy Connolly
38
+ Billy Crystal
39
+ Billy Joel
40
+ Billy Porter
41
+ Billy Wilder
42
+ Bob Hope
43
+ Bob Marley
44
+ Bonnie Hunt
45
+ Brad Pitt
46
+ Brandon Lee
47
+ Brian Cox
48
+ Brian Tee
49
+ Britney Spears
50
+ Bron James
51
+ Bruce Springsteen
52
+ Bruce Willis
53
+ Bryan Cranston
54
+ Buck Henry
55
+ Burt Lancaster
56
+ Burt Reynolds
57
+ Cameron Diaz
58
+ Carol Burnett
59
+ Carol Channing
60
+ Carol Kane
61
+ Carrie Underwood
62
+ Cary Grant
63
+ Cate Blanchett
64
+ Celia Cruz
65
+ Celine Dion
66
+ Charlie Sheen
67
+ Cheryl Hines
68
+ Chris Pratt
69
+ Christina Hendricks
70
+ Christopher Guest
71
+ Cindy Williams
72
+ Claire Danes
73
+ Craig Ferguson
74
+ Craig Robinson
75
+ Cristiano Ronaldo
76
+ Crystal Gayle
77
+ Dan Harmon
78
+ Dan Levy
79
+ Dan Rather
80
+ Dana Gould
81
+ Daniel Radcliffe
82
+ Danny Thomas
83
+ Daryl Hall
84
+ Dave Bautista
85
+ Dave Matthews
86
+ David Beckham
87
+ David Bowie
88
+ David Butler
89
+ David Spade
90
+ Dax Shepard
91
+ Dean Martin
92
+ Debra Messing
93
+ Dennis Chan
94
+ Dennis Franz
95
+ Dennis Hopper
96
+ Dennis Quaid
97
+ Dev Patel
98
+ Devon Aoki
99
+ Diana Ross
100
+ Diane Sawyer
101
+ Dizzy Gillespie
102
+ Donald Crisp
103
+ Donald Glover
104
+ Donna Reed
105
+ Donnie Yen
106
+ Doris Roberts
107
+ Drew Barrymore
108
+ Drew Carey
109
+ Dudley Moore
110
+ Dwayne Johnson
111
+ Ed Sheeran
112
+ Eddie Murphy
113
+ Edgar Wright
114
+ Edward Norton
115
+ Elaine May
116
+ Eleanor Powell
117
+ Eli Roth
118
+ Elizabeth Banks
119
+ Ellen Pompeo
120
+ Elon Musk
121
+ Elton John
122
+ Emma Thompson
123
+ Eric Idle
124
+ Ernie Reyes
125
+ Floyd Mayweather
126
+ Forest Whitaker
127
+ Fred Savage
128
+ Garry Marshall
129
+ Gene Lockhart
130
+ George Benson
131
+ George Burns
132
+ George Clooney
133
+ George Lopez
134
+ George Lucas
135
+ George Marshall
136
+ George Miller
137
+ Gillian Murphy
138
+ Ginger Rogers
139
+ Gregory Hines
140
+ Gregory Peck
141
+ Halle Berry
142
+ Harold Lloyd
143
+ Harrison Ford
144
+ Harry Carey
145
+ Helen Mirren
146
+ Helen Reddy
147
+ Howard Stern
148
+ Hugh Jackman
149
+ Hugh Laurie
150
+ Ira Glass
151
+ Isabel Sanford
152
+ Jack Conway
153
+ Jack Nicholson
154
+ Jackie Chan
155
+ Jackie Mason
156
+ James Burrows
157
+ James Cameron
158
+ James Franco
159
+ James Patterson
160
+ Jamie Foxx
161
+ Jane Lynch
162
+ Janet Jackson
163
+ Jason Alexander
164
+ Jason Bateman
165
+ Jason Biggs
166
+ Jason Nash
167
+ Jay Leno
168
+ Jay Pharoah
169
+ Jeff Gordon
170
+ Jennifer Aniston
171
+ Jennifer Garner
172
+ Jennifer Hudson
173
+ Jennifer Lopez
174
+ Jennifer Saunders
175
+ Jenny Slate
176
+ Jerome Robbins
177
+ Jerry Lewis
178
+ Jerry Seinfeld
179
+ Jim Parsons
180
+ Jodie Foster
181
+ Joe Cornish
182
+ John Cho
183
+ John Legend
184
+ John Ritter
185
+ Johnny Depp
186
+ Jon Hamm
187
+ Joseph Gordon
188
+ Josh Gad
189
+ Julia Roberts
190
+ Julie Bowen
191
+ Julie Kent
192
+ Julie Walters
193
+ Justin Bieber
194
+ Kanye West
195
+ Katy Perry
196
+ Kay Cannon
197
+ Keanu Reeves
198
+ Kelly Clarkson
199
+ Kelly Hu
200
+ Ken Dodd
201
+ Ken Jeong
202
+ Kenny Ortega
203
+ Kerry Washington
204
+ Kevin Dillon
205
+ Kevin Hart
206
+ Kevin James
207
+ Kevin Kline
208
+ Kevin Spacey
209
+ Kiefer Sutherland
210
+ Kim Coles
211
+ Kim Kardashian
212
+ Kobe Bryant
213
+ Kristen Bell
214
+ Kylie Jenner
215
+ Lady Gaga
216
+ Larry King
217
+ LeBron James
218
+ Lee Daniels
219
+ Lena Dunham
220
+ Leonardo DiCaprio
221
+ Leslie Mann
222
+ Leslie Nielsen
223
+ Lillian Hurst
224
+ Lilly Singh
225
+ Lily Tomlin
226
+ Lionel Messi
227
+ Loretta Lynn
228
+ Lucy Liu
229
+ Mackenzie Crook
230
+ Madeline Kahn
231
+ Marcia Wallace
232
+ Margaret Cho
233
+ Mariah Carey
234
+ Mark Wahlberg
235
+ Martin Scorsese
236
+ Mel Brooks
237
+ Mel Gibson
238
+ Michael Cera
239
+ Michael Jackson
240
+ Michael Jordan
241
+ Michael Landon
242
+ Michael Palin
243
+ Mike Myers
244
+ Molly Shannon
245
+ Morgan Freeman
246
+ Naomi Watts
247
+ Natalie Morales
248
+ Natalie Portman
249
+ Nathan Fielder
250
+ Nathan Lane
251
+ Nick Park
252
+ Nicolas Cage
253
+ Nicole Kidman
254
+ Norman Lear
255
+ Patrick Stewart
256
+ Paul McCartney
257
+ Paul Rudd
258
+ Paula Abdul
259
+ Penny Marshall
260
+ Pete Holmes
261
+ Peter Jackson
262
+ Phil McGraw
263
+ Piers Morgan
264
+ Quentin Tarantino
265
+ Randy Jackson
266
+ Randy Travis
267
+ Ray Romano
268
+ Rich Sommer
269
+ Richard Attenborough
270
+ Ricky Gervais
271
+ Ridley Scott
272
+ Rita Moreno
273
+ Rob Lowe
274
+ Robert Downey
275
+ Robin Williams
276
+ Roger Federer
277
+ Roger Moore
278
+ Ron Howard
279
+ Rose Marie
280
+ Russell Brand
281
+ Ryan Murphy
282
+ Ryan Reynolds
283
+ Sally Field
284
+ Sandra Bullock
285
+ Sarah Shahi
286
+ Seth Rogen
287
+ Shirley Jones
288
+ Sidney Franklin
289
+ Simon Cowell
290
+ Snoop Dogg
291
+ Spike Lee
292
+ Stan Lee
293
+ Stephen Curry
294
+ Stephen Fry
295
+ Stephen King
296
+ Stephen Merchant
297
+ Steven Spielberg
298
+ Sung Kang
299
+ Susan Egan
300
+ Taylor Swift
301
+ Terrence Howard
302
+ Terry Bradshaw
303
+ Terry Jones
304
+ Tim Conway
305
+ Tim Robbins
306
+ Tina Fey
307
+ Tom Cruise
308
+ Tom Hanks
309
+ Tom Hiddleston
310
+ Tom Jones
311
+ Tommy Chong
312
+ Tony Bennett
313
+ Tracy Morgan
314
+ Trey Parker
315
+ Tyler Perry
316
+ Valerie Harper
317
+ Vanessa Bayer
318
+ Vanessa Williams
319
+ Viola Davis
320
+ Walt Disney
321
+ Wanda Sykes
322
+ Wayne Brady
323
+ Wendy Whelan
324
+ Will Ferrell
325
+ Will Smith
326
+ Zachary Levi
datasets_face/good_names_man.txt ADDED
@@ -0,0 +1,226 @@
1
+ Adam Savage
2
+ Adam Scott
3
+ Alan Alda
4
+ Alan Hale
5
+ Albert Brooks
6
+ Alec Baldwin
7
+ Alec Guinness
8
+ Alice Cooper
9
+ Amy Adams
10
+ Anderson Cooper
11
+ Andy Richter
12
+ Aziz Ansari
13
+ BD Wong
14
+ Ben Affleck
15
+ Ben Kingsley
16
+ Ben Miller
17
+ Ben Schwartz
18
+ Benedict Cumberbatch
19
+ Bill Burr
20
+ Bill Cosby
21
+ Bill Irwin
22
+ Bill Maher
23
+ Bill Murray
24
+ Bill Nye
25
+ Billy Chow
26
+ Billy Connolly
27
+ Billy Crystal
28
+ Billy Joel
29
+ Billy Porter
30
+ Billy Wilder
31
+ Bob Hope
32
+ Bob Marley
33
+ Brad Pitt
34
+ Brandon Lee
35
+ Brian Cox
36
+ Brian Tee
37
+ Britney Spears
38
+ Bron James
39
+ Bruce Springsteen
40
+ Bruce Willis
41
+ Bryan Cranston
42
+ Buck Henry
43
+ Burt Lancaster
44
+ Burt Reynolds
45
+ Cary Grant
46
+ Charlie Sheen
47
+ Chris Pratt
48
+ Christopher Guest
49
+ Craig Ferguson
50
+ Craig Robinson
51
+ Cristiano Ronaldo
52
+ Dan Harmon
53
+ Dan Levy
54
+ Dan Rather
55
+ Dana Gould
56
+ Daniel Radcliffe
57
+ Danny Thomas
58
+ Daryl Hall
59
+ Dave Bautista
60
+ Dave Matthews
61
+ David Beckham
62
+ David Bowie
63
+ David Butler
64
+ David Spade
65
+ Dax Shepard
66
+ Dean Martin
67
+ Dennis Chan
68
+ Dennis Franz
69
+ Dennis Hopper
70
+ Dennis Quaid
71
+ Dev Patel
72
+ Dizzy Gillespie
73
+ Donald Crisp
74
+ Donald Glover
75
+ Donnie Yen
76
+ Drew Carey
77
+ Dudley Moore
78
+ Dwayne Johnson
79
+ Ed Sheeran
80
+ Eddie Murphy
81
+ Edgar Wright
82
+ Edward Norton
83
+ Eli Roth
84
+ Elon Musk
85
+ Elton John
86
+ Eric Idle
87
+ Ernie Reyes
88
+ Floyd Mayweather
89
+ Forest Whitaker
90
+ Fred Savage
91
+ Garry Marshall
92
+ Gene Lockhart
93
+ George Benson
94
+ George Burns
95
+ George Clooney
96
+ George Lopez
97
+ George Lucas
98
+ George Marshall
99
+ George Miller
100
+ Gregory Hines
101
+ Gregory Peck
102
+ Harold Lloyd
103
+ Harrison Ford
104
+ Harry Carey
105
+ Howard Stern
106
+ Hugh Jackman
107
+ Hugh Laurie
108
+ Ira Glass
109
+ Jack Conway
110
+ Jack Nicholson
111
+ Jackie Chan
112
+ Jackie Mason
113
+ James Burrows
114
+ James Cameron
115
+ James Franco
116
+ James Patterson
117
+ Jamie Foxx
118
+ Jason Alexander
119
+ Jason Bateman
120
+ Jason Biggs
121
+ Jason Nash
122
+ Jay Leno
123
+ Jay Pharoah
124
+ Jeff Gordon
125
+ Jerome Robbins
126
+ Jerry Lewis
127
+ Jerry Seinfeld
128
+ Jim Parsons
129
+ Joe Cornish
130
+ John Cho
131
+ John Legend
132
+ John Ritter
133
+ Johnny Depp
134
+ Jon Hamm
135
+ Joseph Gordon
136
+ Josh Gad
137
+ Justin Bieber
138
+ Kanye West
139
+ Keanu Reeves
140
+ Ken Dodd
141
+ Ken Jeong
142
+ Kenny Ortega
143
+ Kevin Dillon
144
+ Kevin Hart
145
+ Kevin James
146
+ Kevin Kline
147
+ Kevin Spacey
148
+ Kiefer Sutherland
149
+ Kobe Bryant
150
+ Larry King
151
+ LeBron James
152
+ Lee Daniels
153
+ Leonardo DiCaprio
154
+ Lionel Messi
155
+ Mackenzie Crook
156
+ Mark Wahlberg
157
+ Martin Scorsese
158
+ Mel Brooks
159
+ Mel Gibson
160
+ Michael Cera
161
+ Michael Jackson
162
+ Michael Jordan
163
+ Michael Landon
164
+ Michael Palin
165
+ Mike Myers
166
+ Morgan Freeman
167
+ Nathan Fielder
168
+ Nathan Lane
169
+ Nick Park
170
+ Nicolas Cage
171
+ Norman Lear
172
+ Patrick Stewart
173
+ Paul McCartney
174
+ Paul Rudd
175
+ Pete Holmes
176
+ Peter Jackson
177
+ Phil McGraw
178
+ Piers Morgan
179
+ Quentin Tarantino
180
+ Randy Jackson
181
+ Randy Travis
182
+ Ray Romano
183
+ Rich Sommer
184
+ Richard Attenborough
185
+ Ricky Gervais
186
+ Ridley Scott
187
+ Rob Lowe
188
+ Robert Downey
189
+ Robin Williams
190
+ Roger Federer
191
+ Roger Moore
192
+ Ron Howard
193
+ Russell Brand
194
+ Ryan Murphy
195
+ Ryan Reynolds
196
+ Seth Rogen
197
+ Sidney Franklin
198
+ Simon Cowell
199
+ Snoop Dogg
200
+ Spike Lee
201
+ Stan Lee
202
+ Stephen Curry
203
+ Stephen Fry
204
+ Stephen King
205
+ Stephen Merchant
206
+ Steven Spielberg
207
+ Sung Kang
208
+ Terrence Howard
209
+ Terry Bradshaw
210
+ Terry Jones
211
+ Tim Conway
212
+ Tim Robbins
213
+ Tom Cruise
214
+ Tom Hanks
215
+ Tom Hiddleston
216
+ Tom Jones
217
+ Tommy Chong
218
+ Tony Bennett
219
+ Tracy Morgan
220
+ Trey Parker
221
+ Tyler Perry
222
+ Walt Disney
223
+ Wayne Brady
224
+ Will Ferrell
225
+ Will Smith
226
+ Zachary Levi
datasets_face/good_names_woman.txt ADDED
@@ -0,0 +1,100 @@
1
+ Alicia Alonso
2
+ Amy Schumer
3
+ Andrea Martin
4
+ Angelina Jolie
5
+ Ann Curry
6
+ Ann Miller
7
+ Anne Hathaway
8
+ Anne Murray
9
+ Aubrey Plaza
10
+ Audrey Hepburn
11
+ Barbara Walters
12
+ Bonnie Hunt
13
+ Cameron Diaz
14
+ Carol Burnett
15
+ Carol Channing
16
+ Carol Kane
17
+ Carrie Underwood
18
+ Cate Blanchett
19
+ Celia Cruz
20
+ Celine Dion
21
+ Cheryl Hines
22
+ Christina Hendricks
23
+ Cindy Williams
24
+ Claire Danes
25
+ Crystal Gayle
26
+ Debra Messing
27
+ Devon Aoki
28
+ Diana Ross
29
+ Diane Sawyer
30
+ Donna Reed
31
+ Doris Roberts
32
+ Drew Barrymore
33
+ Elaine May
34
+ Eleanor Powell
35
+ Elizabeth Banks
36
+ Ellen Pompeo
37
+ Emma Thompson
38
+ Gillian Murphy
39
+ Ginger Rogers
40
+ Halle Berry
41
+ Helen Mirren
42
+ Helen Reddy
43
+ Isabel Sanford
44
+ Jane Lynch
45
+ Janet Jackson
46
+ Jennifer Aniston
47
+ Jennifer Garner
48
+ Jennifer Hudson
49
+ Jennifer Lopez
50
+ Jennifer Saunders
51
+ Jenny Slate
52
+ Jodie Foster
53
+ Julia Roberts
54
+ Julie Bowen
55
+ Julie Kent
56
+ Julie Walters
57
+ Katy Perry
58
+ Kay Cannon
59
+ Kelly Clarkson
60
+ Kelly Hu
61
+ Kerry Washington
62
+ Kim Coles
63
+ Kim Kardashian
64
+ Kristen Bell
65
+ Kylie Jenner
66
+ Lady Gaga
67
+ Lena Dunham
68
+ Leslie Mann
69
+ Leslie Nielsen
70
+ Lillian Hurst
71
+ Lilly Singh
72
+ Lily Tomlin
73
+ Loretta Lynn
74
+ Lucy Liu
75
+ Madeline Kahn
76
+ Marcia Wallace
77
+ Margaret Cho
78
+ Mariah Carey
79
+ Molly Shannon
80
+ Naomi Watts
81
+ Natalie Morales
82
+ Natalie Portman
83
+ Nicole Kidman
84
+ Paula Abdul
85
+ Penny Marshall
86
+ Rita Moreno
87
+ Rose Marie
88
+ Sally Field
89
+ Sandra Bullock
90
+ Sarah Shahi
91
+ Shirley Jones
92
+ Susan Egan
93
+ Taylor Swift
94
+ Tina Fey
95
+ Valerie Harper
96
+ Vanessa Bayer
97
+ Vanessa Williams
98
+ Viola Davis
99
+ Wanda Sykes
100
+ Wendy Whelan
datasets_face/identity_space.yaml ADDED
@@ -0,0 +1,38 @@
1
+ model:
2
+ use_celeb: True
3
+ use_svd: True
4
+ rm_repeats: True
5
+ n_components: 512 # consistent with meta_inner_dim, should be <= n_samples-1
6
+ use_sample_reduce: False
7
+ n_samples: 513
8
+ use_flatten: False
9
+ num_embeds_per_token: 2 # consistent with personalization_config
10
+ target: models.embedding_manager.EmbeddingManagerId_adain
11
+ params:
12
+ linear_start: 0.00085
13
+ linear_end: 0.0120
14
+ num_timesteps_cond: 1
15
+ log_every_t: 200
16
+ timesteps: 1000
17
+ first_stage_key: image
18
+ cond_stage_key: caption
19
+ image_size: 64
20
+ channels: 4
21
+ cond_stage_trainable: true # Note: different from the one we trained before
22
+ conditioning_key: crossattn
23
+ monitor: val/loss_simple_ema
24
+ scale_factor: 0.18215
25
+ use_ema: False
26
+ embedding_reg_weight: 0.0
27
+ unfreeze_model: False
28
+ model_lr: 0.0
29
+
30
+ personalization_config:
31
+ params:
32
+ num_embeds_per_token: 2 # consistent with cond_stage_config
33
+ mlp_depth: 2
34
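+ # dimension of the random identity vector fed to the name projection MLP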
+ input_dim: 64
35
+ token_dim: 1024
36
+ loss_type: 'none'
37
+
38
+
demo_embeddings/example_1.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d77090f8d1c6cb049c491dd0ffc74a05c1df9e272d9a5788f358b4073f63b75
3
+ size 9288
demo_embeddings/example_2.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:523e07d8ea4af7f74a4bc4e40ea3bd246562389b6f52aa81ad79ba900eeef040
3
+ size 9288
demo_embeddings/example_3.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a09e71305dd6a1a6ad2f8e387012c64f7f38728f33667a776e40456435e4d5a
3
+ size 9288
demo_embeddings/example_4.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:03deeeb1438237f2e69467a8afc0841c644c336c74f7703b39c78cbf211983f7
3
+ size 9288
demo_embeddings/example_5.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b4b3a8e990d1b374748d514a031e617e5230215e9619d97441855bf734b05439
3
+ size 9288
demo_embeddings/example_6.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5888c106b5b574a0f015e9d294f86cc62c77f2d3afb197cdc3f71fbabbceec5d
3
+ size 9283
models/celeb_embeddings.py ADDED
@@ -0,0 +1,74 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from functools import partial
4
+ import clip
5
+ from einops import rearrange, repeat
6
+ from transformers import CLIPTokenizer, CLIPTextModel
7
+ import kornia
8
+ import numpy as np
9
+ import os
10
+
11
+ def embedding_forward(
12
+ self,
13
+ input_ids = None,
14
+ position_ids = None,
15
+ name_batch = None,
16
+ inputs_embeds = None,
17
+ embedding_manager = None,
18
+ only_embedding=True,
19
+ random_embeddings = None,
20
+ timesteps = None,
21
+ ) -> torch.Tensor:
22
+
23
+ seq_length = input_ids.shape[-1] if input_ids is not None else inputs_embeds.shape[-2]
24
+
25
+ if inputs_embeds is None:
26
+ inputs_embeds = self.token_embedding(input_ids)
27
+ if only_embedding:
28
+ return inputs_embeds
29
+
30
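+ # when an embedding_manager is given, it injects identity embeddings at the placeholder-token positions before position embeddings are added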
+ if embedding_manager is not None:
31
+ inputs_embeds, other_return_dict = embedding_manager(input_ids, inputs_embeds, name_batch, random_embeddings, timesteps)
32
+
33
+ if position_ids is None:
34
+ position_ids = self.position_ids[:, :seq_length]
35
+
36
+ position_embeddings = self.position_embedding(position_ids)
37
+ embeddings = inputs_embeds + position_embeddings
38
+
39
+ return embeddings, other_return_dict
40
+
41
+
42
+ @torch.no_grad()
43
+ def _get_celeb_embeddings_basis(tokenizer, text_encoder, good_names_txt):
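+ # encode every celebrity name into its two token embeddings and return the per-dimension mean / std (used as AdaIN statistics)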
44
+
45
+ device = text_encoder.device
46
+ max_length = 77
47
+
48
+ with open(good_names_txt, "r") as f:
49
+ celeb_names = f.read().splitlines()
50
+
51
+ ''' get tokens and embeddings '''
52
+ all_embeddings = []
53
+ for name in celeb_names:
54
+ batch_encoding = tokenizer(name, truncation=True, return_tensors="pt")
55
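+ # keep the two tokens after BOS (first and last name) so each name maps to exactly two embeddings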
+ tokens = batch_encoding["input_ids"].to(device)[:, 1:3]
56
+ embeddings = text_encoder.text_model.embeddings(input_ids=tokens, only_embedding=True)
57
+ all_embeddings.append(embeddings)
58
+
59
+ all_embeddings: torch.Tensor = torch.cat(all_embeddings, dim=0)
60
+
61
+
62
+ print('[all_embeddings loaded] shape =', all_embeddings.shape,
63
+ 'max:', all_embeddings.max(),
64
+ 'min:', all_embeddings.min())
65
+
66
+ name_emb_mean = all_embeddings.mean(0)
67
+ name_emb_std = all_embeddings.std(0)
68
+
69
+ print('[name_emb_mean loaded] shape =', name_emb_mean.shape,
70
+ 'max:', name_emb_mean.max(),
71
+ 'min:', name_emb_mean.min())
72
+
73
+ return name_emb_mean, name_emb_std
74
+
models/embedding_manager.py ADDED
@@ -0,0 +1,217 @@
1
+ import torch
2
+ from torch import nn
3
+ from einops import rearrange
4
+ import numpy as np
5
+ from typing import List
6
+ from models.id_embedding.helpers import get_rep_pos, shift_tensor_dim0
7
+ from models.id_embedding.meta_net import StyleVectorizer
8
+ from models.celeb_embeddings import _get_celeb_embeddings_basis
9
+
10
+ from functools import partial
11
+ import torch.nn.functional as F
12
+ import torch.nn as nn
13
+ import torch.nn.init as init
14
+
15
+
16
+ DEFAULT_PLACEHOLDER_TOKEN = ["*"]
17
+
18
+ PROGRESSIVE_SCALE = 2000
19
+
20
+ def get_clip_token_for_string(tokenizer, string):
21
+ batch_encoding = tokenizer(string, return_length=True, padding=True, truncation=True, return_overflowing_tokens=False, return_tensors="pt")
22
+ tokens = batch_encoding["input_ids"]
23
+
24
+ return tokens
25
+
26
+
27
+ def get_embedding_for_clip_token(embedder, token):
28
+ return embedder(token.unsqueeze(0))
29
+
30
+
31
+ class EmbeddingManagerId_adain(nn.Module):
32
+ def __init__(
33
+ self,
34
+ tokenizer,
35
+ text_encoder,
36
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu"),
37
+ experiment_name = "normal_GAN",
38
+ num_embeds_per_token: int = 2,
39
+ loss_type: str = None,
40
+ mlp_depth: int = 2,
41
+ token_dim: int = 1024,
42
+ input_dim: int = 1024,
43
+ **kwargs
44
+ ):
45
+ super().__init__()
46
+ self.device = device
47
+ self.num_es = num_embeds_per_token
48
+
49
+ self.get_token_for_string = partial(get_clip_token_for_string, tokenizer)
50
+ self.get_embedding_for_tkn = partial(get_embedding_for_clip_token, text_encoder.text_model.embeddings)
51
+
52
+
53
+ self.token_dim = token_dim
54
+
55
+ ''' 1. Placeholder mapping dicts '''
56
+ self.placeholder_token = self.get_token_for_string("*")[0][1]
57
+
58
+ if experiment_name == "normal_GAN":
59
+ self.celeb_embeddings_mean, self.celeb_embeddings_std = _get_celeb_embeddings_basis(tokenizer, text_encoder, "datasets_face/good_names.txt")
60
+ elif experiment_name == "man_GAN":
61
+ self.celeb_embeddings_mean, self.celeb_embeddings_std = _get_celeb_embeddings_basis(tokenizer, text_encoder, "datasets_face/good_names_man.txt")
62
+ elif experiment_name == "woman_GAN":
63
+ self.celeb_embeddings_mean, self.celeb_embeddings_std = _get_celeb_embeddings_basis(tokenizer, text_encoder, "datasets_face/good_names_woman.txt")
64
+ else:
65
+ print("Hello, please notice this ^_^")
66
+ assert 0
67
+ print("now experiment_name:", experiment_name)
68
+
69
+ self.celeb_embeddings_mean = self.celeb_embeddings_mean.to(device)
70
+ self.celeb_embeddings_std = self.celeb_embeddings_std.to(device)
71
+
72
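+ # projection MLP: maps a random identity vector (input_dim) to num_embeds_per_token * token_dim values, i.e. two word embeddings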
+ self.name_projection_layer = StyleVectorizer(input_dim, self.token_dim * self.num_es, depth=mlp_depth, lr_mul=0.1)
73
+ self.embedding_discriminator = Embedding_discriminator(self.token_dim * self.num_es, dropout_rate = 0.2)
74
+
75
+ self.adain_mode = 0
76
+
77
+ def forward(
78
+ self,
79
+ tokenized_text,
80
+ embedded_text,
81
+ name_batch,
82
+ random_embeddings = None,
83
+ timesteps = None,
84
+ ):
85
+
86
+ if tokenized_text is not None:
87
+ batch_size, n, device = *tokenized_text.shape, tokenized_text.device
88
+ other_return_dict = {}
89
+
90
+ if random_embeddings is not None:
91
+ mlp_output_embedding = self.name_projection_layer(random_embeddings)
92
+ total_embedding = mlp_output_embedding.view(mlp_output_embedding.shape[0], 2, 1024)
93
+
94
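+ # AdaIN-style rescaling: match the generated embeddings to the mean/std statistics of real celeb-name embeddings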
+ if self.adain_mode == 0:
95
+ adained_total_embedding = total_embedding * self.celeb_embeddings_std + self.celeb_embeddings_mean
96
+ else:
97
+ adained_total_embedding = total_embedding
98
+
99
+ other_return_dict["total_embedding"] = total_embedding
100
+ other_return_dict["adained_total_embedding"] = adained_total_embedding
101
+
102
+ if name_batch is not None:
103
+ if isinstance(name_batch, list):
104
+ name_tokens = self.get_token_for_string(name_batch)[:, 1:3]
105
+ name_embeddings = self.get_embedding_for_tkn(name_tokens.to(random_embeddings.device))[0]
106
+
107
+ other_return_dict["name_embeddings"] = name_embeddings
108
+ else:
109
+ assert 0
110
+
111
+ if tokenized_text is not None:
112
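+ # locate each "*" placeholder; its slot and the following one are overwritten with the two identity embeddings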
+ placeholder_pos = get_rep_pos(tokenized_text,
113
+ [self.placeholder_token])
114
+ placeholder_pos = np.array(placeholder_pos)
115
+ if len(placeholder_pos) != 0:
116
+ batch_size = adained_total_embedding.shape[0]
117
+ end_index = min(batch_size, placeholder_pos.shape[0])
118
+ embedded_text[placeholder_pos[:, 0], placeholder_pos[:, 1]] = adained_total_embedding[:end_index,0,:]
119
+ embedded_text[placeholder_pos[:, 0], placeholder_pos[:, 1] + 1] = adained_total_embedding[:end_index,1,:]
120
+
121
+ return embedded_text, other_return_dict
122
+
123
+
124
+
125
+ def load(self, ckpt_path):
126
+ ckpt = torch.load(ckpt_path, map_location='cuda')
127
+ if ckpt.get("name_projection_layer") is not None:
128
+ self.name_projection_layer = ckpt.get("name_projection_layer").float()
129
+
130
+ print('[Embedding Manager] weights loaded.')
131
+
132
+
133
+
134
+ def save(self, ckpt_path):
135
+ save_dict = {}
136
+ save_dict["name_projection_layer"] = self.name_projection_layer
137
+
138
+ torch.save(save_dict, ckpt_path)
139
+
140
+
141
+ def trainable_projection_parameters(self):
142
+ trainable_list = []
143
+ trainable_list.extend(list(self.name_projection_layer.parameters()))
144
+
145
+ return trainable_list
146
+
147
+
148
+
149
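+ # small MLP discriminator scoring flattened two-token identity embeddings (e.g. real celeb-name embeddings vs. generated ones)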
+ class Embedding_discriminator(nn.Module):
150
+ def __init__(self, input_size, dropout_rate):
151
+ super(Embedding_discriminator, self).__init__()
152
+ self.input_size = input_size
153
+
154
+ self.fc1 = nn.Linear(input_size, 512)
155
+ self.fc2 = nn.Linear(512, 256)
156
+ self.fc3 = nn.Linear(256, 1)
157
+
158
+ self.LayerNorm1 = nn.LayerNorm(512)
159
+ self.LayerNorm2 = nn.LayerNorm(256)
160
+
161
+ self.leaky_relu = nn.LeakyReLU(0.2)
162
+
163
+ self.dropout_rate = dropout_rate
164
+ if self.dropout_rate > 0:
165
+ self.dropout1 = nn.Dropout(dropout_rate)
166
+ self.dropout2 = nn.Dropout(dropout_rate)
167
+
168
+ def forward(self, input):
169
+ x = input.view(-1, self.input_size)
170
+
171
+ if self.dropout_rate > 0:
172
+ x = self.leaky_relu(self.dropout1(self.fc1(x)))
173
+ else:
174
+ x = self.leaky_relu(self.fc1(x))
175
+
176
+ if self.dropout_rate > 0:
177
+ x = self.leaky_relu(self.dropout2(self.fc2(x)))
178
+ else:
179
+ x = self.leaky_relu(self.fc2(x))
180
+
181
+ x = self.fc3(x)
182
+
183
+ return x
184
+
185
+
186
+ def save(self, ckpt_path):
187
+ save_dict = {}
188
+
189
+ save_dict["fc1"] = self.fc1
190
+ save_dict["fc2"] = self.fc2
191
+ save_dict["fc3"] = self.fc3
192
+ save_dict["LayerNorm1"] = self.LayerNorm1
193
+ save_dict["LayerNorm2"] = self.LayerNorm2
194
+ save_dict["leaky_relu"] = self.leaky_relu
195
+ save_dict["dropout1"] = self.dropout1
196
+ save_dict["dropout2"] = self.dropout2
197
+
198
+ torch.save(save_dict, ckpt_path)
199
+
200
+ def load(self, ckpt_path):
201
+ ckpt = torch.load(ckpt_path, map_location='cuda')
202
+
203
+ if ckpt.get("fc1") is not None:
204
+ self.fc1 = ckpt.get("fc1").float()
205
+ self.fc2 = ckpt.get("fc2").float()
206
+ self.fc3 = ckpt.get("fc3").float()
207
+ self.LayerNorm1 = ckpt.get("LayerNorm1").float()
208
+ self.LayerNorm2 = ckpt.get("LayerNorm2").float()
209
+ self.leaky_relu = ckpt.get("leaky_relu").float()
210
+ self.dropout1 = ckpt.get("dropout1").float()
211
+ self.dropout2 = ckpt.get("dropout2").float()
212
+
213
+ print('[Embedding D] weights loaded.')
214
+
215
+
216
+
217
+
models/id_embedding/__init__.py ADDED
File without changes
models/id_embedding/helpers.py ADDED
@@ -0,0 +1,63 @@
1
+ # -*-coding:utf-8-*-
2
+ import torch
3
+ import numpy as np
4
+ from typing import List
5
+
6
+
7
+ def get_rep_pos(tokenized: torch.Tensor, rep_tokens: list):
8
+ pos_list = []
9
+ for token in rep_tokens:
10
+ pos_list.extend(torch.stack(torch.where(tokenized == token)).T.tolist())
11
+ return pos_list
12
+
13
+
14
+ def shift_tensor_dim0(ori: torch.Tensor, r_pos: List[np.ndarray], reps: int):
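+ # shift entries along dim 0 so each position in r_pos expands into `reps` consecutive slots; returns the tensor and the new positions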
15
+ assert reps >= 1
16
+ device = ori.device
17
+ d = ori.shape[0]
18
+ offset = np.zeros(d, dtype=np.int64)
19
+ r_pos_cat = np.concatenate(r_pos)
20
+ for p in r_pos_cat:
21
+ offset[p + 1:] += (reps - 1)
22
+
23
+ r_cnt = r_pos_cat.shape[0]
24
+ target_pos = (np.arange(d) + offset)[:d - r_cnt * (reps - 1)]
25
+ ori[target_pos] = ori[np.arange(target_pos.shape[0])]
26
+
27
+ rep_final_pos: np.ndarray = target_pos[r_pos_cat].repeat(reps) + np.tile(np.arange(reps), r_cnt)
28
+ ori[rep_final_pos] = ori[target_pos[r_pos_cat].repeat(reps)]
29
+
30
+ rep_final_pos_list = []
31
+ lo = 0
32
+ for i in range(len(r_pos)):
33
+ r_one_times = r_pos[i].shape[0]
34
+ r_one_nums = r_one_times * reps
35
+ rep_final_pos_list.append(rep_final_pos[lo: lo + r_one_nums].reshape(r_one_times, reps))
36
+ lo += r_one_nums
37
+ return ori, rep_final_pos_list
38
+
39
+
40
+ def _test_get_rep_pos():
41
+ tokenized = torch.LongTensor([0, 1, 2, 2, 3, 4, 5, 6, 7, 99] + [99] * 20)
42
+ print('[from]:', tokenized)
43
+ rep_tokens = [2, 6]
44
+ rep_times = 2
45
+
46
+ rep_pos = get_rep_pos(tokenized, rep_tokens)
47
+ print('[rep_pos]:', rep_pos)
48
+ res, rep_pos_final = shift_tensor_dim0(tokenized, rep_pos, rep_times)
49
+ print('[to]:', res)
50
+ print('[final pos]:', rep_pos_final)
51
+
52
+
53
+ def _test_shift_tensor_dim0():
54
+ embedded = torch.arange(20)
55
+ print(embedded)
56
+ pos = np.array([3, 6, 8])
57
+ times = 1
58
+ output = shift_tensor_dim0(embedded, pos, times)
59
+ print(output)
60
+
61
+
62
+ if __name__ == "__main__":
63
+ _test_get_rep_pos()
models/id_embedding/meta_net.py ADDED
@@ -0,0 +1,67 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import kornia
5
+ from einops import rearrange
6
+ import torch.nn.init as init
7
+
8
+ def leaky_relu(p=0.2):
9
+ return nn.LeakyReLU(p, inplace=True)
10
+
11
+ class Residual(nn.Module):
12
+ def __init__(self,
13
+ fn):
14
+ super().__init__()
15
+ self.fn = fn
16
+
17
+ def forward(self, x, **kwargs):
18
+ return x + self.fn(x, **kwargs)
19
+
20
+
21
+ class EqualLinear(nn.Module):
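+ # linear layer whose weight and bias are scaled by lr_mul at forward time, with optional pre-LayerNorm and LeakyReLU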
22
+ def __init__(self, in_dim, out_dim, lr_mul=1, bias=True, pre_norm=False, activate = False):
23
+ super().__init__()
24
+ self.weight = nn.Parameter(torch.randn(out_dim, in_dim))
25
+ if bias:
26
+ self.bias = nn.Parameter(torch.zeros(out_dim))
27
+
28
+ self.lr_mul = lr_mul
29
+
30
+ self.pre_norm = pre_norm
31
+ if pre_norm:
32
+ self.norm = nn.LayerNorm(in_dim, eps=1e-5)
33
+ self.activate = activate
34
+ if self.activate == True:
35
+ self.non_linear = leaky_relu()
36
+
37
+ def forward(self, input):
38
+ if hasattr(self, 'pre_norm') and self.pre_norm:
39
+ out = self.norm(input)
40
+ out = F.linear(out, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
41
+ else:
42
+ out = F.linear(input, self.weight * self.lr_mul, bias=self.bias * self.lr_mul)
43
+
44
+ if self.activate == True:
45
+ out = self.non_linear(out)
46
+ return out
47
+
48
+
49
+ class StyleVectorizer(nn.Module):
50
+ def __init__(self, dim_in, dim_out, depth, lr_mul = 0.1):
51
+ super().__init__()
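+ # mapping network: expand dim_in -> dim_out, stack residual EqualLinear blocks, end with a pre-normed linear and a final LayerNorm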
52
+
53
+ layers = []
54
+ for i in range(depth):
55
+ if i == 0:
56
+ layers.extend([EqualLinear(dim_in, dim_out, lr_mul, pre_norm=False, activate = True)])
57
+ elif i == depth - 1:
58
+ layers.extend([EqualLinear(dim_out, dim_out, lr_mul, pre_norm=True, activate = False)])
59
+ else:
60
+ layers.extend([Residual(EqualLinear(dim_out, dim_out, lr_mul, pre_norm=True, activate = True))])
61
+
62
+ self.net = nn.Sequential(*layers)
63
+ self.norm = nn.LayerNorm(dim_out, eps=1e-5)
64
+
65
+ def forward(self, x):
66
+ return self.norm(self.net(x))
67
+
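A minimal usage sketch of the mapping network defined above; the 64-in / 1024-out dimensions and the depth here are illustrative assumptions, not values read from the training config:

import torch
from models.id_embedding.meta_net import StyleVectorizer

mapper = StyleVectorizer(dim_in=64, dim_out=1024, depth=4)  # assumed dims
z = torch.randn(8, 64)       # a batch of random identity codes
w = mapper(z)                # LayerNorm-ed output, shape (8, 1024)
print(w.shape)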
requirements.txt CHANGED
@@ -1,6 +1,15 @@
1
- accelerate
2
- diffusers
3
- invisible_watermark
4
- torch
5
- transformers
6
- xformers
1
+ torch==2.0.1
2
+ torchvision==0.15.2
3
+ diffusers==0.23.0
4
+ transformers==4.33.2
5
+ xformers==0.0.20
6
+ accelerate==0.23.0
7
+ omegaconf
8
+ clip==0.2.0
9
+ einops
10
+ kornia==0.6.12
11
+ opencv-python
12
+ opencv-contrib-python
13
+ gradio
14
+ huggingface_hub==0.22.2
15
+ IPython
test.ipynb ADDED
@@ -0,0 +1,243 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import os\n",
11
+ "from transformers import ViTModel, ViTImageProcessor\n",
12
+ "from utils import text_encoder_forward\n",
13
+ "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n",
14
+ "from utils import latents_to_images, downsampling, merge_and_save_images\n",
15
+ "from omegaconf import OmegaConf\n",
16
+ "from accelerate.utils import set_seed\n",
17
+ "from tqdm import tqdm\n",
18
+ "from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\n",
19
+ "from PIL import Image\n",
20
+ "from models.celeb_embeddings import embedding_forward\n",
21
+ "import models.embedding_manager\n",
22
+ "import importlib\n",
23
+ "\n",
24
+ "# seed = 42\n",
25
+ "# set_seed(seed) \n",
26
+ "# torch.cuda.set_device(0)\n",
27
+ "\n",
28
+ "# set your sd2.1 path\n",
29
+ "model_path = \"/home/user/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6\"\n",
30
+ "pipe = StableDiffusionPipeline.from_pretrained(model_path) \n",
31
+ "pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
32
+ "pipe = pipe.to(\"cuda\")\n",
33
+ "\n",
34
+ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
35
+ "\n",
36
+ "vae = pipe.vae\n",
37
+ "unet = pipe.unet\n",
38
+ "text_encoder = pipe.text_encoder\n",
39
+ "tokenizer = pipe.tokenizer\n",
40
+ "scheduler = pipe.scheduler\n",
41
+ "\n",
42
+ "input_dim = 64\n",
43
+ "\n",
44
+ "experiment_name = \"normal_GAN\" # \"normal_GAN\", \"man_GAN\", \"woman_GAN\" , \n",
45
+ "if experiment_name == \"normal_GAN\":\n",
46
+ " steps = 10000\n",
47
+ "elif experiment_name == \"man_GAN\":\n",
48
+ " steps = 7000\n",
49
+ "elif experiment_name == \"woman_GAN\":\n",
50
+ " steps = 6000\n",
51
+ "else:\n",
52
+ " print(\"Unknown experiment_name:\", experiment_name)\n",
53
+ " assert 0\n",
54
+ "\n",
55
+ "\n",
56
+ "original_forward = text_encoder.text_model.embeddings.forward\n",
57
+ "text_encoder.text_model.embeddings.forward = embedding_forward.__get__(text_encoder.text_model.embeddings)\n",
58
+ "embedding_manager_config = OmegaConf.load(\"datasets_face/identity_space.yaml\")\n",
59
+ "Embedding_Manager = models.embedding_manager.EmbeddingManagerId_adain( \n",
60
+ " tokenizer,\n",
61
+ " text_encoder,\n",
62
+ " device = device,\n",
63
+ " training = True,\n",
64
+ " experiment_name = experiment_name, \n",
65
+ " num_embeds_per_token = embedding_manager_config.model.personalization_config.params.num_embeds_per_token, \n",
66
+ " token_dim = embedding_manager_config.model.personalization_config.params.token_dim,\n",
67
+ " mlp_depth = embedding_manager_config.model.personalization_config.params.mlp_depth,\n",
68
+ " loss_type = embedding_manager_config.model.personalization_config.params.loss_type,\n",
69
+ " vit_out_dim = input_dim,\n",
70
+ ")\n",
71
+ "embedding_path = os.path.join(\"training_weight\", experiment_name, \"embeddings_manager-{}.pt\".format(str(steps)))\n",
72
+ "Embedding_Manager.load(embedding_path)\n",
73
+ "text_encoder.text_model.embeddings.forward = original_forward\n",
74
+ "\n",
75
+ "print(\"finish init\")"
76
+ ]
77
+ },
78
+ {
79
+ "cell_type": "markdown",
80
+ "metadata": {},
81
+ "source": [
82
+ "1. create a new character and test with prompts"
83
+ ]
84
+ },
85
+ {
86
+ "cell_type": "code",
87
+ "execution_count": null,
88
+ "metadata": {},
89
+ "outputs": [],
90
+ "source": [
91
+ "# sample a z\n",
92
+ "random_embedding = torch.randn(1, 1, input_dim).to(device)\n",
93
+ "\n",
94
+ "# map z to pseudo identity embeddings\n",
95
+ "_, emb_dict = Embedding_Manager(tokenized_text=None, embedded_text=None, name_batch=None, random_embeddings = random_embedding, timesteps = None,)\n",
96
+ "\n",
97
+ "test_emb = emb_dict[\"adained_total_embedding\"].to(device)\n",
98
+ "\n",
99
+ "v1_emb = test_emb[:, 0]\n",
100
+ "v2_emb = test_emb[:, 1]\n",
101
+ "embeddings = [v1_emb, v2_emb]\n",
102
+ "\n",
103
+ "index = \"0000\"\n",
104
+ "save_dir = os.path.join(\"test_results/\" + experiment_name, index)\n",
105
+ "os.makedirs(save_dir, exist_ok=True)\n",
106
+ "test_emb_path = os.path.join(save_dir, \"id_embeddings.pt\")\n",
107
+ "torch.save(test_emb, test_emb_path)\n",
108
+ "\n",
109
+ "'''insert into tokenizer & embedding layer'''\n",
110
+ "tokens = [\"v1*\", \"v2*\"]\n",
111
+ "embeddings = [v1_emb, v2_emb]\n",
112
+ "# add tokens and get ids\n",
113
+ "tokenizer.add_tokens(tokens)\n",
114
+ "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
115
+ "\n",
116
+ "# resize token embeddings and set new embeddings\n",
117
+ "text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of = 8)\n",
118
+ "for token_id, embedding in zip(token_ids, embeddings):\n",
119
+ " text_encoder.get_input_embeddings().weight.data[token_id] = embedding\n",
120
+ "\n",
121
+ "prompts_list = [\"a photo of v1* v2*, facing to camera, best quality, ultra high res\",\n",
122
+ " \"v1* v2* wearing a Superman outfit, facing to camera, best quality, ultra high res\",\n",
123
+ " \"v1* v2* wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
124
+ " \"v1* v2* wearing a red sweater, facing to camera, best quality, ultra high res\",\n",
125
+ " \"v1* v2* wearing a blue hoodie, facing to camera, best quality, ultra high res\",\n",
126
+ "]\n",
127
+ "\n",
128
+ "for prompt in prompts_list:\n",
129
+ " image = pipe(prompt, guidance_scale = 8.5).images[0]\n",
130
+ " save_img_path = os.path.join(save_dir, prompt.replace(\"v1* v2*\", \"a person\") + '.png')\n",
131
+ " image.save(save_img_path)\n",
132
+ " print(save_img_path)\n"
133
+ ]
134
+ },
135
+ {
136
+ "cell_type": "markdown",
137
+ "metadata": {},
138
+ "source": [
139
+ "2. directly use a chosen generated pseudo identity embeddings"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": null,
145
+ "metadata": {},
146
+ "outputs": [],
147
+ "source": [
148
+ "# the path of your generated embeddings\n",
149
+ "test_emb_path = \"demo_embeddings/856.pt\" # \"test_results/normal_GAN/0000/id_embeddings.pt\"\n",
150
+ "test_emb = torch.load(test_emb_path).cuda()\n",
151
+ "v1_emb = test_emb[:, 0]\n",
152
+ "v2_emb = test_emb[:, 1]\n",
153
+ "\n",
154
+ "\n",
155
+ "index = \"chosen_index\"\n",
156
+ "save_dir = os.path.join(\"test_results/\" + experiment_name, index)\n",
157
+ "os.makedirs(save_dir, exist_ok=True)\n",
158
+ "\n",
159
+ "\n",
160
+ "'''insert into tokenizer & embedding layer'''\n",
161
+ "tokens = [\"v1*\", \"v2*\"]\n",
162
+ "embeddings = [v1_emb, v2_emb]\n",
163
+ "# add tokens and get ids\n",
164
+ "tokenizer.add_tokens(tokens)\n",
165
+ "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
166
+ "\n",
167
+ "# resize token embeddings and set new embeddings\n",
168
+ "text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of = 8)\n",
169
+ "for token_id, embedding in zip(token_ids, embeddings):\n",
170
+ " text_encoder.get_input_embeddings().weight.data[token_id] = embedding\n",
171
+ "\n",
172
+ "prompts_list = [\"a photo of v1* v2*, facing to camera, best quality, ultra high res\",\n",
173
+ " \"v1* v2* wearing a Superman outfit, facing to camera, best quality, ultra high res\",\n",
174
+ " \"v1* v2* wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
175
+ " \"v1* v2* wearing a red sweater, facing to camera, best quality, ultra high res\",\n",
176
+ " \"v1* v2* wearing a purple wizard outfit, facing to camera, best quality, ultra high res\",\n",
177
+ " \"v1* v2* wearing a blue hoodie, facing to camera, best quality, ultra high res\",\n",
178
+ " \"v1* v2* wearing headphones, facing to camera, best quality, ultra high res\",\n",
179
+ " \"v1* v2* with red hair, facing to camera, best quality, ultra high res\",\n",
180
+ " \"v1* v2* wearing headphones with red hair, facing to camera, best quality, ultra high res\",\n",
181
+ " \"v1* v2* wearing a Christmas hat, facing to camera, best quality, ultra high res\",\n",
182
+ " \"v1* v2* wearing sunglasses, facing to camera, best quality, ultra high res\",\n",
183
+ " \"v1* v2* wearing sunglasses and necklace, facing to camera, best quality, ultra high res\",\n",
184
+ " \"v1* v2* wearing a blue cap, facing to camera, best quality, ultra high res\",\n",
185
+ " \"v1* v2* wearing a doctoral cap, facing to camera, best quality, ultra high res\",\n",
186
+ " \"v1* v2* with white hair, wearing glasses, facing to camera, best quality, ultra high res\",\n",
187
+ " \"v1* v2* in a helmet and vest riding a motorcycle, facing to camera, best quality, ultra high res\",\n",
188
+ " \"v1* v2* holding a bottle of red wine, facing to camera, best quality, ultra high res\",\n",
189
+ " \"v1* v2* driving a bus in the desert, facing to camera, best quality, ultra high res\",\n",
190
+ " \"v1* v2* playing basketball, facing to camera, best quality, ultra high res\",\n",
191
+ " \"v1* v2* playing the violin, facing to camera, best quality, ultra high res\",\n",
192
+ " \"v1* v2* piloting a spaceship, facing to camera, best quality, ultra high res\",\n",
193
+ " \"v1* v2* riding a horse, facing to camera, best quality, ultra high res\",\n",
194
+ " \"v1* v2* coding in front of a computer, facing to camera, best quality, ultra high res\",\n",
195
+ " \"v1* v2* laughing on the lawn, facing to camera, best quality, ultra high res\",\n",
196
+ " \"v1* v2* frowning at the camera, facing to camera, best quality, ultra high res\",\n",
197
+ " \"v1* v2* happily smiling, looking at the camera, facing to camera, best quality, ultra high res\",\n",
198
+ " \"v1* v2* crying disappointedly, with tears flowing, facing to camera, best quality, ultra high res\",\n",
199
+ " \"v1* v2* wearing sunglasses, facing to camera, best quality, ultra high res\",\n",
200
+ " \"v1* v2* playing the guitar in the view of left side, facing to camera, best quality, ultra high res\",\n",
201
+ " \"v1* v2* holding a bottle of red wine, upper body, facing to camera, best quality, ultra high res\",\n",
202
+ " \"v1* v2* wearing sunglasses and necklace, close-up, in the view of right side, facing to camera, best quality, ultra high res\",\n",
203
+ " \"v1* v2* riding a horse, in the view of the top, facing to camera, best quality, ultra high res\",\n",
204
+ " \"v1* v2* wearing a doctoral cap, upper body, with the left side of the face facing the camera, best quality, ultra high res\",\n",
205
+ " \"v1* v2* crying disappointedly, with tears flowing, with left side of the face facing the camera, best quality, ultra high res\",\n",
206
+ " \"v1* v2* sitting in front of the camera, with a beautiful purple sunset at the beach in the background, best quality, ultra high res\",\n",
207
+ " \"v1* v2* swimming in the pool, facing to camera, best quality, ultra high res\",\n",
208
+ " \"v1* v2* climbing a mountain, facing to camera, best quality, ultra high res\",\n",
209
+ " \"v1* v2* skiing on the snowy mountain, facing to camera, best quality, ultra high res\",\n",
210
+ " \"v1* v2* in the snow, facing to camera, best quality, ultra high res\",\n",
211
+ " \"v1* v2* in space wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
212
+ "]\n",
213
+ "\n",
214
+ "for prompt in prompts_list:\n",
215
+ " image = pipe(prompt, guidance_scale = 8.5).images[0]\n",
216
+ " save_img_path = os.path.join(save_dir, prompt.replace(\"v1* v2*\", \"a person\") + '.png')\n",
217
+ " image.save(save_img_path)\n",
218
+ " print(save_img_path)"
219
+ ]
220
+ }
221
+ ],
222
+ "metadata": {
223
+ "kernelspec": {
224
+ "display_name": "lbl",
225
+ "language": "python",
226
+ "name": "python3"
227
+ },
228
+ "language_info": {
229
+ "codemirror_mode": {
230
+ "name": "ipython",
231
+ "version": 3
232
+ },
233
+ "file_extension": ".py",
234
+ "mimetype": "text/x-python",
235
+ "name": "python",
236
+ "nbconvert_exporter": "python",
237
+ "pygments_lexer": "ipython3",
238
+ "version": "3.8.5"
239
+ }
240
+ },
241
+ "nbformat": 4,
242
+ "nbformat_minor": 2
243
+ }
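The token-injection cells in the notebook above repeat the same four steps (add the placeholder tokens, look up their ids, resize the embedding table, write the rows). A small helper capturing that pattern, assuming the same pipe, v1_emb and v2_emb objects as in the notebook:

import torch

def inject_identity(pipe, embeddings, tokens=("v1*", "v2*")):
    # add the placeholder tokens, grow the embedding table, write the rows
    tokenizer, text_encoder = pipe.tokenizer, pipe.text_encoder
    tokenizer.add_tokens(list(tokens))
    token_ids = tokenizer.convert_tokens_to_ids(list(tokens))
    text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)
    with torch.no_grad():
        for token_id, embedding in zip(token_ids, embeddings):
            text_encoder.get_input_embeddings().weight.data[token_id] = embedding
    return token_ids

For example, inject_identity(pipe, [v1_emb, v2_emb]) before building prompts_list.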
test_create_many_characters.ipynb ADDED
@@ -0,0 +1,255 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "metadata": {},
7
+ "outputs": [],
8
+ "source": [
9
+ "import torch\n",
10
+ "import os\n",
11
+ "from transformers import ViTModel, ViTImageProcessor\n",
12
+ "from utils import text_encoder_forward\n",
13
+ "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n",
14
+ "from utils import latents_to_images, downsampling, merge_and_save_images\n",
15
+ "from omegaconf import OmegaConf\n",
16
+ "from accelerate.utils import set_seed\n",
17
+ "from tqdm import tqdm\n",
18
+ "from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\n",
19
+ "from PIL import Image\n",
20
+ "from models.celeb_embeddings import embedding_forward\n",
21
+ "import models.embedding_manager\n",
22
+ "import importlib\n",
23
+ "\n",
24
+ "# seed = 42\n",
25
+ "# set_seed(seed) \n",
26
+ "# torch.cuda.set_device(0)\n",
27
+ "\n",
28
+ "# set your sd2.1 path\n",
29
+ "model_path = \"/home/user/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6\"\n",
30
+ "pipe = StableDiffusionPipeline.from_pretrained(model_path) \n",
31
+ "pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
32
+ "pipe = pipe.to(\"cuda\")\n",
33
+ "\n",
34
+ "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
35
+ "\n",
36
+ "vae = pipe.vae\n",
37
+ "unet = pipe.unet\n",
38
+ "text_encoder = pipe.text_encoder\n",
39
+ "tokenizer = pipe.tokenizer\n",
40
+ "scheduler = pipe.scheduler\n",
41
+ "\n",
42
+ "input_dim = 64\n",
43
+ "\n",
44
+ "experiment_name = \"normal_GAN\" # \"normal_GAN\", \"man_GAN\", \"woman_GAN\" , \n",
45
+ "if experiment_name == \"normal_GAN\":\n",
46
+ " steps = 10000\n",
47
+ "elif experiment_name == \"man_GAN\":\n",
48
+ " steps = 7000\n",
49
+ "elif experiment_name == \"woman_GAN\":\n",
50
+ " steps = 6000\n",
51
+ "else:\n",
52
+ " print(\"Unknown experiment_name:\", experiment_name)\n",
53
+ " assert 0\n",
54
+ "\n",
55
+ "\n",
56
+ "original_forward = text_encoder.text_model.embeddings.forward\n",
57
+ "text_encoder.text_model.embeddings.forward = embedding_forward.__get__(text_encoder.text_model.embeddings)\n",
58
+ "embedding_manager_config = OmegaConf.load(\"datasets_face/identity_space.yaml\")\n",
59
+ "Embedding_Manager = models.embedding_manager.EmbeddingManagerId_adain( \n",
60
+ " tokenizer,\n",
61
+ " text_encoder,\n",
62
+ " device = device,\n",
63
+ " training = True,\n",
64
+ " experiment_name = experiment_name, \n",
65
+ " num_embeds_per_token = embedding_manager_config.model.personalization_config.params.num_embeds_per_token, \n",
66
+ " token_dim = embedding_manager_config.model.personalization_config.params.token_dim,\n",
67
+ " mlp_depth = embedding_manager_config.model.personalization_config.params.mlp_depth,\n",
68
+ " loss_type = embedding_manager_config.model.personalization_config.params.loss_type,\n",
69
+ " vit_out_dim = input_dim,\n",
70
+ ")\n",
71
+ "\n",
72
+ "\n",
73
+ "embedding_path = os.path.join(\"training_weight\", experiment_name, \"embeddings_manager-{}.pt\".format(str(steps)))\n",
74
+ "Embedding_Manager.load(embedding_path)\n",
75
+ "text_encoder.text_model.embeddings.forward = original_forward\n",
76
+ "\n",
77
+ "print(\"finish init\")"
78
+ ]
79
+ },
80
+ {
81
+ "cell_type": "markdown",
82
+ "metadata": {},
83
+ "source": [
84
+ "1. create a new character and test with prompts"
85
+ ]
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "execution_count": null,
90
+ "metadata": {},
91
+ "outputs": [],
92
+ "source": [
93
+ "# sample a z\n",
94
+ "for index in range(100):\n",
95
+ "\n",
96
+ " random_embedding = torch.randn(1, 1, input_dim).to(device)\n",
97
+ "\n",
98
+ " # map z to pseudo identity embeddings\n",
99
+ " _, emb_dict = Embedding_Manager(tokenized_text=None, embedded_text=None, name_batch=None, random_embeddings = random_embedding, timesteps = None,)\n",
100
+ "\n",
101
+ " test_emb = emb_dict[\"adained_total_embedding\"].to(device)\n",
102
+ "\n",
103
+ " v1_emb = test_emb[:, 0]\n",
104
+ " v2_emb = test_emb[:, 1]\n",
105
+ " embeddings = [v1_emb, v2_emb]\n",
106
+ "\n",
107
+ " save_dir = os.path.join(\"test_results/\" + experiment_name, str(index))\n",
108
+ " os.makedirs(save_dir, exist_ok=True) \n",
109
+ " test_emb_path = os.path.join(save_dir, \"id_embeddings.pt\")\n",
110
+ " torch.save(test_emb, test_emb_path)\n",
111
+ "\n",
112
+ "\n",
113
+ "\n",
114
+ " '''insert into tokenizer & embedding layer'''\n",
115
+ " tokens = [\"v1*\", \"v2*\"]\n",
116
+ " embeddings = [v1_emb, v2_emb]\n",
117
+ " # add tokens and get ids\n",
118
+ " tokenizer.add_tokens(tokens)\n",
119
+ " token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
120
+ "\n",
121
+ " # resize token embeddings and set new embeddings\n",
122
+ " text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of = 8)\n",
123
+ " for token_id, embedding in zip(token_ids, embeddings):\n",
124
+ " text_encoder.get_input_embeddings().weight.data[token_id] = embedding\n",
125
+ "\n",
126
+ " prompts_list = [\"a photo of v1* v2*, facing to camera, best quality, ultra high res\",\n",
127
+ " \"v1* v2* wearing a Superman outfit, facing to camera, best quality, ultra high res\",\n",
128
+ " \"v1* v2* wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
129
+ " \"v1* v2* wearing a red sweater, facing to camera, best quality, ultra high res\",\n",
130
+ " \"v1* v2* wearing a blue hoodie, facing to camera, best quality, ultra high res\",\n",
131
+ " ]\n",
132
+ "\n",
133
+ " for prompt in prompts_list:\n",
134
+ " image = pipe(prompt, guidance_scale = 8.5).images[0]\n",
135
+ " save_img_path = os.path.join(save_dir, prompt.replace(\"v1* v2*\", \"a person\") + '.png')\n",
136
+ " image.save(save_img_path)\n",
137
+ " print(save_img_path)\n"
138
+ ]
139
+ },
140
+ {
141
+ "cell_type": "markdown",
142
+ "metadata": {},
143
+ "source": [
144
+ "2. directly use a chosen generated pseudo identity embeddings"
145
+ ]
146
+ },
147
+ {
148
+ "cell_type": "code",
149
+ "execution_count": null,
150
+ "metadata": {},
151
+ "outputs": [],
152
+ "source": [
153
+ "# the path of your generated embeddings\n",
154
+ "test_emb_path = \"test_results/normal_GAN/0000/id_embeddings.pt\"\n",
155
+ "test_emb = torch.load(test_emb_path).cuda()\n",
156
+ "v1_emb = test_emb[:, 0]\n",
157
+ "v2_emb = test_emb[:, 1]\n",
158
+ "\n",
159
+ "\n",
160
+ "index = \"chosen_index\"\n",
161
+ "save_dir = os.path.join(\"test_results/\" + experiment_name, index)\n",
162
+ "os.makedirs(save_dir, exist_ok=True)\n",
163
+ "\n",
164
+ "\n",
165
+ "'''insert into tokenizer & embedding layer'''\n",
166
+ "tokens = [\"v1*\", \"v2*\"]\n",
167
+ "embeddings = [v1_emb, v2_emb]\n",
168
+ "# add tokens and get ids\n",
169
+ "tokenizer.add_tokens(tokens)\n",
170
+ "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
171
+ "\n",
172
+ "# resize token embeddings and set new embeddings\n",
173
+ "text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of = 8)\n",
174
+ "for token_id, embedding in zip(token_ids, embeddings):\n",
175
+ " text_encoder.get_input_embeddings().weight.data[token_id] = embedding\n",
176
+ "\n",
177
+ "prompts_list = [\"a photo of v1* v2*, facing to camera, best quality, ultra high res\",\n",
178
+ " \"v1* v2* wearing a Superman outfit, facing to camera, best quality, ultra high res\",\n",
179
+ " \"v1* v2* wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
180
+ " \"v1* v2* wearing a red sweater, facing to camera, best quality, ultra high res\",\n",
181
+ " \"v1* v2* wearing a purple wizard outfit, facing to camera, best quality, ultra high res\",\n",
182
+ " \"v1* v2* wearing a blue hoodie, facing to camera, best quality, ultra high res\",\n",
183
+ " \"v1* v2* wearing headphones, facing to camera, best quality, ultra high res\",\n",
184
+ " \"v1* v2* with red hair, facing to camera, best quality, ultra high res\",\n",
185
+ " \"v1* v2* wearing headphones with red hair, facing to camera, best quality, ultra high res\",\n",
186
+ " \"v1* v2* wearing a Christmas hat, facing to camera, best quality, ultra high res\",\n",
187
+ " \"v1* v2* wearing sunglasses, facing to camera, best quality, ultra high res\",\n",
188
+ " \"v1* v2* wearing sunglasses and necklace, facing to camera, best quality, ultra high res\",\n",
189
+ " \"v1* v2* wearing a blue cap, facing to camera, best quality, ultra high res\",\n",
190
+ " \"v1* v2* wearing a doctoral cap, facing to camera, best quality, ultra high res\",\n",
191
+ " \"v1* v2* with white hair, wearing glasses, facing to camera, best quality, ultra high res\",\n",
192
+ " \"v1* v2* in a helmet and vest riding a motorcycle, facing to camera, best quality, ultra high res\",\n",
193
+ " \"v1* v2* holding a bottle of red wine, facing to camera, best quality, ultra high res\",\n",
194
+ " \"v1* v2* driving a bus in the desert, facing to camera, best quality, ultra high res\",\n",
195
+ " \"v1* v2* playing basketball, facing to camera, best quality, ultra high res\",\n",
196
+ " \"v1* v2* playing the violin, facing to camera, best quality, ultra high res\",\n",
197
+ " \"v1* v2* piloting a spaceship, facing to camera, best quality, ultra high res\",\n",
198
+ " \"v1* v2* riding a horse, facing to camera, best quality, ultra high res\",\n",
199
+ " \"v1* v2* coding in front of a computer, facing to camera, best quality, ultra high res\",\n",
200
+ " \"v1* v2* laughing on the lawn, facing to camera, best quality, ultra high res\",\n",
201
+ " \"v1* v2* frowning at the camera, facing to camera, best quality, ultra high res\",\n",
202
+ " \"v1* v2* happily smiling, looking at the camera, facing to camera, best quality, ultra high res\",\n",
203
+ " \"v1* v2* crying disappointedly, with tears flowing, facing to camera, best quality, ultra high res\",\n",
204
+ " \"v1* v2* wearing sunglasses, facing to camera, best quality, ultra high res\",\n",
205
+ " \"v1* v2* playing the guitar in the view of left side, facing to camera, best quality, ultra high res\",\n",
206
+ " \"v1* v2* holding a bottle of red wine, upper body, facing to camera, best quality, ultra high res\",\n",
207
+ " \"v1* v2* wearing sunglasses and necklace, close-up, in the view of right side, facing to camera, best quality, ultra high res\",\n",
208
+ " \"v1* v2* riding a horse, in the view of the top, facing to camera, best quality, ultra high res\",\n",
209
+ " \"v1* v2* wearing a doctoral cap, upper body, with the left side of the face facing the camera, best quality, ultra high res\",\n",
210
+ " \"v1* v2* crying disappointedly, with tears flowing, with left side of the face facing the camera, best quality, ultra high res\",\n",
211
+ " \"v1* v2* sitting in front of the camera, with a beautiful purple sunset at the beach in the background, best quality, ultra high res\",\n",
212
+ " \"v1* v2* swimming in the pool, facing to camera, best quality, ultra high res\",\n",
213
+ " \"v1* v2* climbing a mountain, facing to camera, best quality, ultra high res\",\n",
214
+ " \"v1* v2* skiing on the snowy mountain, facing to camera, best quality, ultra high res\",\n",
215
+ " \"v1* v2* in the snow, facing to camera, best quality, ultra high res\",\n",
216
+ " \"v1* v2* in space wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
217
+ "]\n",
218
+ "\n",
219
+ "for prompt in prompts_list:\n",
220
+ " image = pipe(prompt, guidance_scale = 8.5).images[0]\n",
221
+ " save_img_path = os.path.join(save_dir, prompt.replace(\"v1* v2*\", \"a person\") + '.png')\n",
222
+ " image.save(save_img_path)\n",
223
+ " print(save_img_path)"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "metadata": {},
230
+ "outputs": [],
231
+ "source": []
232
+ }
233
+ ],
234
+ "metadata": {
235
+ "kernelspec": {
236
+ "display_name": "lbl",
237
+ "language": "python",
238
+ "name": "python3"
239
+ },
240
+ "language_info": {
241
+ "codemirror_mode": {
242
+ "name": "ipython",
243
+ "version": 3
244
+ },
245
+ "file_extension": ".py",
246
+ "mimetype": "text/x-python",
247
+ "name": "python",
248
+ "nbconvert_exporter": "python",
249
+ "pygments_lexer": "ipython3",
250
+ "version": "3.8.5"
251
+ }
252
+ },
253
+ "nbformat": 4,
254
+ "nbformat_minor": 2
255
+ }
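If the 100 sampled characters in the loop above should be reproducible across runs, the commented-out seeding at the top of the notebook can be re-enabled; a minimal sketch:

import torch
from accelerate.utils import set_seed

set_seed(42)                 # same value as the commented-out cell
z = torch.randn(1, 1, 64)    # input_dim = 64, as configured above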
train.py ADDED
@@ -0,0 +1,767 @@
1
+ import argparse
2
+ import itertools
3
+ import logging
4
+ import math
5
+ import os
6
+ from pathlib import Path
7
+ import accelerate
8
+ import torch
9
+ import torch.nn.functional as F
10
+ import torch.utils.checkpoint
11
+ import transformers
12
+ from accelerate import Accelerator
13
+ from accelerate.logging import get_logger
14
+ from accelerate.utils import ProjectConfiguration, set_seed
15
+ from packaging import version
16
+ from PIL import Image
17
+ from torch.utils.data import Dataset
18
+ from torchvision import transforms
19
+ from tqdm.auto import tqdm
20
+ from transformers import AutoTokenizer, PretrainedConfig
21
+ import diffusers
22
+ from diffusers import (
23
+ AutoencoderKL,
24
+ DDPMScheduler,
25
+ DiffusionPipeline,
26
+ UNet2DConditionModel,
27
+ StableDiffusionPipeline,
28
+ DPMSolverMultistepScheduler,
29
+ )
30
+ from diffusers.optimization import get_scheduler
31
+ from diffusers.utils.import_utils import is_xformers_available
32
+ import numpy as np
33
+ from omegaconf import OmegaConf
34
+ import random
35
+ from transformers import ViTModel, ViTImageProcessor
36
+ from models.celeb_embeddings import embedding_forward
37
+ from models.embedding_manager import EmbeddingManagerId_adain, Embedding_discriminator
38
+ from datasets_face.face_id import FaceIdDataset
39
+ from utils import text_encoder_forward, set_requires_grad, add_noise_return_paras, latents_to_images, discriminator_r1_loss, discriminator_r1_loss_accelerator, downsampling, GANLoss
40
+ import types
41
+ import torch.nn as nn
42
+ from tqdm import tqdm
43
+ from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput
44
+ import importlib
45
+
46
+ logger = get_logger(__name__)
47
+
48
+ def import_model_class_from_model_name_or_path(pretrained_model_name_or_path: str, revision: str):
49
+ text_encoder_config = PretrainedConfig.from_pretrained(
50
+ pretrained_model_name_or_path,
51
+ subfolder="text_encoder",
52
+ revision=revision,
53
+ )
54
+ model_class = text_encoder_config.architectures[0]
55
+
56
+ if model_class == "CLIPTextModel":
57
+ from transformers import CLIPTextModel
58
+
59
+ return CLIPTextModel
60
+ elif model_class == "RobertaSeriesModelWithTransformation":
61
+ from diffusers.pipelines.alt_diffusion.modeling_roberta_series import RobertaSeriesModelWithTransformation
62
+
63
+ return RobertaSeriesModelWithTransformation
64
+ elif model_class == "T5EncoderModel":
65
+ from transformers import T5EncoderModel
66
+
67
+ return T5EncoderModel
68
+ else:
69
+ raise ValueError(f"{model_class} is not supported.")
70
+
71
+ def parse_args(input_args=None):
72
+ parser = argparse.ArgumentParser(description="Training script for the identity-embedding GAN.")
73
+ parser.add_argument(
74
+ "--embedding_manager_config",
75
+ type=str,
76
+ default="datasets_face/identity_space.yaml",
77
+ help=('config used to load the training model and dataset'),
78
+ )
79
+ parser.add_argument(
80
+ "--d_reg_every",
81
+ type=int,
82
+ default=16,
83
+ help="interval for applying r1 regularization"
84
+ )
85
+ parser.add_argument(
86
+ "--r1",
87
+ type=float,
88
+ default=1,
89
+ help="weight of the r1 regularization"
90
+ )
91
+ parser.add_argument(
92
+ "--l_gan_lambda",
93
+ type=float,
94
+ default=1,
95
+ help="Weight of the GAN loss term.",
96
+ )
97
+ parser.add_argument(
98
+ "--l_consis_lambda",
99
+ type=float,
100
+ default=8,
101
+ help="Weight of the identity-consistency loss term.",
102
+ )
103
+ parser.add_argument(
104
+ "--pretrained_model_name_or_path",
105
+ type=str,
106
+ default="/home/user/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6",
107
+ help="Path to pretrained model or model identifier from huggingface.co/models.",
108
+ )
109
+ parser.add_argument(
110
+ "--pretrained_embedding_manager_path",
111
+ type=str,
112
+ default=None,
113
+ help="pretrained_embedding_manager_path",
114
+ )
115
+ parser.add_argument(
116
+ "--pretrained_embedding_manager_epoch",
117
+ type=str,
118
+ default=800,
119
+ help="pretrained_embedding_manager_epoch",
120
+ )
121
+ parser.add_argument(
122
+ "--revision",
123
+ type=str,
124
+ default=None,
125
+ required=False,
126
+ help=(
127
+ "Revision of pretrained model identifier from huggingface.co/models. Trainable model components should be"
128
+ " float32 precision."
129
+ ),
130
+ )
131
+ parser.add_argument(
132
+ "--tokenizer_name",
133
+ type=str,
134
+ default=None,
135
+ help="Pretrained tokenizer name or path if not the same as model_name",
136
+ )
137
+ parser.add_argument(
138
+ "--output_dir",
139
+ type=str,
140
+ default="training_weight/normal_GAN", # training_weight/woman_GAN training_weight/man_GAN
141
+ help="The output directory where the model predictions and checkpoints will be written.",
142
+ )
143
+ parser.add_argument("--seed", type=int, default= None, help="A seed for reproducible training.")
144
+ parser.add_argument(
145
+ "--resolution",
146
+ type=int,
147
+ default=512,
148
+ help=(
149
+ "The resolution for input images, all the images in the train/validation dataset will be resized to this"
150
+ " resolution"
151
+ ),
152
+ )
153
+ parser.add_argument(
154
+ "--center_crop",
155
+ default=False,
156
+ action="store_true",
157
+ help=(
158
+ "Whether to center crop the input images to the resolution. If not set, the images will be randomly"
159
+ " cropped. The images will be resized to the resolution first before cropping."
160
+ ),
161
+ )
162
+ parser.add_argument(
163
+ "--train_batch_size",
164
+ type=int, default=8,
165
+ help="Batch size (per device) for the training dataloader."
166
+ )
167
+ parser.add_argument(
168
+ "--num_train_epochs",
169
+ type=int,
170
+ default=None
171
+ )
172
+ parser.add_argument(
173
+ "--max_train_steps",
174
+ type=int,
175
+ # default=None,
176
+ default=10001,
177
+ help="Total number of training steps to perform. If provided, overrides num_train_epochs.",
178
+ )
179
+ parser.add_argument(
180
+ "--checkpointing_steps",
181
+ type=int,
182
+ default=1000,
183
+ help=(
184
+ "Save a checkpoint of the training state every X updates. Checkpoints can be used for resuming training via"
185
+ " `--resume_from_checkpoint`. In the case that the checkpoint is better than the final trained model, the"
186
+ " checkpoint can also be used for inference. Using a checkpoint for inference requires separate loading of"
187
+ " the original pipeline and the individual checkpointed model components."
188
+ ),
189
+ )
190
+ parser.add_argument(
191
+ "--resume_from_checkpoint",
192
+ type=str,
193
+ default=None,
194
+ help=(
195
+ "Whether training should be resumed from a previous checkpoint. Use a path saved by"
196
+ ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
197
+ ),
198
+ )
199
+ parser.add_argument(
200
+ "--gradient_accumulation_steps",
201
+ type=int,
202
+ default=1,
203
+ help="Number of updates steps to accumulate before performing a backward/update pass.",
204
+ )
205
+ parser.add_argument(
206
+ "--gradient_checkpointing",
207
+ action="store_true",
208
+ help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
209
+ )
210
+ parser.add_argument(
211
+ "--learning_rate",
212
+ type=float,
213
+ default=5e-5,
214
+ help="Initial learning rate (after the potential warmup period) to use.",
215
+ )
216
+ parser.add_argument(
217
+ "--scale_lr",
218
+ action="store_true",
219
+ default=False,
220
+ help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
221
+ )
222
+ parser.add_argument(
223
+ "--lr_scheduler",
224
+ type=str,
225
+ default="constant",
226
+ help=(
227
+ 'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
228
+ ' "constant", "constant_with_warmup"]'
229
+ ),
230
+ )
231
+ parser.add_argument(
232
+ "--lr_warmup_steps", type=int, default=0, help="Number of steps for the warmup in the lr scheduler."
233
+ )
234
+ parser.add_argument(
235
+ "--lr_num_cycles",
236
+ type=int,
237
+ default=1,
238
+ help="Number of hard resets of the lr in cosine_with_restarts scheduler.",
239
+ )
240
+ parser.add_argument("--lr_power", type=float, default=1.0, help="Power factor of the polynomial scheduler.")
241
+ parser.add_argument(
242
+ "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
243
+ )
244
+ parser.add_argument(
245
+ "--dataloader_num_workers",
246
+ type=int,
247
+ default=2,
248
+ help=(
249
+ "Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process."
250
+ ),
251
+ )
252
+ parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
253
+ parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
254
+ parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
255
+ parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
256
+ parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
257
+ parser.add_argument(
258
+ "--logging_dir",
259
+ type=str,
260
+ default="logs",
261
+ help=(
262
+ "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
263
+ " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
264
+ ),
265
+ )
266
+ parser.add_argument(
267
+ "--allow_tf32",
268
+ action="store_true",
269
+ help=(
270
+ "Whether or not to allow TF32 on Ampere GPUs. Can be used to speed up training. For more information, see"
271
+ " https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices"
272
+ ),
273
+ )
274
+ parser.add_argument(
275
+ "--report_to",
276
+ type=str,
277
+ default="tensorboard",
278
+ help=(
279
+ 'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
280
+ ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
281
+ ),
282
+ )
283
+ parser.add_argument(
284
+ "--mixed_precision",
285
+ type=str,
286
+ default=None,
287
+ choices=["no", "fp16", "bf16"],
288
+ help=(
289
+ "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
290
+ " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
291
+ " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
292
+ ),
293
+ )
294
+ parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")
295
+ parser.add_argument(
296
+ "--enable_xformers_memory_efficient_attention", action="store_true", help="Whether or not to use xformers."
297
+ )
298
+ parser.add_argument(
299
+ "--set_grads_to_none",
300
+ action="store_true",
301
+ help=(
302
+ "Save more memory by setting grads to None instead of zero. Be aware that this changes certain"
303
+ " behaviors, so disable this argument if it causes any problems. More info:"
304
+ " https://pytorch.org/docs/stable/generated/torch.optim.Optimizer.zero_grad.html"
305
+ ),
306
+ )
307
+ parser.add_argument(
308
+ "--input_dim",
309
+ type=int,
310
+ default=64,
311
+ help="dimension of the randomly sampled input vectors fed to the MLP"
312
+ )
313
+ parser.add_argument(
314
+ "--experiment_name",
315
+ type=str,
316
+ default="normal_GAN", # "man_GAN" "woman_GAN"
317
+ help="which GAN variant to train: normal_GAN, man_GAN, or woman_GAN"
318
+ )
319
+
320
+
321
+ if input_args is not None:
322
+ args = parser.parse_args(input_args)
323
+ else:
324
+ args = parser.parse_args()
325
+
326
+ env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
327
+ if env_local_rank != -1 and env_local_rank != args.local_rank:
328
+ args.local_rank = env_local_rank
329
+
330
+ return args
331
+
332
+ def encode_prompt(prompt_batch, name_batch, text_encoder, tokenizer, embedding_manager, is_train=True,
333
+ random_embeddings = None, timesteps = None):
334
+ captions = []
335
+ proportion_empty_prompts = 0
336
+
337
+ for caption in prompt_batch:
338
+ if random.random() < proportion_empty_prompts:
339
+ captions.append("")
340
+ elif isinstance(caption, str):
341
+ captions.append(caption)
342
+ elif isinstance(caption, (list, np.ndarray)):
343
+ captions.append(random.choice(caption) if is_train else caption[0])
344
+
345
+ text_inputs = tokenizer(
346
+ captions,
347
+ padding="max_length",
348
+ max_length=tokenizer.model_max_length,
349
+ truncation=True,
350
+ return_tensors="pt",
351
+ )
352
+ text_input_ids = text_inputs.input_ids.to(text_encoder.device)
353
+
354
+ positions_list = []
355
+ for prompt_ids in text_input_ids:
356
+ position = int(torch.where(prompt_ids == 265)[0][0])
357
+ positions_list.append(position)
358
+
359
+ prompt_embeds, other_return_dict = text_encoder_forward(
360
+ text_encoder = text_encoder,
361
+ input_ids = text_input_ids,
362
+ name_batch = name_batch,
363
+ output_hidden_states=True,
364
+ embedding_manager = embedding_manager,
365
+ random_embeddings = random_embeddings,
366
+ timesteps = timesteps)
367
+
368
+ return prompt_embeds, other_return_dict, positions_list
369
+
370
+
371
+ def weights_init_normal(m):
372
+ classname = m.__class__.__name__
373
+ if classname.find("Linear") != -1:
374
+ torch.nn.init.normal_(m.weight.data, 0.0, 0.02)
375
+ torch.nn.init.constant_(m.bias.data, 0.0)
376
+
377
+
378
+ def main(args):
379
+ args.output_dir = os.path.join(args.output_dir, args.experiment_name)
380
+ print("output_dir", args.output_dir)
381
+ logging_dir = Path(args.output_dir, args.logging_dir)
382
+ accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)
383
+ accelerator = Accelerator(
384
+ gradient_accumulation_steps=args.gradient_accumulation_steps,
385
+ mixed_precision=args.mixed_precision,
386
+ log_with=args.report_to,
387
+ project_config=accelerator_project_config,
388
+ )
389
+
390
+ # Currently, it's not possible to do gradient accumulation when training two models with accelerate.accumulate
391
+ # This will be enabled soon in accelerate. For now, we don't allow gradient accumulation when training two models.
392
+ if args.gradient_accumulation_steps > 1 and accelerator.num_processes > 1:
393
+ raise ValueError(
394
+ "Gradient accumulation is not supported when training the text encoder in distributed training. "
395
+ "Please set gradient_accumulation_steps to 1. This feature will be supported in the future."
396
+ )
397
+
398
+ # Make one log on every process with the configuration for debugging.
399
+ logging.basicConfig(
400
+ format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
401
+ datefmt="%m/%d/%Y %H:%M:%S",
402
+ level=logging.INFO,
403
+ )
404
+ logger.info(accelerator.state, main_process_only=False)
405
+ if accelerator.is_local_main_process:
406
+ transformers.utils.logging.set_verbosity_warning()
407
+ diffusers.utils.logging.set_verbosity_info()
408
+ else:
409
+ transformers.utils.logging.set_verbosity_error()
410
+ diffusers.utils.logging.set_verbosity_error()
411
+
412
+ if args.seed is not None:
413
+ set_seed(args.seed)
414
+
415
+ if accelerator.is_main_process:
416
+ if args.output_dir is not None:
417
+ os.makedirs(args.output_dir, exist_ok=True)
418
+
419
+ # Load the tokenizer
420
+ if args.tokenizer_name:
421
+ tokenizer = AutoTokenizer.from_pretrained(args.tokenizer_name, revision=args.revision, use_fast=False)
422
+ elif args.pretrained_model_name_or_path:
423
+ tokenizer = AutoTokenizer.from_pretrained(
424
+ args.pretrained_model_name_or_path,
425
+ subfolder="tokenizer",
426
+ revision=args.revision,
427
+ use_fast=False,
428
+ )
429
+ # import correct text encoder class
430
+ text_encoder_cls = import_model_class_from_model_name_or_path(args.pretrained_model_name_or_path, args.revision)
431
+
432
+ # Load scheduler and models
433
+ noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
434
+ noise_scheduler.add_noise = types.MethodType(add_noise_return_paras, noise_scheduler)
435
+
436
+ text_encoder = text_encoder_cls.from_pretrained(
437
+ args.pretrained_model_name_or_path, subfolder="text_encoder", revision=args.revision
438
+ )
439
+
440
+ text_encoder.text_model.embeddings.forward = embedding_forward.__get__(text_encoder.text_model.embeddings)
441
+
442
+
443
+ embedding_manager_config = OmegaConf.load(args.embedding_manager_config)
444
+ experiment_name = args.experiment_name
445
+
446
+ Embedding_Manager = EmbeddingManagerId_adain(
447
+ tokenizer,
448
+ text_encoder,
449
+ device = accelerator.device,
450
+ training = True,
451
+ num_embeds_per_token = embedding_manager_config.model.personalization_config.params.num_embeds_per_token,
452
+ token_dim = embedding_manager_config.model.personalization_config.params.token_dim,
453
+ mlp_depth = embedding_manager_config.model.personalization_config.params.mlp_depth,
454
+ loss_type = embedding_manager_config.model.personalization_config.params.loss_type,
455
+ input_dim = embedding_manager_config.model.personalization_config.params.input_dim,
456
+ experiment_name = experiment_name,
457
+ )
458
+
459
+ Embedding_Manager.name_projection_layer.apply(weights_init_normal)
460
+
461
+ Embedding_D = Embedding_discriminator(embedding_manager_config.model.personalization_config.params.token_dim * 2, dropout_rate = 0.2)
462
+ Embedding_D.apply(weights_init_normal)
463
+
464
+ if args.pretrained_embedding_manager_path is not None:
465
+ epoch = args.pretrained_embedding_manager_epoch
466
+ embedding_manager_path = os.path.join(args.pretrained_embedding_manager_path, "embeddings_manager-{}.pt".format(epoch))
467
+ Embedding_Manager.load(embedding_manager_path)
468
+ embedding_D_path = os.path.join(args.pretrained_embedding_manager_path, "embedding_D-{}.pt".format(epoch))
469
+ Embedding_D = torch.load(embedding_D_path)
470
+
471
+ for param in Embedding_Manager.trainable_projection_parameters():
472
+ param.requires_grad = True
473
+ Embedding_D.requires_grad_(True)
474
+
475
+ text_encoder.requires_grad_(False)
476
+
477
+
478
+ # Check that all trainable models are in full precision
479
+ low_precision_error_string = (
480
+ "Please make sure to always have all model weights in full float32 precision when starting training - even if"
481
+ " doing mixed precision training. A copy of the weights should still be float32."
482
+ )
483
+
484
+ if accelerator.unwrap_model(text_encoder).dtype != torch.float32:
485
+ raise ValueError(
486
+ f"Text encoder loaded as datatype {accelerator.unwrap_model(text_encoder).dtype}."
487
+ f" {low_precision_error_string}"
488
+ )
489
+
490
+ # Enable TF32 for faster training on Ampere GPUs,
491
+ # cf https://pytorch.org/docs/stable/notes/cuda.html#tensorfloat-32-tf32-on-ampere-devices
492
+ if args.allow_tf32:
493
+ torch.backends.cuda.matmul.allow_tf32 = True
494
+
495
+ if args.scale_lr:
496
+ args.learning_rate = (
497
+ args.learning_rate * args.gradient_accumulation_steps * args.train_batch_size * accelerator.num_processes
498
+ )
499
+
500
+ # Use 8-bit Adam for lower memory usage or to fine-tune the model on 16GB GPUs
501
+ if args.use_8bit_adam:
502
+ try:
503
+ import bitsandbytes as bnb
504
+ except ImportError:
505
+ raise ImportError(
506
+ "To use 8-bit Adam, please install the bitsandbytes library: `pip install bitsandbytes`."
507
+ )
508
+
509
+ optimizer_class = bnb.optim.AdamW8bit
510
+ else:
511
+ optimizer_class = torch.optim.AdamW
512
+
513
+
514
+ projection_params_to_optimize = Embedding_Manager.trainable_projection_parameters()
515
+ optimizer_projection = optimizer_class(
516
+ projection_params_to_optimize,
517
+ lr=args.learning_rate,
518
+ betas=(args.adam_beta1, args.adam_beta2),
519
+ weight_decay=args.adam_weight_decay,
520
+ eps=args.adam_epsilon,
521
+ )
522
+
523
+ discriminator_params_to_optimize = list(Embedding_D.parameters())
524
+ optimizer_discriminator = optimizer_class(
525
+ discriminator_params_to_optimize,
526
+ lr=args.learning_rate,
527
+ betas=(args.adam_beta1, args.adam_beta2),
528
+ weight_decay=args.adam_weight_decay,
529
+ eps=args.adam_epsilon,
530
+ )
531
+
532
+
533
+ train_dataset = FaceIdDataset(
534
+ experiment_name = experiment_name
535
+ )
536
+
537
+ print("dataset_length", train_dataset._length)
538
+ train_dataloader = torch.utils.data.DataLoader(
539
+ train_dataset,
540
+ batch_size=args.train_batch_size,
541
+ shuffle=True,
542
+ num_workers=accelerator.num_processes,
543
+ )
544
+
545
+ # Scheduler and math around the number of training steps.
546
+ overrode_max_train_steps = False
547
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
548
+ if args.max_train_steps is None:
549
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
550
+ overrode_max_train_steps = True
551
+
552
+ lr_scheduler_proj = get_scheduler(
553
+ args.lr_scheduler,
554
+ optimizer=optimizer_projection,
555
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
556
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
557
+ num_cycles=args.lr_num_cycles,
558
+ power=args.lr_power,
559
+ )
560
+
561
+ lr_scheduler_disc = get_scheduler(
562
+ args.lr_scheduler,
563
+ optimizer=optimizer_discriminator,
564
+ num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
565
+ num_training_steps=args.max_train_steps * args.gradient_accumulation_steps,
566
+ num_cycles=args.lr_num_cycles,
567
+ power=args.lr_power,
568
+ )
569
+
570
+ Embedding_Manager, optimizer_projection, optimizer_discriminator, train_dataloader, lr_scheduler_proj, lr_scheduler_disc = accelerator.prepare(
571
+ Embedding_Manager, optimizer_projection, optimizer_discriminator, train_dataloader, lr_scheduler_proj, lr_scheduler_disc
572
+ )
573
+
574
+
575
+ # For mixed precision training we cast the text_encoder and vae weights to half-precision
576
+ # as these models are only used for inference; keeping weights in full precision is not required.
577
+ weight_dtype = torch.float32
578
+ if accelerator.mixed_precision == "fp16":
579
+ weight_dtype = torch.float16
580
+ elif accelerator.mixed_precision == "bf16":
581
+ weight_dtype = torch.bfloat16
582
+
583
+ # Move the text encoder and embedding modules to device and cast to weight_dtype
584
+ text_encoder.to(accelerator.device, dtype=weight_dtype)
585
+ Embedding_Manager.to(accelerator.device, dtype=weight_dtype)
586
+ Embedding_D.to(accelerator.device, dtype=weight_dtype)
587
+
588
+ # We need to recalculate our total training steps as the size of the training dataloader may have changed.
589
+ num_update_steps_per_epoch = math.ceil(len(train_dataloader) / args.gradient_accumulation_steps)
590
+ if overrode_max_train_steps:
591
+ args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch
592
+
593
+ # Afterwards we recalculate our number of training epochs
594
+ args.num_train_epochs = math.ceil(args.max_train_steps / num_update_steps_per_epoch)
595
+
596
+ # We need to initialize the trackers we use, and also store our configuration.
597
+ # The trackers initialize automatically on the main process.
598
+ if accelerator.is_main_process:
599
+ accelerator.init_trackers("identity_space", config=vars(args))
600
+
601
+
602
+ # Train!
603
+ total_batch_size = args.train_batch_size * accelerator.num_processes * args.gradient_accumulation_steps
604
+
605
+ logger.info("***** Running training *****")
606
+ logger.info(f" Num examples = {len(train_dataset)}")
607
+ logger.info(f" Num batches each epoch = {len(train_dataloader)}")
608
+ logger.info(f" Num Epochs = {args.num_train_epochs}")
609
+ logger.info(f" Instantaneous batch size per device = {args.train_batch_size}")
610
+ logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}")
611
+ logger.info(f" Gradient Accumulation steps = {args.gradient_accumulation_steps}")
612
+ logger.info(f" Total optimization steps = {args.max_train_steps}")
613
+ global_step = 0
614
+ first_epoch = 0
615
+
616
+ # Potentially load in the weights and states from a previous save
617
+ if args.resume_from_checkpoint:
618
+ if args.resume_from_checkpoint != "latest":
619
+ path = os.path.basename(args.resume_from_checkpoint)
620
+ else:
621
+ # Get the most recent checkpoint
622
+ dirs = os.listdir(args.output_dir)
623
+ dirs = [d for d in dirs if d.startswith("checkpoint")]
624
+ dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
625
+ path = dirs[-1] if len(dirs) > 0 else None
626
+
627
+ if path is None:
628
+ accelerator.print(
629
+ f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
630
+ )
631
+ args.resume_from_checkpoint = None
632
+ else:
633
+ accelerator.print(f"Resuming from checkpoint {path}")
634
+ accelerator.load_state(os.path.join(args.output_dir, path))
635
+ global_step = int(path.split("-")[1])
636
+
637
+ resume_global_step = global_step * args.gradient_accumulation_steps
638
+ first_epoch = global_step // num_update_steps_per_epoch
639
+ resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)
640
+
641
+     # Only show the progress bar once on each machine.
+     progress_bar = tqdm(range(global_step, args.max_train_steps), disable=not accelerator.is_local_main_process)
+     progress_bar.set_description("Steps")
+
+     num_iter = 0
+     # trained_images_num = 0
+     for epoch in range(first_epoch, args.num_train_epochs):
+         print("=====================================")
+         print("epoch:", epoch)
+         print("=====================================")
+         Embedding_Manager.train()
+         for step, batch in enumerate(train_dataloader):
+
+             # Skip steps until we reach the resumed step
+             if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
+                 if step % args.gradient_accumulation_steps == 0:
+                     progress_bar.update(1)
+                 continue
+
+             random_embeddings = torch.randn(1, 1, args.input_dim).to(accelerator.device)
+             random_embeddings = random_embeddings.repeat(args.train_batch_size, 1, 1)
+
+             encoder_hidden_states, other_return_dict, positions_list = encode_prompt(batch["caption"],
+                                                                                      batch["name"],
+                                                                                      text_encoder, tokenizer,
+                                                                                      Embedding_Manager,
+                                                                                      is_train=True,
+                                                                                      random_embeddings=random_embeddings,
+                                                                                      timesteps=0)
+
+             name_embeddings = other_return_dict["name_embeddings"]
+             adained_total_embedding = other_return_dict["adained_total_embedding"]
+             fake_emb = adained_total_embedding
+
+             criterionGAN = GANLoss().to(accelerator.device)
+
+             # Discriminator update: train Embedding_D to separate generated embeddings from (lightly noised) real name embeddings.
+             set_requires_grad(Embedding_D, True)
+             optimizer_discriminator.zero_grad(set_to_none=args.set_grads_to_none)
+             # fake
+             pred_fake = Embedding_D(fake_emb.detach())
+             loss_D_fake = criterionGAN(pred_fake[0], False)
+
+             # Real
+             random_noise = torch.rand_like(name_embeddings) * 0.005
+             real_name_embeddings = random_noise + name_embeddings
+             pred_real = Embedding_D(real_name_embeddings)
+             loss_D_real = criterionGAN(pred_real[0], True)
+
+             loss_D = (loss_D_fake + loss_D_real) * 0.5
+             accelerator.backward(loss_D)
+             if accelerator.sync_gradients:
+                 accelerator.clip_grad_norm_(discriminator_params_to_optimize, args.max_grad_norm)
+             optimizer_discriminator.step()
+
+             # Projection (generator) update: freeze the discriminator and push its prediction on the fake embedding toward "real".
+             set_requires_grad(Embedding_D, False)
+             optimizer_projection.zero_grad(set_to_none=args.set_grads_to_none)
+             pred_fake = Embedding_D(fake_emb)
+
+             loss_G_GAN = criterionGAN(pred_fake[0], True)
+
+             # Pairwise consistency loss over the two name-token embeddings of every sample pair in the batch.
+             num_embeddings = encoder_hidden_states.size(0)
+             loss_consistency = 0.0
+             for i in range(num_embeddings):
+                 position1 = positions_list[i]
+                 name_embedding1 = torch.cat([encoder_hidden_states[i][position1], encoder_hidden_states[i][position1 + 1]], dim=0)
+                 for j in range(i + 1, num_embeddings):
+                     position2 = positions_list[j]
+                     name_embedding2 = torch.cat([encoder_hidden_states[j][position2], encoder_hidden_states[j][position2 + 1]], dim=0)
+                     loss_consistency += F.mse_loss(name_embedding1, name_embedding2)
+
+             loss_consistency /= (num_embeddings * (num_embeddings - 1)) / 2
+
+             loss = loss_G_GAN * args.l_gan_lambda + loss_consistency * args.l_consis_lambda
+
+             accelerator.backward(loss)
+
+             if accelerator.sync_gradients:
+                 accelerator.clip_grad_norm_(projection_params_to_optimize, args.max_grad_norm)
+             optimizer_projection.step()
+             lr_scheduler_proj.step()
+             lr_scheduler_disc.step()
+
+             num_iter += 1
+
+             # Checks if the accelerator has performed an optimization step behind the scenes
+             if accelerator.sync_gradients:
+                 progress_bar.update(1)
+                 if global_step % args.checkpointing_steps == 0:
+                     if accelerator.is_main_process:
+                         save_path = os.path.join(args.output_dir, f"embeddings_manager-{global_step}.pt")
+                         # accelerator.save_state(save_path)
+                         try:
+                             Embedding_Manager.save(save_path)
+                         except AttributeError:
+                             # When the manager is wrapped (e.g. by DDP), .save lives on .module.
+                             Embedding_Manager.module.save(save_path)
+
+                         save_path_d = os.path.join(args.output_dir, f"embedding_D-{global_step}.pt")
+                         Embedding_D.save(save_path_d)
+
+                         logger.info(f"Saved state to {save_path}")
+
+                 global_step += 1
+
+             adained_total_embeddings_max_min = (round(adained_total_embedding.max().detach().item(), 4),
+                                                 round(adained_total_embedding.min().detach().item(), 4))
+
+             logs = {"m1": adained_total_embeddings_max_min,
+                     "l_G_GAN": loss_G_GAN.detach().item(),
+                     "l_consistency": loss_consistency.detach().item(),
+                     "l_D_real": loss_D_real.detach().item(),
+                     "l_D_fake": loss_D_fake.detach().item(),
+                     "loss": loss.detach().item(),
+                     }
+             progress_bar.set_postfix(**logs)
+             accelerator.log(logs, step=global_step)
+
+             if global_step >= args.max_train_steps:
+                 break
+
+     # Create the pipeline using the trained modules and save it.
+     accelerator.wait_for_everyone()
+     accelerator.end_training()
+
+
+ if __name__ == "__main__":
+     args = parse_args()
+     main(args)
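The loop above interleaves two optimizers: `Embedding_D` is first trained on detached fake embeddings versus noised real name embeddings, then frozen with `set_requires_grad` while the projection network is updated to fool it. The following is a minimal, self-contained sketch of that alternating LSGAN-style update; the toy `G`/`D` modules, shapes, and learning rates are hypothetical stand-ins, not the repository's actual classes.

```python
import torch
import torch.nn as nn

# Hypothetical stand-ins for the projection network and the embedding discriminator.
G = nn.Linear(16, 32)                   # maps random noise to a "fake" embedding
D = nn.Sequential(nn.Linear(32, 1))     # scores an embedding as real / fake

opt_D = torch.optim.Adam(D.parameters(), lr=1e-4)
opt_G = torch.optim.Adam(G.parameters(), lr=1e-4)
mse = nn.MSELoss()                      # LSGAN-style criterion, like GANLoss(use_lsgan=True)

real_emb = torch.randn(4, 32)           # placeholder for real name embeddings
fake_emb = G(torch.randn(4, 16))

# Discriminator step: detach the fake so no gradient reaches G.
for p in D.parameters():
    p.requires_grad_(True)
opt_D.zero_grad()
loss_D = 0.5 * (mse(D(fake_emb.detach()), torch.zeros(4, 1)) +
                mse(D(real_emb), torch.ones(4, 1)))
loss_D.backward()
opt_D.step()

# Generator step: freeze D and push D(fake) toward the "real" label.
for p in D.parameters():
    p.requires_grad_(False)
opt_G.zero_grad()
loss_G = mse(D(fake_emb), torch.ones(4, 1))
loss_G.backward()
opt_G.step()
```

The `.detach()` in the discriminator step is what keeps `loss_D` from back-propagating into the generator; in the generator step the graph is rebuilt through `D` with its parameters frozen, so only `G` receives an update.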
training_weight/man_GAN/embeddings_manager-7000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:23ad2b0e562fac58f51b6deb1fb129b1317b76aa98bdf5b70e870c7a5ed38862
+ size 17356032
training_weight/normal_GAN/embeddings_manager-10000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2c0051eccaa1435d0ec678dc8d8e7130d09849b48aa7b67f51aee9aa4bad71cb
+ size 17356044
training_weight/woman_GAN/embeddings_manager-6000.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:23fc0dc612520375801ec596354479c194baee2f090a50f7ae7cc9e41479eb3f
+ size 17356032
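The three `.pt` entries above are Git LFS pointer files (note the `version https://git-lfs.github.com/spec/v1` header), so a plain clone only fetches these small stubs; the roughly 17 MB weight files are downloaded by Git LFS. If `Embedding_Manager.save` writes a standard `torch.save` payload, which is an assumption not confirmed by this diff, a checkpoint could be inspected along these lines:

```python
import torch

# Hypothetical inspection of one of the LFS-tracked checkpoints; the path and the
# assumption that it is a torch.save payload are not guaranteed by the diff.
ckpt_path = "training_weight/man_GAN/embeddings_manager-7000.pt"
state = torch.load(ckpt_path, map_location="cpu")

# Print top-level keys and tensor shapes to see what the manager stored.
if isinstance(state, dict):
    for key, value in state.items():
        shape = tuple(value.shape) if torch.is_tensor(value) else type(value).__name__
        print(key, shape)
else:
    print(type(state))
```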
utils.py ADDED
@@ -0,0 +1,237 @@
+ from PIL import Image
+
+ import torch
+ import torch.nn.functional as F
+
+ from transformers.modeling_outputs import BaseModelOutputWithPooling
+ from transformers.models.clip.modeling_clip import _make_causal_mask, _expand_mask
+ from torch import autograd
+ import accelerate
+ import torch.nn as nn
+
+ from PIL import Image
+ import numpy as np
+
+ def set_requires_grad(nets, requires_grad=False):
+     """Set requires_grad=False for all the networks to avoid unnecessary computations
+     Parameters:
+         nets (network list)   -- a list of networks
+         requires_grad (bool)  -- whether the networks require gradients or not
+     """
+     if not isinstance(nets, list):
+         nets = [nets]
+     for net in nets:
+         if net is not None:
+             for param in net.parameters():
+                 param.requires_grad = requires_grad
+
+ def discriminator_r1_loss_accelerator(accelerator, real_pred, real_w):
+     # NOTE: the original code called accelerate.gradient(...), which is not an
+     # accelerate API; the standard torch.autograd.grad call is used instead.
+     grad_real, = autograd.grad(
+         outputs=real_pred.sum(), inputs=real_w, create_graph=True  # , only_inputs=True
+     )
+     grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
+
+     return grad_penalty
+
+ class GANLoss(nn.Module):
+     def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
+         super(GANLoss, self).__init__()
+         self.register_buffer('real_label', torch.tensor(target_real_label))
+         self.register_buffer('fake_label', torch.tensor(target_fake_label))
+         if use_lsgan:
+             self.loss = nn.MSELoss()
+         else:
+             self.loss = nn.BCEWithLogitsLoss()
+
+     def get_target_tensor(self, input, target_is_real):
+         if target_is_real:
+             target_tensor = self.real_label
+         else:
+             target_tensor = self.fake_label
+         return target_tensor.expand_as(input)
+
+     def __call__(self, input, target_is_real):
+         target_tensor = self.get_target_tensor(input, target_is_real)
+         return self.loss(input, target_tensor)
+
+
+
+ def discriminator_r1_loss(real_pred, real_w):
+     grad_real, = autograd.grad(
+         outputs=real_pred.sum(), inputs=real_w, create_graph=True  # , only_inputs=True
+     )
+     grad_penalty = grad_real.pow(2).reshape(grad_real.shape[0], -1).sum(1).mean()
+
+     return grad_penalty
+
+ def add_noise_return_paras(
+     self,
+     original_samples: torch.FloatTensor,
+     noise: torch.FloatTensor,
+     timesteps: torch.IntTensor,
+ ) -> torch.FloatTensor:
+     # Make sure alphas_cumprod and timestep have same device and dtype as original_samples
+     alphas_cumprod = self.alphas_cumprod.to(device=original_samples.device, dtype=original_samples.dtype)
+     timesteps = timesteps.to(original_samples.device)
+
+     sqrt_alpha_prod = alphas_cumprod[timesteps] ** 0.5
+     sqrt_alpha_prod = sqrt_alpha_prod.flatten()
+     while len(sqrt_alpha_prod.shape) < len(original_samples.shape):
+         sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)
+
+     sqrt_one_minus_alpha_prod = (1 - alphas_cumprod[timesteps]) ** 0.5
+     sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
+     while len(sqrt_one_minus_alpha_prod.shape) < len(original_samples.shape):
+         sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)
+
+     noisy_samples = sqrt_alpha_prod * original_samples + sqrt_one_minus_alpha_prod * noise
+     return noisy_samples, sqrt_alpha_prod, sqrt_one_minus_alpha_prod
+
+
+ def text_encoder_forward(
+     text_encoder=None,
+     input_ids=None,
+     name_batch=None,
+     attention_mask=None,
+     position_ids=None,
+     output_attentions=None,
+     output_hidden_states=None,
+     return_dict=None,
+     embedding_manager=None,
+     only_embedding=False,
+     random_embeddings=None,
+     timesteps=None,
+ ):
+     output_attentions = output_attentions if output_attentions is not None else text_encoder.config.output_attentions
+     output_hidden_states = (
+         output_hidden_states if output_hidden_states is not None else text_encoder.config.output_hidden_states
+     )
+     return_dict = return_dict if return_dict is not None else text_encoder.config.use_return_dict
+
+     if input_ids is None:
+         raise ValueError("You have to specify input_ids")
+
+     input_shape = input_ids.size()
+     input_ids = input_ids.view(-1, input_shape[-1])
+
+     hidden_states, other_return_dict = text_encoder.text_model.embeddings(input_ids=input_ids,
+                                                                           position_ids=position_ids,
+                                                                           name_batch=name_batch,
+                                                                           embedding_manager=embedding_manager,
+                                                                           only_embedding=only_embedding,
+                                                                           random_embeddings=random_embeddings,
+                                                                           timesteps=timesteps,
+                                                                           )
+     if only_embedding:
+         return hidden_states
+
+     causal_attention_mask = _make_causal_mask(input_shape, hidden_states.dtype, device=hidden_states.device)
+     if attention_mask is not None:
+         attention_mask = _expand_mask(attention_mask, hidden_states.dtype)
+
+     encoder_outputs = text_encoder.text_model.encoder(
+         inputs_embeds=hidden_states,
+         attention_mask=attention_mask,
+         causal_attention_mask=causal_attention_mask,
+         output_attentions=output_attentions,
+         output_hidden_states=output_hidden_states,
+         return_dict=return_dict,
+     )
+
+     last_hidden_state = encoder_outputs[0]
+     last_hidden_state = text_encoder.text_model.final_layer_norm(last_hidden_state)
+
+     if text_encoder.text_model.eos_token_id == 2:
+         pooled_output = last_hidden_state[
+             torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+             input_ids.to(dtype=torch.int, device=last_hidden_state.device).argmax(dim=-1),
+         ]
+     else:
+         pooled_output = last_hidden_state[
+             torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
+             (input_ids.to(dtype=torch.int, device=last_hidden_state.device) == text_encoder.text_model.eos_token_id)
+             .int()
+             .argmax(dim=-1),
+         ]
+
+     if not return_dict:
+         return (last_hidden_state, pooled_output) + encoder_outputs[1:]
+
+     return BaseModelOutputWithPooling(
+         last_hidden_state=last_hidden_state,
+         pooler_output=pooled_output,
+         hidden_states=encoder_outputs.hidden_states,
+         attentions=encoder_outputs.attentions,
+     )[0], other_return_dict
+
+
+ def downsampling(img: torch.Tensor, w: int, h: int) -> torch.Tensor:
+     return F.interpolate(
+         img.unsqueeze(0).unsqueeze(1),
+         size=(w, h),
+         mode="bilinear",
+         align_corners=True,
+     ).squeeze()
+
+
+ def image_grid(images, rows=2, cols=2):
+     w, h = images[0].size
+     grid = Image.new('RGB', size=(cols * w, rows * h))
+
+     for i, img in enumerate(images):
+         grid.paste(img, box=(i % cols * w, i // cols * h))
+     return grid
+
+
+ def latents_to_images(vae, latents, scale_factor=0.18215):
+     """
+     Decode latents to PIL images.
+     """
+     scaled_latents = 1.0 / scale_factor * latents.clone()
+     images = vae.decode(scaled_latents).sample
+     images = (images / 2 + 0.5).clamp(0, 1)
+     images = images.detach().cpu().permute(0, 2, 3, 1).numpy()
+
+     if images.ndim == 3:
+         images = images[None, ...]
+     images = (images * 255).round().astype("uint8")
+     pil_images = [Image.fromarray(image) for image in images]
+
+     return pil_images
+
+
+ def merge_and_save_images(output_images):
+     image_size = output_images[0].size
+
+     merged_width = len(output_images) * image_size[0]
+     merged_height = image_size[1]
+
+     merged_image = Image.new('RGB', (merged_width, merged_height), (255, 255, 255))
+
+     for i, image in enumerate(output_images):
+         merged_image.paste(image, (i * image_size[0], 0))
+
+     return merged_image
+
+ # NOTE: this second GANLoss definition shadows the class defined earlier in this
+ # file; the only difference is that the non-LSGAN branch uses BCELoss here
+ # instead of BCEWithLogitsLoss.
+ class GANLoss(nn.Module):
+     def __init__(self, use_lsgan=True, target_real_label=1.0, target_fake_label=0.0):
+         super(GANLoss, self).__init__()
+         self.register_buffer('real_label', torch.tensor(target_real_label))
+         self.register_buffer('fake_label', torch.tensor(target_fake_label))
+
+         if use_lsgan:
+             self.loss = nn.MSELoss()
+         else:
+             self.loss = nn.BCELoss()
+
+     def get_target_tensor(self, input, target_is_real):
+         if target_is_real:
+             target_tensor = self.real_label
+         else:
+             target_tensor = self.fake_label
+         return target_tensor.expand_as(input)
+
+     def __call__(self, input, target_is_real):
+         target_tensor = self.get_target_tensor(input, target_is_real)
+         return self.loss(input, target_tensor)
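To make the image helpers above concrete, here is a small usage sketch; it assumes `utils.py` is importable from the working directory, and the solid-colour tiles are placeholders for generated samples.

```python
from PIL import Image
from utils import image_grid, merge_and_save_images

# Four 64x64 placeholder tiles standing in for generated images.
tiles = [Image.new("RGB", (64, 64), colour) for colour in ("red", "green", "blue", "white")]

grid = image_grid(tiles, rows=2, cols=2)   # 128x128 collage, pasted in row-major order
strip = merge_and_save_images(tiles)       # 256x64 horizontal strip
grid.save("grid_preview.png")
strip.save("strip_preview.png")
```

Note that, despite its name, `merge_and_save_images` only returns the merged strip and leaves saving to the caller.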