linoyts HF staff commited on
Commit
d5a8945
1 Parent(s): e10cbeb

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +15 -11
app.py CHANGED
@@ -10,30 +10,31 @@ flash_pipe.scheduler = EulerDiscreteScheduler.from_config(flash_pipe.scheduler.c
10
  clip_slider = CLIPSliderXL(flash_pipe, device=torch.device("cuda"), iterations=50)
11
 
12
  @spaces.GPU
13
- def generate(slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2):
14
 
15
  # check if avg diff for directions need to be re-calculated
16
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
17
  clip_slider.avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1])
18
  x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
19
- print("clip_slider.avg_diff[0]", clip_slider.avg_diff[0])
20
- print("clip_slider.avg_diff[1]", clip_slider.avg_diff[1])
21
  if not sorted(slider_y) == sorted([y_concept_1, y_concept_2]):
22
  clip_slider.avg_diff_2nd = clip_slider.find_latent_direction(slider_y[0], slider_y[1])
23
  y_concept_1, y_concept_2 = slider_y[0], slider_y[1]
24
 
 
25
  comma_concepts_x = ', '.join(slider_x)
26
  comma_concepts_y = ', '.join(slider_y)
27
 
28
- image = clip_slider.generate(prompt, scale=0, scale_2nd=0, num_inference_steps=8)
 
29
 
30
- return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, image
31
 
32
- def update_x(x,y,prompt):
33
  image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
34
  return image
35
 
36
- def update_y(x,y,prompt):
37
  image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
38
  return image
39
 
@@ -67,6 +68,9 @@ with gr.Blocks(css=css) as demo:
67
  x_concept_2 = gr.State("")
68
  y_concept_1 = gr.State("")
69
  y_concept_2 = gr.State("")
 
 
 
70
 
71
  with gr.Row():
72
  with gr.Column():
@@ -80,10 +84,10 @@ with gr.Blocks(css=css) as demo:
80
  output_image = gr.Image(elem_id="image_out")
81
 
82
  submit.click(fn=generate,
83
- inputs=[slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2],
84
- outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, output_image])
85
- x.change(fn=update_x, inputs=[x,y, prompt], outputs=[output_image])
86
- y.change(fn=update_y, inputs=[x,y, prompt], outputs=[output_image])
87
 
88
  if __name__ == "__main__":
89
  demo.launch()
 
10
  clip_slider = CLIPSliderXL(flash_pipe, device=torch.device("cuda"), iterations=50)
11
 
12
  @spaces.GPU
13
+ def generate(slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2, , avg_diff_x, avg_diff_y):
14
 
15
  # check if avg diff for directions need to be re-calculated
16
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
17
  clip_slider.avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1])
18
  x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
19
+
 
20
  if not sorted(slider_y) == sorted([y_concept_1, y_concept_2]):
21
  clip_slider.avg_diff_2nd = clip_slider.find_latent_direction(slider_y[0], slider_y[1])
22
  y_concept_1, y_concept_2 = slider_y[0], slider_y[1]
23
 
24
+ image = clip_slider.generate(prompt, scale=0, scale_2nd=0, num_inference_steps=8)
25
  comma_concepts_x = ', '.join(slider_x)
26
  comma_concepts_y = ', '.join(slider_y)
27
 
28
+ avg_diff_x = clip_slider.avg_diff.cpu()
29
+ avg_diff_y = clip_slider.avg_diff_2nd.cpu()
30
 
31
+ return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, image
32
 
33
+ def update_x(x,y,prompt, avg_diff_x, avg_diff_y):
34
  image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
35
  return image
36
 
37
+ def update_y(x,y,prompt, avg_diff_x, avg_diff_y):
38
  image = clip_slider.generate(prompt, scale=x, scale_2nd=y, num_inference_steps=8)
39
  return image
40
 
 
68
  x_concept_2 = gr.State("")
69
  y_concept_1 = gr.State("")
70
  y_concept_2 = gr.State("")
71
+
72
+ avg_diff_x = gr.State()
73
+ avg_diff_y = gr.State()
74
 
75
  with gr.Row():
76
  with gr.Column():
 
84
  output_image = gr.Image(elem_id="image_out")
85
 
86
  submit.click(fn=generate,
87
+ inputs=[slider_x, slider_y, prompt, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y],
88
+ outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image])
89
+ x.change(fn=update_x, inputs=[x,y, prompt, avg_diff_x, avg_diff_y], outputs=[output_image])
90
+ y.change(fn=update_y, inputs=[x,y, prompt, avg_diff_x, avg_diff_y], outputs=[output_image])
91
 
92
  if __name__ == "__main__":
93
  demo.launch()