ashawkey commited on
Commit
3a2ea0a
1 Parent(s): af95a32

both worked... but imagedream quality is unmatched

Browse files
.gitignore CHANGED
@@ -3,6 +3,6 @@
3
  **/__pycache__
4
  *.pyc
5
 
6
- weights
7
  models
8
  sd-v2*
 
3
  **/__pycache__
4
  *.pyc
5
 
6
+ weights*
7
  models
8
  sd-v2*
README.md CHANGED
@@ -3,6 +3,8 @@
3
  modified from https://github.com/KokeCacao/mvdream-hf.
4
 
5
  ### convert weights
 
 
6
  ```bash
7
  # dependency
8
  pip install -U omegaconf diffusers safetensors huggingface_hub transformers accelerate
@@ -17,33 +19,20 @@ cd ..
17
  python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view.pt --dump_path ./weights_mvdream --original_config_file models/sd-v2-base.yaml --half --to_safetensors --test
18
  ```
19
 
 
20
  ```bash
21
  # download original ckpt
22
  wget https://huggingface.co/Peng-Wang/ImageDream/resolve/main/sd-v2.1-base-4view-ipmv-local.pt
23
  wget https://raw.githubusercontent.com/bytedance/ImageDream/main/extern/ImageDream/imagedream/configs/sd_v2_base_ipmv_local.yaml
24
 
25
  # convert
26
- python convert_imagedream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view-ipmv-local.pt --dump_path ./weights_imagedream --original_config_file models/sd-v2-base_ipmv_local.yaml --half --to_safetensors --test
27
  ```
28
 
29
  ### usage
30
 
31
  example:
32
  ```bash
33
- python main.py "a cute owl"
34
- ```
35
-
36
- detailed usage:
37
- ```python
38
- import torch
39
- import kiui
40
- from mvdream.pipeline_mvdream import MVDreamPipeline
41
-
42
- pipe = MVDreamPipeline.from_pretrained('./weights', torch_dtype=torch.float16)
43
- pipe = pipe.to("cuda")
44
-
45
- prompt = "a photo of an astronaut riding a horse on mars"
46
- image = pipe(prompt) # np.ndarray [4, 256, 256, 3]
47
-
48
- kiui.vis.plot_image(image)
49
  ```
 
3
  modified from https://github.com/KokeCacao/mvdream-hf.
4
 
5
  ### convert weights
6
+
7
+ MVDream:
8
  ```bash
9
  # dependency
10
  pip install -U omegaconf diffusers safetensors huggingface_hub transformers accelerate
 
19
  python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view.pt --dump_path ./weights_mvdream --original_config_file models/sd-v2-base.yaml --half --to_safetensors --test
20
  ```
21
 
22
+ ImageDream:
23
  ```bash
24
  # download original ckpt
25
  wget https://huggingface.co/Peng-Wang/ImageDream/resolve/main/sd-v2.1-base-4view-ipmv-local.pt
26
  wget https://raw.githubusercontent.com/bytedance/ImageDream/main/extern/ImageDream/imagedream/configs/sd_v2_base_ipmv_local.yaml
27
 
28
  # convert
29
+ python convert_mvdream_to_diffusers.py --checkpoint_path models/sd-v2.1-base-4view-ipmv-local.pt --dump_path ./weights_imagedream --original_config_file models/sd_v2_base_ipmv_local.yaml --half --to_safetensors --test
30
  ```
31
 
32
  ### usage
33
 
34
  example:
35
  ```bash
36
+ python run_mvdream.py "a cute owl"
37
+ python run_imagedream.py data/anya_rgba.png
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  ```
convert_imagedream_to_diffusers.py DELETED
@@ -1,561 +0,0 @@
1
- # Modified from https://github.com/huggingface/diffusers/blob/bc691231360a4cbc7d19a58742ebb8ed0f05e027/scripts/convert_original_stable_diffusion_to_diffusers.py
2
-
3
- import argparse
4
- import torch
5
- import sys
6
-
7
- sys.path.insert(0, ".")
8
-
9
- from diffusers.models import (
10
- AutoencoderKL,
11
- )
12
- from omegaconf import OmegaConf
13
- from diffusers.schedulers import DDIMScheduler
14
- from diffusers.utils import logging
15
- from typing import Any
16
- from accelerate import init_empty_weights
17
- from accelerate.utils import set_module_tensor_to_device
18
- from imagedream.models import MultiViewUNetModel
19
- from imagedream.pipeline_imagedream import ImageDreamPipeline
20
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPFeatureExtractor
21
-
22
- logger = logging.get_logger(__name__)
23
-
24
-
25
- def assign_to_checkpoint(
26
- paths,
27
- checkpoint,
28
- old_checkpoint,
29
- attention_paths_to_split=None,
30
- additional_replacements=None,
31
- config=None,
32
- ):
33
- """
34
- This does the final conversion step: take locally converted weights and apply a global renaming to them. It splits
35
- attention layers, and takes into account additional replacements that may arise.
36
- Assigns the weights to the new checkpoint.
37
- """
38
- assert isinstance(
39
- paths, list
40
- ), "Paths should be a list of dicts containing 'old' and 'new' keys."
41
-
42
- # Splits the attention layers into three variables.
43
- if attention_paths_to_split is not None:
44
- for path, path_map in attention_paths_to_split.items():
45
- old_tensor = old_checkpoint[path]
46
- channels = old_tensor.shape[0] // 3
47
-
48
- target_shape = (-1, channels) if len(old_tensor.shape) == 3 else (-1)
49
-
50
- assert config is not None
51
- num_heads = old_tensor.shape[0] // config["num_head_channels"] // 3
52
-
53
- old_tensor = old_tensor.reshape(
54
- (num_heads, 3 * channels // num_heads) + old_tensor.shape[1:]
55
- )
56
- query, key, value = old_tensor.split(channels // num_heads, dim=1)
57
-
58
- checkpoint[path_map["query"]] = query.reshape(target_shape)
59
- checkpoint[path_map["key"]] = key.reshape(target_shape)
60
- checkpoint[path_map["value"]] = value.reshape(target_shape)
61
-
62
- for path in paths:
63
- new_path = path["new"]
64
-
65
- # These have already been assigned
66
- if (
67
- attention_paths_to_split is not None
68
- and new_path in attention_paths_to_split
69
- ):
70
- continue
71
-
72
- # Global renaming happens here
73
- new_path = new_path.replace("middle_block.0", "mid_block.resnets.0")
74
- new_path = new_path.replace("middle_block.1", "mid_block.attentions.0")
75
- new_path = new_path.replace("middle_block.2", "mid_block.resnets.1")
76
-
77
- if additional_replacements is not None:
78
- for replacement in additional_replacements:
79
- new_path = new_path.replace(replacement["old"], replacement["new"])
80
-
81
- # proj_attn.weight has to be converted from conv 1D to linear
82
- is_attn_weight = "proj_attn.weight" in new_path or (
83
- "attentions" in new_path and "to_" in new_path
84
- )
85
- shape = old_checkpoint[path["old"]].shape
86
- if is_attn_weight and len(shape) == 3:
87
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
88
- elif is_attn_weight and len(shape) == 4:
89
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0, 0]
90
- else:
91
- checkpoint[new_path] = old_checkpoint[path["old"]]
92
-
93
-
94
- def shave_segments(path, n_shave_prefix_segments=1):
95
- """
96
- Removes segments. Positive values shave the first segments, negative shave the last segments.
97
- """
98
- if n_shave_prefix_segments >= 0:
99
- return ".".join(path.split(".")[n_shave_prefix_segments:])
100
- else:
101
- return ".".join(path.split(".")[:n_shave_prefix_segments])
102
-
103
-
104
- def create_vae_diffusers_config(original_config, image_size: int):
105
- """
106
- Creates a config for the diffusers based on the config of the LDM model.
107
- """
108
- vae_params = original_config.model.params.first_stage_config.params.ddconfig
109
- _ = original_config.model.params.first_stage_config.params.embed_dim
110
-
111
- block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
112
- down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
113
- up_block_types = ["UpDecoderBlock2D"] * len(block_out_channels)
114
-
115
- config = {
116
- "sample_size": image_size,
117
- "in_channels": vae_params.in_channels,
118
- "out_channels": vae_params.out_ch,
119
- "down_block_types": tuple(down_block_types),
120
- "up_block_types": tuple(up_block_types),
121
- "block_out_channels": tuple(block_out_channels),
122
- "latent_channels": vae_params.z_channels,
123
- "layers_per_block": vae_params.num_res_blocks,
124
- }
125
- return config
126
-
127
-
128
- def convert_ldm_vae_checkpoint(checkpoint, config):
129
- # extract state dict for VAE
130
- vae_state_dict = {}
131
- vae_key = "first_stage_model."
132
- keys = list(checkpoint.keys())
133
- for key in keys:
134
- if key.startswith(vae_key):
135
- vae_state_dict[key.replace(vae_key, "")] = checkpoint.get(key)
136
-
137
- new_checkpoint = {}
138
-
139
- new_checkpoint["encoder.conv_in.weight"] = vae_state_dict["encoder.conv_in.weight"]
140
- new_checkpoint["encoder.conv_in.bias"] = vae_state_dict["encoder.conv_in.bias"]
141
- new_checkpoint["encoder.conv_out.weight"] = vae_state_dict[
142
- "encoder.conv_out.weight"
143
- ]
144
- new_checkpoint["encoder.conv_out.bias"] = vae_state_dict["encoder.conv_out.bias"]
145
- new_checkpoint["encoder.conv_norm_out.weight"] = vae_state_dict[
146
- "encoder.norm_out.weight"
147
- ]
148
- new_checkpoint["encoder.conv_norm_out.bias"] = vae_state_dict[
149
- "encoder.norm_out.bias"
150
- ]
151
-
152
- new_checkpoint["decoder.conv_in.weight"] = vae_state_dict["decoder.conv_in.weight"]
153
- new_checkpoint["decoder.conv_in.bias"] = vae_state_dict["decoder.conv_in.bias"]
154
- new_checkpoint["decoder.conv_out.weight"] = vae_state_dict[
155
- "decoder.conv_out.weight"
156
- ]
157
- new_checkpoint["decoder.conv_out.bias"] = vae_state_dict["decoder.conv_out.bias"]
158
- new_checkpoint["decoder.conv_norm_out.weight"] = vae_state_dict[
159
- "decoder.norm_out.weight"
160
- ]
161
- new_checkpoint["decoder.conv_norm_out.bias"] = vae_state_dict[
162
- "decoder.norm_out.bias"
163
- ]
164
-
165
- new_checkpoint["quant_conv.weight"] = vae_state_dict["quant_conv.weight"]
166
- new_checkpoint["quant_conv.bias"] = vae_state_dict["quant_conv.bias"]
167
- new_checkpoint["post_quant_conv.weight"] = vae_state_dict["post_quant_conv.weight"]
168
- new_checkpoint["post_quant_conv.bias"] = vae_state_dict["post_quant_conv.bias"]
169
-
170
- # Retrieves the keys for the encoder down blocks only
171
- num_down_blocks = len(
172
- {
173
- ".".join(layer.split(".")[:3])
174
- for layer in vae_state_dict
175
- if "encoder.down" in layer
176
- }
177
- )
178
- down_blocks = {
179
- layer_id: [key for key in vae_state_dict if f"down.{layer_id}" in key]
180
- for layer_id in range(num_down_blocks)
181
- }
182
-
183
- # Retrieves the keys for the decoder up blocks only
184
- num_up_blocks = len(
185
- {
186
- ".".join(layer.split(".")[:3])
187
- for layer in vae_state_dict
188
- if "decoder.up" in layer
189
- }
190
- )
191
- up_blocks = {
192
- layer_id: [key for key in vae_state_dict if f"up.{layer_id}" in key]
193
- for layer_id in range(num_up_blocks)
194
- }
195
-
196
- for i in range(num_down_blocks):
197
- resnets = [
198
- key
199
- for key in down_blocks[i]
200
- if f"down.{i}" in key and f"down.{i}.downsample" not in key
201
- ]
202
-
203
- if f"encoder.down.{i}.downsample.conv.weight" in vae_state_dict:
204
- new_checkpoint[
205
- f"encoder.down_blocks.{i}.downsamplers.0.conv.weight"
206
- ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.weight")
207
- new_checkpoint[
208
- f"encoder.down_blocks.{i}.downsamplers.0.conv.bias"
209
- ] = vae_state_dict.pop(f"encoder.down.{i}.downsample.conv.bias")
210
-
211
- paths = renew_vae_resnet_paths(resnets)
212
- meta_path = {"old": f"down.{i}.block", "new": f"down_blocks.{i}.resnets"}
213
- assign_to_checkpoint(
214
- paths,
215
- new_checkpoint,
216
- vae_state_dict,
217
- additional_replacements=[meta_path],
218
- config=config,
219
- )
220
-
221
- mid_resnets = [key for key in vae_state_dict if "encoder.mid.block" in key]
222
- num_mid_res_blocks = 2
223
- for i in range(1, num_mid_res_blocks + 1):
224
- resnets = [key for key in mid_resnets if f"encoder.mid.block_{i}" in key]
225
-
226
- paths = renew_vae_resnet_paths(resnets)
227
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
228
- assign_to_checkpoint(
229
- paths,
230
- new_checkpoint,
231
- vae_state_dict,
232
- additional_replacements=[meta_path],
233
- config=config,
234
- )
235
-
236
- mid_attentions = [key for key in vae_state_dict if "encoder.mid.attn" in key]
237
- paths = renew_vae_attention_paths(mid_attentions)
238
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
239
- assign_to_checkpoint(
240
- paths,
241
- new_checkpoint,
242
- vae_state_dict,
243
- additional_replacements=[meta_path],
244
- config=config,
245
- )
246
- conv_attn_to_linear(new_checkpoint)
247
-
248
- for i in range(num_up_blocks):
249
- block_id = num_up_blocks - 1 - i
250
- resnets = [
251
- key
252
- for key in up_blocks[block_id]
253
- if f"up.{block_id}" in key and f"up.{block_id}.upsample" not in key
254
- ]
255
-
256
- if f"decoder.up.{block_id}.upsample.conv.weight" in vae_state_dict:
257
- new_checkpoint[
258
- f"decoder.up_blocks.{i}.upsamplers.0.conv.weight"
259
- ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.weight"]
260
- new_checkpoint[
261
- f"decoder.up_blocks.{i}.upsamplers.0.conv.bias"
262
- ] = vae_state_dict[f"decoder.up.{block_id}.upsample.conv.bias"]
263
-
264
- paths = renew_vae_resnet_paths(resnets)
265
- meta_path = {"old": f"up.{block_id}.block", "new": f"up_blocks.{i}.resnets"}
266
- assign_to_checkpoint(
267
- paths,
268
- new_checkpoint,
269
- vae_state_dict,
270
- additional_replacements=[meta_path],
271
- config=config,
272
- )
273
-
274
- mid_resnets = [key for key in vae_state_dict if "decoder.mid.block" in key]
275
- num_mid_res_blocks = 2
276
- for i in range(1, num_mid_res_blocks + 1):
277
- resnets = [key for key in mid_resnets if f"decoder.mid.block_{i}" in key]
278
-
279
- paths = renew_vae_resnet_paths(resnets)
280
- meta_path = {"old": f"mid.block_{i}", "new": f"mid_block.resnets.{i - 1}"}
281
- assign_to_checkpoint(
282
- paths,
283
- new_checkpoint,
284
- vae_state_dict,
285
- additional_replacements=[meta_path],
286
- config=config,
287
- )
288
-
289
- mid_attentions = [key for key in vae_state_dict if "decoder.mid.attn" in key]
290
- paths = renew_vae_attention_paths(mid_attentions)
291
- meta_path = {"old": "mid.attn_1", "new": "mid_block.attentions.0"}
292
- assign_to_checkpoint(
293
- paths,
294
- new_checkpoint,
295
- vae_state_dict,
296
- additional_replacements=[meta_path],
297
- config=config,
298
- )
299
- conv_attn_to_linear(new_checkpoint)
300
- return new_checkpoint
301
-
302
-
303
- def renew_vae_resnet_paths(old_list, n_shave_prefix_segments=0):
304
- """
305
- Updates paths inside resnets to the new naming scheme (local renaming)
306
- """
307
- mapping = []
308
- for old_item in old_list:
309
- new_item = old_item
310
-
311
- new_item = new_item.replace("nin_shortcut", "conv_shortcut")
312
- new_item = shave_segments(
313
- new_item, n_shave_prefix_segments=n_shave_prefix_segments
314
- )
315
-
316
- mapping.append({"old": old_item, "new": new_item})
317
-
318
- return mapping
319
-
320
-
321
- def renew_vae_attention_paths(old_list, n_shave_prefix_segments=0):
322
- """
323
- Updates paths inside attentions to the new naming scheme (local renaming)
324
- """
325
- mapping = []
326
- for old_item in old_list:
327
- new_item = old_item
328
-
329
- new_item = new_item.replace("norm.weight", "group_norm.weight")
330
- new_item = new_item.replace("norm.bias", "group_norm.bias")
331
-
332
- new_item = new_item.replace("q.weight", "to_q.weight")
333
- new_item = new_item.replace("q.bias", "to_q.bias")
334
-
335
- new_item = new_item.replace("k.weight", "to_k.weight")
336
- new_item = new_item.replace("k.bias", "to_k.bias")
337
-
338
- new_item = new_item.replace("v.weight", "to_v.weight")
339
- new_item = new_item.replace("v.bias", "to_v.bias")
340
-
341
- new_item = new_item.replace("proj_out.weight", "to_out.0.weight")
342
- new_item = new_item.replace("proj_out.bias", "to_out.0.bias")
343
-
344
- new_item = shave_segments(
345
- new_item, n_shave_prefix_segments=n_shave_prefix_segments
346
- )
347
-
348
- mapping.append({"old": old_item, "new": new_item})
349
-
350
- return mapping
351
-
352
-
353
- def conv_attn_to_linear(checkpoint):
354
- keys = list(checkpoint.keys())
355
- attn_keys = ["query.weight", "key.weight", "value.weight"]
356
- for key in keys:
357
- if ".".join(key.split(".")[-2:]) in attn_keys:
358
- if checkpoint[key].ndim > 2:
359
- checkpoint[key] = checkpoint[key][:, :, 0, 0]
360
- elif "proj_attn.weight" in key:
361
- if checkpoint[key].ndim > 2:
362
- checkpoint[key] = checkpoint[key][:, :, 0]
363
-
364
-
365
- def create_unet_config(original_config) -> Any:
366
- return OmegaConf.to_container(
367
- original_config.model.params.unet_config.params, resolve=True
368
- )
369
-
370
-
371
- def convert_from_original_imagedream_ckpt(checkpoint_path, original_config_file, device):
372
- checkpoint = torch.load(checkpoint_path, map_location=device)
373
- # print(f"Checkpoint: {checkpoint.keys()}")
374
- torch.cuda.empty_cache()
375
-
376
- original_config = OmegaConf.load(original_config_file)
377
- # print(f"Original Config: {original_config}")
378
- prediction_type = "epsilon"
379
- image_size = 256
380
- num_train_timesteps = (
381
- getattr(original_config.model.params, "timesteps", None) or 1000
382
- )
383
- beta_start = getattr(original_config.model.params, "linear_start", None) or 0.02
384
- beta_end = getattr(original_config.model.params, "linear_end", None) or 0.085
385
- scheduler = DDIMScheduler(
386
- beta_end=beta_end,
387
- beta_schedule="scaled_linear",
388
- beta_start=beta_start,
389
- num_train_timesteps=num_train_timesteps,
390
- steps_offset=1,
391
- clip_sample=False,
392
- set_alpha_to_one=False,
393
- prediction_type=prediction_type,
394
- )
395
- scheduler.register_to_config(clip_sample=False)
396
-
397
- # Convert the UNet2DConditionModel model.
398
- # upcast_attention = None
399
- # unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
400
- # unet_config["upcast_attention"] = upcast_attention
401
- # with init_empty_weights():
402
- # unet = UNet2DConditionModel(**unet_config)
403
- # converted_unet_checkpoint = convert_ldm_unet_checkpoint(
404
- # checkpoint, unet_config, path=None, extract_ema=extract_ema
405
- # )
406
- # print(f"Unet Config: {original_config.model.params.unet_config.params}")
407
- unet_config = create_unet_config(original_config)
408
-
409
- # remove unused configs
410
- del unet_config['legacy']
411
- del unet_config['use_linear_in_transformer']
412
- del unet_config['use_spatial_transformer']
413
- del unet_config['ip_mode']
414
-
415
- unet = MultiViewUNetModel(**unet_config)
416
- unet.register_to_config(**unet_config)
417
- # print(f"Unet State Dict: {unet.state_dict().keys()}")
418
- unet.load_state_dict(
419
- {
420
- key.replace("model.diffusion_model.", ""): value
421
- for key, value in checkpoint.items()
422
- if key.replace("model.diffusion_model.", "") in unet.state_dict()
423
- }
424
- )
425
- for param_name, param in unet.state_dict().items():
426
- set_module_tensor_to_device(unet, param_name, device=device, value=param)
427
-
428
- # Convert the VAE model.
429
- vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
430
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
431
-
432
- if (
433
- "model" in original_config
434
- and "params" in original_config.model
435
- and "scale_factor" in original_config.model.params
436
- ):
437
- vae_scaling_factor = original_config.model.params.scale_factor
438
- else:
439
- vae_scaling_factor = 0.18215 # default SD scaling factor
440
-
441
- vae_config["scaling_factor"] = vae_scaling_factor
442
-
443
- with init_empty_weights():
444
- vae = AutoencoderKL(**vae_config)
445
-
446
- for param_name, param in converted_vae_checkpoint.items():
447
- set_module_tensor_to_device(vae, param_name, device=device, value=param)
448
-
449
-
450
- tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
451
- text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=device) # type: ignore
452
-
453
- # this is the clip used by sd2.1
454
- feature_extractor: CLIPFeatureExtractor = CLIPFeatureExtractor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
455
- image_encoder: CLIPVisionModel = CLIPVisionModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
456
-
457
- pipe = ImageDreamPipeline(
458
- vae=vae,
459
- unet=unet,
460
- tokenizer=tokenizer,
461
- text_encoder=text_encoder,
462
- scheduler=scheduler,
463
- feature_extractor=feature_extractor,
464
- image_encoder=image_encoder,
465
- )
466
-
467
- return pipe
468
-
469
-
470
- if __name__ == "__main__":
471
- parser = argparse.ArgumentParser()
472
-
473
- parser.add_argument(
474
- "--checkpoint_path",
475
- default=None,
476
- type=str,
477
- required=True,
478
- help="Path to the checkpoint to convert.",
479
- )
480
- parser.add_argument(
481
- "--original_config_file",
482
- default=None,
483
- type=str,
484
- help="The YAML config file corresponding to the original architecture.",
485
- )
486
- parser.add_argument(
487
- "--to_safetensors",
488
- action="store_true",
489
- help="Whether to store pipeline in safetensors format or not.",
490
- )
491
- parser.add_argument(
492
- "--half", action="store_true", help="Save weights in half precision."
493
- )
494
- parser.add_argument(
495
- "--test",
496
- action="store_true",
497
- help="Whether to test inference after convertion.",
498
- )
499
- parser.add_argument(
500
- "--dump_path",
501
- default=None,
502
- type=str,
503
- required=True,
504
- help="Path to the output model.",
505
- )
506
- parser.add_argument(
507
- "--device", type=str, help="Device to use (e.g. cpu, cuda:0, cuda:1, etc.)"
508
- )
509
- args = parser.parse_args()
510
-
511
- args.device = torch.device(
512
- args.device
513
- if args.device is not None
514
- else "cuda"
515
- if torch.cuda.is_available()
516
- else "cpu"
517
- )
518
-
519
- pipe = convert_from_original_imagedream_ckpt(
520
- checkpoint_path=args.checkpoint_path,
521
- original_config_file=args.original_config_file,
522
- device=args.device,
523
- )
524
-
525
- if args.half:
526
- pipe.to(torch_dtype=torch.float16)
527
-
528
- print(f"Saving pipeline to {args.dump_path}...")
529
- pipe.save_pretrained(args.dump_path, safe_serialization=args.to_safetensors)
530
-
531
- # TODO: input image...
532
- if args.test:
533
- try:
534
- print(f"Testing each subcomponent of the pipeline...")
535
- images = pipe(
536
- prompt="Head of Hatsune Miku",
537
- negative_prompt="painting, bad quality, flat",
538
- output_type="pil",
539
- guidance_scale=7.5,
540
- num_inference_steps=50,
541
- device=args.device,
542
- )
543
- for i, image in enumerate(images):
544
- image.save(f"image_{i}.png") # type: ignore
545
-
546
- print(f"Testing entire pipeline...")
547
- loaded_pipe = ImageDreamPipeline.from_pretrained(args.dump_path, safe_serialization=args.to_safetensors) # type: ignore
548
- images = loaded_pipe(
549
- prompt="Head of Hatsune Miku",
550
- negative_prompt="painting, bad quality, flat",
551
- output_type="pil",
552
- guidance_scale=7.5,
553
- num_inference_steps=50,
554
- device=args.device,
555
- )
556
- for i, image in enumerate(images):
557
- image.save(f"image_{i}.png") # type: ignore
558
- except Exception as e:
559
- print(f"Failed to test inference: {e}")
560
- raise e from e
561
- print("Inference test passed!")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
convert_mvdream_to_diffusers.py CHANGED
@@ -17,7 +17,9 @@ from accelerate import init_empty_weights
17
  from accelerate.utils import set_module_tensor_to_device
18
  from mvdream.models import MultiViewUNetModel
19
  from mvdream.pipeline_mvdream import MVDreamPipeline
20
- from transformers import CLIPTokenizer, CLIPTextModel
 
 
21
 
22
  logger = logging.get_logger(__name__)
23
 
@@ -101,12 +103,20 @@ def shave_segments(path, n_shave_prefix_segments=1):
101
  return ".".join(path.split(".")[:n_shave_prefix_segments])
102
 
103
 
104
- def create_vae_diffusers_config(original_config, image_size: int):
105
  """
106
  Creates a config for the diffusers based on the config of the LDM model.
107
  """
108
- vae_params = original_config.model.params.first_stage_config.params.ddconfig
109
- _ = original_config.model.params.first_stage_config.params.embed_dim
 
 
 
 
 
 
 
 
110
 
111
  block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
112
  down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
@@ -122,13 +132,12 @@ def create_vae_diffusers_config(original_config, image_size: int):
122
  "latent_channels": vae_params.z_channels,
123
  "layers_per_block": vae_params.num_res_blocks,
124
  }
125
- return config
126
 
127
 
128
- def convert_ldm_vae_checkpoint(checkpoint, config):
129
  # extract state dict for VAE
130
  vae_state_dict = {}
131
- vae_key = "first_stage_model."
132
  keys = list(checkpoint.keys())
133
  for key in keys:
134
  if key.startswith(vae_key):
@@ -394,22 +403,15 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
394
  )
