linoyts HF staff commited on
Commit
1eb5467
β€’
1 Parent(s): dc736bc

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -91
app.py CHANGED
@@ -1,6 +1,6 @@
1
  import gradio as gr
2
  import spaces
3
- from clip_slider_pipeline import T5SliderFlux
4
  from diffusers import FluxPipeline
5
  import torch
6
  import time
@@ -22,7 +22,7 @@ def process_controlnet_img(image):
22
  pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
23
  torch_dtype=torch.bfloat16)
24
  #pipe.enable_model_cpu_offload()
25
- t5_slider = T5SliderFlux(pipe, device=torch.device("cuda"))
26
 
27
  base_model = 'black-forest-labs/FLUX.1-schnell'
28
  controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Canny-alpha'
@@ -32,10 +32,9 @@ controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Canny-alpha'
32
 
33
 
34
  @spaces.GPU(duration=200)
35
- def generate(slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale,
36
- x_concept_1, x_concept_2, y_concept_1, y_concept_2,
37
  avg_diff_x,
38
- avg_diff_y,correlation,
39
  img2img_type = None, img = None,
40
  controlnet_scale= None, ip_adapter_scale=None,
41
 
@@ -47,51 +46,43 @@ def generate(slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale
47
  print("x_concept_1", x_concept_1, "x_concept_2", x_concept_2)
48
 
49
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
50
- avg_diff = t5_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations).to(torch.float16)
51
  x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
52
-
53
-
54
- if not sorted(slider_y) == sorted([y_concept_1, y_concept_2]):
55
- avg_diff_2nd = t5_slider.find_latent_direction(slider_y[0], slider_y[1], num_iterations=iterations).to(torch.float16)
56
- y_concept_1, y_concept_2 = slider_y[0], slider_y[1]
57
- end_time = time.time()
58
  print(f"direction time: {end_time - start_time:.2f} ms")
59
 
60
  start_time = time.time()
61
 
62
  if img2img_type=="controlnet canny" and img is not None:
63
  control_img = process_controlnet_img(img)
64
- image = t5_slider_controlnet.generate(prompt, correlation_weight_factor=correlation, guidance_scale=guidance_scale, image=control_img, controlnet_conditioning_scale =controlnet_scale, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
65
  elif img2img_type=="ip adapter" and img is not None:
66
- image = t5_slider.generate(prompt, guidance_scale=guidance_scale, correlation_weight_factor=correlation, ip_adapter_image=img, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
67
  else: # text to image
68
- image = t5_slider.generate(prompt, guidance_scale=guidance_scale, correlation_weight_factor=correlation, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
69
 
70
  end_time = time.time()
71
  print(f"generation time: {end_time - start_time:.2f} ms")
72
 
73
  comma_concepts_x = ', '.join(slider_x)
74
- comma_concepts_y = ', '.join(slider_y)
75
 
76
  avg_diff_x = avg_diff.cpu()
77
- avg_diff_y = avg_diff_2nd.cpu()
78
 
79
- return gr.update(label=comma_concepts_x, interactive=True),gr.update(label=comma_concepts_y, interactive=True), x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, image
80
 
81
  @spaces.GPU
82
- def update_scales(x,y,prompt,seed, steps, guidance_scale,
83
- avg_diff_x, avg_diff_y,
84
  img2img_type = None, img = None,
85
  controlnet_scale= None, ip_adapter_scale=None,):
86
  avg_diff = avg_diff_x.cuda()
87
- avg_diff_2nd = avg_diff_y.cuda()
88
  if img2img_type=="controlnet canny" and img is not None:
89
  control_img = process_controlnet_img(img)
90
- image = t5_slider_controlnet.generate(prompt, guidance_scale=guidance_scale, image=control_img, controlnet_conditioning_scale =controlnet_scale, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
91
  elif img2img_type=="ip adapter" and img is not None:
92
- image = t5_slider.generate(prompt, guidance_scale=guidance_scale, ip_adapter_image=img, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
93
  else:
94
- image = t5_slider.generate(prompt, guidance_scale=guidance_scale, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
95
  return image
96
 
97
 
@@ -103,7 +94,7 @@ def update_x(x,y,prompt,seed, steps,
103
  img = None):
104
  avg_diff = avg_diff_x.cuda()
105
  avg_diff_2nd = avg_diff_y.cuda()
106
- image = t5_slider.generate(prompt, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
107
  return image
108
 
109
  @spaces.GPU
@@ -113,7 +104,7 @@ def update_y(x,y,prompt,seed, steps,
113
  img = None):
114
  avg_diff = avg_diff_x.cuda()
115
  avg_diff_2nd = avg_diff_y.cuda()
116
- image = t5_slider.generate(prompt, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
117
  return image
118
 
119
 
@@ -146,29 +137,29 @@ with gr.Blocks(css=css) as demo:
146
 
147
  x_concept_1 = gr.State("")
148
  x_concept_2 = gr.State("")
149
- y_concept_1 = gr.State("")
150
- y_concept_2 = gr.State("")
151
 
152
  avg_diff_x = gr.State()
153
- avg_diff_y = gr.State()
154
 
155
  with gr.Tab("text2image"):
156
  with gr.Row():
157
  with gr.Column():
158
- slider_x = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
159
- slider_y = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
160
  prompt = gr.Textbox(label="Prompt")
161
  submit = gr.Button("find directions")
162
  with gr.Column():
163
  with gr.Group(elem_id="group"):
164
- x = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="x", interactive=False)
165
- y = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
166
  output_image = gr.Image(elem_id="image_out")
167
  with gr.Row():
168
  generate_butt = gr.Button("generate")
169
 
170
  with gr.Accordion(label="advanced options", open=False):
171
- iterations = gr.Slider(label = "num iterations", minimum=0, value=200, maximum=400)
172
  steps = gr.Slider(label = "num inference steps", minimum=1, value=4, maximum=10)
173
  guidance_scale = gr.Slider(
174
  label="Guidance scale",
@@ -177,69 +168,72 @@ with gr.Blocks(css=css) as demo:
177
  step=0.1,
178
  value=5,
179
  )
180
- correlation = gr.Slider(
181
- label="correlation",
182
- minimum=0.1,
183
- maximum=1.0,
184
- step=0.05,
185
- value=0.6,
186
- )
187
  seed = gr.Slider(minimum=0, maximum=np.iinfo(np.int32).max, label="Seed", interactive=True, randomize=True)
188
 
189
 
190
- with gr.Tab(label="image2image"):
191
- with gr.Row():
192
- with gr.Column():
193
- image = gr.ImageEditor(type="pil", image_mode="L", crop_size=(512, 512))
194
- slider_x_a = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
195
- slider_y_a = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
196
- img2img_type = gr.Radio(["controlnet canny", "ip adapter"], label="", info="", visible=False, value="controlnet canny")
197
- prompt_a = gr.Textbox(label="Prompt")
198
- submit_a = gr.Button("Submit")
199
- with gr.Column():
200
- with gr.Group(elem_id="group"):
201
- x_a = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="x", interactive=False)
202
- y_a = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
203
- output_image_a = gr.Image(elem_id="image_out")
204
- with gr.Row():
205
- generate_butt_a = gr.Button("generate")
206
 
207
- with gr.Accordion(label="advanced options", open=False):
208
- iterations_a = gr.Slider(label = "num iterations", minimum=0, value=200, maximum=300)
209
- steps_a = gr.Slider(label = "num inference steps", minimum=1, value=8, maximum=30)
210
- guidance_scale_a = gr.Slider(
211
- label="Guidance scale",
212
- minimum=0.1,
213
- maximum=10.0,
214
- step=0.1,
215
- value=5,
216
- )
217
- controlnet_conditioning_scale = gr.Slider(
218
- label="controlnet conditioning scale",
219
- minimum=0.5,
220
- maximum=5.0,
221
- step=0.1,
222
- value=0.7,
223
- )
224
- ip_adapter_scale = gr.Slider(
225
- label="ip adapter scale",
226
- minimum=0.5,
227
- maximum=5.0,
228
- step=0.1,
229
- value=0.8,
230
- visible=False
231
- )
232
- seed_a = gr.Slider(minimum=0, maximum=np.iinfo(np.int32).max, label="Seed", interactive=True, randomize=True)
233
 
 
 
 
234
  submit.click(fn=generate,
235
- inputs=[slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y,correlation],
236
- outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image])
237
-
238
- generate_butt.click(fn=update_scales, inputs=[x,y, prompt, seed, steps, guidance_scale, avg_diff_x, avg_diff_y], outputs=[output_image])
239
- generate_butt_a.click(fn=update_scales, inputs=[x_a,y_a, prompt_a, seed_a, steps_a, guidance_scale_a, avg_diff_x, avg_diff_y, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale], outputs=[output_image_a])
240
- submit_a.click(fn=generate,
241
- inputs=[slider_x_a, slider_y_a, prompt_a, seed_a, iterations_a, steps_a, guidance_scale_a, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, correlation, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale],
242
- outputs=[x_a, y_a, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image_a])
243
 
244
 
245
  if __name__ == "__main__":
 
1
  import gradio as gr
2
  import spaces
3
+ from clip_slider_pipeline import CLIPSliderFlux
4
  from diffusers import FluxPipeline
5
  import torch
6
  import time
 
22
  pipe = FluxPipeline.from_pretrained("black-forest-labs/FLUX.1-schnell",
23
  torch_dtype=torch.bfloat16)
24
  #pipe.enable_model_cpu_offload()
25
+ clip_slider = CLIPSliderFlux(pipe, device=torch.device("cuda"))
26
 
27
  base_model = 'black-forest-labs/FLUX.1-schnell'
28
  controlnet_model = 'InstantX/FLUX.1-dev-Controlnet-Canny-alpha'
 
32
 
33
 
34
  @spaces.GPU(duration=200)
35
+ def generate(slider_x, prompt, seed, iterations, steps, guidance_scale,
36
+ x_concept_1, x_concept_2,
37
  avg_diff_x,
 
38
  img2img_type = None, img = None,
39
  controlnet_scale= None, ip_adapter_scale=None,
40
 
 
46
  print("x_concept_1", x_concept_1, "x_concept_2", x_concept_2)
47
 
48
  if not sorted(slider_x) == sorted([x_concept_1, x_concept_2]):
49
+ avg_diff = clip_slider.find_latent_direction(slider_x[0], slider_x[1], num_iterations=iterations).to(torch.float16)
50
  x_concept_1, x_concept_2 = slider_x[0], slider_x[1]
51
+
 
 
 
 
 
52
  print(f"direction time: {end_time - start_time:.2f} ms")
53
 
54
  start_time = time.time()
55
 
56
  if img2img_type=="controlnet canny" and img is not None:
57
  control_img = process_controlnet_img(img)
58
+ image = clip_slider.generate(prompt, guidance_scale=guidance_scale, image=control_img, controlnet_conditioning_scale =controlnet_scale, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
59
  elif img2img_type=="ip adapter" and img is not None:
60
+ image = clip_slider.generate(prompt, guidance_scale=guidance_scale, ip_adapter_image=img, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=avg_diff, avg_diff_2nd=avg_diff_2nd)
61
  else: # text to image
62
+ image = clip_slider.generate(prompt, guidance_scale=guidance_scale, scale=0, scale_2nd=0, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
63
 
64
  end_time = time.time()
65
  print(f"generation time: {end_time - start_time:.2f} ms")
66
 
67
  comma_concepts_x = ', '.join(slider_x)
 
68
 
69
  avg_diff_x = avg_diff.cpu()
 
70
 
71
+ return gr.update(label=comma_concepts_x, interactive=True), x_concept_1, x_concept_2, avg_diff_x, image
72
 
73
  @spaces.GPU
74
+ def update_scales(x,prompt,seed, steps, guidance_scale,
75
+ avg_diff_x,
76
  img2img_type = None, img = None,
77
  controlnet_scale= None, ip_adapter_scale=None,):
78
  avg_diff = avg_diff_x.cuda()
 
79
  if img2img_type=="controlnet canny" and img is not None:
80
  control_img = process_controlnet_img(img)
81
+ image = t5_slider_controlnet.generate(prompt, guidance_scale=guidance_scale, image=control_img, controlnet_conditioning_scale =controlnet_scale, scale=x, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
82
  elif img2img_type=="ip adapter" and img is not None:
83
+ image = clip_slider.generate(prompt, guidance_scale=guidance_scale, ip_adapter_image=img, scale=x,seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
84
  else:
85
+ image = clip_slider.generate(prompt, guidance_scale=guidance_scale, scale=x, seed=seed, num_inference_steps=steps, avg_diff=avg_diff)
86
  return image
87
 
88
 
 
94
  img = None):
95
  avg_diff = avg_diff_x.cuda()
96
  avg_diff_2nd = avg_diff_y.cuda()
97
+ image = clip_slider.generate(prompt, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
98
  return image
99
 
100
  @spaces.GPU
 
104
  img = None):
105
  avg_diff = avg_diff_x.cuda()
106
  avg_diff_2nd = avg_diff_y.cuda()
107
+ image = clip_slider.generate(prompt, scale=x, scale_2nd=y, seed=seed, num_inference_steps=steps, avg_diff=avg_diff,avg_diff_2nd=avg_diff_2nd)
108
  return image
109
 
110
 
 
137
 
138
  x_concept_1 = gr.State("")
139
  x_concept_2 = gr.State("")
140
+ # y_concept_1 = gr.State("")
141
+ # y_concept_2 = gr.State("")
142
 
143
  avg_diff_x = gr.State()
144
+ #avg_diff_y = gr.State()
145
 
146
  with gr.Tab("text2image"):
147
  with gr.Row():
148
  with gr.Column():
149
+ slider_x = gr.Dropdown(label="Slider concept range", allow_custom_value=True, multiselect=True, max_choices=2)
150
+ #slider_y = gr.Dropdown(label="Slider Y concept range", allow_custom_value=True, multiselect=True, max_choices=2)
151
  prompt = gr.Textbox(label="Prompt")
152
  submit = gr.Button("find directions")
153
  with gr.Column():
154
  with gr.Group(elem_id="group"):
155
+ x = gr.Slider(minimum=-4, value=0, maximum=4, elem_id="x", interactive=False)
156
+ #y = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
157
  output_image = gr.Image(elem_id="image_out")
158
  with gr.Row():
159
  generate_butt = gr.Button("generate")
160
 
161
  with gr.Accordion(label="advanced options", open=False):
162
+ iterations = gr.Slider(label = "num iterations", minimum=0, value=300, maximum=400)
163
  steps = gr.Slider(label = "num inference steps", minimum=1, value=4, maximum=10)
164
  guidance_scale = gr.Slider(
165
  label="Guidance scale",
 
168
  step=0.1,
169
  value=5,
170
  )
171
+ # correlation = gr.Slider(
172
+ # label="correlation",
173
+ # minimum=0.1,
174
+ # maximum=1.0,
175
+ # step=0.05,
176
+ # value=0.6,
177
+ # )
178
  seed = gr.Slider(minimum=0, maximum=np.iinfo(np.int32).max, label="Seed", interactive=True, randomize=True)
179
 
180
 
181
+ # with gr.Tab(label="image2image"):
182
+ # with gr.Row():
183
+ # with gr.Column():
184
+ # image = gr.ImageEditor(type="pil", image_mode="L", crop_size=(512, 512))
185
+ # slider_x_a = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
186
+ # slider_y_a = gr.Dropdown(label="Slider X concept range", allow_custom_value=True, multiselect=True, max_choices=2)
187
+ # img2img_type = gr.Radio(["controlnet canny", "ip adapter"], label="", info="", visible=False, value="controlnet canny")
188
+ # prompt_a = gr.Textbox(label="Prompt")
189
+ # submit_a = gr.Button("Submit")
190
+ # with gr.Column():
191
+ # with gr.Group(elem_id="group"):
192
+ # x_a = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="x", interactive=False)
193
+ # y_a = gr.Slider(minimum=-10, value=0, maximum=10, elem_id="y", interactive=False)
194
+ # output_image_a = gr.Image(elem_id="image_out")
195
+ # with gr.Row():
196
+ # generate_butt_a = gr.Button("generate")
197
 
198
+ # with gr.Accordion(label="advanced options", open=False):
199
+ # iterations_a = gr.Slider(label = "num iterations", minimum=0, value=200, maximum=300)
200
+ # steps_a = gr.Slider(label = "num inference steps", minimum=1, value=8, maximum=30)
201
+ # guidance_scale_a = gr.Slider(
202
+ # label="Guidance scale",
203
+ # minimum=0.1,
204
+ # maximum=10.0,
205
+ # step=0.1,
206
+ # value=5,
207
+ # )
208
+ # controlnet_conditioning_scale = gr.Slider(
209
+ # label="controlnet conditioning scale",
210
+ # minimum=0.5,
211
+ # maximum=5.0,
212
+ # step=0.1,
213
+ # value=0.7,
214
+ # )
215
+ # ip_adapter_scale = gr.Slider(
216
+ # label="ip adapter scale",
217
+ # minimum=0.5,
218
+ # maximum=5.0,
219
+ # step=0.1,
220
+ # value=0.8,
221
+ # visible=False
222
+ # )
223
+ # seed_a = gr.Slider(minimum=0, maximum=np.iinfo(np.int32).max, label="Seed", interactive=True, randomize=True)
224
 
225
+ # submit.click(fn=generate,
226
+ # inputs=[slider_x, slider_y, prompt, seed, iterations, steps, guidance_scale, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y],
227
+ # outputs=[x, y, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image])
228
  submit.click(fn=generate,
229
+ inputs=[slider_x, prompt, seed, iterations, steps, guidance_scale, x_concept_1, x_concept_2, avg_diff_x],
230
+ outputs=[x, x_concept_1, x_concept_2, avg_diff_x, output_image])
231
+
232
+ generate_butt.click(fn=update_scales, inputs=[x, prompt, seed, steps, guidance_scale, avg_diff_x], outputs=[output_image])
233
+ # generate_butt_a.click(fn=update_scales, inputs=[x_a,y_a, prompt_a, seed_a, steps_a, guidance_scale_a, avg_diff_x, avg_diff_y, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale], outputs=[output_image_a])
234
+ # submit_a.click(fn=generate,
235
+ # inputs=[slider_x_a, slider_y_a, prompt_a, seed_a, iterations_a, steps_a, guidance_scale_a, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, img2img_type, image, controlnet_conditioning_scale, ip_adapter_scale],
236
+ # outputs=[x_a, y_a, x_concept_1, x_concept_2, y_concept_1, y_concept_2, avg_diff_x, avg_diff_y, output_image_a])
237
 
238
 
239
  if __name__ == "__main__":