{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import os\n",
    "from transformers import ViTModel, ViTImageProcessor\n",
    "from utils import text_encoder_forward\n",
    "from diffusers import StableDiffusionPipeline, DPMSolverMultistepScheduler\n",
    "from utils import latents_to_images, downsampling, merge_and_save_images\n",
    "from omegaconf import OmegaConf\n",
    "from accelerate.utils import set_seed\n",
    "from tqdm import tqdm\n",
    "from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion import StableDiffusionPipelineOutput\n",
    "from PIL import Image\n",
    "from models.celeb_embeddings import embedding_forward\n",
    "import models.embedding_manager\n",
    "import importlib\n",
    "\n",
    "# seed = 42\n",
    "# set_seed(seed)  \n",
    "# torch.cuda.set_device(0)\n",
    "\n",
    "# set your sd2.1 path\n",
    "model_path = \"/home/user/.cache/huggingface/hub/models--stabilityai--stable-diffusion-2-1/snapshots/5cae40e6a2745ae2b01ad92ae5043f95f23644d6\"\n",
    "pipe = StableDiffusionPipeline.from_pretrained(model_path)   \n",
    "pipe.scheduler = DPMSolverMultistepScheduler.from_config(pipe.scheduler.config)\n",
    "pipe = pipe.to(\"cuda\")\n",
    "\n",
    "device = torch.device(\"cuda\") if torch.cuda.is_available() else torch.device(\"cpu\")\n",
    "\n",
    "vae = pipe.vae\n",
    "unet = pipe.unet\n",
    "text_encoder = pipe.text_encoder\n",
    "tokenizer = pipe.tokenizer\n",
    "scheduler = pipe.scheduler\n",
    "\n",
    "input_dim = 64\n",
    "\n",
    "experiment_name = \"normal_GAN\"   # \"normal_GAN\", \"man_GAN\", \"woman_GAN\" , \n",
    "if experiment_name == \"normal_GAN\":\n",
    "    steps = 10000\n",
    "elif experiment_name == \"man_GAN\":\n",
    "    steps = 7000\n",
    "elif experiment_name == \"woman_GAN\":\n",
    "    steps = 6000\n",
    "else:\n",
    "    print(\"Hello, please notice this ^_^\")\n",
    "    assert 0\n",
    "\n",
    "\n",
    "original_forward = text_encoder.text_model.embeddings.forward\n",
    "text_encoder.text_model.embeddings.forward = embedding_forward.__get__(text_encoder.text_model.embeddings)\n",
    "embedding_manager_config = OmegaConf.load(\"datasets_face/identity_space.yaml\")\n",
    "Embedding_Manager = models.embedding_manager.EmbeddingManagerId_adain(  \n",
    "        tokenizer,\n",
    "        text_encoder,\n",
    "        device = device,\n",
    "        training = True,\n",
    "        experiment_name = experiment_name, \n",
    "        num_embeds_per_token = embedding_manager_config.model.personalization_config.params.num_embeds_per_token,            \n",
    "        token_dim = embedding_manager_config.model.personalization_config.params.token_dim,\n",
    "        mlp_depth = embedding_manager_config.model.personalization_config.params.mlp_depth,\n",
    "        loss_type = embedding_manager_config.model.personalization_config.params.loss_type,\n",
    "        vit_out_dim = input_dim,\n",
    ")\n",
    "embedding_path = os.path.join(\"training_weight\", experiment_name, \"embeddings_manager-{}.pt\".format(str(steps)))\n",
    "Embedding_Manager.load(embedding_path)\n",
    "text_encoder.text_model.embeddings.forward = original_forward\n",
    "\n",
    "print(\"finish init\")"
   ]
  },
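  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional sanity check (not in the original notebook): the cell below samples a random identity code and inspects the shape of the mapped embedding. The expected shape is an assumption based on `num_embeds_per_token` and `token_dim` in `datasets_face/identity_space.yaml`."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical sanity check: confirm the Embedding Manager maps a 64-dim z\n",
    "# to one embedding per pseudo token (v1*, v2*).\n",
    "with torch.no_grad():\n",
    "    z = torch.randn(1, 1, input_dim).to(device)\n",
    "    _, d = Embedding_Manager(tokenized_text=None, embedded_text=None, name_batch=None, random_embeddings=z, timesteps=None)\n",
    "print(d[\"adained_total_embedding\"].shape)  # assumed: (1, num_embeds_per_token, token_dim)\n"
   ]
  },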
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "1. create a new character and test with prompts"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample a z\n",
    "random_embedding = torch.randn(1, 1, input_dim).to(device)\n",
    "\n",
    "# map z to pseudo identity embeddings\n",
    "_, emb_dict = Embedding_Manager(tokenized_text=None, embedded_text=None, name_batch=None, random_embeddings = random_embedding, timesteps = None,)\n",
    "\n",
    "test_emb = emb_dict[\"adained_total_embedding\"].to(device)\n",
    "\n",
    "v1_emb = test_emb[:, 0]\n",
    "v2_emb = test_emb[:, 1]\n",
    "embeddings = [v1_emb, v2_emb]\n",
    "\n",
    "index = \"0000\"\n",
    "save_dir = os.path.join(\"test_results/\" + experiment_name, index)\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "test_emb_path = os.path.join(save_dir, \"id_embeddings.pt\")\n",
    "torch.save(test_emb, test_emb_path)\n",
    "\n",
    "'''insert into tokenizer & embedding layer'''\n",
    "tokens = [\"v1*\", \"v2*\"]\n",
    "embeddings = [v1_emb, v2_emb]\n",
    "# add tokens and get ids\n",
    "tokenizer.add_tokens(tokens)\n",
    "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "\n",
    "# resize token embeddings and set new embeddings\n",
    "text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of = 8)\n",
    "for token_id, embedding in zip(token_ids, embeddings):\n",
    "    text_encoder.get_input_embeddings().weight.data[token_id] = embedding\n",
    "\n",
    "prompts_list = [\"a photo of v1* v2*, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a Superman outfit, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a red sweater, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a blue hoodie, facing to camera, best quality, ultra high res\",\n",
    "]\n",
    "\n",
    "for prompt in prompts_list:\n",
    "    image = pipe(prompt, guidance_scale = 8.5).images[0]\n",
    "    save_img_path = os.path.join(save_dir, prompt.replace(\"v1* v2*\", \"a person\") + '.png')\n",
    "    image.save(save_img_path)\n",
    "    print(save_img_path)\n"
   ]
  },
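  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The token-insertion steps above are repeated verbatim in the next section, so a small helper keeps things tidy. This is a hypothetical convenience wrapper, not part of the original repo; it only bundles the exact calls already shown."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical helper: bundle the token-insertion steps used above so a new\n",
    "# identity can be swapped in with one call, e.g.\n",
    "#   insert_identity_tokens(tokenizer, text_encoder, [v1_emb, v2_emb])\n",
    "def insert_identity_tokens(tokenizer, text_encoder, embeddings, tokens=(\"v1*\", \"v2*\")):\n",
    "    tokenizer.add_tokens(list(tokens))  # a no-op if the tokens already exist\n",
    "    token_ids = tokenizer.convert_tokens_to_ids(list(tokens))\n",
    "    text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=8)\n",
    "    with torch.no_grad():\n",
    "        for token_id, emb in zip(token_ids, embeddings):\n",
    "            text_encoder.get_input_embeddings().weight.data[token_id] = emb\n",
    "    return token_ids\n"
   ]
  },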
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "2. directly use a chosen generated pseudo identity embeddings"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# the path of your generated embeddings\n",
    "test_emb_path = \"demo_embeddings/856.pt\"  # \"test_results/normal_GAN/0000/id_embeddings.pt\"\n",
    "test_emb = torch.load(test_emb_path).cuda()\n",
    "v1_emb = test_emb[:, 0]\n",
    "v2_emb = test_emb[:, 1]\n",
    "\n",
    "\n",
    "index = \"chosen_index\"\n",
    "save_dir = os.path.join(\"test_results/\" + experiment_name, index)\n",
    "os.makedirs(save_dir, exist_ok=True)\n",
    "\n",
    "\n",
    "'''insert into tokenizer & embedding layer'''\n",
    "tokens = [\"v1*\", \"v2*\"]\n",
    "embeddings = [v1_emb, v2_emb]\n",
    "# add tokens and get ids\n",
    "tokenizer.add_tokens(tokens)\n",
    "token_ids = tokenizer.convert_tokens_to_ids(tokens)\n",
    "\n",
    "# resize token embeddings and set new embeddings\n",
    "text_encoder.resize_token_embeddings(len(tokenizer), pad_to_multiple_of = 8)\n",
    "for token_id, embedding in zip(token_ids, embeddings):\n",
    "    text_encoder.get_input_embeddings().weight.data[token_id] = embedding\n",
    "\n",
    "prompts_list = [\"a photo of v1* v2*, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a Superman outfit, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a red sweater, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a purple wizard outfit, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a blue hoodie, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing headphones, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* with red hair, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing headphones with red hair, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a Christmas hat, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing sunglasses, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing sunglasses and necklace, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a blue cap, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a doctoral cap, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* with white hair, wearing glasses, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* in a helmet and vest riding a motorcycle, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* holding a bottle of red wine, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* driving a bus in the desert, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* playing basketball, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* playing the violin, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* piloting a spaceship, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* riding a horse, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* coding in front of a computer, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* laughing on the lawn, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* frowning at the camera, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* happily smiling, looking at the camera, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* crying disappointedly, with tears flowing, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing sunglasses, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* playing the guitar in the view of left side, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* holding a bottle of red wine, upper body, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing sunglasses and necklace, close-up, in the view of right side, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* riding a horse, in the view of the top, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* wearing a doctoral cap, upper body, with the left side of the face facing the camera, best quality, ultra high res\",\n",
    "    \"v1* v2* crying disappointedly, with tears flowing, with left side of the face facing the camera, best quality, ultra high res\",\n",
    "    \"v1* v2* sitting in front of the camera, with a beautiful purple sunset at the beach in the background, best quality, ultra high res\",\n",
    "    \"v1* v2* swimming in the pool, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* climbing a mountain, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* skiing on the snowy mountain, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* in the snow, facing to camera, best quality, ultra high res\",\n",
    "    \"v1* v2* in space wearing a spacesuit, facing to camera, best quality, ultra high res\",\n",
    "]\n",
    "\n",
    "for prompt in prompts_list:\n",
    "    image = pipe(prompt, guidance_scale = 8.5).images[0]\n",
    "    save_img_path = os.path.join(save_dir, prompt.replace(\"v1* v2*\", \"a person\") + '.png')\n",
    "    image.save(save_img_path)\n",
    "    print(save_img_path)"
   ]
  },
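  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Optional: the first cell hints at seeding (`set_seed` is imported but commented out). For reproducible generations, a `torch.Generator` can be passed to the pipeline. This cell is a hypothetical addition; it reuses `prompts_list` and `save_dir` from above."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Hypothetical: fix the initial latents with a seeded generator so the same\n",
    "# prompt + identity reproduces the same image across runs.\n",
    "generator = torch.Generator(device=device).manual_seed(42)\n",
    "image = pipe(prompts_list[0], guidance_scale=8.5, generator=generator).images[0]\n",
    "image.save(os.path.join(save_dir, \"seeded_example.png\"))\n"
   ]
  }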
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "lbl",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.8.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}