395
  scheduler.register_to_config(clip_sample=False)
396
 
397
- # Convert the UNet2DConditionModel model.
398
- # upcast_attention = None
399
- # unet_config = create_unet_diffusers_config(original_config, image_size=image_size)
400
- # unet_config["upcast_attention"] = upcast_attention
401
- # with init_empty_weights():
402
- # unet = UNet2DConditionModel(**unet_config)
403
- # converted_unet_checkpoint = convert_ldm_unet_checkpoint(
404
- # checkpoint, unet_config, path=None, extract_ema=extract_ema
405
- # )
406
- # print(f"Unet Config: {original_config.model.params.unet_config.params}")
407
  unet_config = create_unet_config(original_config)
408
 
409
  # remove unused configs
410
- del unet_config['legacy']
411
- del unet_config['use_linear_in_transformer']
412
- del unet_config['use_spatial_transformer']
 
 
 
413
 
414
  unet = MultiViewUNetModel(**unet_config)
415
  unet.register_to_config(**unet_config)
@@ -425,8 +427,8 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
425
  set_module_tensor_to_device(unet, param_name, device=device, value=param)
426
 
427
  # Convert the VAE model.
428
- vae_config = create_vae_diffusers_config(original_config, image_size=image_size)
429
- converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config)
430
 
431
  if (
432
  "model" in original_config
@@ -445,20 +447,17 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
445
  for param_name, param in converted_vae_checkpoint.items():
446
  set_module_tensor_to_device(vae, param_name, device=device, value=param)
447
 
448
- if original_config.model.params.unet_config.params.context_dim == 768:
449
- tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
450
- "openai/clip-vit-large-patch14"
451
- )
452
- text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14").to(device=device) # type: ignore
453
- elif original_config.model.params.unet_config.params.context_dim == 1024:
454
- tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(
455
- "stabilityai/stable-diffusion-2-1", subfolder="tokenizer"
456
- )
457
- text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=device) # type: ignore
458
  else:
459
- raise ValueError(
460
- f"Unknown context_dim: {original_config.model.paams.unet_config.params.context_dim}"
461
- )
462
 
463
  pipe = MVDreamPipeline(
464
  vae=vae,
@@ -466,6 +465,8 @@ def convert_from_original_mvdream_ckpt(checkpoint_path, original_config_file, de
466
  tokenizer=tokenizer,
467
  text_encoder=text_encoder,
468
  scheduler=scheduler,
 
 
469
  )
470
 
471
  return pipe
@@ -534,31 +535,63 @@ if __name__ == "__main__":
534
 
535
  if args.test:
536
  try:
537
- print(f"Testing each subcomponent of the pipeline...")
538
- images = pipe(
539
- prompt="Head of Hatsune Miku",
540
- negative_prompt="painting, bad quality, flat",
541
- output_type="pil",
542
- guidance_scale=7.5,
543
- num_inference_steps=50,
544
- device=args.device,
545
- )
546
- for i, image in enumerate(images):
547
- image.save(f"image_{i}.png") # type: ignore
548
-
549
- print(f"Testing entire pipeline...")
550
- loaded_pipe: MVDreamPipeline = MVDreamPipeline.from_pretrained(args.dump_path, safe_serialization=args.to_safetensors) # type: ignore
551
- images = loaded_pipe(
552
- prompt="Head of Hatsune Miku",
553
- negative_prompt="painting, bad quality, flat",
554
- output_type="pil",
555
- guidance_scale=7.5,
556
- num_inference_steps=50,
557
- device=args.device,
558
- )
559
- for i, image in enumerate(images):
560
- image.save(f"image_{i}.png") # type: ignore
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561
  except Exception as e:
562
  print(f"Failed to test inference: {e}")
563
- raise e from e
564
- print("Inference test passed!")
 
17
  from accelerate.utils import set_module_tensor_to_device
18
  from mvdream.models import MultiViewUNetModel
19
  from mvdream.pipeline_mvdream import MVDreamPipeline
20
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
21
+
22
+ import kiui
23
 
24
  logger = logging.get_logger(__name__)
25
 
 
103
  return ".".join(path.split(".")[:n_shave_prefix_segments])
104
 
105
 
106
+ def create_vae_diffusers_config(original_config, image_size):
107
  """
108
  Creates a config for the diffusers based on the config of the LDM model.
109
  """
110
+
111
+
112
+ if 'imagedream' in original_config.model.target:
113
+ vae_params = original_config.model.params.vae_config.params.ddconfig
114
+ _ = original_config.model.params.vae_config.params.embed_dim
115
+ vae_key = "vae_model."
116
+ else:
117
+ vae_params = original_config.model.params.first_stage_config.params.ddconfig
118
+ _ = original_config.model.params.first_stage_config.params.embed_dim
119
+ vae_key = "first_stage_model."
120
 
121
  block_out_channels = [vae_params.ch * mult for mult in vae_params.ch_mult]
122
  down_block_types = ["DownEncoderBlock2D"] * len(block_out_channels)
 
132
  "latent_channels": vae_params.z_channels,
133
  "layers_per_block": vae_params.num_res_blocks,
134
  }
135
+ return config, vae_key
136
 
137
 
138
+ def convert_ldm_vae_checkpoint(checkpoint, config, vae_key):
139
  # extract state dict for VAE
140
  vae_state_dict = {}
 
141
  keys = list(checkpoint.keys())
142
  for key in keys:
143
  if key.startswith(vae_key):
 
403
  )
404
  scheduler.register_to_config(clip_sample=False)
405
 
 
 
 
 
 
 
 
 
 
 
406
  unet_config = create_unet_config(original_config)
407
 
408
  # remove unused configs
409
+ unet_config.pop('legacy', None)
410
+ unet_config.pop('use_linear_in_transformer', None)
411
+ unet_config.pop('use_spatial_transformer', None)
412
+
413
+ unet_config.pop('ip_mode', None)
414
+ unet_config.pop('with_ip', None)
415
 
416
  unet = MultiViewUNetModel(**unet_config)
417
  unet.register_to_config(**unet_config)
 
427
  set_module_tensor_to_device(unet, param_name, device=device, value=param)
428
 
429
  # Convert the VAE model.
430
+ vae_config, vae_key = create_vae_diffusers_config(original_config, image_size=image_size)
431
+ converted_vae_checkpoint = convert_ldm_vae_checkpoint(checkpoint, vae_config, vae_key)
432
 
433
  if (
434
  "model" in original_config
 
447
  for param_name, param in converted_vae_checkpoint.items():
448
  set_module_tensor_to_device(vae, param_name, device=device, value=param)
449
 
450
+ # we only supports SD 2.1 based model
451
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="tokenizer")
452
+ text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained("stabilityai/stable-diffusion-2-1", subfolder="text_encoder").to(device=device) # type: ignore
453
+
454
+ # imagedream variant
455
+ if unet.ip_dim > 0:
456
+ feature_extractor: CLIPImageProcessor = CLIPImageProcessor.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
457
+ image_encoder: CLIPVisionModel = CLIPVisionModel.from_pretrained("laion/CLIP-ViT-H-14-laion2B-s32B-b79K")
 
 
458
  else:
459
+ feature_extractor = None
460
+ image_encoder = None
 
461
 
462
  pipe = MVDreamPipeline(
463
  vae=vae,
 
465
  tokenizer=tokenizer,
466
  text_encoder=text_encoder,
467
  scheduler=scheduler,
468
+ feature_extractor=feature_extractor,
469
+ image_encoder=image_encoder,
470
  )
471
 
472
  return pipe
 
535
 
536
  if args.test:
537
  try:
