linoyts (HF staff) committed
Commit 54507b8 • 1 Parent(s): 6f3bf64

Update app.py

Files changed (1)
  1. app.py +18 -7
app.py CHANGED
@@ -626,11 +626,10 @@ class LEditsPPPipelineStableDiffusionXL(
                 else:
                     # "2" because SDXL always indexes from the penultimate layer.
                     edit_concepts_embeds = edit_concepts_embeds.hidden_states[-(clip_skip + 2)]
-
-                print("SHALOM???")
-                if avg_diff is not None and avg_diff_2 is not None:
-                    #scale=3
-                    print("SHALOM")
+
+
+                if avg_diff is not None:
+
                     normed_prompt_embeds = edit_concepts_embeds / edit_concepts_embeds.norm(dim=-1, keepdim=True)
                     sims = normed_prompt_embeds[0] @ normed_prompt_embeds[0].T
                     if i == 0:
@@ -639,14 +638,26 @@
                         standard_weights = torch.ones_like(weights)
 
                         weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                        edit_concepts_embeds = edit_concepts_embeds + (weights * avg_diff[None, :].repeat(1,tokenizer.model_max_length, 1) * scale)
+                        edit_concepts_embeds = edit_concepts_embeds + (
+                            weights * avg_diff[0][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
+
+                        if avg_diff_2nd is not None:
+                            edit_concepts_embeds += (weights * avg_diff_2nd[0][None, :].repeat(1,
+                                                         self.pipe.tokenizer.model_max_length,
+                                                         1) * scale_2nd)
                     else:
                         weights = sims[toks.argmax(), :][None, :, None].repeat(1, 1, 1280)
 
                         standard_weights = torch.ones_like(weights)
 
                         weights = standard_weights + (weights - standard_weights) * correlation_weight_factor
-                        edit_concepts_embeds = edit_concepts_embeds + (weights * avg_diff_2[None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
+                        edit_concepts_embeds = edit_concepts_embeds + (
+                            weights * avg_diff[1][None, :].repeat(1, tokenizer.model_max_length, 1) * scale)
+                        if avg_diff_2nd is not None:
+                            edit_concepts_embeds += (weights * avg_diff_2nd[1][None, :].repeat(1,
+                                                         self.pipe.tokenizer_2.model_max_length,
+                                                         1) * scale_2nd)
+
 
                 edit_prompt_embeds_list.append(edit_concepts_embeds)
                 i+=1
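
In plain terms, the commit drops the debug prints and replaces the single shared slider direction with a per-encoder one: avg_diff[0] is applied to the first SDXL text-encoder embeddings, avg_diff[1] to the second (1280-dim) encoder, and an optional second direction avg_diff_2nd with its own strength scale_2nd can be layered on top. Below is a minimal, hypothetical sketch of that shift, assuming avg_diff is an averaged embedding-difference direction of shape (dim,); the helper name shift_embeds, the weights argument, and the default max_len=77 are illustrative and not code from app.py.

    import torch

    def shift_embeds(embeds, weights, avg_diff, scale,
                     avg_diff_2nd=None, scale_2nd=0.0, max_len=77):
        # embeds:       (1, max_len, dim) prompt embeddings from one SDXL text encoder
        # weights:      (1, max_len, 1) or (1, max_len, dim) per-token weighting
        # avg_diff:     (dim,) averaged "concept" direction for this encoder
        # avg_diff_2nd: optional second direction, applied with its own scale_2nd
        embeds = embeds + weights * avg_diff[None, :].repeat(1, max_len, 1) * scale
        if avg_diff_2nd is not None:
            embeds = embeds + weights * avg_diff_2nd[None, :].repeat(1, max_len, 1) * scale_2nd
        return embeds

    # Hypothetical usage mirroring the diff: index 0 for the first text encoder,
    # index 1 for the second one.
    # edit_concepts_embeds = shift_embeds(edit_concepts_embeds, weights,
    #                                     avg_diff[0], scale, avg_diff_2nd[0], scale_2nd)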