bluestarburst commited on
Commit
83f7703
1 Parent(s): 96eabf2

Upload folder using huggingface_hub

Browse files
animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc CHANGED
Binary files a/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc and b/animatediff/utils/__pycache__/convert_from_ckpt.cpython-310.pyc differ
 
animatediff/utils/convert_from_ckpt.py CHANGED
@@ -198,20 +198,21 @@ def assign_to_checkpoint(
198
  new_path = new_path.replace(replacement["old"], replacement["new"])
199
 
200
  # proj_attn.weight has to be converted from conv 1D to linear
201
- if "proj_attn.weight" in new_path:
202
- checkpoint[new_path] = old_checkpoint[path["old"]][:, :, 0]
 
203
  else:
204
  checkpoint[new_path] = old_checkpoint[path["old"]]
205
 
206
 
207
  def conv_attn_to_linear(checkpoint):
208
  keys = list(checkpoint.keys())
209
- attn_keys = ["query.weight", "key.weight", "value.weight"]
210
  for key in keys:
211
  if ".".join(key.split(".")[-2:]) in attn_keys:
212
  if checkpoint[key].ndim > 2:
213
  checkpoint[key] = checkpoint[key][:, :, 0, 0]
214
- elif "proj_attn.weight" in key:
215
  if checkpoint[key].ndim > 2:
216
  checkpoint[key] = checkpoint[key][:, :, 0]
217
 
@@ -632,7 +633,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
632
  oldKey = {"old": "key", "new": "to_k"}
633
  oldQuery = {"old": "query", "new": "to_q"}
634
  oldValue = {"old": "value", "new": "to_v"}
635
- oldOut = {"old": "proj_attn", "new": "to_out"}
636
  assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
637
  conv_attn_to_linear(new_checkpoint)
638
 
@@ -669,7 +670,7 @@ def convert_ldm_vae_checkpoint(checkpoint, config):
669
  oldKey = {"old": "key", "new": "to_k"}
670
  oldQuery = {"old": "query", "new": "to_q"}
671
  oldValue = {"old": "value", "new": "to_v"}
672
- oldOut = {"old": "proj_attn", "new": "to_out"}
673
  assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
674
  conv_attn_to_linear(new_checkpoint)
675
  return new_checkpoint
 
198
  new_path = new_path.replace(replacement["old"], replacement["new"])
199
 
200
  # proj_attn.weight has to be converted from conv 1D to linear
201
+ if "to_out.0.weight" in new_path and "decoder" in new_path:
202
+ # turn [512, 512, 1] into [512, 512]
203
+ checkpoint[new_path] = old_checkpoint[path["old"]].squeeze(-1)
204
  else:
205
  checkpoint[new_path] = old_checkpoint[path["old"]]
206
 
207
 
208
  def conv_attn_to_linear(checkpoint):
209
  keys = list(checkpoint.keys())
210
+ attn_keys = ["to_q.weight", "to_k.weight", "to_v.weight"]
211
  for key in keys:
212
  if ".".join(key.split(".")[-2:]) in attn_keys:
213
  if checkpoint[key].ndim > 2:
214
  checkpoint[key] = checkpoint[key][:, :, 0, 0]
215
+ elif "to_out.0.weight" in key:
216
  if checkpoint[key].ndim > 2:
217
  checkpoint[key] = checkpoint[key][:, :, 0]
218
 
 
633
  oldKey = {"old": "key", "new": "to_k"}
634
  oldQuery = {"old": "query", "new": "to_q"}
635
  oldValue = {"old": "value", "new": "to_v"}
636
+ oldOut = {"old": "proj_attn", "new": "to_out.0"}
637
  assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
638
  conv_attn_to_linear(new_checkpoint)
639
 
 
670
  oldKey = {"old": "key", "new": "to_k"}
671
  oldQuery = {"old": "query", "new": "to_q"}
672
  oldValue = {"old": "value", "new": "to_v"}
673
+ oldOut = {"old": "proj_attn", "new": "to_out.0"}
674
  assign_to_checkpoint(paths, new_checkpoint, vae_state_dict, additional_replacements=[meta_path, oldKey, oldQuery, oldValue, oldOut], config=config)
675
  conv_attn_to_linear(new_checkpoint)
676
  return new_checkpoint