538
+ # mvdream
539
+ if pipe.unet.ip_dim == 0:
540
+ print(f"Testing each subcomponent of the pipeline...")
541
+ images = pipe(
542
+ prompt="Head of Hatsune Miku",
543
+ negative_prompt="painting, bad quality, flat",
544
+ output_type="pil",
545
+ guidance_scale=7.5,
546
+ num_inference_steps=50,
547
+ device=args.device,
548
+ )
549
+ for i, image in enumerate(images):
550
+ image.save(f"test_image_{i}.png") # type: ignore
551
+
552
+ print(f"Testing entire pipeline...")
553
+ loaded_pipe = MVDreamPipeline.from_pretrained(args.dump_path) # type: ignore
554
+ images = loaded_pipe(
555
+ prompt="Head of Hatsune Miku",
556
+ negative_prompt="painting, bad quality, flat",
557
+ output_type="pil",
558
+ guidance_scale=7.5,
559
+ num_inference_steps=50,
560
+ device=args.device,
561
+ )
562
+ for i, image in enumerate(images):
563
+ image.save(f"test_image_{i}.png") # type: ignore
564
+ # imagedream
565
+ else:
566
+ input_image = kiui.read_image('data/anya_rgba.png', mode='float')
567
+ print(f"Testing each subcomponent of the pipeline...")
568
+ images = pipe(
569
+ image=input_image,
570
+ prompt="",
571
+ negative_prompt="painting, bad quality, flat",
572
+ output_type="pil",
573
+ guidance_scale=5.0,
574
+ num_inference_steps=50,
575
+ device=args.device,
576
+ )
577
+ for i, image in enumerate(images):
578
+ image.save(f"test_image_{i}.png") # type: ignore
579
+
580
+ print(f"Testing entire pipeline...")
581
+ loaded_pipe = MVDreamPipeline.from_pretrained(args.dump_path) # type: ignore
582
+ images = loaded_pipe(
583
+ image=input_image,
584
+ prompt="",
585
+ negative_prompt="painting, bad quality, flat",
586
+ output_type="pil",
587
+ guidance_scale=5.0,
588
+ num_inference_steps=50,
589
+ device=args.device,
590
+ )
591
+ for i, image in enumerate(images):
592
+ image.save(f"test_image_{i}.png") # type: ignore
593
+
594
+
595
+ print("Inference test passed!")
596
  except Exception as e:
597
  print(f"Failed to test inference: {e}")
 
 
