Wauplin HF staff commited on
Commit
9169c11
1 Parent(s): 77d4381

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -4
app.py CHANGED
@@ -14,6 +14,7 @@ import copy
14
  import gc
15
  import pickle
16
  import spaces
 
17
 
18
  lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
19
 
@@ -93,7 +94,7 @@ div#share-btn-container > div {flex-direction: row;background: black;align-items
93
  original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
94
 
95
  @spaces.GPU
96
- def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed=-1):
97
 
98
  repo_id_1 = shuffled_items[0]['repo']
99
  repo_id_2 = shuffled_items[1]['repo']
@@ -134,6 +135,18 @@ def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lor
134
  seed = random.randint(0, 2147483647)
135
  generator = torch.Generator(device="cuda").manual_seed(seed)
136
  image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768, generator=generator).images[0]
 
 
 
 
 
 
 
 
 
 
 
 
137
  return image, gr.update(visible=True), seed, gr.update(visible=True, interactive=True), gr.update(visible=False), gr.update(visible=True, interactive=True), gr.update(visible=False)
138
 
139
  def get_description(item):
@@ -176,7 +189,7 @@ def save_preferences(lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, g
176
  def hide_post_gen_info():
177
  return gr.update(visible=False)
178
 
179
- with gr.Blocks(css=css) as demo:
180
  shuffled_items = gr.State()
181
  title = gr.HTML(
182
  '''<h1>LoRA Roulette 🎰</h1>
@@ -245,5 +258,11 @@ with gr.Blocks(css=css) as demo:
245
  thumbs_down.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("down"), last_used_seed], outputs=[thumbs_down, thumbs_down_clicked, thumbs_up])
246
  share_button.click(None, [], [], _js=share_js)
247
 
248
- demo.queue(concurrency_count=2)
249
- demo.launch()
 
 
 
 
 
 
 
14
  import gc
15
  import pickle
16
  import spaces
17
+ import user_history
18
 
19
  lora_list = hf_hub_download(repo_id="multimodalart/LoraTheExplorer", filename="sdxl_loras.json", repo_type="space")
20
 
 
94
  original_pipe = DiffusionPipeline.from_pretrained("stabilityai/stable-diffusion-xl-base-1.0", torch_dtype=torch.float16)
95
 
96
  @spaces.GPU
97
+ def merge_and_run(prompt, negative_prompt, shuffled_items, lora_1_scale=0.5, lora_2_scale=0.5, seed=-1, profile: gr.OAuthProfile | None=None):
98
 
99
  repo_id_1 = shuffled_items[0]['repo']
100
  repo_id_2 = shuffled_items[1]['repo']
 
135
  seed = random.randint(0, 2147483647)
136
  generator = torch.Generator(device="cuda").manual_seed(seed)
137
  image = pipe(prompt=prompt, negative_prompt=negative_prompt, num_inference_steps=20, width=768, height=768, generator=generator).images[0]
138
+
139
+ # save generated images (if logged in)
140
+ user_history.save_image(label=prompt, image=image, profile=profile, metadata={
141
+ "prompt": prompt,
142
+ "negative_prompt": negative_prompt,
143
+ "lora_1_repo_id": shuffled_items[0]['repo'],
144
+ "lora_2_repo_id": shuffled_items[1]['repo'],
145
+ "lora_1_scale": lora_1_scale,
146
+ "lora_2_scale": lora_2_scale,
147
+ "seed": seed,
148
+ })
149
+
150
  return image, gr.update(visible=True), seed, gr.update(visible=True, interactive=True), gr.update(visible=False), gr.update(visible=True, interactive=True), gr.update(visible=False)
151
 
152
  def get_description(item):
 
189
  def hide_post_gen_info():
190
  return gr.update(visible=False)
191
 
192
+ with gr.Blocks() as demo:
193
  shuffled_items = gr.State()
194
  title = gr.HTML(
195
  '''<h1>LoRA Roulette 🎰</h1>
 
258
  thumbs_down.click(save_preferences, inputs=[lora_1_id, lora_1_scale, lora_2_id, lora_2_scale, prompt, output_image, gr.State("down"), last_used_seed], outputs=[thumbs_down, thumbs_down_clicked, thumbs_up])
259
  share_button.click(None, [], [], _js=share_js)
260
 
261
+ with gr.Blocks(css=css) as demo_with_history:
262
+ with gr.Tab("Lora Roulette"):
263
+ demo.render()
264
+ with gr.Tab("Past generations"):
265
+ user_history.render()
266
+
267
+ demo_with_history.queue(concurrency_count=2)
268
+ demo_with_history.launch()