multimodalart HF staff commited on
Commit
3d9ac9f
1 Parent(s): 166b3f2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +49 -13
app.py CHANGED
@@ -24,7 +24,7 @@ pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
24
  torch_dtype=torch.bfloat16)
25
 
26
  pipe.transformer.to(memory_format=torch.channels_last)
27
- pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
28
  #pipe.enable_model_cpu_offload()
29
  clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))
30
 
@@ -35,13 +35,26 @@ controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Canny-alpha'
35
  # pipe_controlnet = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
36
  # t5_slider_controlnet = T5SliderFlux(sd_pipe=pipe_controlnet,device=torch.device("cuda"))
37
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
  @spaces.GPU(duration=200)
39
  def generate(concept_1, concept_2, scale, prompt, seed, recalc_directions, iterations, steps, interm_steps, guidance_scale,
40
  x_concept_1, x_concept_2,
41
  avg_diff_x,
42
  img2img_type = None, img = None,
43
  controlnet_scale= None, ip_adapter_scale=None,
44
-
45
  ):
46
  slider_x = [concept_1, concept_2]
47
  # check if avg diff for directions need to be re-calculated
@@ -58,7 +71,7 @@ def generate(concept_1, concept_2, scale, prompt, seed, recalc_directions, itera
58
  high_scale = scale
59
  low_scale = -1 * scale
60
  for i in range(interm_steps):
61
- cur_scale = low_scale + (high_scale - low_scale) * i / (steps - 1)
62
  image = clip_slider.generate(prompt,
63
  #guidance_scale=guidance_scale,
64
  scale=cur_scale, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
@@ -69,15 +82,19 @@ def generate(concept_1, concept_2, scale, prompt, seed, recalc_directions, itera
69
 
70
  comma_concepts_x = f"{slider_x[1]}, {slider_x[0]}"
71
 
 
 
 
 
72
  avg_diff_x = avg_diff.cpu()
73
-
74
- return gr.update(label=comma_concepts_x, interactive=True, value=scale), x_concept_1, x_concept_2, avg_diff_x, export_to_gif(images, "clip.gif", fps=5), canvas
75
 
76
  @spaces.GPU
77
  def update_scales(x,prompt,seed, steps, interm_steps, guidance_scale,
78
  avg_diff_x,
79
  img2img_type = None, img = None,
80
- controlnet_scale= None, ip_adapter_scale=None,):
81
  print("Hola", x)
82
  avg_diff = avg_diff_x.cuda()
83
 
@@ -102,9 +119,22 @@ def update_scales(x,prompt,seed, steps, interm_steps, guidance_scale,
102
  canvas = Image.new('RGB', (256*interm_steps, 256))
103
  for i, im in enumerate(images):
104
  canvas.paste(im.resize((256,256)), (256 * i, 0))
105
- return export_to_gif(images, "clip.gif", fps=5), canvas
 
 
106
 
 
 
 
107
 
 
 
 
 
 
 
 
 
108
  def reset_recalc_directions():
109
  return True
110
 
@@ -160,6 +190,7 @@ with gr.Blocks() as demo:
160
 
161
  x_concept_1 = gr.State("")
162
  x_concept_2 = gr.State("")
 
163
  # y_concept_1 = gr.State("")
164
  # y_concept_2 = gr.State("")
165
 
@@ -181,9 +212,14 @@ with gr.Blocks() as demo:
181
  submit = gr.Button("find directions")
182
  with gr.Column():
183
  with gr.Group(elem_id="group"):
 
 
184
  #y = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
185
- output_image = gr.Image(elem_id="image_out", label="Gif")
186
- image_seq = gr.Image(label="Strip")
 
 
 
187
  # with gr.Row():
188
  # generate_butt = gr.Button("generate")
189
 
@@ -250,17 +286,17 @@ with gr.Blocks() as demo:
250
  # inputs=[slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y],
251
  # outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image])
252
  submit.click(fn=generate,
253
- inputs=[concept_1, concept_2, x, prompt, seed, recalc_directions, iterations, steps, interm_steps, guidance_scale, x_concept_1, x_concept_2, avg_diff_x],
254
- outputs=[x, x_concept_1, x_concept_2, avg_diff_x, output_image, image_seq])
255
 
256
  iterations.change(fn=reset_recalc_directions, outputs=[recalc_directions])
257
  seed.change(fn=reset_recalc_directions, outputs=[recalc_directions])
258
- x.release(fn=update_scales, inputs=[x, prompt, seed, steps, interm_steps, guidance_scale, avg_diff_x], outputs=[output_image, image_seq], trigger_mode='always_last')
259
  # generate_butt_a.click(fn=update_scales, inputs=[x_a,y_a, prompt_a, seed_a, steps_a, guidance_scale_a, avg_diff_x, avg_diff_y, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale], outputs=[output_image_a])
260
  # submit_a.click(fn=generate,
261
  # inputs=[slider_x_a, slider_y_a, prompt_a, seed_a, iterations_a, steps_a, guidance_scale_a, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale],
262
  # outputs=[x_a, y_a, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image_a])
263
-
264
 
265
  if __name__ == "__main__":
266
  demo.launch()
 
24
  torch_dtype=torch.bfloat16)
25
 
26
  pipe.transformer.to(memory_format=torch.channels_last)
27
+ #pipe.transformer = torch.compile(pipe.transformer, mode="max-autotune", fullgraph=True)
28
  #pipe.enable_model_cpu_offload()
29
  clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))
30
 
 
35
  # pipe_controlnet = FluxControlNetPipeline.from_pretrained(base_model, controlnet=controlnet, torch_dtype=torch.bfloat16)
36
  # t5_slider_controlnet = T5SliderFlux(sd_pipe=pipe_controlnet,device=torch.device("cuda"))
37
 
38
+ def convert_to_centered_scale(num):
39
+ if num <= 0:
40
+ raise ValueError("Input must be a positive integer")
41
+
42
+ if num % 2 == 0: # even
43
+ start = -(num // 2 - 1)
44
+ end = num // 2
45
+ else: # odd
46
+ start = -(num // 2)
47
+ end = num // 2
48
+
49
+ return tuple(range(start, end + 1))
50
+
51
  @spaces.GPU(duration=200)
52
  def generate(concept_1, concept_2, scale, prompt, seed, recalc_directions, iterations, steps, interm_steps, guidance_scale,
53
  x_concept_1, x_concept_2,
54
  avg_diff_x,
55
  img2img_type = None, img = None,
56
  controlnet_scale= None, ip_adapter_scale=None,
57
+ total_images
58
  ):
59
  slider_x = [concept_1, concept_2]
60
  # check if avg diff for directions need to be re-calculated
 
71
  high_scale = scale
72
  low_scale = -1 * scale
73
  for i in range(interm_steps):
74
+ cur_scale = low_scale + (high_scale - low_scale) * i / (interm_steps - 1)
75
  image = clip_slider.generate(prompt,
76
  #guidance_scale=guidance_scale,
77
  scale=cur_scale, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
 
82
 
83
  comma_concepts_x = f"{slider_x[1]}, {slider_x[0]}"
84
 
85
+ scale_min = convert_to_centered_scale(interm_steps)[0]
86
+ scale_max = convert_to_centered_scale(interm_steps)[-1]
87
+
88
+ post_generation_slider_update = gr.update(minimum=scale_min, maximum=scale_max, visible=True)
89
  avg_diff_x = avg_diff.cpu()
90
+
91
+ return gr.update(label=comma_concepts_x, interactive=True, value=scale), x_concept_1, x_concept_2, avg_diff_x, export_to_gif(images, "clip.gif", fps=5), canvas, images, post_generation_slider_update
92
 
93
  @spaces.GPU
94
  def update_scales(x,prompt,seed, steps, interm_steps, guidance_scale,
95
  avg_diff_x,
96
  img2img_type = None, img = None,
97
+ controlnet_scale= None, ip_adapter_scale=None, total_images):
98
  print("Hola", x)
99
  avg_diff = avg_diff_x.cuda()
100
 
 
119
  canvas = Image.new('RGB', (256*interm_steps, 256))
120
  for i, im in enumerate(images):
121
  canvas.paste(im.resize((256,256)), (256 * i, 0))
122
+
123
+ scale_min = convert_to_centered_scale(interm_steps)[0]
124
+ scale_max = convert_to_centered_scale(interm_steps)[-1]
125
 
126
+ post_generation_slider_update = gr.update(minimum=scale_min, maximum=scale_max, visible=True)
127
+
128
+ return export_to_gif(images, "clip.gif", fps=5), canvas, images, post_generation_slider_update
129
 
130
+ def update_pre_generated_images(slider_value, total_images):
131
+ number_images = len(total_images)
132
+ if(number_images > 0):
133
+ scale_tuple = convert_to_centered_scale(number_images)
134
+ return total_images[scale_tuple.index(slider_value)]
135
+ else:
136
+ return None
137
+
138
  def reset_recalc_directions():
139
  return True
140
 
 
190
 
191
  x_concept_1 = gr.State("")
192
  x_concept_2 = gr.State("")
193
+ total_images = gr.State([])
194
  # y_concept_1 = gr.State("")
195
  # y_concept_2 = gr.State("")
196
 
 
212
  submit = gr.Button("find directions")
213
  with gr.Column():
214
  with gr.Group(elem_id="group"):
215
+ post_generation_image = gr.Image(label="Generated Images")
216
+ post_generation_slider = gr.Slider(minimum=-2, maximum=2, value=0, step=1, interactive=False)
217
  #y = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
218
+ with gr.Row():
219
+ with gr.Column(scale=4):
220
+ output_image = gr.Image(elem_id="image_out", label="Gif")
221
+ with gr.Column(scale=1):
222
+ image_seq = gr.Image(label="Strip")
223
  # with gr.Row():
224
  # generate_butt = gr.Button("generate")
225
 
 
286
  # inputs=[slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y],
287
  # outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image])
288
  submit.click(fn=generate,
289
+ inputs=[concept_1, concept_2, x, prompt, seed, recalc_directions, iterations, steps, interm_steps, guidance_scale, x_concept_1, x_concept_2, avg_diff_x, total_images],
290
+ outputs=[x, x_concept_1, x_concept_2, avg_diff_x, output_image, image_seq, total_images, post_generation_slider])
291
 
292
  iterations.change(fn=reset_recalc_directions, outputs=[recalc_directions])
293
  seed.change(fn=reset_recalc_directions, outputs=[recalc_directions])
294
+ x.release(fn=update_scales, inputs=[x, prompt, seed, steps, interm_steps, guidance_scale, avg_diff_x, total_images], outputs=[output_image, image_seq, total_images, post_generation_slider], trigger_mode='always_last')
295
  # generate_butt_a.click(fn=update_scales, inputs=[x_a,y_a, prompt_a, seed_a, steps_a, guidance_scale_a, avg_diff_x, avg_diff_y, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale], outputs=[output_image_a])
296
  # submit_a.click(fn=generate,
297
  # inputs=[slider_x_a, slider_y_a, prompt_a, seed_a, iterations_a, steps_a, guidance_scale_a, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale],
298
  # outputs=[x_a, y_a, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image_a])
299
+ post_generation_slider.release(fn=update_pre_generated_images, inputs=[post_generation_slider, total_images], outputs=[post_generation_image], trigger_mode='always_last')
300
 
301
  if __name__ == "__main__":
302
  demo.launch()