data/anya_rgba.png ADDED
imagedream/attention.py DELETED
@@ -1,259 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
-
5
- from inspect import isfunction
6
- from einops import rearrange, repeat
7
- from typing import Optional, Any
8
-
9
- # require xformers
10
- import xformers # type: ignore
11
- import xformers.ops # type: ignore
12
-
13
- from .util import checkpoint, zero_module
14
-
15
- def default(val, d):
16
- if val is not None:
17
- return val
18
- return d() if isfunction(d) else d
19
-
20
-
21
- class GEGLU(nn.Module):
22
- def __init__(self, dim_in, dim_out):
23
- super().__init__()
24
- self.proj = nn.Linear(dim_in, dim_out * 2)
25
-
26
- def forward(self, x):
27
- x, gate = self.proj(x).chunk(2, dim=-1)
28
- return x * F.gelu(gate)
29
-
30
-
31
- class FeedForward(nn.Module):
32
- def __init__(self, dim, dim_out=None, mult=4, glu=False, dropout=0.0):
33
- super().__init__()
34
- inner_dim = int(dim * mult)
35
- dim_out = default(dim_out, dim)
36
- project_in = (
37
- nn.Sequential(nn.Linear(dim, inner_dim), nn.GELU())
38
- if not glu
39
- else GEGLU(dim, inner_dim)
40
- )
41
-
42
- self.net = nn.Sequential(
43
- project_in, nn.Dropout(dropout), nn.Linear(inner_dim, dim_out)
44
- )
45
-
46
- def forward(self, x):
47
- return self.net(x)
48
-
49
-
50
- class MemoryEfficientCrossAttention(nn.Module):
51
- # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
52
- def __init__(
53
- self,
54
- query_dim,
55
- context_dim=None,
56
- heads=8,
57
- dim_head=64,
58
- dropout=0.0,
59
- with_ip=False,
60
- ip_dim=16,
61
- ip_weight=1,
62
- ):
63
- super().__init__()
64
-
65
- inner_dim = dim_head * heads
66
- context_dim = default(context_dim, query_dim)
67
-
68
- self.heads = heads
69
- self.dim_head = dim_head
70
-
71
- self.with_ip = with_ip and (context_dim is not None)
72
- self.ip_dim = ip_dim
73
- self.ip_weight = ip_weight
74
-
75
- if self.with_ip:
76
- self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
77
- self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
78
-
79
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
80
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
81
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
82
-
83
- self.to_out = nn.Sequential(
84
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
85
- )
86
- self.attention_op: Optional[Any] = None
87
-
88
- def forward(self, x, context=None):
89
- q = self.to_q(x)
90
- context = default(context, x)
91
-
92
- if self.with_ip:
93
- # context dim [(b frame_num), (77 + img_token), 1024]
94
- token_len = context.shape[1]
95
- context_ip = context[:, -self.ip_dim :, :]
96
- k_ip = self.to_k_ip(context_ip)
97
- v_ip = self.to_v_ip(context_ip)
98
- context = context[:, : (token_len - self.ip_dim), :]
99
-
100
- k = self.to_k(context)
101
- v = self.to_v(context)
102
-
103
- b, _, _ = q.shape
104
- q, k, v = map(
105
- lambda t: t.unsqueeze(3)
106
- .reshape(b, t.shape[1], self.heads, self.dim_head)
107
- .permute(0, 2, 1, 3)
108
- .reshape(b * self.heads, t.shape[1], self.dim_head)
109
- .contiguous(),
110
- (q, k, v),
111
- )
112
-
113
- # actually compute the attention, what we cannot get enough of
114
- out = xformers.ops.memory_efficient_attention(
115
- q, k, v, attn_bias=None, op=self.attention_op
116
- )
117
-
118
- if self.with_ip:
119
- k_ip, v_ip = map(
120
- lambda t: t.unsqueeze(3)
121
- .reshape(b, t.shape[1], self.heads, self.dim_head)
122
- .permute(0, 2, 1, 3)
123
- .reshape(b * self.heads, t.shape[1], self.dim_head)
124
- .contiguous(),
125
- (k_ip, v_ip),
126
- )
127
- # actually compute the attention, what we cannot get enough of
128
- out_ip = xformers.ops.memory_efficient_attention(
129
- q, k_ip, v_ip, attn_bias=None, op=self.attention_op
130
- )
131
- out = out + self.ip_weight * out_ip
132
-
133
- out = (
134
- out.unsqueeze(0)
135
- .reshape(b, self.heads, out.shape[1], self.dim_head)
136
- .permute(0, 2, 1, 3)
137
- .reshape(b, out.shape[1], self.heads * self.dim_head)
138
- )
139
- return self.to_out(out)
140
-
141
-
142
- class BasicTransformerBlock3D(nn.Module):
143
-
144
- def __init__(
145
- self,
146
- dim,
147
- context_dim,
148
- n_heads,
149
- d_head,
150
- dropout=0.0,
151
- gated_ff=True,
152
- checkpoint=True,
153
- with_ip=False,
154
- ip_dim=16,
155
- ip_weight=1,
156
- ):
157
- super().__init__()
158
-
159
- self.attn1 = MemoryEfficientCrossAttention(
160
- query_dim=dim,
161
- context_dim=None, # self-attention
162
- heads=n_heads,
163
- dim_head=d_head,
164
- dropout=dropout,
165
- )
166
- self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
167
- self.attn2 = MemoryEfficientCrossAttention(
168
- query_dim=dim,
169
- context_dim=context_dim,
170
- heads=n_heads,
171
- dim_head=d_head,
172
- dropout=dropout,
173
- # ip only applies to cross-attention
174
- with_ip=with_ip,
175
- ip_dim=ip_dim,
176
- ip_weight=ip_weight,
177
- )
178
- self.norm1 = nn.LayerNorm(dim)
179
- self.norm2 = nn.LayerNorm(dim)
180
- self.norm3 = nn.LayerNorm(dim)
181
- self.checkpoint = checkpoint
182
-
183
- def forward(self, x, context=None, num_frames=1):
184
- return checkpoint(
185
- self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
186
- )
187
-
188
- def _forward(self, x, context=None, num_frames=1):
189
- x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
190
- x = self.attn1(self.norm1(x), context=None) + x
191
- x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
192
- x = self.attn2(self.norm2(x), context=context) + x
193
- x = self.ff(self.norm3(x)) + x
194
- return x
195
-
196
-
197
- class SpatialTransformer3D(nn.Module):
198
-
199
- def __init__(
200
- self,
201
- in_channels,
202
- n_heads,
203
- d_head,
204
- context_dim, # cross attention input dim
205
- depth=1,
206
- dropout=0.0,
207
- with_ip=False,
208
- ip_dim=16,
209
- ip_weight=1,
210
- use_checkpoint=True,
211
- ):
212
- super().__init__()
213
-
214
- if not isinstance(context_dim, list):
215
- context_dim = [context_dim]
216
-
217
- self.in_channels = in_channels
218
-
219
- inner_dim = n_heads * d_head
220
- self.norm = nn.GroupNorm(
221
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
222
- )
223
- self.proj_in = nn.Linear(in_channels, inner_dim)
224
-
225
- self.transformer_blocks = nn.ModuleList(
226
- [
227
- BasicTransformerBlock3D(
228
- inner_dim,
229
- n_heads,
230
- d_head,
231
- context_dim=context_dim[d],
232
- dropout=dropout,
233
- checkpoint=use_checkpoint,
234
- with_ip=with_ip,
235
- ip_dim=ip_dim,
236
- ip_weight=ip_weight,
237
- )
238
- for d in range(depth)
239
- ]
240
- )
241
-
242
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
243
-
244
-
245
- def forward(self, x, context=None, num_frames=1):
246
- # note: if no context is given, cross-attention defaults to self-attention
247
- if not isinstance(context, list):
248
- context = [context]
249
- b, c, h, w = x.shape
250
- x_in = x
251
- x = self.norm(x)
252
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
253
- x = self.proj_in(x)
254
- for i, block in enumerate(self.transformer_blocks):
255
- x = block(x, context=context[i], num_frames=num_frames)
256
- x = self.proj_out(x)
257
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
258
-
259
- return x + x_in
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
imagedream/models.py DELETED
@@ -1,627 +0,0 @@
1
- import torch
2
- import torch.nn as nn
3
- import torch.nn.functional as F
4
- from diffusers.configuration_utils import ConfigMixin
5
- from diffusers.models.modeling_utils import ModelMixin
6
- from typing import Any, List, Optional
7
- from torch import Tensor
8
-
9
- from .util import (
10
- checkpoint,
11
- conv_nd,
12
- avg_pool_nd,
13
- zero_module,
14
- timestep_embedding,
15
- )
16
- from .attention import SpatialTransformer3D
17
- from .adaptor import Resampler, ImageProjModel
18
-
19
- class CondSequential(nn.Sequential):
20
- """
21
- A sequential module that passes timestep embeddings to the children that
22
- support it as an extra input.
23
- """
24
-
25
- def forward(self, x, emb, context=None, num_frames=1):
26
- for layer in self:
27
- if isinstance(layer, ResBlock):
28
- x = layer(x, emb)
29
- elif isinstance(layer, SpatialTransformer3D):
30
- x = layer(x, context, num_frames=num_frames)
31
- else:
32
- x = layer(x)
33
- return x
34
-
35
-
36
- class Upsample(nn.Module):
37
- """
38
- An upsampling layer with an optional convolution.
39
- :param channels: channels in the inputs and outputs.
40
- :param use_conv: a bool determining if a convolution is applied.
41
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
42
- upsampling occurs in the inner-two dimensions.
43
- """
44
-
45
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
46
- super().__init__()
47
- self.channels = channels
48
- self.out_channels = out_channels or channels
49
- self.use_conv = use_conv
50
- self.dims = dims
51
- if use_conv:
52
- self.conv = conv_nd(
53
- dims, self.channels, self.out_channels, 3, padding=padding
54
- )
55
-
56
- def forward(self, x):
57
- assert x.shape[1] == self.channels
58
- if self.dims == 3:
59
- x = F.interpolate(
60
- x, (x.shape[2], x.shape[3] * 2, x.shape[4] * 2), mode="nearest"
61
- )
62
- else:
63
- x = F.interpolate(x, scale_factor=2, mode="nearest")
64
- if self.use_conv:
65
- x = self.conv(x)
66
- return x
67
-
68
-
69
- class Downsample(nn.Module):
70
- """
71
- A downsampling layer with an optional convolution.
72
- :param channels: channels in the inputs and outputs.
73
- :param use_conv: a bool determining if a convolution is applied.
74
- :param dims: determines if the signal is 1D, 2D, or 3D. If 3D, then
75
- downsampling occurs in the inner-two dimensions.
76
- """
77
-
78
- def __init__(self, channels, use_conv, dims=2, out_channels=None, padding=1):
79
- super().__init__()
80
- self.channels = channels
81
- self.out_channels = out_channels or channels
82
- self.use_conv = use_conv
83
- self.dims = dims
84
- stride = 2 if dims != 3 else (1, 2, 2)
85
- if use_conv:
86
- self.op = conv_nd(
87
- dims,
88
- self.channels,
89
- self.out_channels,
90
- 3,
91
- stride=stride,
92
- padding=padding,
93
- )
94
- else:
95
- assert self.channels == self.out_channels
96
- self.op = avg_pool_nd(dims, kernel_size=stride, stride=stride)
97
-
98
- def forward(self, x):
99
- assert x.shape[1] == self.channels
100
- return self.op(x)
101
-
102
-
103
- class ResBlock(nn.Module):
104
- """
105
- A residual block that can optionally change the number of channels.
106
- :param channels: the number of input channels.
107
- :param emb_channels: the number of timestep embedding channels.
108
- :param dropout: the rate of dropout.
109
- :param out_channels: if specified, the number of out channels.
110
- :param use_conv: if True and out_channels is specified, use a spatial
111
- convolution instead of a smaller 1x1 convolution to change the
112
- channels in the skip connection.
113
- :param dims: determines if the signal is 1D, 2D, or 3D.
114
- :param use_checkpoint: if True, use gradient checkpointing on this module.
115
- :param up: if True, use this block for upsampling.
116
- :param down: if True, use this block for downsampling.
117
- """
118
-
119
- def __init__(
120
- self,
121
- channels,
122
- emb_channels,
123
- dropout,
124
- out_channels=None,
125
- use_conv=False,
126
- use_scale_shift_norm=False,
127
- dims=2,
128
- use_checkpoint=False,
129
- up=False,
130
- down=False,
131
- ):
132
- super().__init__()
133
- self.channels = channels
134
- self.emb_channels = emb_channels
135
- self.dropout = dropout
136
- self.out_channels = out_channels or channels
137
- self.use_conv = use_conv
138
- self.use_checkpoint = use_checkpoint
139
- self.use_scale_shift_norm = use_scale_shift_norm
140
-
141
- self.in_layers = nn.Sequential(
142
- nn.GroupNorm(32, channels),
143
- nn.SiLU(),
144
- conv_nd(dims, channels, self.out_channels, 3, padding=1),
145
- )
146
-
147
- self.updown = up or down
148
-
149
- if up:
150
- self.h_upd = Upsample(channels, False, dims)
151
- self.x_upd = Upsample(channels, False, dims)
152
- elif down:
153
- self.h_upd = Downsample(channels, False, dims)
154
- self.x_upd = Downsample(channels, False, dims)
155
- else:
156
- self.h_upd = self.x_upd = nn.Identity()
157
-
158
- self.emb_layers = nn.Sequential(
159
- nn.SiLU(),
160
- nn.Linear(
161
- emb_channels,
162
- 2 * self.out_channels if use_scale_shift_norm else self.out_channels,
163
- ),
164
- )
165
- self.out_layers = nn.Sequential(
166
- nn.GroupNorm(32, self.out_channels),
167
- nn.SiLU(),
168
- nn.Dropout(p=dropout),
169
- zero_module(
170
- conv_nd(dims, self.out_channels, self.out_channels, 3, padding=1)
171
- ),
172
- )
173
-
174
- if self.out_channels == channels:
175
- self.skip_connection = nn.Identity()
176
- elif use_conv:
177
- self.skip_connection = conv_nd(
178
- dims, channels, self.out_channels, 3, padding=1
179
- )
180
- else:
181
- self.skip_connection = conv_nd(dims, channels, self.out_channels, 1)
182
-
183
- def forward(self, x, emb):
184
- """
185
- Apply the block to a Tensor, conditioned on a timestep embedding.
186
- :param x: an [N x C x ...] Tensor of features.
187
- :param emb: an [N x emb_channels] Tensor of timestep embeddings.
188
- :return: an [N x C x ...] Tensor of outputs.
189
- """
190
- return checkpoint(
191
- self._forward, (x, emb), self.parameters(), self.use_checkpoint
192
- )
193
-
194
- def _forward(self, x, emb):
195
- if self.updown:
196
- in_rest, in_conv = self.in_layers[:-1], self.in_layers[-1]
197
- h = in_rest(x)
198
- h = self.h_upd(h)
199
- x = self.x_upd(x)
200
- h = in_conv(h)
201
- else:
202
- h = self.in_layers(x)
203
- emb_out = self.emb_layers(emb).type(h.dtype)
204
- while len(emb_out.shape) < len(h.shape):
205
- emb_out = emb_out[..., None]
206
- if self.use_scale_shift_norm:
207
- out_norm, out_rest = self.out_layers[0], self.out_layers[1:]
208
- scale, shift = torch.chunk(emb_out, 2, dim=1)
209
- h = out_norm(h) * (1 + scale) + shift
210
- h = out_rest(h)
211
- else:
212
- h = h + emb_out
213
- h = self.out_layers(h)
214
- return self.skip_connection(x) + h
215
-
216
-
217
- class MultiViewUNetModel(ModelMixin, ConfigMixin):
218
- """
219
- The full multi-view UNet model with attention, timestep embedding and camera embedding.
220
- :param in_channels: channels in the input Tensor.
221
- :param model_channels: base channel count for the model.
222
- :param out_channels: channels in the output Tensor.
223
- :param num_res_blocks: number of residual blocks per downsample.
224
- :param attention_resolutions: a collection of downsample rates at which
225
- attention will take place. May be a set, list, or tuple.
226
- For example, if this contains 4, then at 4x downsampling, attention
227
- will be used.
228
- :param dropout: the dropout probability.
229
- :param channel_mult: channel multiplier for each level of the UNet.
230
- :param conv_resample: if True, use learned convolutions for upsampling and
231
- downsampling.
232
- :param dims: determines if the signal is 1D, 2D, or 3D.
233
- :param num_classes: if specified (as an int), then this model will be
234
- class-conditional with `num_classes` classes.
235
- :param use_checkpoint: use gradient checkpointing to reduce memory usage.
236
- :param num_heads: the number of attention heads in each attention layer.
237
- :param num_heads_channels: if specified, ignore num_heads and instead use
238
- a fixed channel width per attention head.
239
- :param num_heads_upsample: works with num_heads to set a different number
240
- of heads for upsampling. Deprecated.
241
- :param use_scale_shift_norm: use a FiLM-like conditioning mechanism.
242
- :param resblock_updown: use residual blocks for up/downsampling.
243
- :param use_new_attention_order: use a different attention pattern for potentially
244
- increased efficiency.
245
- :param camera_dim: dimensionality of camera input.
246
- """
247
-
248
- def __init__(
249
- self,
250
- image_size,
251
- in_channels,
252
- model_channels,
253
- out_channels,
254
- num_res_blocks,
255
- attention_resolutions,
256
- dropout=0,
257
- channel_mult=(1, 2, 4, 8),
258
- conv_resample=True,
259
- dims=2,
260
- num_classes=None,
261
- use_checkpoint=False,
262
- num_heads=-1,
263
- num_head_channels=-1,
264
- num_heads_upsample=-1,
265
- use_scale_shift_norm=False,
266
- resblock_updown=False,
267
- transformer_depth=1, # custom transformer support
268
- context_dim=None, # custom transformer support
269
- n_embed=None, # custom support for prediction of discrete ids into codebook of first stage vq model
270
- disable_self_attentions=None,
271
- num_attention_blocks=None,
272
- disable_middle_self_attn=False,
273
- adm_in_channels=None,
274
- camera_dim=None,
275
- with_ip=True,
276
- ip_dim=16,
277
- ip_weight=1.0,
278
- **kwargs,
279
- ):
280
- super().__init__()
281
- assert context_dim is not None
282
-
283
- if num_heads_upsample == -1:
284
- num_heads_upsample = num_heads
285
-
286
- if num_heads == -1:
287
- assert (
288
- num_head_channels != -1
289
- ), "Either num_heads or num_head_channels has to be set"
290
-
291
- if num_head_channels == -1:
292
- assert (
293
- num_heads != -1
294
- ), "Either num_heads or num_head_channels has to be set"
295
-
296
- self.image_size = image_size
297
- self.in_channels = in_channels
298
- self.model_channels = model_channels
299
- self.out_channels = out_channels
300
- if isinstance(num_res_blocks, int):
301
- self.num_res_blocks = len(channel_mult) * [num_res_blocks]
302
- else:
303
- if len(num_res_blocks) != len(channel_mult):
304
- raise ValueError(
305
- "provide num_res_blocks either as an int (globally constant) or "
306
- "as a list/tuple (per-level) with the same length as channel_mult"
307
- )
308
- self.num_res_blocks = num_res_blocks
309
-
310
- if num_attention_blocks is not None:
311
- assert len(num_attention_blocks) == len(self.num_res_blocks)
312
- assert all(
313
- map(
314
- lambda i: self.num_res_blocks[i] >= num_attention_blocks[i],
315
- range(len(num_attention_blocks)),
316
- )
317
- )
318
- print(
319
- f"Constructor of UNetModel received num_attention_blocks={num_attention_blocks}. "
320
- f"This option has LESS priority than attention_resolutions {attention_resolutions}, "
321
- f"i.e., in cases where num_attention_blocks[i] > 0 but 2**i not in attention_resolutions, "
322
- f"attention will still not be set."
323
- )
324
-
325
- self.attention_resolutions = attention_resolutions
326
- self.dropout = dropout
327
- self.channel_mult = channel_mult
328
- self.conv_resample = conv_resample
329
- self.num_classes = num_classes
330
- self.use_checkpoint = use_checkpoint
331
- self.num_heads = num_heads
332
- self.num_head_channels = num_head_channels
333
- self.num_heads_upsample = num_heads_upsample
334
- self.predict_codebook_ids = n_embed is not None
335
-
336
- self.with_ip = with_ip
337
- self.ip_dim = ip_dim
338
- self.ip_weight = ip_weight
339
-
340
- if self.with_ip and self.ip_dim > 0:
341
- self.image_embed = Resampler(
342
- dim=context_dim,
343
- depth=4,
344
- dim_head=64,
345
- heads=12,
346
- num_queries=ip_dim, # num token
347
- embedding_dim=1280,
348
- output_dim=context_dim,
349
- ff_mult=4,
350
- )
351
-
352
- time_embed_dim = model_channels * 4
353
- self.time_embed = nn.Sequential(
354
- nn.Linear(model_channels, time_embed_dim),
355
- nn.SiLU(),
356
- nn.Linear(time_embed_dim, time_embed_dim),
357
- )
358
-
359
- if camera_dim is not None:
360
- time_embed_dim = model_channels * 4
361
- self.camera_embed = nn.Sequential(
362
- nn.Linear(camera_dim, time_embed_dim),
363
- nn.SiLU(),
364
- nn.Linear(time_embed_dim, time_embed_dim),
365
- )
366
-
367
- if self.num_classes is not None:
368
- if isinstance(self.num_classes, int):
369
- self.label_emb = nn.Embedding(self.num_classes, time_embed_dim)
370
- elif self.num_classes == "continuous":
371
- # print("setting up linear c_adm embedding layer")
372
- self.label_emb = nn.Linear(1, time_embed_dim)
373
- elif self.num_classes == "sequential":
374
- assert adm_in_channels is not None
375
- self.label_emb = nn.Sequential(
376
- nn.Sequential(
377
- nn.Linear(adm_in_channels, time_embed_dim),
378
- nn.SiLU(),
379
- nn.Linear(time_embed_dim, time_embed_dim),
380
- )
381
- )
382
- else:
383
- raise ValueError()
384
-
385
- self.input_blocks = nn.ModuleList(
386
- [
387
- CondSequential(
388
- conv_nd(dims, in_channels, model_channels, 3, padding=1)
389
- )
390
- ]
391
- )
392
- self._feature_size = model_channels
393
- input_block_chans = [model_channels]
394
- ch = model_channels
395
- ds = 1
396
- for level, mult in enumerate(channel_mult):
397
- for nr in range(self.num_res_blocks[level]):
398
- layers: List[Any] = [
399
- ResBlock(
400
- ch,
401
- time_embed_dim,
402
- dropout,
403
- out_channels=mult * model_channels,
404
- dims=dims,
405
- use_checkpoint=use_checkpoint,
406
- use_scale_shift_norm=use_scale_shift_norm,
407
- )
408
- ]
409
- ch = mult * model_channels
410
- if ds in attention_resolutions:
411
- if num_head_channels == -1:
412
- dim_head = ch // num_heads
413
- else:
414
- num_heads = ch // num_head_channels
415
- dim_head = num_head_channels
416
-
417
- if num_attention_blocks is None or nr < num_attention_blocks[level]:
418
- layers.append(
419
- SpatialTransformer3D(
420
- ch,
421
- num_heads,
422
- dim_head,
423
- context_dim=context_dim,
424
- depth=transformer_depth,
425
- use_checkpoint=use_checkpoint,
426
- with_ip=self.with_ip,
427
- ip_dim=self.ip_dim,
428
- ip_weight=self.ip_weight,
429
- )
430
- )
431
- self.input_blocks.append(CondSequential(*layers))
432
- self._feature_size += ch
433
- input_block_chans.append(ch)
434
- if level != len(channel_mult) - 1:
435
- out_ch = ch
436
- self.input_blocks.append(
437
- CondSequential(
438
- ResBlock(
439
- ch,
440
- time_embed_dim,
441
- dropout,
442
- out_channels=out_ch,
443
- dims=dims,
444
- use_checkpoint=use_checkpoint,
445
- use_scale_shift_norm=use_scale_shift_norm,
446
- down=True,
447
- )
448
- if resblock_updown
449
- else Downsample(
450
- ch, conv_resample, dims=dims, out_channels=out_ch
451
- )
452
- )
453
- )
454
- ch = out_ch
455
- input_block_chans.append(ch)
456
- ds *= 2
457
- self._feature_size += ch
458
-
459
- if num_head_channels == -1:
460
- dim_head = ch // num_heads
461
- else:
462
- num_heads = ch // num_head_channels
463
- dim_head = num_head_channels
464
-
465
- self.middle_block = CondSequential(
466
- ResBlock(
467
- ch,
468
- time_embed_dim,
469
- dropout,
470
- dims=dims,
471
- use_checkpoint=use_checkpoint,
472
- use_scale_shift_norm=use_scale_shift_norm,
473
- ),
474
- SpatialTransformer3D(
475
- ch,
476
- num_heads,
477
- dim_head,
478
- context_dim=context_dim,
479
- depth=transformer_depth,
480
- use_checkpoint=use_checkpoint,
481
- with_ip=self.with_ip,
482
- ip_dim=self.ip_dim,
483
- ip_weight=self.ip_weight,
484
- ),
485
- ResBlock(
486
- ch,
487
- time_embed_dim,
488
- dropout,
489
- dims=dims,
490
- use_checkpoint=use_checkpoint,
491
- use_scale_shift_norm=use_scale_shift_norm,
492
- ),
493
- )
494
- self._feature_size += ch
495
-
496
- self.output_blocks = nn.ModuleList([])
497
- for level, mult in list(enumerate(channel_mult))[::-1]:
498
- for i in range(self.num_res_blocks[level] + 1):
499
- ich = input_block_chans.pop()
500
- layers = [
501
- ResBlock(
502
- ch + ich,
503
- time_embed_dim,
504
- dropout,
505
- out_channels=model_channels * mult,
506
- dims=dims,
507
- use_checkpoint=use_checkpoint,
508
- use_scale_shift_norm=use_scale_shift_norm,
509
- )
510
- ]
511
- ch = model_channels * mult
512
- if ds in attention_resolutions:
513
- if num_head_channels == -1:
514
- dim_head = ch // num_heads
515
- else:
516
- num_heads = ch // num_head_channels
517
- dim_head = num_head_channels
518
-
519
- if num_attention_blocks is None or i < num_attention_blocks[level]:
520
- layers.append(
521
- SpatialTransformer3D(
522
- ch,
523
- num_heads,
524
- dim_head,
525
- context_dim=context_dim,
526
- depth=transformer_depth,
527
- use_checkpoint=use_checkpoint,
528
- with_ip=self.with_ip,
529
- ip_dim=self.ip_dim,
530
- ip_weight=self.ip_weight,
531
- )
532
- )
533
- if level and i == self.num_res_blocks[level]:
534
- out_ch = ch
535
- layers.append(
536
- ResBlock(
537
- ch,
538
- time_embed_dim,
539
- dropout,
540
- out_channels=out_ch,
541
- dims=dims,
542
- use_checkpoint=use_checkpoint,
543
- use_scale_shift_norm=use_scale_shift_norm,
544
- up=True,
545
- )
546
- if resblock_updown
547
- else Upsample(ch, conv_resample, dims=dims, out_channels=out_ch)
548
- )
549
- ds //= 2
550
- self.output_blocks.append(CondSequential(*layers))
551
- self._feature_size += ch
552
-
553
- self.out = nn.Sequential(
554
- nn.GroupNorm(32, ch),
555
- nn.SiLU(),
556
- zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
557
- )
558
- if self.predict_codebook_ids:
559
- self.id_predictor = nn.Sequential(
560
- nn.GroupNorm(32, ch),
561
- conv_nd(dims, model_channels, n_embed, 1),
562
- # nn.LogSoftmax(dim=1) # change to cross_entropy and produce non-normalized logits
563
- )
564
-
565
- def forward(
566
- self,
567
- x,
568
- timesteps=None,
569
- context=None,
570
- y: Optional[Tensor] = None,
571
- camera=None,
572
- num_frames=1,
573
- # should be provided if with_ip
574
- ip = None,
575
- ip_img = None,
576
- **kwargs,
577
- ):
578
- """
579
- Apply the model to an input batch.
580
- :param x: an [(N x F) x C x ...] Tensor of inputs. F is the number of frames (views).
581
- :param timesteps: a 1-D batch of timesteps.
582
- :param context: conditioning plugged in via crossattn
583
- :param y: an [N] Tensor of labels, if class-conditional.
584
- :param num_frames: a integer indicating number of frames for tensor reshaping.
585
- :return: an [(N x F) x C x ...] Tensor of outputs. F is the number of frames (views).
586
- """
587
- assert (
588
- x.shape[0] % num_frames == 0
589
- ), "[UNet] input batch size must be dividable by num_frames!"
590
- assert (y is not None) == (
591
- self.num_classes is not None
592
- ), "must specify y if and only if the model is class-conditional"
593
- hs = []
594
- t_emb = timestep_embedding(
595
- timesteps, self.model_channels, repeat_only=False
596
- ).to(x.dtype)
597
-
598
- emb = self.time_embed(t_emb)
599
-
600
- if self.num_classes is not None:
601
- assert y is not None
602
- assert y.shape[0] == x.shape[0]
603
- emb = emb + self.label_emb(y)
604
-
605
- # Add camera embeddings
606
- if camera is not None:
607
- assert camera.shape[0] == emb.shape[0]
608
- emb = emb + self.camera_embed(camera)
609
-
610
- if self.with_ip:
611
- x[(num_frames - 1) :: num_frames, :, :, :] = ip_img
612
- ip_emb = self.image_embed(ip)
613
- context = torch.cat((context, ip_emb), 1)
614
-
615
- h = x
616
- for module in self.input_blocks:
617
- h = module(h, emb, context, num_frames=num_frames)
618
- hs.append(h)
619
- h = self.middle_block(h, emb, context, num_frames=num_frames)
620
- for module in self.output_blocks:
621
- h = torch.cat([h, hs.pop()], dim=1)
622
- h = module(h, emb, context, num_frames=num_frames)
623
- h = h.type(x.dtype)
624
- if self.predict_codebook_ids:
625
- return self.id_predictor(h)
626
- else:
627
- return self.out(h)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
imagedream/pipeline_imagedream.py DELETED
@@ -1,620 +0,0 @@
1
- import torch
2
- import inspect
3
- import numpy as np
4
- from typing import Callable, List, Optional, Union
5
- from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPFeatureExtractor
6
- from diffusers import AutoencoderKL, DiffusionPipeline
7
- from diffusers.utils import (
8
- deprecate,
9
- is_accelerate_available,
10
- is_accelerate_version,
11
- logging,
12
- )
13
- from diffusers.configuration_utils import FrozenDict
14
- from diffusers.schedulers import DDIMScheduler
15
- from diffusers.utils.torch_utils import randn_tensor
16
-
17
- from .models import MultiViewUNetModel
18
-
19
- import kiui
20
-
21
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
22
-
23
-
24
- def create_camera_to_world_matrix(elevation, azimuth):
25
- elevation = np.radians(elevation)
26
- azimuth = np.radians(azimuth)
27
- # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
28
- x = np.cos(elevation) * np.sin(azimuth)
29
- y = np.sin(elevation)
30
- z = np.cos(elevation) * np.cos(azimuth)
31
-
32
- # Calculate camera position, target, and up vectors
33
- camera_pos = np.array([x, y, z])
34
- target = np.array([0, 0, 0])
35
- up = np.array([0, 1, 0])
36
-
37
- # Construct view matrix
38
- forward = target - camera_pos
39
- forward /= np.linalg.norm(forward)
40
- right = np.cross(forward, up)
41
- right /= np.linalg.norm(right)
42
- new_up = np.cross(right, forward)
43
- new_up /= np.linalg.norm(new_up)
44
- cam2world = np.eye(4)
45
- cam2world[:3, :3] = np.array([right, new_up, -forward]).T
46
- cam2world[:3, 3] = camera_pos
47
- return cam2world
48
-
49
-
50
- def convert_opengl_to_blender(camera_matrix):
51
- if isinstance(camera_matrix, np.ndarray):
52
- # Construct transformation matrix to convert from OpenGL space to Blender space
53
- flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
54
- camera_matrix_blender = np.dot(flip_yz, camera_matrix)
55
- else:
56
- # Construct transformation matrix to convert from OpenGL space to Blender space
57
- flip_yz = torch.tensor(
58
- [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
59
- )
60
- if camera_matrix.ndim == 3:
61
- flip_yz = flip_yz.unsqueeze(0)
62
- camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
63
- return camera_matrix_blender
64
-
65
-
66
- def get_camera(
67
- num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
68
- ):
69
- angle_gap = azimuth_span / num_frames
70
- cameras = []
71
- for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
72
- camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
73
- if blender_coord:
74
- camera_matrix = convert_opengl_to_blender(camera_matrix)
75
- cameras.append(camera_matrix.flatten())
76
- if extra_view:
77
- dim = len(cameras[0])
78
- cameras.append(np.zeros(dim))
79
- return torch.tensor(np.stack(cameras, 0)).float()
80
-
81
-
82
- class ImageDreamPipeline(DiffusionPipeline):
83
- def __init__(
84
- self,
85
- vae: AutoencoderKL,
86
- unet: MultiViewUNetModel,
87
- tokenizer: CLIPTokenizer,
88
- text_encoder: CLIPTextModel,
89
- scheduler: DDIMScheduler,
90
- feature_extractor: CLIPFeatureExtractor = None,
91
- image_encoder: CLIPVisionModel = None,
92
- requires_safety_checker: bool = False,
93
- ):
94
- super().__init__()
95
-
96
- if hasattr(scheduler.config, "steps_offset") and scheduler.config.steps_offset != 1: # type: ignore
97
- deprecation_message = (
98
- f"The configuration file of this scheduler: {scheduler} is outdated. `steps_offset`"
99
- f" should be set to 1 instead of {scheduler.config.steps_offset}. Please make sure " # type: ignore
100
- "to update the config accordingly as leaving `steps_offset` might led to incorrect results"
101
- " in future versions. If you have downloaded this checkpoint from the Hugging Face Hub,"
102
- " it would be very nice if you could open a Pull request for the `scheduler/scheduler_config.json`"
103
- " file"
104
- )
105
- deprecate(
106
- "steps_offset!=1", "1.0.0", deprecation_message, standard_warn=False
107
- )
108
- new_config = dict(scheduler.config)
109
- new_config["steps_offset"] = 1
110
- scheduler._internal_dict = FrozenDict(new_config)
111
-
112
- if hasattr(scheduler.config, "clip_sample") and scheduler.config.clip_sample is True: # type: ignore
113
- deprecation_message = (
114
- f"The configuration file of this scheduler: {scheduler} has not set the configuration `clip_sample`."
115
- " `clip_sample` should be set to False in the configuration file. Please make sure to update the"
116
- " config accordingly as not setting `clip_sample` in the config might lead to incorrect results in"
117
- " future versions. If you have downloaded this checkpoint from the Hugging Face Hub, it would be very"
118
- " nice if you could open a Pull request for the `scheduler/scheduler_config.json` file"
119
- )
120
- deprecate(
121
- "clip_sample not set", "1.0.0", deprecation_message, standard_warn=False
122
- )
123
- new_config = dict(scheduler.config)
124
- new_config["clip_sample"] = False
125
- scheduler._internal_dict = FrozenDict(new_config)
126
-
127
- self.register_modules(
128
- vae=vae,
129
- unet=unet,
130
- scheduler=scheduler,
131
- tokenizer=tokenizer,
132
- text_encoder=text_encoder,
133
- feature_extractor=feature_extractor,
134
- image_encoder=image_encoder,
135
- )
136
- self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
137
- self.register_to_config(requires_safety_checker=requires_safety_checker)
138
-
139
- def enable_vae_slicing(self):
140
- r"""
141
- Enable sliced VAE decoding.
142
-
143
- When this option is enabled, the VAE will split the input tensor in slices to compute decoding in several
144
- steps. This is useful to save some memory and allow larger batch sizes.
145
- """
146
- self.vae.enable_slicing()
147
-
148
- def disable_vae_slicing(self):
149
- r"""
150
- Disable sliced VAE decoding. If `enable_vae_slicing` was previously invoked, this method will go back to
151
- computing decoding in one step.
152
- """
153
- self.vae.disable_slicing()
154
-
155
- def enable_vae_tiling(self):
156
- r"""
157
- Enable tiled VAE decoding.
158
-
159
- When this option is enabled, the VAE will split the input tensor into tiles to compute decoding and encoding in
160
- several steps. This is useful to save a large amount of memory and to allow the processing of larger images.
161
- """
162
- self.vae.enable_tiling()
163
-
164
- def disable_vae_tiling(self):
165
- r"""
166
- Disable tiled VAE decoding. If `enable_vae_tiling` was previously invoked, this method will go back to
167
- computing decoding in one step.
168
- """
169
- self.vae.disable_tiling()
170
-
171
- def enable_sequential_cpu_offload(self, gpu_id=0):
172
- r"""
173
- Offloads all models to CPU using accelerate, significantly reducing memory usage. When called, unet,
174
- text_encoder, vae and safety checker have their state dicts saved to CPU and then are moved to a
175
- `torch.device('meta') and loaded to GPU only when their specific submodule has its `forward` method called.
176
- Note that offloading happens on a submodule basis. Memory savings are higher than with
177
- `enable_model_cpu_offload`, but performance is lower.
178
- """
179
- if is_accelerate_available() and is_accelerate_version(">=", "0.14.0"):
180
- from accelerate import cpu_offload
181
- else:
182
- raise ImportError(
183
- "`enable_sequential_cpu_offload` requires `accelerate v0.14.0` or higher"
184
- )
185
-
186
- device = torch.device(f"cuda:{gpu_id}")
187
-
188
- if self.device.type != "cpu":
189
- self.to("cpu", silence_dtype_warnings=True)
190
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
191
-
192
- for cpu_offloaded_model in [self.unet, self.text_encoder, self.vae]:
193
- cpu_offload(cpu_offloaded_model, device)
194
-
195
- def enable_model_cpu_offload(self, gpu_id=0):
196
- r"""
197
- Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
198
- to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
199
- method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
200
- `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
201
- """
202
- if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
203
- from accelerate import cpu_offload_with_hook
204
- else:
205
- raise ImportError(
206
- "`enable_model_offload` requires `accelerate v0.17.0` or higher."
207
- )
208
-
209
- device = torch.device(f"cuda:{gpu_id}")
210
-
211
- if self.device.type != "cpu":
212
- self.to("cpu", silence_dtype_warnings=True)
213
- torch.cuda.empty_cache() # otherwise we don't see the memory savings (but they probably exist)
214
-
215
- hook = None
216
- for cpu_offloaded_model in [self.text_encoder, self.unet, self.vae]:
217
- _, hook = cpu_offload_with_hook(
218
- cpu_offloaded_model, device, prev_module_hook=hook
219
- )
220
-
221
- # We'll offload the last model manually.
222
- self.final_offload_hook = hook
223
-
224
- @property
225
- def _execution_device(self):
226
- r"""
227
- Returns the device on which the pipeline's models will be executed. After calling
228
- `pipeline.enable_sequential_cpu_offload()` the execution device can only be inferred from Accelerate's module
229
- hooks.
230
- """
231
- if not hasattr(self.unet, "_hf_hook"):
232
- return self.device
233
- for module in self.unet.modules():
234
- if (
235
- hasattr(module, "_hf_hook")
236
- and hasattr(module._hf_hook, "execution_device")
237
- and module._hf_hook.execution_device is not None
238
- ):
239
- return torch.device(module._hf_hook.execution_device)
240
- return self.device
241
-
242
- def _encode_prompt(
243
- self,
244
- prompt,
245
- device,
246
- num_images_per_prompt,
247
- do_classifier_free_guidance: bool,
248
- negative_prompt=None,
249
- ):
250
- r"""
251
- Encodes the prompt into text encoder hidden states.
252
-
253
- Args:
254
- prompt (`str` or `List[str]`, *optional*):
255
- prompt to be encoded
256
- device: (`torch.device`):
257
- torch device
258
- num_images_per_prompt (`int`):
259
- number of images that should be generated per prompt
260
- do_classifier_free_guidance (`bool`):
261
- whether to use classifier free guidance or not
262
- negative_prompt (`str` or `List[str]`, *optional*):
263
- The prompt or prompts not to guide the image generation. If not defined, one has to pass
264
- `negative_prompt_embeds`. instead. If not defined, one has to pass `negative_prompt_embeds`. instead.
265
- Ignored when not using guidance (i.e., ignored if `guidance_scale` is less than `1`).
266
- prompt_embeds (`torch.FloatTensor`, *optional*):
267
- Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not
268
- provided, text embeddings will be generated from `prompt` input argument.
269
- negative_prompt_embeds (`torch.FloatTensor`, *optional*):
270
- Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt
271
- weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input
272
- argument.
273
- """
274
- if prompt is not None and isinstance(prompt, str):
275
- batch_size = 1
276
- elif prompt is not None and isinstance(prompt, list):
277
- batch_size = len(prompt)
278
- else:
279
- raise ValueError(
280
- f"`prompt` should be either a string or a list of strings, but got {type(prompt)}."
281
- )
282
-
283
- text_inputs = self.tokenizer(
284
- prompt,
285
- padding="max_length",
286
- max_length=self.tokenizer.model_max_length,
287
- truncation=True,
288
- return_tensors="pt",
289
- )
290
- text_input_ids = text_inputs.input_ids
291
- untruncated_ids = self.tokenizer(
292
- prompt, padding="longest", return_tensors="pt"
293
- ).input_ids
294
-
295
- if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(
296
- text_input_ids, untruncated_ids
297
- ):
298
- removed_text = self.tokenizer.batch_decode(
299
- untruncated_ids[:, self.tokenizer.model_max_length - 1 : -1]
300
- )
301
- logger.warning(
302
- "The following part of your input was truncated because CLIP can only handle sequences up to"
303
- f" {self.tokenizer.model_max_length} tokens: {removed_text}"
304
- )
305
-
306
- if (
307
- hasattr(self.text_encoder.config, "use_attention_mask")
308
- and self.text_encoder.config.use_attention_mask
309
- ):
310
- attention_mask = text_inputs.attention_mask.to(device)
311
- else:
312
- attention_mask = None
313
-
314
- prompt_embeds = self.text_encoder(
315
- text_input_ids.to(device),
316
- attention_mask=attention_mask,
317
- )
318
- prompt_embeds = prompt_embeds[0]
319
-
320
- prompt_embeds = prompt_embeds.to(dtype=self.text_encoder.dtype, device=device)
321
-
322
- bs_embed, seq_len, _ = prompt_embeds.shape
323
- # duplicate text embeddings for each generation per prompt, using mps friendly method
324
- prompt_embeds = prompt_embeds.repeat(1, num_images_per_prompt, 1)
325
- prompt_embeds = prompt_embeds.view(
326
- bs_embed * num_images_per_prompt, seq_len, -1
327
- )
328
-
329
- # get unconditional embeddings for classifier free guidance
330
- if do_classifier_free_guidance:
331
- uncond_tokens: List[str]
332
- if negative_prompt is None:
333
- uncond_tokens = [""] * batch_size
334
- elif type(prompt) is not type(negative_prompt):
335
- raise TypeError(
336
- f"`negative_prompt` should be the same type to `prompt`, but got {type(negative_prompt)} !="
337
- f" {type(prompt)}."
338
- )
339
- elif isinstance(negative_prompt, str):
340
- uncond_tokens = [negative_prompt]
341
- elif batch_size != len(negative_prompt):
342
- raise ValueError(
343
- f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
344
- f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
345
- " the batch size of `prompt`."
346
- )
347
- else:
348
- uncond_tokens = negative_prompt
349
-
350
- max_length = prompt_embeds.shape[1]
351
- uncond_input = self.tokenizer(
352
- uncond_tokens,
353
- padding="max_length",
354
- max_length=max_length,
355
- truncation=True,
356
- return_tensors="pt",
357
- )
358
-
359
- if (
360
- hasattr(self.text_encoder.config, "use_attention_mask")
361
- and self.text_encoder.config.use_attention_mask
362
- ):
363
- attention_mask = uncond_input.attention_mask.to(device)
364
- else:
365
- attention_mask = None
366
-
367
- negative_prompt_embeds = self.text_encoder(
368
- uncond_input.input_ids.to(device),
369
- attention_mask=attention_mask,
370
- )
371
- negative_prompt_embeds = negative_prompt_embeds[0]
372
-
373
- # duplicate unconditional embeddings for each generation per prompt, using mps friendly method
374
- seq_len = negative_prompt_embeds.shape[1]
375
-
376
- negative_prompt_embeds = negative_prompt_embeds.to(
377
- dtype=self.text_encoder.dtype, device=device
378
- )
379
-
380
- negative_prompt_embeds = negative_prompt_embeds.repeat(
381
- 1, num_images_per_prompt, 1
382
- )
383
- negative_prompt_embeds = negative_prompt_embeds.view(
384
- batch_size * num_images_per_prompt, seq_len, -1
385
- )
386
-
387
- # For classifier free guidance, we need to do two forward passes.
388
- # Here we concatenate the unconditional and text embeddings into a single batch
389
- # to avoid doing two forward passes
390
- prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
391
-
392
- return prompt_embeds
393
-
394
- def decode_latents(self, latents):
395
- latents = 1 / self.vae.config.scaling_factor * latents
396
- image = self.vae.decode(latents).sample
397
- image = (image / 2 + 0.5).clamp(0, 1)
398
- # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
399
- image = image.cpu().permute(0, 2, 3, 1).float().numpy()
400
- return image
401
-
402
- def prepare_extra_step_kwargs(self, generator, eta):
403
- # prepare extra kwargs for the scheduler step, since not all schedulers have the same signature
404
- # eta (η) is only used with the DDIMScheduler, it will be ignored for other schedulers.
405
- # eta corresponds to η in DDIM paper: https://arxiv.org/abs/2010.02502
406
- # and should be between [0, 1]
407
-
408
- accepts_eta = "eta" in set(
409
- inspect.signature(self.scheduler.step).parameters.keys()
410
- )
411
- extra_step_kwargs = {}
412
- if accepts_eta:
413
- extra_step_kwargs["eta"] = eta
414
-
415
- # check if the scheduler accepts generator
416
- accepts_generator = "generator" in set(
417
- inspect.signature(self.scheduler.step).parameters.keys()
418
- )
419
- if accepts_generator:
420
- extra_step_kwargs["generator"] = generator
421
- return extra_step_kwargs
422
-
423
- def prepare_latents(
424
- self,
425
- batch_size,
426
- num_channels_latents,
427
- height,
428
- width,
429
- dtype,
430
- device,
431
- generator,
432
- latents=None,
433
- ):
434
- shape = (
435
- batch_size,
436
- num_channels_latents,
437
- height // self.vae_scale_factor,
438
- width // self.vae_scale_factor,
439
- )
440
- if isinstance(generator, list) and len(generator) != batch_size:
441
- raise ValueError(
442
- f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
443
- f" size of {batch_size}. Make sure the batch size matches the length of the generators."
444
- )
445
-
446
- if latents is None:
447
- latents = randn_tensor(
448
- shape, generator=generator, device=device, dtype=dtype
449
- )
450
- else:
451
- latents = latents.to(device)
452
-
453
- # scale the initial noise by the standard deviation required by the scheduler
454
- latents = latents * self.scheduler.init_noise_sigma
455
- return latents
456
-
457
- def encode_image(self, image, device, num_images_per_prompt):
458
- dtype = next(self.image_encoder.parameters()).dtype
459
-
460
- image = (image * 255).astype(np.uint8)
461
- image = self.feature_extractor(image, return_tensors="pt").pixel_values
462
-
463
- image = image.to(device=device, dtype=dtype)
464
-
465
- image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
466
- image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
467
-
468
- # imagedream directly use zero as uncond image embeddings
469
- uncond_image_enc_hidden_states = torch.zeros_like(image_enc_hidden_states)
470
-
471
- return uncond_image_enc_hidden_states, image_enc_hidden_states
472
-
473
- def encode_image_latents(self, image, device, num_images_per_prompt):
474
-
475
- image = torch.from_numpy(image).to(device)
476
- posterior = self.vae.encode(image).latent_dist
477
-
478
- latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
479
- latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
480
-
481
- return torch.zeros_like(latents), latents
482
-
483
- @torch.no_grad()
484
- def __call__(
485
- self,
486
- image, # input image, np.ndarray float32!
487
- prompt: str = "a car",
488
- height: int = 256,
489
- width: int = 256,
490
- num_inference_steps: int = 50,
491
- guidance_scale: float = 7.0,
492
- negative_prompt: str = "bad quality",
493
- num_images_per_prompt: int = 1,
494
- eta: float = 0.0,
495
- generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
496
- output_type: Optional[str] = "image",
497
- callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
498
- callback_steps: int = 1,
499
- num_frames: int = 4,
500
- device=torch.device("cuda:0"),
501
- ):
502
- self.unet = self.unet.to(device=device)
503
- self.vae = self.vae.to(device=device)
504
-
505
- self.text_encoder = self.text_encoder.to(device=device)
506
-
507
- # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
508
- # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
509
- # corresponds to doing no classifier free guidance.
510
- do_classifier_free_guidance = guidance_scale > 1.0
511
-
512
- # Prepare timesteps
513
- self.scheduler.set_timesteps(num_inference_steps, device=device)
514
- timesteps = self.scheduler.timesteps
515
-
516
- # encode image
517
- assert isinstance(image, np.ndarray) and image.dtype == np.float32
518
-
519
- self.image_encoder = self.image_encoder.to(device=device)
520
- image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
521
- kiui.lo(image_embeds_pos) # should be [1, 257, 1280]?
522
-
523
- image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
524
- kiui.lo(image_latents_pos)
525
-
526
- # encode text
527
- _prompt_embeds = self._encode_prompt(
528
- prompt=prompt,
529
- device=device,
530
- num_images_per_prompt=num_images_per_prompt,
531
- do_classifier_free_guidance=do_classifier_free_guidance,
532
- negative_prompt=negative_prompt,
533
- ) # type: ignore
534
- prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
535
-
536
- # Prepare latent variables
537
- latents: torch.Tensor = self.prepare_latents(
538
- (num_frames + 1) * num_images_per_prompt,
539
- 4, # channel
540
- height,
541
- width,
542
- prompt_embeds_pos.dtype,
543
- device,
544
- generator,
545
- None,
546
- )
547
-
548
- camera = get_camera(num_frames, extra_view=True).to(dtype=latents.dtype, device=device)
549
- camera = camera.repeat(num_images_per_prompt, 1).to(self.device)
550
-
551
- # Prepare extra step kwargs.
552
- extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
553
-
554
- # Denoising loop
555
- num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
556
- with self.progress_bar(total=num_inference_steps) as progress_bar:
557
- for i, t in enumerate(timesteps):
558
- # expand the latents if we are doing classifier free guidance
559
- multiplier = 2 if do_classifier_free_guidance else 1
560
- latent_model_input = torch.cat([latents] * multiplier)
561
- latent_model_input = self.scheduler.scale_model_input(
562
- latent_model_input, t
563
- )
564
-
565
- # predict the noise residual
566
- noise_pred = self.unet.forward(
567
- x=latent_model_input,
568
- timesteps=torch.tensor(
569
- [t] * (num_frames + 1) * multiplier,
570
- dtype=latent_model_input.dtype,
571
- device=device,
572
- ),
573
- context=torch.cat(
574
- [prompt_embeds_neg] * (num_frames + 1) + [prompt_embeds_pos] * (num_frames + 1)
575
- ),
576
- num_frames=num_frames + 1,
577
- camera=torch.cat([camera] * multiplier),
578
- # for with_ip
579
- ip=torch.cat(
580
- [image_embeds_neg] * (num_frames + 1) + [image_embeds_pos] * (num_frames + 1)
581
- ),
582
- ip_img=torch.cat(
583
- [image_latents_neg] * (num_frames + 1) + [image_latents_pos] * (num_frames + 1)
584
- ),
585
- )
586
-
587
- # perform guidance
588
- if do_classifier_free_guidance:
589
- noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
590
- noise_pred = noise_pred_uncond + guidance_scale * (
591
- noise_pred_text - noise_pred_uncond
592
- )
593
-
594
- # compute the previous noisy sample x_t -> x_t-1
595
- latents: torch.Tensor = self.scheduler.step(
596
- noise_pred, t, latents, **extra_step_kwargs, return_dict=False
597
- )[0]
598
-
599
- # call the callback, if provided
600
- if i == len(timesteps) - 1 or (
601
- (i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0
602
- ):
603
- progress_bar.update()
604
- if callback is not None and i % callback_steps == 0:
605
- callback(i, t, latents) # type: ignore
606
-
607
- # Post-processing
608
- if output_type == "latent":
609
- image = latents
610
- elif output_type == "pil":
611
- image = self.decode_latents(latents)
612
- image = self.numpy_to_pil(image)
613
- else:
614
- image = self.decode_latents(latents)
615
-
616
- # Offload last model to CPU
617
- if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
618
- self.final_offload_hook.offload()
619
-
620
- return image
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
imagedream/util.py DELETED
@@ -1,116 +0,0 @@
1
- import math
2
- import torch
3
- import torch.nn as nn
4
- from einops import repeat
5
-
6
-
7
- def checkpoint(func, inputs, params, flag):
8
- """
9
- Evaluate a function without caching intermediate activations, allowing for
10
- reduced memory at the expense of extra compute in the backward pass.
11
- :param func: the function to evaluate.
12
- :param inputs: the argument sequence to pass to `func`.
13
- :param params: a sequence of parameters `func` depends on but does not
14
- explicitly take as arguments.
15
- :param flag: if False, disable gradient checkpointing.
16
- """
17
- if flag:
18
- args = tuple(inputs) + tuple(params)
19
- return CheckpointFunction.apply(func, len(inputs), *args)
20
- else:
21
- return func(*inputs)
22
-
23
-
24
- class CheckpointFunction(torch.autograd.Function):
25
- @staticmethod
26
- def forward(ctx, run_function, length, *args):
27
- ctx.run_function = run_function
28
- ctx.input_tensors = list(args[:length])
29
- ctx.input_params = list(args[length:])
30
-
31
- with torch.no_grad():
32
- output_tensors = ctx.run_function(*ctx.input_tensors)
33
- return output_tensors
34
-
35
- @staticmethod
36
- def backward(ctx, *output_grads):
37
- ctx.input_tensors = [x.detach().requires_grad_(True) for x in ctx.input_tensors]
38
- with torch.enable_grad():
39
- # Fixes a bug where the first op in run_function modifies the
40
- # Tensor storage in place, which is not allowed for detach()'d
41
- # Tensors.
42
- shallow_copies = [x.view_as(x) for x in ctx.input_tensors]
43
- output_tensors = ctx.run_function(*shallow_copies)
44
- input_grads = torch.autograd.grad(
45
- output_tensors,
46
- ctx.input_tensors + ctx.input_params,
47
- output_grads,
48
- allow_unused=True,
49
- )
50
- del ctx.input_tensors
51
- del ctx.input_params
52
- del output_tensors
53
- return (None, None) + input_grads
54
-
55
-
56
- def timestep_embedding(timesteps, dim, max_period=10000, repeat_only=False):
57
- """
58
- Create sinusoidal timestep embeddings.
59
- :param timesteps: a 1-D Tensor of N indices, one per batch element.
60
- These may be fractional.
61
- :param dim: the dimension of the output.
62
- :param max_period: controls the minimum frequency of the embeddings.
63
- :return: an [N x dim] Tensor of positional embeddings.
64
- """
65
- if not repeat_only:
66
- half = dim // 2
67
- freqs = torch.exp(
68
- -math.log(max_period)
69
- * torch.arange(start=0, end=half, dtype=torch.float32)
70
- / half
71
- ).to(device=timesteps.device)
72
- args = timesteps[:, None] * freqs[None]
73
- embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
74
- if dim % 2:
75
- embedding = torch.cat(
76
- [embedding, torch.zeros_like(embedding[:, :1])], dim=-1
77
- )
78
- else:
79
- embedding = repeat(timesteps, "b -> b d", d=dim)
80
- # import pdb; pdb.set_trace()
81
- return embedding
82
-
83
-
84
- def zero_module(module):
85
- """
86
- Zero out the parameters of a module and return it.
87
- """
88
- for p in module.parameters():
89
- p.detach().zero_()
90
- return module
91
-
92
-
93
- def conv_nd(dims, *args, **kwargs):
94
- """
95
- Create a 1D, 2D, or 3D convolution module.
96
- """
97
- if dims == 1:
98
- return nn.Conv1d(*args, **kwargs)
99
- elif dims == 2:
100
- return nn.Conv2d(*args, **kwargs)
101
- elif dims == 3:
102
- return nn.Conv3d(*args, **kwargs)
103
- raise ValueError(f"unsupported dimensions: {dims}")
104
-
105
-
106
- def avg_pool_nd(dims, *args, **kwargs):
107
- """
108
- Create a 1D, 2D, or 3D average pooling module.
109
- """
110
- if dims == 1:
111
- return nn.AvgPool1d(*args, **kwargs)
112
- elif dims == 2:
113
- return nn.AvgPool2d(*args, **kwargs)
114
- elif dims == 3:
115
- return nn.AvgPool3d(*args, **kwargs)
116
- raise ValueError(f"unsupported dimensions: {dims}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
{imagedream → mvdream}/adaptor.py RENAMED
File without changes
mvdream/attention.py CHANGED
@@ -1,26 +1,16 @@
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
4
- from torch.amp.autocast_mode import autocast
5
 
6
  from inspect import isfunction
7
  from einops import rearrange, repeat
8
  from typing import Optional, Any
9
- from .util import checkpoint, zero_module
10
-
11
- try:
12
- import xformers # type: ignore
13
- import xformers.ops # type: ignore
14
- XFORMERS_IS_AVAILBLE = True
15
- except:
16
- print(f'[WARN] xformers is unavailable!')
17
- XFORMERS_IS_AVAILBLE = False
18
 
19
- # CrossAttn precision handling
20
- import os
21
-
22
- _ATTN_PRECISION = os.environ.get("ATTN_PRECISION", "fp32")
23
 
 
24
 
25
  def default(val, d):
26
  if val is not None:
@@ -57,68 +47,33 @@ class FeedForward(nn.Module):
57
  return self.net(x)
58
 
59
 
60
- class CrossAttention(nn.Module):
61
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
62
- super().__init__()
63
- inner_dim = dim_head * heads
64
- context_dim = default(context_dim, query_dim)
65
-
66
- self.scale = dim_head**-0.5
67
- self.heads = heads
68
-
69
- self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
70
- self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
71
- self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
72
-
73
- self.to_out = nn.Sequential(
74
- nn.Linear(inner_dim, query_dim), nn.Dropout(dropout)
75
- )
76
-
77
- def forward(self, x, context=None, mask=None):
78
- h = self.heads
79
-
80
- q = self.to_q(x)
81
- context = default(context, x)
82
- k = self.to_k(context)
83
- v = self.to_v(context)
84
-
85
- q, k, v = map(lambda t: rearrange(t, "b n (h d) -> (b h) n d", h=h), (q, k, v))
86
-
87
- # force cast to fp32 to avoid overflowing
88
- if _ATTN_PRECISION == "fp32":
89
- with autocast(enabled=False, device_type="cuda"):
90
- q, k = q.float(), k.float()
91
- sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
92
- else:
93
- sim = torch.einsum("b i d, b j d -> b i j", q, k) * self.scale
94
-
95
- del q, k
96
-
97
- if mask is not None:
98
- mask = rearrange(mask, "b ... -> b (...)")
99
- max_neg_value = -torch.finfo(sim.dtype).max
100
- mask = repeat(mask, "b j -> (b h) () j", h=h)
101
- sim.masked_fill_(~mask, max_neg_value)
102
-
103
- # attention, what we cannot get enough of
104
- sim = sim.softmax(dim=-1)
105
-
106
- out = torch.einsum("b i j, b j d -> b i d", sim, v)
107
- out = rearrange(out, "(b h) n d -> b n (h d)", h=h)
108
- return self.to_out(out)
109
-
110
-
111
  class MemoryEfficientCrossAttention(nn.Module):
112
  # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
113
- def __init__(self, query_dim, context_dim=None, heads=8, dim_head=64, dropout=0.0):
 
 
 
 
 
 
 
 
 
114
  super().__init__()
115
- # print(f"Setting up {self.__class__.__name__}. Query dim is {query_dim}, context_dim is {context_dim} and using {heads} heads.")
116
  inner_dim = dim_head * heads
117
  context_dim = default(context_dim, query_dim)
118
 
119
  self.heads = heads
120
  self.dim_head = dim_head
121
 
 
 
 
 
 
 
 
122
  self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
123
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
124
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
@@ -128,9 +83,18 @@ class MemoryEfficientCrossAttention(nn.Module):
128
  )
129
  self.attention_op: Optional[Any] = None
130
 
131
- def forward(self, x, context=None, mask=None):
132
  q = self.to_q(x)
133
  context = default(context, x)
 
 
 
 
 
 
 
 
 
134
  k = self.to_k(context)
135
  v = self.to_v(context)
136
 
@@ -149,8 +113,21 @@ class MemoryEfficientCrossAttention(nn.Module):
149
  q, k, v, attn_bias=None, op=self.attention_op
150
  )
151
 
152
- if mask is not None:
153
- raise NotImplementedError
 
 
 
 
 
 
 
 
 
 
 
 
 
154
  out = (
155
  out.unsqueeze(0)
156
  .reshape(b, self.heads, out.shape[1], self.dim_head)
@@ -160,148 +137,45 @@ class MemoryEfficientCrossAttention(nn.Module):
160
  return self.to_out(out)
161
 
162
 
163
- class BasicTransformerBlock(nn.Module):
164
- ATTENTION_MODES = {
165
- "softmax": CrossAttention,
166
- "softmax-xformers": MemoryEfficientCrossAttention,
167
- } # vanilla attention
168
-
169
  def __init__(
170
  self,
171
  dim,
172
  n_heads,
173
  d_head,
 
174
  dropout=0.0,
175
- context_dim=None,
176
  gated_ff=True,
177
  checkpoint=True,
178
- disable_self_attn=False,
 
179
  ):
180
  super().__init__()
181
- attn_mode = "softmax-xformers" if XFORMERS_IS_AVAILBLE else "softmax"
182
- assert attn_mode in self.ATTENTION_MODES
183
- attn_cls = self.ATTENTION_MODES[attn_mode]
184
- self.disable_self_attn = disable_self_attn
185
- self.attn1 = attn_cls(
186
  query_dim=dim,
 
187
  heads=n_heads,
188
  dim_head=d_head,
189
  dropout=dropout,
190
- context_dim=context_dim if self.disable_self_attn else None,
191
- ) # is a self-attention if not self.disable_self_attn
192
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
193
- self.attn2 = attn_cls(
194
  query_dim=dim,
195
  context_dim=context_dim,
196
  heads=n_heads,
197
  dim_head=d_head,
198
  dropout=dropout,
199
- ) # is self-attn if context is none
 
 
 
200
  self.norm1 = nn.LayerNorm(dim)
201
  self.norm2 = nn.LayerNorm(dim)
202
  self.norm3 = nn.LayerNorm(dim)
203
  self.checkpoint = checkpoint
204
 
205
- def forward(self, x, context=None):
206
- return checkpoint(
207
- self._forward, (x, context), self.parameters(), self.checkpoint
208
- )
209
-
210
- def _forward(self, x, context=None):
211
- x = (
212
- self.attn1(
213
- self.norm1(x), context=context if self.disable_self_attn else None
214
- )
215
- + x
216
- )
217
- x = self.attn2(self.norm2(x), context=context) + x
218
- x = self.ff(self.norm3(x)) + x
219
- return x
220
-
221
-
222
- class SpatialTransformer(nn.Module):
223
- """
224
- Transformer block for image-like data.
225
- First, project the input (aka embedding)
226
- and reshape to b, t, d.
227
- Then apply standard transformer action.
228
- Finally, reshape to image
229
- NEW: use_linear for more efficiency instead of the 1x1 convs
230
- """
231
-
232
- def __init__(
233
- self,
234
- in_channels,
235
- n_heads,
236
- d_head,
237
- depth=1,
238
- dropout=0.0,
239
- context_dim=None,
240
- disable_self_attn=False,
241
- use_linear=False,
242
- use_checkpoint=True,
243
- ):
244
- super().__init__()
245
- assert context_dim is not None
246
- if not isinstance(context_dim, list):
247
- context_dim = [context_dim]
248
- self.in_channels = in_channels
249
- inner_dim = n_heads * d_head
250
- self.norm = nn.GroupNorm(
251
- num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
252
- )
253
- if not use_linear:
254
- self.proj_in = nn.Conv2d(
255
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
256
- )
257
- else:
258
- self.proj_in = nn.Linear(in_channels, inner_dim)
259
-
260
- self.transformer_blocks = nn.ModuleList(
261
- [
262
- BasicTransformerBlock(
263
- inner_dim,
264
- n_heads,
265
- d_head,
266
- dropout=dropout,
267
- context_dim=context_dim[d],
268
- disable_self_attn=disable_self_attn,
269
- checkpoint=use_checkpoint,
270
- )
271
- for d in range(depth)
272
- ]
273
- )
274
- if not use_linear:
275
- self.proj_out = zero_module(
276
- nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
277
- )
278
- else:
279
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
280
- self.use_linear = use_linear
281
-
282
- def forward(self, x, context=None):
283
- # note: if no context is given, cross-attention defaults to self-attention
284
- if not isinstance(context, list):
285
- context = [context]
286
- b, c, h, w = x.shape
287
- x_in = x
288
- x = self.norm(x)
289
- if not self.use_linear:
290
- x = self.proj_in(x)
291
- x = rearrange(x, "b c h w -> b (h w) c").contiguous()
292
- if self.use_linear:
293
- x = self.proj_in(x)
294
- for i, block in enumerate(self.transformer_blocks):
295
- x = block(x, context=context[i])
296
- if self.use_linear:
297
- x = self.proj_out(x)
298
- x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
299
- if not self.use_linear:
300
- x = self.proj_out(x)
301
- return x + x_in
302
-
303
-
304
- class BasicTransformerBlock3D(BasicTransformerBlock):
305
  def forward(self, x, context=None, num_frames=1):
306
  return checkpoint(
307
  self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
@@ -309,12 +183,7 @@ class BasicTransformerBlock3D(BasicTransformerBlock):
309
 
310
  def _forward(self, x, context=None, num_frames=1):
311
  x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
312
- x = (
313
- self.attn1(
314
- self.norm1(x), context=context if self.disable_self_attn else None
315
- )
316
- + x
317
- )
318
  x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
319
  x = self.attn2(self.norm2(x), context=context) + x
320
  x = self.ff(self.norm3(x)) + x
@@ -322,35 +191,31 @@ class BasicTransformerBlock3D(BasicTransformerBlock):
322
 
323
 
324
  class SpatialTransformer3D(nn.Module):
325
- """3D self-attention"""
326
 
327
  def __init__(
328
  self,
329
  in_channels,
330
  n_heads,
331
  d_head,
 
332
  depth=1,
333
  dropout=0.0,
334
- context_dim=None,
335
- disable_self_attn=False,
336
- use_linear=True,
337
  use_checkpoint=True,
338
  ):
339
  super().__init__()
340
- assert context_dim is not None
341
  if not isinstance(context_dim, list):
342
  context_dim = [context_dim]
 
343
  self.in_channels = in_channels
 
344
  inner_dim = n_heads * d_head
345
  self.norm = nn.GroupNorm(
346
  num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
347
  )
348
- if not use_linear:
349
- self.proj_in = nn.Conv2d(
350
- in_channels, inner_dim, kernel_size=1, stride=1, padding=0
351
- )
352
- else:
353
- self.proj_in = nn.Linear(in_channels, inner_dim)
354
 
355
  self.transformer_blocks = nn.ModuleList(
356
  [
@@ -358,21 +223,18 @@ class SpatialTransformer3D(nn.Module):
358
  inner_dim,
359
  n_heads,
360
  d_head,
361
- dropout=dropout,
362
  context_dim=context_dim[d],
363
- disable_self_attn=disable_self_attn,
364
  checkpoint=use_checkpoint,
 
 
365
  )
366
  for d in range(depth)
367
  ]
368
  )
369
- if not use_linear:
370
- self.proj_out = zero_module(
371
- nn.Conv2d(inner_dim, in_channels, kernel_size=1, stride=1, padding=0)
372
- )
373
- else:
374
- self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
375
- self.use_linear = use_linear
376
 
377
  def forward(self, x, context=None, num_frames=1):
378
  # note: if no context is given, cross-attention defaults to self-attention
@@ -381,16 +243,11 @@ class SpatialTransformer3D(nn.Module):
381
  b, c, h, w = x.shape
382
  x_in = x
383
  x = self.norm(x)
384
- if not self.use_linear:
385
- x = self.proj_in(x)
386
  x = rearrange(x, "b c h w -> b (h w) c").contiguous()
387
- if self.use_linear:
388
- x = self.proj_in(x)
389
  for i, block in enumerate(self.transformer_blocks):
390
  x = block(x, context=context[i], num_frames=num_frames)
391
- if self.use_linear:
392
- x = self.proj_out(x)
393
  x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
394
- if not self.use_linear:
395
- x = self.proj_out(x)
396
  return x + x_in
 
1
  import torch
2
  import torch.nn as nn
3
  import torch.nn.functional as F
 
4
 
5
  from inspect import isfunction
6
  from einops import rearrange, repeat
7
  from typing import Optional, Any
 
 
 
 
 
 
 
 
 
8
 
9
+ # require xformers
10
+ import xformers # type: ignore
11
+ import xformers.ops # type: ignore
 
12
 
13
+ from .util import checkpoint, zero_module
14
 
15
  def default(val, d):
16
  if val is not None:
 
47
  return self.net(x)
48
 
49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
50
  class MemoryEfficientCrossAttention(nn.Module):
51
  # https://github.com/MatthieuTPHR/diffusers/blob/d80b531ff8060ec1ea982b65a1b8df70f73aa67c/src/diffusers/models/attention.py#L223
52
+ def __init__(
53
+ self,
54
+ query_dim,
55
+ context_dim=None,
56
+ heads=8,
57
+ dim_head=64,
58
+ dropout=0.0,
59
+ ip_dim=0,
60
+ ip_weight=1,
61
+ ):
62
  super().__init__()
63
+
64
  inner_dim = dim_head * heads
65
  context_dim = default(context_dim, query_dim)
66
 
67
  self.heads = heads
68
  self.dim_head = dim_head
69
 
70
+ self.ip_dim = ip_dim
71
+ self.ip_weight = ip_weight
72
+
73
+ if self.ip_dim > 0:
74
+ self.to_k_ip = nn.Linear(context_dim, inner_dim, bias=False)
75
+ self.to_v_ip = nn.Linear(context_dim, inner_dim, bias=False)
76
+
77
  self.to_q = nn.Linear(query_dim, inner_dim, bias=False)
78
  self.to_k = nn.Linear(context_dim, inner_dim, bias=False)
79
  self.to_v = nn.Linear(context_dim, inner_dim, bias=False)
 
83
  )
84
  self.attention_op: Optional[Any] = None
85
 
86
+ def forward(self, x, context=None):
87
  q = self.to_q(x)
88
  context = default(context, x)
89
+
90
+ if self.ip_dim > 0:
91
+ # context dim [(b frame_num), (77 + img_token), 1024]
92
+ token_len = context.shape[1]
93
+ context_ip = context[:, -self.ip_dim :, :]
94
+ k_ip = self.to_k_ip(context_ip)
95
+ v_ip = self.to_v_ip(context_ip)
96
+ context = context[:, : (token_len - self.ip_dim), :]
97
+
98
  k = self.to_k(context)
99
  v = self.to_v(context)
100
 
 
113
  q, k, v, attn_bias=None, op=self.attention_op
114
  )
