ML-Motivators committed on
Commit
5b93708
1 Parent(s): c5c8fc9

uploads files

ip_adapter/__init__.py ADDED
@@ -0,0 +1,11 @@
+ from .ip_adapter import IPAdapter, IPAdapterPlus, IPAdapterPlusXL, IPAdapterXL, IPAdapterFull, IPAdapterPlus_Lora, IPAdapterPlus_Lora_up
+
+ __all__ = [
+     "IPAdapter",
+     "IPAdapterPlus",
+     "IPAdapterPlusXL",
+     "IPAdapterXL",
+     "IPAdapterFull",
+     "IPAdapterPlus_Lora",
+     "IPAdapterPlus_Lora_up",
+ ]
ip_adapter/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (362 Bytes).

ip_adapter/__pycache__/attention_processor.cpython-310.pyc ADDED
Binary file (30.8 kB).

ip_adapter/__pycache__/ip_adapter.cpython-310.pyc ADDED
Binary file (14.9 kB).

ip_adapter/__pycache__/resampler.cpython-310.pyc ADDED
Binary file (4.77 kB).

ip_adapter/__pycache__/utils.cpython-310.pyc ADDED
Binary file (360 Bytes).

ip_adapter/attention_processor.py ADDED
The diff for this file is too large to render. See raw diff
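Although the attention_processor.py diff is not rendered here, the rest of the commit shows how it is used: IPAttnProcessor carries a `scale`, a `num_tokens` count, and separate `to_k_ip` / `to_v_ip` projections that `set_ip_adapter` initializes from the UNet's own `to_k` / `to_v` weights. The snippet below is an illustrative, self-contained sketch of that decoupled cross-attention idea only; it is not the actual file contents, and the class and helper names are hypothetical.

import torch
import torch.nn as nn
import torch.nn.functional as F


class DecoupledCrossAttention(nn.Module):
    """Toy stand-in for the IP-Adapter attention processor: text tokens and the
    trailing image-prompt tokens get separate key/value projections, and the two
    attention results are summed with a tunable `scale`."""

    def __init__(self, hidden_size, cross_attention_dim, heads=8, num_tokens=4, scale=1.0):
        super().__init__()
        assert hidden_size % heads == 0
        self.heads = heads
        self.num_tokens = num_tokens
        self.scale = scale
        self.to_q = nn.Linear(hidden_size, hidden_size, bias=False)
        self.to_k = nn.Linear(cross_attention_dim, hidden_size, bias=False)
        self.to_v = nn.Linear(cross_attention_dim, hidden_size, bias=False)
        self.to_k_ip = nn.Linear(cross_attention_dim, hidden_size, bias=False)  # image-prompt keys
        self.to_v_ip = nn.Linear(cross_attention_dim, hidden_size, bias=False)  # image-prompt values
        self.to_out = nn.Linear(hidden_size, hidden_size)

    def _heads(self, t):
        # (b, n, d) -> (b, heads, n, d_head)
        b, n, d = t.shape
        return t.view(b, n, self.heads, d // self.heads).transpose(1, 2)

    def forward(self, hidden_states, encoder_hidden_states):
        # The last `num_tokens` context entries are the projected image tokens.
        n_text = encoder_hidden_states.shape[1] - self.num_tokens
        text_ctx = encoder_hidden_states[:, :n_text]
        image_ctx = encoder_hidden_states[:, n_text:]

        q = self._heads(self.to_q(hidden_states))
        text_out = F.scaled_dot_product_attention(
            q, self._heads(self.to_k(text_ctx)), self._heads(self.to_v(text_ctx))
        )
        image_out = F.scaled_dot_product_attention(
            q, self._heads(self.to_k_ip(image_ctx)), self._heads(self.to_v_ip(image_ctx))
        )

        out = text_out + self.scale * image_out  # decoupled cross-attention
        b, h, n, d = out.shape
        out = out.transpose(1, 2).reshape(b, n, h * d)  # merge heads
        return self.to_out(out)

Judging from how `set_ip_adapter` wires things up below, the real processors reuse the UNet's existing q/k/v/out projections and only add `to_k_ip` / `to_v_ip` (plus the extra low-rank layers in the LoRA variants) as new parameters.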
 
ip_adapter/ip_adapter.py ADDED
@@ -0,0 +1,907 @@
1
+ import os
2
+ from typing import List
3
+
4
+ import torch
5
+ from diffusers import StableDiffusionPipeline
6
+ from diffusers.pipelines.controlnet import MultiControlNetModel
7
+ from PIL import Image
8
+ from safetensors import safe_open
9
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
10
+
11
+ from .utils import is_torch2_available
12
+
13
+ if is_torch2_available():
14
+ from .attention_processor import (
15
+ AttnProcessor2_0 as AttnProcessor,
16
+ )
17
+ from .attention_processor import (
18
+ CNAttnProcessor2_0 as CNAttnProcessor,
19
+ )
20
+ from .attention_processor import (
21
+ IPAttnProcessor2_0 as IPAttnProcessor,
22
+ )
23
+ from .attention_processor import IPAttnProcessor2_0_Lora
24
+ # else:
25
+ # from .attention_processor import AttnProcessor, CNAttnProcessor, IPAttnProcessor
26
+ from .resampler import Resampler
27
+ from diffusers.models.lora import LoRALinearLayer
28
+
29
+
30
+ class ImageProjModel(torch.nn.Module):
31
+ """Projection Model"""
32
+
33
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024, clip_extra_context_tokens=4):
34
+ super().__init__()
35
+
36
+ self.cross_attention_dim = cross_attention_dim
37
+ self.clip_extra_context_tokens = clip_extra_context_tokens
38
+ self.proj = torch.nn.Linear(clip_embeddings_dim, self.clip_extra_context_tokens * cross_attention_dim)
39
+ self.norm = torch.nn.LayerNorm(cross_attention_dim)
40
+
41
+ def forward(self, image_embeds):
42
+ embeds = image_embeds
43
+ clip_extra_context_tokens = self.proj(embeds).reshape(
44
+ -1, self.clip_extra_context_tokens, self.cross_attention_dim
45
+ )
46
+ clip_extra_context_tokens = self.norm(clip_extra_context_tokens)
47
+ return clip_extra_context_tokens
48
+
49
+
50
+ class MLPProjModel(torch.nn.Module):
51
+ """SD model with image prompt"""
52
+ def __init__(self, cross_attention_dim=1024, clip_embeddings_dim=1024):
53
+ super().__init__()
54
+
55
+ self.proj = torch.nn.Sequential(
56
+ torch.nn.Linear(clip_embeddings_dim, clip_embeddings_dim),
57
+ torch.nn.GELU(),
58
+ torch.nn.Linear(clip_embeddings_dim, cross_attention_dim),
59
+ torch.nn.LayerNorm(cross_attention_dim)
60
+ )
61
+
62
+ def forward(self, image_embeds):
63
+ clip_extra_context_tokens = self.proj(image_embeds)
64
+ return clip_extra_context_tokens
65
+
66
+
67
+ class IPAdapter:
68
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4):
69
+ self.device = device
70
+ self.image_encoder_path = image_encoder_path
71
+ self.ip_ckpt = ip_ckpt
72
+ self.num_tokens = num_tokens
73
+
74
+ self.pipe = sd_pipe.to(self.device)
75
+ self.set_ip_adapter()
76
+
77
+ # load image encoder
78
+ self.image_encoder = CLIPVisionModelWithProjection.from_pretrained(self.image_encoder_path).to(
79
+ self.device, dtype=torch.float16
80
+ )
81
+ self.clip_image_processor = CLIPImageProcessor()
82
+ # image proj model
83
+ self.image_proj_model = self.init_proj()
84
+
85
+ self.load_ip_adapter()
86
+
87
+ def init_proj(self):
88
+ image_proj_model = ImageProjModel(
89
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
90
+ clip_embeddings_dim=self.image_encoder.config.projection_dim,
91
+ clip_extra_context_tokens=self.num_tokens,
92
+ ).to(self.device, dtype=torch.float16)
93
+ return image_proj_model
94
+
95
+ def set_ip_adapter(self):
96
+ unet = self.pipe.unet
97
+ attn_procs = {}
98
+ for name in unet.attn_processors.keys():
99
+ cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
100
+ if name.startswith("mid_block"):
101
+ hidden_size = unet.config.block_out_channels[-1]
102
+ elif name.startswith("up_blocks"):
103
+ block_id = int(name[len("up_blocks.")])
104
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
105
+ elif name.startswith("down_blocks"):
106
+ block_id = int(name[len("down_blocks.")])
107
+ hidden_size = unet.config.block_out_channels[block_id]
108
+ if cross_attention_dim is None:
109
+ attn_procs[name] = AttnProcessor()
110
+ else:
111
+ attn_procs[name] = IPAttnProcessor(
112
+ hidden_size=hidden_size,
113
+ cross_attention_dim=cross_attention_dim,
114
+ scale=1.0,
115
+ num_tokens=self.num_tokens,
116
+ ).to(self.device, dtype=torch.float16)
117
+ unet.set_attn_processor(attn_procs)
118
+ if hasattr(self.pipe, "controlnet"):
119
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
120
+ for controlnet in self.pipe.controlnet.nets:
121
+ controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
122
+ else:
123
+ self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
124
+
125
+ def load_ip_adapter(self):
126
+ if self.ip_ckpt is not None:
127
+ if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
128
+ state_dict = {"image_proj": {}, "ip_adapter": {}}
129
+ with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
130
+ for key in f.keys():
131
+ if key.startswith("image_proj."):
132
+ state_dict["image_proj"][key.replace("image_proj.", "")] = f.get_tensor(key)
133
+ elif key.startswith("ip_adapter."):
134
+ state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
135
+ else:
136
+ state_dict = torch.load(self.ip_ckpt, map_location="cpu")
137
+ self.image_proj_model.load_state_dict(state_dict["image_proj"])
138
+ ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
139
+ ip_layers.load_state_dict(state_dict["ip_adapter"])
140
+
141
+
142
+ # def load_ip_adapter(self):
143
+ # if self.ip_ckpt is not None:
144
+ # if os.path.splitext(self.ip_ckpt)[-1] == ".safetensors":
145
+ # state_dict = {"image_proj_model": {}, "ip_adapter": {}}
146
+ # with safe_open(self.ip_ckpt, framework="pt", device="cpu") as f:
147
+ # for key in f.keys():
148
+ # if key.startswith("image_proj_model."):
149
+ # state_dict["image_proj_model"][key.replace("image_proj_model.", "")] = f.get_tensor(key)
150
+ # elif key.startswith("ip_adapter."):
151
+ # state_dict["ip_adapter"][key.replace("ip_adapter.", "")] = f.get_tensor(key)
152
+ # else:
153
+ # state_dict = torch.load(self.ip_ckpt, map_location="cpu")
154
+
155
+ # tmp1 = {}
156
+ # for k,v in state_dict.items():
157
+ # if 'image_proj_model' in k:
158
+ # tmp1[k.replace('image_proj_model.','')] = v
159
+ # self.image_proj_model.load_state_dict(tmp1, strict=True)
160
+ # # ip_layers = torch.nn.ModuleList(self.pipe.unet.attn_processors.values())
161
+ # tmp2 = {}
162
+ # for k,v in state_dict.ites():
163
+ # if 'adapter_mode' in k:
164
+ # tmp1[k] = v
165
+
166
+ # print(ip_layers.state_dict())
167
+ # ip_layers.load_state_dict(state_dict,strict=False)
168
+
169
+
170
+ @torch.inference_mode()
171
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
172
+ if pil_image is not None:
173
+ if isinstance(pil_image, Image.Image):
174
+ pil_image = [pil_image]
175
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
176
+ clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float16)).image_embeds
177
+ else:
178
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
179
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
180
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
181
+ return image_prompt_embeds, uncond_image_prompt_embeds
182
+
183
+ def get_image_embeds_train(self, pil_image=None, clip_image_embeds=None):
184
+ if pil_image is not None:
185
+ if isinstance(pil_image, Image.Image):
186
+ pil_image = [pil_image]
187
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
188
+ clip_image_embeds = self.image_encoder(clip_image.to(self.device, dtype=torch.float32)).image_embeds
189
+ else:
190
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float32)
191
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
192
+ uncond_image_prompt_embeds = self.image_proj_model(torch.zeros_like(clip_image_embeds))
193
+ return image_prompt_embeds, uncond_image_prompt_embeds
194
+
195
+
196
+ def set_scale(self, scale):
197
+ for attn_processor in self.pipe.unet.attn_processors.values():
198
+ if isinstance(attn_processor, IPAttnProcessor):
199
+ attn_processor.scale = scale
200
+
201
+ def generate(
202
+ self,
203
+ pil_image=None,
204
+ clip_image_embeds=None,
205
+ prompt=None,
206
+ negative_prompt=None,
207
+ scale=1.0,
208
+ num_samples=4,
209
+ seed=None,
210
+ guidance_scale=7.5,
211
+ num_inference_steps=50,
212
+ **kwargs,
213
+ ):
214
+ self.set_scale(scale)
215
+
216
+ if pil_image is not None:
217
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
218
+ else:
219
+ num_prompts = clip_image_embeds.size(0)
220
+
221
+ if prompt is None:
222
+ prompt = "best quality, high quality"
223
+ if negative_prompt is None:
224
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
225
+
226
+ if not isinstance(prompt, List):
227
+ prompt = [prompt] * num_prompts
228
+ if not isinstance(negative_prompt, List):
229
+ negative_prompt = [negative_prompt] * num_prompts
230
+
231
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
232
+ pil_image=pil_image, clip_image_embeds=clip_image_embeds
233
+ )
234
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
235
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
236
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
237
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
238
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
239
+
240
+ with torch.inference_mode():
241
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
242
+ prompt,
243
+ device=self.device,
244
+ num_images_per_prompt=num_samples,
245
+ do_classifier_free_guidance=True,
246
+ negative_prompt=negative_prompt,
247
+ )
248
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
249
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
250
+
251
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
252
+ images = self.pipe(
253
+ prompt_embeds=prompt_embeds,
254
+ negative_prompt_embeds=negative_prompt_embeds,
255
+ guidance_scale=guidance_scale,
256
+ num_inference_steps=num_inference_steps,
257
+ generator=generator,
258
+ **kwargs,
259
+ ).images
260
+
261
+ return images
262
+
263
+
264
+ class IPAdapterXL(IPAdapter):
265
+ """SDXL"""
266
+
267
+ def generate_test(
268
+ self,
269
+ pil_image,
270
+ prompt=None,
271
+ negative_prompt=None,
272
+ scale=1.0,
273
+ num_samples=4,
274
+ seed=None,
275
+ num_inference_steps=30,
276
+ **kwargs,
277
+ ):
278
+ self.set_scale(scale)
279
+
280
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
281
+
282
+ if prompt is None:
283
+ prompt = "best quality, high quality"
284
+ if negative_prompt is None:
285
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
286
+
287
+ if not isinstance(prompt, List):
288
+ prompt = [prompt] * num_prompts
289
+ if not isinstance(negative_prompt, List):
290
+ negative_prompt = [negative_prompt] * num_prompts
291
+
292
+
293
+ with torch.inference_mode():
294
+ (
295
+ prompt_embeds,
296
+ negative_prompt_embeds,
297
+ pooled_prompt_embeds,
298
+ negative_pooled_prompt_embeds,
299
+ ) = self.pipe.encode_prompt(
300
+ prompt,
301
+ num_images_per_prompt=num_samples,
302
+ do_classifier_free_guidance=True,
303
+ negative_prompt=negative_prompt,
304
+ )
305
+
306
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
307
+ images = self.pipe(
308
+ prompt_embeds=prompt_embeds,
309
+ negative_prompt_embeds=negative_prompt_embeds,
310
+ pooled_prompt_embeds=pooled_prompt_embeds,
311
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
312
+ num_inference_steps=num_inference_steps,
313
+ generator=generator,
314
+ **kwargs,
315
+ ).images
316
+
317
+
318
+ # with torch.autocast("cuda"):
319
+ # images = self.pipe(
320
+ # prompt_embeds=prompt_embeds,
321
+ # negative_prompt_embeds=negative_prompt_embeds,
322
+ # pooled_prompt_embeds=pooled_prompt_embeds,
323
+ # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
324
+ # num_inference_steps=num_inference_steps,
325
+ # generator=generator,
326
+ # **kwargs,
327
+ # ).images
328
+
329
+ return images
330
+
331
+
332
+ def generate(
333
+ self,
334
+ pil_image,
335
+ prompt=None,
336
+ negative_prompt=None,
337
+ scale=1.0,
338
+ num_samples=4,
339
+ seed=None,
340
+ num_inference_steps=30,
341
+ **kwargs,
342
+ ):
343
+ self.set_scale(scale)
344
+
345
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
346
+
347
+ if prompt is None:
348
+ prompt = "best quality, high quality"
349
+ if negative_prompt is None:
350
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
351
+
352
+ if not isinstance(prompt, List):
353
+ prompt = [prompt] * num_prompts
354
+ if not isinstance(negative_prompt, List):
355
+ negative_prompt = [negative_prompt] * num_prompts
356
+
357
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
358
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
359
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
360
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
361
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
362
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
363
+
364
+ with torch.inference_mode():
365
+ (
366
+ prompt_embeds,
367
+ negative_prompt_embeds,
368
+ pooled_prompt_embeds,
369
+ negative_pooled_prompt_embeds,
370
+ ) = self.pipe.encode_prompt(
371
+ prompt,
372
+ num_images_per_prompt=num_samples,
373
+ do_classifier_free_guidance=True,
374
+ negative_prompt=negative_prompt,
375
+ )
376
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
377
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
378
+
379
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
380
+ images = self.pipe(
381
+ prompt_embeds=prompt_embeds,
382
+ negative_prompt_embeds=negative_prompt_embeds,
383
+ pooled_prompt_embeds=pooled_prompt_embeds,
384
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
385
+ num_inference_steps=num_inference_steps,
386
+ generator=generator,
387
+ **kwargs,
388
+ ).images
389
+
390
+
391
+ # with torch.autocast("cuda"):
392
+ # images = self.pipe(
393
+ # prompt_embeds=prompt_embeds,
394
+ # negative_prompt_embeds=negative_prompt_embeds,
395
+ # pooled_prompt_embeds=pooled_prompt_embeds,
396
+ # negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
397
+ # num_inference_steps=num_inference_steps,
398
+ # generator=generator,
399
+ # **kwargs,
400
+ # ).images
401
+
402
+ return images
403
+
404
+
405
+ class IPAdapterPlus(IPAdapter):
406
+ """IP-Adapter with fine-grained features"""
407
+
408
+ def generate(
409
+ self,
410
+ pil_image=None,
411
+ clip_image_embeds=None,
412
+ prompt=None,
413
+ negative_prompt=None,
414
+ scale=1.0,
415
+ num_samples=4,
416
+ seed=None,
417
+ guidance_scale=7.5,
418
+ num_inference_steps=50,
419
+ **kwargs,
420
+ ):
421
+ self.set_scale(scale)
422
+
423
+ if pil_image is not None:
424
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
425
+ else:
426
+ num_prompts = clip_image_embeds.size(0)
427
+
428
+ if prompt is None:
429
+ prompt = "best quality, high quality"
430
+ if negative_prompt is None:
431
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
432
+
433
+ if not isinstance(prompt, List):
434
+ prompt = [prompt] * num_prompts
435
+ if not isinstance(negative_prompt, List):
436
+ negative_prompt = [negative_prompt] * num_prompts
437
+
438
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
439
+ pil_image=pil_image, clip_image=clip_image_embeds
440
+ )
441
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
442
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
443
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
444
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
445
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
446
+
447
+ with torch.inference_mode():
448
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
449
+ prompt,
450
+ device=self.device,
451
+ num_images_per_prompt=num_samples,
452
+ do_classifier_free_guidance=True,
453
+ negative_prompt=negative_prompt,
454
+ )
455
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
456
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
457
+
458
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
459
+ images = self.pipe(
460
+ prompt_embeds=prompt_embeds,
461
+ negative_prompt_embeds=negative_prompt_embeds,
462
+ guidance_scale=guidance_scale,
463
+ num_inference_steps=num_inference_steps,
464
+ generator=generator,
465
+ **kwargs,
466
+ ).images
467
+
468
+ return images
469
+
470
+
471
+ def init_proj(self):
472
+ image_proj_model = Resampler(
473
+ dim=self.pipe.unet.config.cross_attention_dim,
474
+ depth=4,
475
+ dim_head=64,
476
+ heads=12,
477
+ num_queries=self.num_tokens,
478
+ embedding_dim=self.image_encoder.config.hidden_size,
479
+ output_dim=self.pipe.unet.config.cross_attention_dim,
480
+ ff_mult=4,
481
+ ).to(self.device, dtype=torch.float16)
482
+ return image_proj_model
483
+
484
+ @torch.inference_mode()
485
+ def get_image_embeds(self, pil_image=None, clip_image=None, uncond=None):
486
+ if pil_image is not None:
487
+ if isinstance(pil_image, Image.Image):
488
+ pil_image = [pil_image]
489
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
490
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
491
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
492
+ else:
493
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
494
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
495
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
496
+ uncond_clip_image_embeds = self.image_encoder(
497
+ torch.zeros_like(clip_image), output_hidden_states=True
498
+ ).hidden_states[-2]
499
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
500
+ return image_prompt_embeds, uncond_image_prompt_embeds
501
+
502
+
503
+
504
+
505
+ class IPAdapterPlus_Lora(IPAdapter):
506
+ """IP-Adapter with fine-grained features"""
507
+
508
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, rank=32):
509
+ self.rank = rank
510
+ super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens)
511
+
512
+
513
+ def generate(
514
+ self,
515
+ pil_image=None,
516
+ clip_image_embeds=None,
517
+ prompt=None,
518
+ negative_prompt=None,
519
+ scale=1.0,
520
+ num_samples=4,
521
+ seed=None,
522
+ guidance_scale=7.5,
523
+ num_inference_steps=50,
524
+ **kwargs,
525
+ ):
526
+ self.set_scale(scale)
527
+
528
+ if pil_image is not None:
529
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
530
+ else:
531
+ num_prompts = clip_image_embeds.size(0)
532
+
533
+ if prompt is None:
534
+ prompt = "best quality, high quality"
535
+ if negative_prompt is None:
536
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
537
+
538
+ if not isinstance(prompt, List):
539
+ prompt = [prompt] * num_prompts
540
+ if not isinstance(negative_prompt, List):
541
+ negative_prompt = [negative_prompt] * num_prompts
542
+
543
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
544
+ pil_image=pil_image, clip_image=clip_image_embeds
545
+ )
546
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
547
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
548
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
549
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
550
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
551
+
552
+ with torch.inference_mode():
553
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
554
+ prompt,
555
+ device=self.device,
556
+ num_images_per_prompt=num_samples,
557
+ do_classifier_free_guidance=True,
558
+ negative_prompt=negative_prompt,
559
+ )
560
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
561
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
562
+
563
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
564
+ images = self.pipe(
565
+ prompt_embeds=prompt_embeds,
566
+ negative_prompt_embeds=negative_prompt_embeds,
567
+ guidance_scale=guidance_scale,
568
+ num_inference_steps=num_inference_steps,
569
+ generator=generator,
570
+ **kwargs,
571
+ ).images
572
+
573
+ return images
574
+
575
+
576
+ def init_proj(self):
577
+ image_proj_model = Resampler(
578
+ dim=self.pipe.unet.config.cross_attention_dim,
579
+ depth=4,
580
+ dim_head=64,
581
+ heads=12,
582
+ num_queries=self.num_tokens,
583
+ embedding_dim=self.image_encoder.config.hidden_size,
584
+ output_dim=self.pipe.unet.config.cross_attention_dim,
585
+ ff_mult=4,
586
+ ).to(self.device, dtype=torch.float16)
587
+ return image_proj_model
588
+
589
+ @torch.inference_mode()
590
+ def get_image_embeds(self, pil_image=None, clip_image=None, uncond=None):
591
+ if pil_image is not None:
592
+ if isinstance(pil_image, Image.Image):
593
+ pil_image = [pil_image]
594
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
595
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
596
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
597
+ else:
598
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
599
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
600
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
601
+ uncond_clip_image_embeds = self.image_encoder(
602
+ torch.zeros_like(clip_image), output_hidden_states=True
603
+ ).hidden_states[-2]
604
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
605
+ return image_prompt_embeds, uncond_image_prompt_embeds
606
+
607
+ def set_ip_adapter(self):
608
+ unet = self.pipe.unet
609
+ attn_procs = {}
610
+ unet_sd = unet.state_dict()
611
+
612
+ for attn_processor_name, attn_processor in unet.attn_processors.items():
613
+ # Parse the attention module.
614
+ cross_attention_dim = None if attn_processor_name.endswith("attn1.processor") else unet.config.cross_attention_dim
615
+ if attn_processor_name.startswith("mid_block"):
616
+ hidden_size = unet.config.block_out_channels[-1]
617
+ elif attn_processor_name.startswith("up_blocks"):
618
+ block_id = int(attn_processor_name[len("up_blocks.")])
619
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
620
+ elif attn_processor_name.startswith("down_blocks"):
621
+ block_id = int(attn_processor_name[len("down_blocks.")])
622
+ hidden_size = unet.config.block_out_channels[block_id]
623
+ if cross_attention_dim is None:
624
+ attn_procs[attn_processor_name] = AttnProcessor()
625
+ else:
626
+ layer_name = attn_processor_name.split(".processor")[0]
627
+ weights = {
628
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
629
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
630
+ }
631
+ attn_procs[attn_processor_name] = IPAttnProcessor2_0_Lora(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens)
632
+ attn_procs[attn_processor_name].load_state_dict(weights, strict=False)
633
+
634
+ attn_module = unet
635
+ for n in attn_processor_name.split(".")[:-1]:
636
+ attn_module = getattr(attn_module, n)
637
+
638
+ attn_module.q_lora = LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.rank)
639
+ attn_module.k_lora = LoRALinearLayer(in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=self.rank)
640
+ attn_module.v_lora = LoRALinearLayer(in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=self.rank)
641
+ attn_module.out_lora = LoRALinearLayer(in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.rank)
642
+
643
+ unet.set_attn_processor(attn_procs)
644
+ if hasattr(self.pipe, "controlnet"):
645
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
646
+ for controlnet in self.pipe.controlnet.nets:
647
+ controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
648
+ else:
649
+ self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
650
+
651
+
652
+
653
+ class IPAdapterPlus_Lora_up(IPAdapter):
654
+ """IP-Adapter with fine-grained features"""
655
+
656
+ def __init__(self, sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens=4, rank=32):
657
+ self.rank = rank
658
+ super().__init__(sd_pipe, image_encoder_path, ip_ckpt, device, num_tokens)
659
+
660
+
661
+ def generate(
662
+ self,
663
+ pil_image=None,
664
+ clip_image_embeds=None,
665
+ prompt=None,
666
+ negative_prompt=None,
667
+ scale=1.0,
668
+ num_samples=4,
669
+ seed=None,
670
+ guidance_scale=7.5,
671
+ num_inference_steps=50,
672
+ **kwargs,
673
+ ):
674
+ self.set_scale(scale)
675
+
676
+ if pil_image is not None:
677
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
678
+ else:
679
+ num_prompts = clip_image_embeds.size(0)
680
+
681
+ if prompt is None:
682
+ prompt = "best quality, high quality"
683
+ if negative_prompt is None:
684
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
685
+
686
+ if not isinstance(prompt, List):
687
+ prompt = [prompt] * num_prompts
688
+ if not isinstance(negative_prompt, List):
689
+ negative_prompt = [negative_prompt] * num_prompts
690
+
691
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(
692
+ pil_image=pil_image, clip_image=clip_image_embeds
693
+ )
694
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
695
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
696
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
697
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
698
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
699
+
700
+ with torch.inference_mode():
701
+ prompt_embeds_, negative_prompt_embeds_ = self.pipe.encode_prompt(
702
+ prompt,
703
+ device=self.device,
704
+ num_images_per_prompt=num_samples,
705
+ do_classifier_free_guidance=True,
706
+ negative_prompt=negative_prompt,
707
+ )
708
+ prompt_embeds = torch.cat([prompt_embeds_, image_prompt_embeds], dim=1)
709
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds_, uncond_image_prompt_embeds], dim=1)
710
+
711
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
712
+ images = self.pipe(
713
+ prompt_embeds=prompt_embeds,
714
+ negative_prompt_embeds=negative_prompt_embeds,
715
+ guidance_scale=guidance_scale,
716
+ num_inference_steps=num_inference_steps,
717
+ generator=generator,
718
+ **kwargs,
719
+ ).images
720
+
721
+ return images
722
+
723
+
724
+ def init_proj(self):
725
+ image_proj_model = Resampler(
726
+ dim=self.pipe.unet.config.cross_attention_dim,
727
+ depth=4,
728
+ dim_head=64,
729
+ heads=12,
730
+ num_queries=self.num_tokens,
731
+ embedding_dim=self.image_encoder.config.hidden_size,
732
+ output_dim=self.pipe.unet.config.cross_attention_dim,
733
+ ff_mult=4,
734
+ ).to(self.device, dtype=torch.float16)
735
+ return image_proj_model
736
+
737
+ @torch.inference_mode()
738
+ def get_image_embeds(self, pil_image=None, clip_image=None, uncond=None):
739
+ if pil_image is not None:
740
+ if isinstance(pil_image, Image.Image):
741
+ pil_image = [pil_image]
742
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
743
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
744
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
745
+ else:
746
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
747
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
748
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
749
+ uncond_clip_image_embeds = self.image_encoder(
750
+ torch.zeros_like(clip_image), output_hidden_states=True
751
+ ).hidden_states[-2]
752
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
753
+ return image_prompt_embeds, uncond_image_prompt_embeds
754
+
755
+ def set_ip_adapter(self):
756
+ unet = self.pipe.unet
757
+ attn_procs = {}
758
+ unet_sd = unet.state_dict()
759
+
760
+ for attn_processor_name, attn_processor in unet.attn_processors.items():
761
+ # Parse the attention module.
762
+ cross_attention_dim = None if attn_processor_name.endswith("attn1.processor") else unet.config.cross_attention_dim
763
+ if attn_processor_name.startswith("mid_block"):
764
+ hidden_size = unet.config.block_out_channels[-1]
765
+ elif attn_processor_name.startswith("up_blocks"):
766
+ block_id = int(attn_processor_name[len("up_blocks.")])
767
+ hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
768
+ elif attn_processor_name.startswith("down_blocks"):
769
+ block_id = int(attn_processor_name[len("down_blocks.")])
770
+ hidden_size = unet.config.block_out_channels[block_id]
771
+ if cross_attention_dim is None:
772
+ attn_procs[attn_processor_name] = AttnProcessor()
773
+ else:
774
+ layer_name = attn_processor_name.split(".processor")[0]
775
+ weights = {
776
+ "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
777
+ "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
778
+ }
779
+ attn_procs[attn_processor_name] = IPAttnProcessor2_0_Lora(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=self.num_tokens)
780
+ attn_procs[attn_processor_name].load_state_dict(weights, strict=False)
781
+
782
+ attn_module = unet
783
+ for n in attn_processor_name.split(".")[:-1]:
784
+ attn_module = getattr(attn_module, n)
785
+
786
+
787
+ if "up_blocks" in attn_processor_name:
788
+ attn_module.q_lora = LoRALinearLayer(in_features=attn_module.to_q.in_features, out_features=attn_module.to_q.out_features, rank=self.rank)
789
+ attn_module.k_lora = LoRALinearLayer(in_features=attn_module.to_k.in_features, out_features=attn_module.to_k.out_features, rank=self.rank)
790
+ attn_module.v_lora = LoRALinearLayer(in_features=attn_module.to_v.in_features, out_features=attn_module.to_v.out_features, rank=self.rank)
791
+ attn_module.out_lora = LoRALinearLayer(in_features=attn_module.to_out[0].in_features, out_features=attn_module.to_out[0].out_features, rank=self.rank)
792
+
793
+
794
+
795
+ unet.set_attn_processor(attn_procs)
796
+ if hasattr(self.pipe, "controlnet"):
797
+ if isinstance(self.pipe.controlnet, MultiControlNetModel):
798
+ for controlnet in self.pipe.controlnet.nets:
799
+ controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
800
+ else:
801
+ self.pipe.controlnet.set_attn_processor(CNAttnProcessor(num_tokens=self.num_tokens))
802
+
803
+
804
+
805
+ class IPAdapterFull(IPAdapterPlus):
806
+ """IP-Adapter with full features"""
807
+
808
+ def init_proj(self):
809
+ image_proj_model = MLPProjModel(
810
+ cross_attention_dim=self.pipe.unet.config.cross_attention_dim,
811
+ clip_embeddings_dim=self.image_encoder.config.hidden_size,
812
+ ).to(self.device, dtype=torch.float16)
813
+ return image_proj_model
814
+
815
+
816
+ class IPAdapterPlusXL(IPAdapter):
817
+ """SDXL"""
818
+
819
+ def init_proj(self):
820
+ image_proj_model = Resampler(
821
+ dim=1280,
822
+ depth=4,
823
+ dim_head=64,
824
+ heads=20,
825
+ num_queries=self.num_tokens,
826
+ embedding_dim=self.image_encoder.config.hidden_size,
827
+ output_dim=self.pipe.unet.config.cross_attention_dim,
828
+ ff_mult=4,
829
+ ).to(self.device, dtype=torch.float16)
830
+ return image_proj_model
831
+
832
+ @torch.inference_mode()
833
+ def get_image_embeds(self, pil_image=None, clip_image_embeds=None):
834
+ if pil_image is not None:
835
+ if isinstance(pil_image, Image.Image):
836
+ pil_image = [pil_image]
837
+ clip_image = self.clip_image_processor(images=pil_image, return_tensors="pt").pixel_values
838
+ clip_image = clip_image.to(self.device, dtype=torch.float16)
839
+ clip_image_embeds = self.image_encoder(clip_image, output_hidden_states=True).hidden_states[-2]
840
+ else:
841
+ clip_image_embeds = clip_image_embeds.to(self.device, dtype=torch.float16)
842
+ image_prompt_embeds = self.image_proj_model(clip_image_embeds)
843
+ uncond_clip_image_embeds = self.image_encoder(
844
+ torch.zeros_like(clip_image), output_hidden_states=True
845
+ ).hidden_states[-2]
846
+ uncond_image_prompt_embeds = self.image_proj_model(uncond_clip_image_embeds)
847
+ return image_prompt_embeds, uncond_image_prompt_embeds
848
+
849
+ def generate(
850
+ self,
851
+ pil_image,
852
+ prompt=None,
853
+ negative_prompt=None,
854
+ scale=1.0,
855
+ num_samples=4,
856
+ seed=None,
857
+ num_inference_steps=30,
858
+ **kwargs,
859
+ ):
860
+ self.set_scale(scale)
861
+
862
+ num_prompts = 1 if isinstance(pil_image, Image.Image) else len(pil_image)
863
+
864
+ if prompt is None:
865
+ prompt = "best quality, high quality"
866
+ if negative_prompt is None:
867
+ negative_prompt = "monochrome, lowres, bad anatomy, worst quality, low quality"
868
+
869
+ if not isinstance(prompt, List):
870
+ prompt = [prompt] * num_prompts
871
+ if not isinstance(negative_prompt, List):
872
+ negative_prompt = [negative_prompt] * num_prompts
873
+
874
+ image_prompt_embeds, uncond_image_prompt_embeds = self.get_image_embeds(pil_image)
875
+ bs_embed, seq_len, _ = image_prompt_embeds.shape
876
+ image_prompt_embeds = image_prompt_embeds.repeat(1, num_samples, 1)
877
+ image_prompt_embeds = image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
878
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.repeat(1, num_samples, 1)
879
+ uncond_image_prompt_embeds = uncond_image_prompt_embeds.view(bs_embed * num_samples, seq_len, -1)
880
+
881
+ with torch.inference_mode():
882
+ (
883
+ prompt_embeds,
884
+ negative_prompt_embeds,
885
+ pooled_prompt_embeds,
886
+ negative_pooled_prompt_embeds,
887
+ ) = self.pipe.encode_prompt(
888
+ prompt,
889
+ num_images_per_prompt=num_samples,
890
+ do_classifier_free_guidance=True,
891
+ negative_prompt=negative_prompt,
892
+ )
893
+ prompt_embeds = torch.cat([prompt_embeds, image_prompt_embeds], dim=1)
894
+ negative_prompt_embeds = torch.cat([negative_prompt_embeds, uncond_image_prompt_embeds], dim=1)
895
+
896
+ generator = torch.Generator(self.device).manual_seed(seed) if seed is not None else None
897
+ images = self.pipe(
898
+ prompt_embeds=prompt_embeds,
899
+ negative_prompt_embeds=negative_prompt_embeds,
900
+ pooled_prompt_embeds=pooled_prompt_embeds,
901
+ negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
902
+ num_inference_steps=num_inference_steps,
903
+ generator=generator,
904
+ **kwargs,
905
+ ).images
906
+
907
+ return images
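The classes above are driven from a diffusers pipeline. Below is a minimal usage sketch; it is not part of the commit, and the base-model id, image-encoder directory, checkpoint path, and reference image are placeholders.

import torch
from PIL import Image
from diffusers import StableDiffusionPipeline
from ip_adapter import IPAdapter

# Load an SD 1.5-compatible base model in fp16; IPAdapter moves the pipeline
# to the chosen device and casts its own modules to fp16 itself.
pipe = StableDiffusionPipeline.from_pretrained(
    "runwayml/stable-diffusion-v1-5", torch_dtype=torch.float16, safety_checker=None
)

ip_model = IPAdapter(
    pipe,
    image_encoder_path="models/image_encoder",   # CLIP vision encoder directory (placeholder)
    ip_ckpt="models/ip-adapter_sd15.bin",        # IP-Adapter weights (placeholder)
    device="cuda",
)

images = ip_model.generate(
    pil_image=Image.open("reference.png"),       # image prompt
    prompt="best quality, high quality",
    scale=0.8,                                   # weight of the image prompt vs. the text prompt
    num_samples=2,
    seed=42,
    guidance_scale=7.5,
    num_inference_steps=50,
)
images[0].save("output.png")

The SDXL variants (IPAdapterXL, IPAdapterPlusXL) drop guidance_scale from the generate signature and default to 30 steps, but are driven the same way from an SDXL pipeline.
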
ip_adapter/resampler.py ADDED
@@ -0,0 +1,188 @@
1
+ # modified from https://github.com/mlfoundations/open_flamingo/blob/main/open_flamingo/src/helpers.py
2
+ # and https://github.com/lucidrains/imagen-pytorch/blob/main/imagen_pytorch/imagen_pytorch.py
3
+
4
+ import math
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from einops import rearrange
9
+ from einops.layers.torch import Rearrange
10
+
11
+
12
+ # FFN
13
+ def FeedForward(dim, mult=4):
14
+ inner_dim = int(dim * mult)
15
+ return nn.Sequential(
16
+ nn.LayerNorm(dim),
17
+ nn.Linear(dim, inner_dim, bias=False),
18
+ nn.GELU(),
19
+ nn.Linear(inner_dim, dim, bias=False),
20
+ )
21
+
22
+
23
+ def reshape_tensor(x, heads):
24
+ bs, length, width = x.shape
25
+ # (bs, length, width) --> (bs, length, n_heads, dim_per_head)
26
+ x = x.view(bs, length, heads, -1)
27
+ # (bs, length, n_heads, dim_per_head) --> (bs, n_heads, length, dim_per_head)
28
+ x = x.transpose(1, 2)
29
+ # (bs, n_heads, length, dim_per_head) --> (bs*n_heads, length, dim_per_head)
30
+ x = x.reshape(bs, heads, length, -1)
31
+ return x
32
+
33
+
34
+ class PerceiverAttention(nn.Module):
35
+ def __init__(self, *, dim, dim_head=64, heads=8):
36
+ super().__init__()
37
+ self.scale = dim_head**-0.5
38
+ self.dim_head = dim_head
39
+ self.heads = heads
40
+ inner_dim = dim_head * heads
41
+
42
+ self.norm1 = nn.LayerNorm(dim)
43
+ self.norm2 = nn.LayerNorm(dim)
44
+
45
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
46
+ self.to_kv = nn.Linear(dim, inner_dim * 2, bias=False)
47
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
48
+
49
+ def forward(self, x, latents):
50
+ """
51
+ Args:
52
+ x (torch.Tensor): image features
53
+ shape (b, n1, D)
54
+ latent (torch.Tensor): latent features
55
+ shape (b, n2, D)
56
+ """
57
+ x = self.norm1(x)
58
+ latents = self.norm2(latents)
59
+
60
+ b, l, _ = latents.shape
61
+
62
+ q = self.to_q(latents)
63
+ kv_input = torch.cat((x, latents), dim=-2)
64
+ k, v = self.to_kv(kv_input).chunk(2, dim=-1)
65
+
66
+ q = reshape_tensor(q, self.heads)
67
+ k = reshape_tensor(k, self.heads)
68
+ v = reshape_tensor(v, self.heads)
69
+
70
+ # attention
71
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
72
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
73
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
74
+ out = weight @ v
75
+
76
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
77
+
78
+ return self.to_out(out)
79
+
80
+
81
+ class CrossAttention(nn.Module):
82
+ def __init__(self, *, dim, dim_head=64, heads=8):
83
+ super().__init__()
84
+ self.scale = dim_head**-0.5
85
+ self.dim_head = dim_head
86
+ self.heads = heads
87
+ inner_dim = dim_head * heads
88
+
89
+ self.norm1 = nn.LayerNorm(dim)
90
+ self.norm2 = nn.LayerNorm(dim)
91
+
92
+ self.to_q = nn.Linear(dim, inner_dim, bias=False)
93
+ self.to_k = nn.Linear(dim, inner_dim, bias=False)
94
+ self.to_v = nn.Linear(dim, inner_dim, bias=False)
95
+ self.to_out = nn.Linear(inner_dim, dim, bias=False)
96
+
97
+
98
+ def forward(self, x, x2):
99
+ """
100
+ Args:
101
+ x (torch.Tensor): image features
102
+ shape (b, n1, D)
103
+ latent (torch.Tensor): latent features
104
+ shape (b, n2, D)
105
+ """
106
+ x = self.norm1(x)
107
+ x2 = self.norm2(x2)
108
+
109
+ b, l, _ = x2.shape
110
+
111
+ q = self.to_q(x)
112
+ k = self.to_k(x2)
113
+ v = self.to_v(x2)
114
+
115
+ q = reshape_tensor(q, self.heads)
116
+ k = reshape_tensor(k, self.heads)
117
+ v = reshape_tensor(v, self.heads)
118
+
119
+ # attention
120
+ scale = 1 / math.sqrt(math.sqrt(self.dim_head))
121
+ weight = (q * scale) @ (k * scale).transpose(-2, -1) # More stable with f16 than dividing afterwards
122
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
123
+ out = weight @ v
124
+
125
+ out = out.permute(0, 2, 1, 3).reshape(b, l, -1)
126
+ return self.to_out(out)
127
+
128
+
129
+ class Resampler(nn.Module):
130
+ def __init__(
131
+ self,
132
+ dim=1024,
133
+ depth=8,
134
+ dim_head=64,
135
+ heads=16,
136
+ num_queries=8,
137
+ embedding_dim=768,
138
+ output_dim=1024,
139
+ ff_mult=4,
140
+ max_seq_len: int = 257, # CLIP tokens + CLS token
141
+ apply_pos_emb: bool = False,
142
+ num_latents_mean_pooled: int = 0, # number of latents derived from mean pooled representation of the sequence
143
+ ):
144
+ super().__init__()
145
+
146
+ self.latents = nn.Parameter(torch.randn(1, num_queries, dim) / dim**0.5)
147
+
148
+ self.proj_in = nn.Linear(embedding_dim, dim)
149
+
150
+ self.proj_out = nn.Linear(dim, output_dim)
151
+ self.norm_out = nn.LayerNorm(output_dim)
152
+
153
+ self.layers = nn.ModuleList([])
154
+ for _ in range(depth):
155
+ self.layers.append(
156
+ nn.ModuleList(
157
+ [
158
+ PerceiverAttention(dim=dim, dim_head=dim_head, heads=heads),
159
+ FeedForward(dim=dim, mult=ff_mult),
160
+ ]
161
+ )
162
+ )
163
+
164
+ def forward(self, x):
165
+
166
+ latents = self.latents.repeat(x.size(0), 1, 1)
167
+
168
+ x = self.proj_in(x)
169
+
170
+
171
+ for attn, ff in self.layers:
172
+ latents = attn(x, latents) + latents
173
+ latents = ff(latents) + latents
174
+
175
+ latents = self.proj_out(latents)
176
+ return self.norm_out(latents)
177
+
178
+
179
+
180
+ def masked_mean(t, *, dim, mask=None):
181
+ if mask is None:
182
+ return t.mean(dim=dim)
183
+
184
+ denom = mask.sum(dim=dim, keepdim=True)
185
+ mask = rearrange(mask, "b n -> b n 1")
186
+ masked_t = t.masked_fill(~mask, 0.0)
187
+
188
+ return masked_t.sum(dim=dim) / denom.clamp(min=1e-5)
ip_adapter/test_resampler.py ADDED
@@ -0,0 +1,44 @@
+ import torch
+ from resampler import Resampler
+ from transformers import CLIPVisionModel
+
+ BATCH_SIZE = 2
+ OUTPUT_DIM = 1280
+ NUM_QUERIES = 8
+ NUM_LATENTS_MEAN_POOLED = 4  # 0 for no mean pooling (previous behavior)
+ APPLY_POS_EMB = True  # False for no positional embeddings (previous behavior)
+ IMAGE_ENCODER_NAME_OR_PATH = "laion/CLIP-ViT-H-14-laion2B-s32B-b79K"
+
+
+ def main():
+     image_encoder = CLIPVisionModel.from_pretrained(IMAGE_ENCODER_NAME_OR_PATH)
+     embedding_dim = image_encoder.config.hidden_size
+     print("image_encoder hidden size:", embedding_dim)
+
+     image_proj_model = Resampler(
+         dim=1024,
+         depth=2,
+         dim_head=64,
+         heads=16,
+         num_queries=NUM_QUERIES,
+         embedding_dim=embedding_dim,
+         output_dim=OUTPUT_DIM,
+         ff_mult=2,
+         max_seq_len=257,
+         apply_pos_emb=APPLY_POS_EMB,
+         num_latents_mean_pooled=NUM_LATENTS_MEAN_POOLED,
+     )
+
+     dummy_images = torch.randn(BATCH_SIZE, 3, 224, 224)
+     with torch.no_grad():
+         image_embeds = image_encoder(dummy_images, output_hidden_states=True).hidden_states[-2]
+     print("image_embeds shape:", image_embeds.shape)
+
+     with torch.no_grad():
+         ip_tokens = image_proj_model(image_embeds)
+     print("ip_tokens shape:", ip_tokens.shape)
+     assert ip_tokens.shape == (BATCH_SIZE, NUM_QUERIES + NUM_LATENTS_MEAN_POOLED, OUTPUT_DIM)
+
+
+ if __name__ == "__main__":
+     main()
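Note: the test imports `from resampler import Resampler` as a top-level module, so it is presumably meant to be run from inside the ip_adapter/ directory (for example, `cd ip_adapter && python test_resampler.py`). Its assertion also assumes a Resampler that honors `num_latents_mean_pooled`; the resampler.py committed above accepts that argument but does not use it, so the output would have only NUM_QUERIES tokens.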
ip_adapter/utils.py ADDED
@@ -0,0 +1,5 @@
+ import torch.nn.functional as F
+
+
+ def is_torch2_available():
+     return hasattr(F, "scaled_dot_product_attention")