115
 
116
+ if self.ip_dim > 0:
117
+ k_ip, v_ip = map(
118
+ lambda t: t.unsqueeze(3)
119
+ .reshape(b, t.shape[1], self.heads, self.dim_head)
120
+ .permute(0, 2, 1, 3)
121
+ .reshape(b * self.heads, t.shape[1], self.dim_head)
122
+ .contiguous(),
123
+ (k_ip, v_ip),
124
+ )
125
+ # actually compute the attention, what we cannot get enough of
126
+ out_ip = xformers.ops.memory_efficient_attention(
127
+ q, k_ip, v_ip, attn_bias=None, op=self.attention_op
128
+ )
129
+ out = out + self.ip_weight * out_ip
130
+
131
  out = (
132
  out.unsqueeze(0)
133
  .reshape(b, self.heads, out.shape[1], self.dim_head)
 
137
  return self.to_out(out)
138
 
139
 
140
+ class BasicTransformerBlock3D(nn.Module):
141
+
 
 
 
 
142
  def __init__(
143
  self,
144
  dim,
145
  n_heads,
146
  d_head,
147
+ context_dim,
148
  dropout=0.0,
 
149
  gated_ff=True,
150
  checkpoint=True,
151
+ ip_dim=0,
152
+ ip_weight=1,
153
  ):
154
  super().__init__()
155
+
156
+ self.attn1 = MemoryEfficientCrossAttention(
 
 
 
157
  query_dim=dim,
158
+ context_dim=None, # self-attention
159
  heads=n_heads,
160
  dim_head=d_head,
161
  dropout=dropout,
162
+ )
 
163
  self.ff = FeedForward(dim, dropout=dropout, glu=gated_ff)
164
+ self.attn2 = MemoryEfficientCrossAttention(
165
  query_dim=dim,
166
  context_dim=context_dim,
167
  heads=n_heads,
168
  dim_head=d_head,
169
  dropout=dropout,
170
+ # ip only applies to cross-attention
171
+ ip_dim=ip_dim,
172
+ ip_weight=ip_weight,
173
+ )
174
  self.norm1 = nn.LayerNorm(dim)
175
  self.norm2 = nn.LayerNorm(dim)
176
  self.norm3 = nn.LayerNorm(dim)
177
  self.checkpoint = checkpoint
178
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
179
  def forward(self, x, context=None, num_frames=1):
180
  return checkpoint(
181
  self._forward, (x, context, num_frames), self.parameters(), self.checkpoint
 
183
 
184
  def _forward(self, x, context=None, num_frames=1):
185
  x = rearrange(x, "(b f) l c -> b (f l) c", f=num_frames).contiguous()
186
+ x = self.attn1(self.norm1(x), context=None) + x
 
 
 
 
 
187
  x = rearrange(x, "b (f l) c -> (b f) l c", f=num_frames).contiguous()
188
  x = self.attn2(self.norm2(x), context=context) + x
189
  x = self.ff(self.norm3(x)) + x
 
191
 
192
 
193
  class SpatialTransformer3D(nn.Module):
 
194
 
195
  def __init__(
196
  self,
197
  in_channels,
198
  n_heads,
199
  d_head,
200
+ context_dim, # cross attention input dim
201
  depth=1,
202
  dropout=0.0,
203
+ ip_dim=0,
204
+ ip_weight=1,
 
205
  use_checkpoint=True,
206
  ):
207
  super().__init__()
208
+
209
  if not isinstance(context_dim, list):
210
  context_dim = [context_dim]
211
+
212
  self.in_channels = in_channels
213
+
214
  inner_dim = n_heads * d_head
215
  self.norm = nn.GroupNorm(
216
  num_groups=32, num_channels=in_channels, eps=1e-6, affine=True
217
  )
218
+ self.proj_in = nn.Linear(in_channels, inner_dim)
 
 
 
 
 
219
 
220
  self.transformer_blocks = nn.ModuleList(
221
  [
 
223
  inner_dim,
224
  n_heads,
225
  d_head,
 
226
  context_dim=context_dim[d],
227
+ dropout=dropout,
228
  checkpoint=use_checkpoint,
229
+ ip_dim=ip_dim,
230
+ ip_weight=ip_weight,
231
  )
232
  for d in range(depth)
233
  ]
234
  )
235
+
236
+ self.proj_out = zero_module(nn.Linear(in_channels, inner_dim))
237
+
 
 
 
 
238
 
239
  def forward(self, x, context=None, num_frames=1):
240
  # note: if no context is given, cross-attention defaults to self-attention
 
243
  b, c, h, w = x.shape
244
  x_in = x
245
  x = self.norm(x)
 
 
246
  x = rearrange(x, "b c h w -> b (h w) c").contiguous()
247
+ x = self.proj_in(x)
 
248
  for i, block in enumerate(self.transformer_blocks):
249
  x = block(x, context=context[i], num_frames=num_frames)
250
+ x = self.proj_out(x)
 
251
  x = rearrange(x, "b (h w) c -> b c h w", h=h, w=w).contiguous()
252
+
 
253
  return x + x_in
mvdream/models.py CHANGED
@@ -13,8 +13,10 @@ from .util import (
13
  zero_module,
14
  timestep_embedding,
15
  )
16
- from .attention import SpatialTransformer, SpatialTransformer3D
 
17
 
 
18
 
19
  class CondSequential(nn.Sequential):
20
  """
@@ -28,8 +30,6 @@ class CondSequential(nn.Sequential):
28
  x = layer(x, emb)
29
  elif isinstance(layer, SpatialTransformer3D):
30
  x = layer(x, context, num_frames=num_frames)
31
- elif isinstance(layer, SpatialTransformer):
32
- x = layer(x, context)
33
  else:
34
  x = layer(x)
35
  return x
@@ -274,6 +274,8 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
274
  disable_middle_self_attn=False,
275
  adm_in_channels=None,
276
  camera_dim=None,
 
 
277
  **kwargs,
278
  ):
279
  super().__init__()
@@ -305,9 +307,7 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
305
  "as a list/tuple (per-level) with the same length as channel_mult"
306
  )
307
  self.num_res_blocks = num_res_blocks
308
- if disable_self_attentions is not None:
309
- # should be a list of booleans, indicating whether to disable self-attention in TransformerBlocks or not
310
- assert len(disable_self_attentions) == len(channel_mult)
311
  if num_attention_blocks is not None:
312
  assert len(num_attention_blocks) == len(self.num_res_blocks)
313
  assert all(
@@ -334,6 +334,21 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
334
  self.num_heads_upsample = num_heads_upsample
335
  self.predict_codebook_ids = n_embed is not None
336
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
337
  time_embed_dim = model_channels * 4
338
  self.time_embed = nn.Sequential(
339
  nn.Linear(model_channels, time_embed_dim),
@@ -398,11 +413,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
398
  else:
399
  num_heads = ch // num_head_channels
400
  dim_head = num_head_channels
401
-
402
- if disable_self_attentions is not None:
403
- disabled_sa = disable_self_attentions[level]
404
- else:
405
- disabled_sa = False
406
 
407
  if num_attention_blocks is None or nr < num_attention_blocks[level]:
408
  layers.append(
@@ -410,10 +420,11 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
410
  ch,
411
  num_heads,
412
  dim_head,
413
- depth=transformer_depth,
414
  context_dim=context_dim,
415
- disable_self_attn=disabled_sa,
416
  use_checkpoint=use_checkpoint,
 
 
417
  )
418
  )
419
  self.input_blocks.append(CondSequential(*layers))
@@ -463,10 +474,11 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
463
  ch,
464
  num_heads,
465
  dim_head,
466
- depth=transformer_depth,
467
  context_dim=context_dim,
468
- disable_self_attn=disable_middle_self_attn,
469
  use_checkpoint=use_checkpoint,
 
 
470
  ),
471
  ResBlock(
472
  ch,
@@ -501,11 +513,6 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
501
  else:
502
  num_heads = ch // num_head_channels
503
  dim_head = num_head_channels
504
-
505
- if disable_self_attentions is not None:
506
- disabled_sa = disable_self_attentions[level]
507
- else:
508
- disabled_sa = False
509
 
510
  if num_attention_blocks is None or i < num_attention_blocks[level]:
511
  layers.append(
@@ -513,10 +520,11 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
513
  ch,
514
  num_heads,
515
  dim_head,
516
- depth=transformer_depth,
517
  context_dim=context_dim,
518
- disable_self_attn=disabled_sa,
519
  use_checkpoint=use_checkpoint,
 
 
520
  )
521
  )
522
  if level and i == self.num_res_blocks[level]:
@@ -556,9 +564,11 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
556
  x,
557
  timesteps=None,
558
  context=None,
559
- y: Optional[Tensor] = None,
560
  camera=None,
561
  num_frames=1,
 
 
562
  **kwargs,
563
  ):
564
  """
@@ -572,14 +582,14 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
572
  """
573
  assert (
574
  x.shape[0] % num_frames == 0
575
- ), "[UNet] input batch size must be dividable by num_frames!"
576
  assert (y is not None) == (
577
  self.num_classes is not None
578
  ), "must specify y if and only if the model is class-conditional"
 
579
  hs = []
580
- t_emb = timestep_embedding(
581
- timesteps, self.model_channels, repeat_only=False
582
- ).to(x.dtype)
583
 
584
  emb = self.time_embed(t_emb)
585
 
@@ -590,8 +600,13 @@ class MultiViewUNetModel(ModelMixin, ConfigMixin):
590
 
591
  # Add camera embeddings
592
  if camera is not None:
593
- assert camera.shape[0] == emb.shape[0]
594
  emb = emb + self.camera_embed(camera)
 
 
 
 
 
 
595
 
596
  h = x
597
  for module in self.input_blocks:
 
13
  zero_module,
14
  timestep_embedding,
15
  )
16
+ from .attention import SpatialTransformer3D
17
+ from .adaptor import Resampler, ImageProjModel
18
 
19
+ import kiui
20
 
21
  class CondSequential(nn.Sequential):
22
  """
 
30
  x = layer(x, emb)
31
  elif isinstance(layer, SpatialTransformer3D):
32
  x = layer(x, context, num_frames=num_frames)
 
 
33
  else:
34
  x = layer(x)
35
  return x
 
274
  disable_middle_self_attn=False,
275
  adm_in_channels=None,
276
  camera_dim=None,
277
+ ip_dim=0,
278
+ ip_weight=1.0,
279
  **kwargs,
280
  ):
281
  super().__init__()
 
307
  "as a list/tuple (per-level) with the same length as channel_mult"
308
  )
309
  self.num_res_blocks = num_res_blocks
310
+
 
 
311
  if num_attention_blocks is not None:
312
  assert len(num_attention_blocks) == len(self.num_res_blocks)
313
  assert all(
 
334
  self.num_heads_upsample = num_heads_upsample
335
  self.predict_codebook_ids = n_embed is not None
336
 
337
+ self.ip_dim = ip_dim
338
+ self.ip_weight = ip_weight
339
+
340
+ if self.ip_dim > 0:
341
+ self.image_embed = Resampler(
342
+ dim=context_dim,
343
+ depth=4,
344
+ dim_head=64,
345
+ heads=12,
346
+ num_queries=ip_dim, # num token
347
+ embedding_dim=1280,
348
+ output_dim=context_dim,
349
+ ff_mult=4,
350
+ )
351
+
352
  time_embed_dim = model_channels * 4
353
  self.time_embed = nn.Sequential(
354
  nn.Linear(model_channels, time_embed_dim),
 
413
  else:
414
  num_heads = ch // num_head_channels
415
  dim_head = num_head_channels
 
 
 
 
 
416
 
417
  if num_attention_blocks is None or nr < num_attention_blocks[level]:
418
  layers.append(
 
420
  ch,
421
  num_heads,
422
  dim_head,
 
423
  context_dim=context_dim,
424
+ depth=transformer_depth,
425
  use_checkpoint=use_checkpoint,
426
+ ip_dim=self.ip_dim,
427
+ ip_weight=self.ip_weight,
428
  )
429
  )
430
  self.input_blocks.append(CondSequential(*layers))
 
474
  ch,
475
  num_heads,
476
  dim_head,
 
477
  context_dim=context_dim,
478
+ depth=transformer_depth,
479
  use_checkpoint=use_checkpoint,
480
+ ip_dim=self.ip_dim,
481
+ ip_weight=self.ip_weight,
482
  ),
483
  ResBlock(
484
  ch,
 
513
  else:
514
  num_heads = ch // num_head_channels
515
  dim_head = num_head_channels
 
 
 
 
 
516
 
517
  if num_attention_blocks is None or i < num_attention_blocks[level]:
518
  layers.append(
 
520
  ch,
521
  num_heads,
522
  dim_head,
 
523
  context_dim=context_dim,
524
+ depth=transformer_depth,
525
  use_checkpoint=use_checkpoint,
526
+ ip_dim=self.ip_dim,
527
+ ip_weight=self.ip_weight,
528
  )
529
  )
530
  if level and i == self.num_res_blocks[level]:
 
564
  x,
565
  timesteps=None,
566
  context=None,
567
+ y=None,
568
  camera=None,
569
  num_frames=1,
570
+ ip=None,
571
+ ip_img=None,
572
  **kwargs,
573
  ):
574
  """
 
582
  """
583
  assert (
584
  x.shape[0] % num_frames == 0
585
+ ), "input batch size must be dividable by num_frames!"
586
  assert (y is not None) == (
587
  self.num_classes is not None
588
  ), "must specify y if and only if the model is class-conditional"
589
+
590
  hs = []
591
+
592
+ t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False).to(x.dtype)
 
593
 
594
  emb = self.time_embed(t_emb)
595
 
 
600
 
601
  # Add camera embeddings
602
  if camera is not None:
 
603
  emb = emb + self.camera_embed(camera)
604
+
605
+ # imagedream variant
606
+ if self.ip_dim > 0:
607
+ x[(num_frames - 1) :: num_frames, :, :, :] = ip_img
608
+ ip_emb = self.image_embed(ip)
609
+ context = torch.cat((context, ip_emb), 1)
610
 
611
  h = x
612
  for module in self.input_blocks:
mvdream/pipeline_mvdream.py CHANGED
@@ -1,8 +1,9 @@
1
  import torch
 
2
  import inspect
3
  import numpy as np
4
  from typing import Callable, List, Optional, Union
5
- from transformers import CLIPTextModel, CLIPTokenizer
6
  from diffusers import AutoencoderKL, DiffusionPipeline
7
  from diffusers.utils import (
8
  deprecate,
@@ -15,66 +16,17 @@ from diffusers.schedulers import DDIMScheduler
15
  from diffusers.utils.torch_utils import randn_tensor
16
 
17
  from .models import MultiViewUNetModel
 
18
 
19
- logger = logging.get_logger(__name__) # pylint: disable=invalid-name
20
-
21
 
22
- def create_camera_to_world_matrix(elevation, azimuth):
23
- elevation = np.radians(elevation)
24
- azimuth = np.radians(azimuth)
25
- # Convert elevation and azimuth angles to Cartesian coordinates on a unit sphere
26
- x = np.cos(elevation) * np.sin(azimuth)
27
- y = np.sin(elevation)
28
- z = np.cos(elevation) * np.cos(azimuth)
29
-
30
- # Calculate camera position, target, and up vectors
31
- camera_pos = np.array([x, y, z])
32
- target = np.array([0, 0, 0])
33
- up = np.array([0, 1, 0])
34
-
35
- # Construct view matrix
36
- forward = target - camera_pos
37
- forward /= np.linalg.norm(forward)
38
- right = np.cross(forward, up)
39
- right /= np.linalg.norm(right)
40
- new_up = np.cross(right, forward)
41
- new_up /= np.linalg.norm(new_up)
42
- cam2world = np.eye(4)
43
- cam2world[:3, :3] = np.array([right, new_up, -forward]).T
44
- cam2world[:3, 3] = camera_pos
45
- return cam2world
46
-
47
-
48
- def convert_opengl_to_blender(camera_matrix):
49
- if isinstance(camera_matrix, np.ndarray):
50
- # Construct transformation matrix to convert from OpenGL space to Blender space
51
- flip_yz = np.array([[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
52
- camera_matrix_blender = np.dot(flip_yz, camera_matrix)
53
- else:
54
- # Construct transformation matrix to convert from OpenGL space to Blender space
55
- flip_yz = torch.tensor(
56
- [[1, 0, 0, 0], [0, 0, -1, 0], [0, 1, 0, 0], [0, 0, 0, 1]]
57
- )
58
- if camera_matrix.ndim == 3:
59
- flip_yz = flip_yz.unsqueeze(0)
60
- camera_matrix_blender = torch.matmul(flip_yz.to(camera_matrix), camera_matrix)
61
- return camera_matrix_blender
62
 
63
 
64
- def get_camera(
65
- num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True
66
- ):
67
- angle_gap = azimuth_span / num_frames
68
- cameras = []
69
- for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
70
- camera_matrix = create_camera_to_world_matrix(elevation, azimuth)
71
- if blender_coord:
72
- camera_matrix = convert_opengl_to_blender(camera_matrix)
73
- cameras.append(camera_matrix.flatten())
74
- return torch.tensor(np.stack(cameras, 0)).float()
75
 
 
76
 
77
- class MVDreamPipeline(DiffusionPipeline):
78
  def __init__(
79
  self,
80
  vae: AutoencoderKL,
@@ -82,6 +34,9 @@ class MVDreamPipeline(DiffusionPipeline):
82
  tokenizer: CLIPTokenizer,
83
  text_encoder: CLIPTextModel,
84
  scheduler: DDIMScheduler,
 
 
 
85
  requires_safety_checker: bool = False,
86
  ):
87
  super().__init__()
@@ -123,6 +78,8 @@ class MVDreamPipeline(DiffusionPipeline):
123
  scheduler=scheduler,
124
  tokenizer=tokenizer,
125
  text_encoder=text_encoder,
 
 
126
  )
127
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
128
  self.register_to_config(requires_safety_checker=requires_safety_checker)
@@ -445,10 +402,42 @@ class MVDreamPipeline(DiffusionPipeline):
445
  latents = latents * self.scheduler.init_noise_sigma
446
  return latents
447
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
448
  @torch.no_grad()
449
  def __call__(
450
  self,
451
  prompt: str = "a car",
 
452
  height: int = 256,
453
  width: int = 256,
454
  num_inference_steps: int = 50,
@@ -457,10 +446,10 @@ class MVDreamPipeline(DiffusionPipeline):
457
  num_images_per_prompt: int = 1,
458
  eta: float = 0.0,
459
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
460
- output_type: Optional[str] = "image",
461
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
462
  callback_steps: int = 1,
463
- batch_size: int = 4,
464
  device=torch.device("cuda:0"),
465
  ):
466
  self.unet = self.unet.to(device=device)
@@ -477,7 +466,15 @@ class MVDreamPipeline(DiffusionPipeline):
477
  self.scheduler.set_timesteps(num_inference_steps, device=device)
478
  timesteps = self.scheduler.timesteps
479
 
480
- _prompt_embeds: torch.Tensor = self._encode_prompt(
 
 
 
 
 
 
 
 
481
  prompt=prompt,
482
  device=device,
483
  num_images_per_prompt=num_images_per_prompt,
@@ -487,8 +484,9 @@ class MVDreamPipeline(DiffusionPipeline):
487
  prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
488
 
489
  # Prepare latent variables
 
490
  latents: torch.Tensor = self.prepare_latents(
491
- batch_size * num_images_per_prompt,
492
  4,
493
  height,
494
  width,
@@ -498,9 +496,9 @@ class MVDreamPipeline(DiffusionPipeline):
498
  None,
499
  )
500
 
501
- camera = get_camera(batch_size).to(dtype=latents.dtype, device=device)
502
 
503
- # Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
504
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
505
 
506
  # Denoising loop
@@ -514,20 +512,21 @@ class MVDreamPipeline(DiffusionPipeline):
514
  latent_model_input, t
515
  )
516
 
 
 
 
 
 
 
 
 
 
 
 
 
 
517
  # predict the noise residual
518
- noise_pred = self.unet.forward(
519
- x=latent_model_input,
520
- timesteps=torch.tensor(
521
- [t] * 4 * multiplier,
522
- dtype=latent_model_input.dtype,
523
- device=device,
524
- ),
525
- context=torch.cat(
526
- [prompt_embeds_neg] * 4 + [prompt_embeds_pos] * 4
527
- ),
528
- num_frames=4,
529
- camera=torch.cat([camera] * multiplier),
530
- )
531
 
532
  # perform guidance
533
  if do_classifier_free_guidance:
@@ -537,7 +536,6 @@ class MVDreamPipeline(DiffusionPipeline):
537
  )
538
 
539
  # compute the previous noisy sample x_t -> x_t-1
540
- # latents = self.scheduler.step(noise_pred.to(dtype=torch.float32), t, latents.to(dtype=torch.float32)).prev_sample.to(prompt_embeds.dtype)
541
  latents: torch.Tensor = self.scheduler.step(
542
  noise_pred, t, latents, **extra_step_kwargs, return_dict=False
543
  )[0]
@@ -556,7 +554,7 @@ class MVDreamPipeline(DiffusionPipeline):
556
  elif output_type == "pil":
557
  image = self.decode_latents(latents)
558
  image = self.numpy_to_pil(image)
559
- else:
560
  image = self.decode_latents(latents)
561
 
562
  # Offload last model to CPU
 
1
  import torch
2
+ import torch.nn.functional as F
3
  import inspect
4
  import numpy as np
5
  from typing import Callable, List, Optional, Union
6
+ from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModel, CLIPImageProcessor
7
  from diffusers import AutoencoderKL, DiffusionPipeline
8
  from diffusers.utils import (
9
  deprecate,
 
16
  from diffusers.utils.torch_utils import randn_tensor
17
 
18
  from .models import MultiViewUNetModel
19
+ from .util import get_camera
20
 
21
+ import kiui
 
22
 
23
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
 
25
 
26
+ class MVDreamPipeline(DiffusionPipeline):
 
 
 
 
 
 
 
 
 
 
27
 
28
+ _optional_components = ["feature_extractor", "image_encoder"]
29
 
 
30
  def __init__(
31
  self,
32
  vae: AutoencoderKL,
 
34
  tokenizer: CLIPTokenizer,
35
  text_encoder: CLIPTextModel,
36
  scheduler: DDIMScheduler,
37
+ # imagedream variant
38
+ feature_extractor: CLIPImageProcessor,
39
+ image_encoder: CLIPVisionModel,
40
  requires_safety_checker: bool = False,
41
  ):
42
  super().__init__()
 
78
  scheduler=scheduler,
79
  tokenizer=tokenizer,
80
  text_encoder=text_encoder,
81
+ feature_extractor=feature_extractor,
82
+ image_encoder=image_encoder,
83
  )
84
  self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
85
  self.register_to_config(requires_safety_checker=requires_safety_checker)
 
402
  latents = latents * self.scheduler.init_noise_sigma
403
  return latents
404
 
405
+ def encode_image(self, image, device, num_images_per_prompt):
406
+ dtype = next(self.image_encoder.parameters()).dtype
407
+
408
+ image = (image * 255).astype(np.uint8)
409
+ image = self.feature_extractor(image, return_tensors="pt").pixel_values
410
+
411
+ image = image.to(device=device, dtype=dtype)
412
+
413
+ image_enc_hidden_states = self.image_encoder(image, output_hidden_states=True).hidden_states[-2]
414
+ image_enc_hidden_states = image_enc_hidden_states.repeat_interleave(num_images_per_prompt, dim=0)
415
+
416
+ # imagedream directly use zero as uncond image embeddings
417
+ uncond_image_enc_hidden_states = torch.zeros_like(image_enc_hidden_states)
418
+
419
+ return uncond_image_enc_hidden_states, image_enc_hidden_states
420
+
421
+ def encode_image_latents(self, image, device, num_images_per_prompt):
422
+
423
+ image = torch.from_numpy(image).unsqueeze(0).permute(0, 3, 1, 2) # [1, 3, H, W]
424
+ image = image.to(device=device)
425
+ image = F.interpolate(image, (256, 256), mode='bilinear', align_corners=False)
426
+ dtype = next(self.image_encoder.parameters()).dtype
427
+ image = image.to(dtype=dtype)
428
+
429
+ posterior = self.vae.encode(image).latent_dist
430
+
431
+ latents = posterior.sample() * self.vae.config.scaling_factor # [B, C, H, W]
432
+ latents = latents.repeat_interleave(num_images_per_prompt, dim=0)
433
+
434
+ return torch.zeros_like(latents), latents
435
+
436
  @torch.no_grad()
437
  def __call__(
438
  self,
439
  prompt: str = "a car",
440
+ image: Optional[np.ndarray] = None,
441
  height: int = 256,
442
  width: int = 256,
443
  num_inference_steps: int = 50,
 
446
  num_images_per_prompt: int = 1,
447
  eta: float = 0.0,
448
  generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
449
+ output_type: Optional[str] = "numpy", # pil, numpy, latents
450
  callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
451
  callback_steps: int = 1,
452
+ num_frames: int = 4,
453
  device=torch.device("cuda:0"),
454
  ):
455
  self.unet = self.unet.to(device=device)
 
466
  self.scheduler.set_timesteps(num_inference_steps, device=device)
467
  timesteps = self.scheduler.timesteps
468
 
469
+ # imagedream variant (TODO: debug)
470
+ if image is not None:
471
+ assert isinstance(image, np.ndarray) and image.dtype == np.float32
472
+
473
+ self.image_encoder = self.image_encoder.to(device=device)
474
+ image_embeds_neg, image_embeds_pos = self.encode_image(image, device, num_images_per_prompt)
475
+ image_latents_neg, image_latents_pos = self.encode_image_latents(image, device, num_images_per_prompt)
476
+
477
+ _prompt_embeds = self._encode_prompt(
478
  prompt=prompt,
479
  device=device,
480
  num_images_per_prompt=num_images_per_prompt,
 
484
  prompt_embeds_neg, prompt_embeds_pos = _prompt_embeds.chunk(2)
485
 
486
  # Prepare latent variables
487
+ actual_num_frames = num_frames if image is None else num_frames + 1
488
  latents: torch.Tensor = self.prepare_latents(
489
+ actual_num_frames * num_images_per_prompt,
490
  4,
491
  height,
492
  width,
 
496
  None,
497
  )
498
 
499
+ camera = get_camera(num_frames, extra_view=(actual_num_frames != num_frames)).to(dtype=latents.dtype, device=device)
500
 
501
+ # Prepare extra step kwargs.
502
  extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)
503
 
504
  # Denoising loop
 
512
  latent_model_input, t
513
  )
514
 
515
+
516
+ unet_inputs = {
517
+ 'x': latent_model_input,
518
+ 'timesteps': torch.tensor([t] * actual_num_frames * multiplier, dtype=latent_model_input.dtype, device=device),
519
+ 'context': torch.cat([prompt_embeds_neg] * actual_num_frames + [prompt_embeds_pos] * actual_num_frames),
520
+ 'num_frames': actual_num_frames,
521
+ 'camera': torch.cat([camera] * multiplier),
522
+ }
523
+
524
+ if image is not None:
525
+ unet_inputs['ip'] = torch.cat([image_embeds_neg] * actual_num_frames + [image_embeds_pos] * actual_num_frames)
526
+ unet_inputs['ip_img'] = torch.cat([image_latents_neg] + [image_latents_pos]) # no repeat
527
+
528
  # predict the noise residual
529
+ noise_pred = self.unet.forward(**unet_inputs)
 
 
 
 
 
 
 
 
 
 
 
 
530
 
531
  # perform guidance
532
  if do_classifier_free_guidance:
 
536
  )
537
 
538
  # compute the previous noisy sample x_t -> x_t-1
 
539
  latents: torch.Tensor = self.scheduler.step(
540
  noise_pred, t, latents, **extra_step_kwargs, return_dict=False
541
  )[0]
 
554
  elif output_type == "pil":
555
  image = self.decode_latents(latents)
556
  image = self.numpy_to_pil(image)
557
+ else: # numpy
558
  image = self.decode_latents(latents)
559
 
560
  # Offload last model to CPU
mvdream/util.py CHANGED
@@ -1,8 +1,32 @@
1
  import math
2
  import torch
3
  import torch.nn as nn
 
4
  from einops import repeat
5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
  def checkpoint(func, inputs, params, flag):
8
  """
 
1
  import math
2
  import torch
3
  import torch.nn as nn
4
+ import numpy as np
5
  from einops import repeat
6
 
7
+ from kiui.cam import orbit_camera
8
+
9
+ def get_camera(
10
+ num_frames, elevation=15, azimuth_start=0, azimuth_span=360, blender_coord=True, extra_view=False,
11
+ ):
12
+ angle_gap = azimuth_span / num_frames
13
+ cameras = []
14
+ for azimuth in np.arange(azimuth_start, azimuth_span + azimuth_start, angle_gap):
15
+
16
+ pose = orbit_camera(-elevation, azimuth, radius=1) # kiui's elevation is negated, [4, 4]
17
+
18
+ # opengl to blender
19
+ if blender_coord:
20
+ pose[2] *= -1
21
+ pose[[1, 2]] = pose[[2, 1]]
22
+
23
+ cameras.append(pose.flatten())
24
+
25
+ if extra_view:
26
+ cameras.append(np.zeros_like(cameras[0]))
27
+
28
+ return torch.from_numpy(np.stack(cameras, axis=0)).float() # [num_frames, 16]
29
+
30
 
31
  def checkpoint(func, inputs, params, flag):
32
  """
run_imagedream.py ADDED
@@ -0,0 +1,32 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import kiui
3
+ import numpy as np
4
+ import argparse
5
+ from mvdream.pipeline_mvdream import MVDreamPipeline
6
+
7
+ pipe = MVDreamPipeline.from_pretrained(
8
+ "./weights_imagedream", # local weights
9
+ # "ashawkey/mvdream-sd2.1-diffusers",
10
+ torch_dtype=torch.float16
11
+ )
12
+ pipe = pipe.to("cuda")
13
+
14
+
15
+ parser = argparse.ArgumentParser(description="ImageDream")
16
+ parser.add_argument("image", type=str, default='data/anya_rgba.png')
17
+ parser.add_argument("--prompt", type=str, default="")
18
+ args = parser.parse_args()
19
+
20
+ while True:
21
+ input_image = kiui.read_image(args.image, mode='float')
22
+ image = pipe(args.prompt, input_image)
23
+ grid = np.concatenate(
24
+ [
25
+ np.concatenate([image[0], image[2]], axis=0),
26
+ np.concatenate([image[1], image[3]], axis=0),
27
+ ],
28
+ axis=1,
29
+ )
30
+ # kiui.vis.plot_image(grid)
31
+ kiui.write_image('test_imagedream.jpg', grid)
32
+ break
main.py → run_mvdream.py RENAMED
@@ -25,4 +25,6 @@ while True:
25
  ],
26
  axis=1,
27
  )
28
- kiui.vis.plot_image(grid)
 
 
 
25
  ],
26
  axis=1,
27
  )
28
+ # kiui.vis.plot_image(grid)
29
+ kiui.write_image('test_mvdream.jpg', grid)
30
+ break