mridulk committed
Commit ed13e89
1 Parent(s): 6226078

updated the app

Files changed (2):
  1. app.py +64 -15
  2. phylogeny_tree.jpg +0 -0
app.py CHANGED
@@ -46,13 +46,23 @@ def masking_embed(embedding, levels=1):
     embedding[:, :, -replace_size:] = random_noise
     return embedding
 
+
+# LOAD MODEL GLOBALLY
+ckpt_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/checkpoints/epoch=000119.ckpt'
+config_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/configs/2024-03-01T23-15-36-project.yaml'
+config = OmegaConf.load(config_path) # TODO: Optionally download from same location as ckpt and change this logic
+model = load_model_from_config(config, ckpt_path) # TODO: check path
+
+device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+model = model.to(device)
+
 def generate_image(fish_name, masking_level_input,
                    swap_fish_name, swap_level_input):
 
     fish_name = fish_name.lower()
 
-    ckpt_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/checkpoints/epoch=000119.ckpt'
-    config_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/configs/2024-03-01T23-15-36-project.yaml'
+    # ckpt_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/checkpoints/epoch=000119.ckpt'
+    # config_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/configs/2024-03-01T23-15-36-project.yaml'
 
 
 
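The core of this hunk is a performance fix: the checkpoint and config used to be loaded inside `generate_image`, i.e. on every request, and are now loaded once at module import. A minimal sketch of the pattern under the same assumptions as app.py (`load_model_from_config` and `OmegaConf` come from the app's existing imports; the paths and `run_sampler` are placeholders, not the real values):

```python
import torch
import gradio as gr
from omegaconf import OmegaConf

CKPT = "checkpoints/epoch=000119.ckpt"    # placeholder path
CONFIG = "configs/project.yaml"           # placeholder path

# Load once at import time; every Gradio request reuses the same objects.
config = OmegaConf.load(CONFIG)
model = load_model_from_config(config, CKPT)   # helper defined earlier in app.py
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = model.to(device)

def generate(prompt):
    # Per-request work is inference only; no checkpoint I/O here.
    with torch.no_grad():
        return run_sampler(model, prompt)      # stand-in for the DDIM/PLMS sampling call

gr.Interface(fn=generate, inputs="text", outputs="image").launch()
```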
 
@@ -72,19 +82,19 @@ def generate_image(fish_name, masking_level_input,
         if value == class_name:
             return key
 
-    config = OmegaConf.load(config_path) # TODO: Optionally download from same location as ckpt and change this logic
-    model = load_model_from_config(config, ckpt_path) # TODO: check path
+    # config = OmegaConf.load(config_path) # TODO: Optionally download from same location as ckpt and change this logic
+    # model = load_model_from_config(config, ckpt_path) # TODO: check path
 
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    model = model.to(device)
+    # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+    # model = model.to(device)
 
     if opt.plms:
         sampler = PLMSSampler(model)
     else:
         sampler = DDIMSampler(model)
 
-    os.makedirs(opt.outdir, exist_ok=True)
-    outpath = opt.outdir
+    # os.makedirs(opt.outdir, exist_ok=True)
+    # outpath = opt.outdir
 
     prompt = opt.prompt
     all_images = []
@@ -96,12 +106,13 @@ def generate_image(fish_name, masking_level_input,
 
     class_to_node_dict = {key.lower(): value for key, value in class_to_node_dict.items()}
 
-    sample_path = os.path.join(outpath, opt.output_dir_name)
-    os.makedirs(sample_path, exist_ok=True)
-    base_count = len(os.listdir(sample_path))
+    # sample_path = os.path.join(outpath, opt.output_dir_name)
+    # os.makedirs(sample_path, exist_ok=True)
+    # base_count = len(os.listdir(sample_path))
 
     prompt = class_to_node_dict[fish_name]
     if swap_fish_name:
+        swap_fish_name = swap_fish_name.lower()
         swap_level = int(swap_level_input.split(" ")[-1]) - 1
         swap_fish = class_to_node_dict[swap_fish_name]
 
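The one functional line added here, `swap_fish_name = swap_fish_name.lower()`, mirrors what was already done for `fish_name`: `class_to_node_dict` is rebuilt with lowercased keys, so any lookup key has to be normalized the same way. A tiny illustration with made-up values (not the app's real mapping):

```python
class_to_node = {"Lepomis Auritus": "node_12"}                     # hypothetical original mapping
class_to_node = {k.lower(): v for k, v in class_to_node.items()}   # keys are now all lowercase

user_input = "Lepomis Auritus"
# class_to_node[user_input]            # KeyError: the mixed-case key no longer exists
print(class_to_node[user_input.lower()])                           # -> node_12
```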
 
@@ -243,12 +254,26 @@ if __name__ == "__main__":
     )
     opt = parser.parse_args()
 
+    title = "🎞️ Phylo Diffusion - Generating Fish Images Tool"
+    description = "Write the Species name to generate an image for.\n For Trait Masking: Specify the Level information as well"
+
+
+    def load_example(prompt, level, option, components):
+        components['prompt_input'].value = prompt
+        components['masking_level_input'].value = level
+        # components['option'].value = option
 
     def setup_interface():
         with gr.Blocks() as demo:
+
+            gr.Markdown("# Phylo Diffusion - Generating Fish Images Tool")
+            gr.Markdown("### Write the Species name to generate a fish image")
+            gr.Markdown("### 1. Trait Masking: Specify the Level information as well")
+            gr.Markdown("### 2. Trait Swapping: Specify the species name to swap traits with and at what level")
+
             with gr.Row():
                 with gr.Column():
-                    gr.Markdown("### Generate Images Based on Prompts")
+                    gr.Markdown("## Generate Images Based on Prompts")
                     gr.Markdown("Enter a prompt to generate an image:")
                     prompt_input = gr.Textbox(label="Species Name")
                     gr.Markdown("Trait Masking")
@@ -261,13 +286,37 @@ if __name__ == "__main__":
                     swap_fish_name = gr.Textbox(label="Species Name to swap trait with:")
                     swap_level_input = gr.Dropdown(label="Level of swapping", choices=["Level 3", "Level 2"], value="Level 3")
                     submit_button = gr.Button("Generate")
-                    gr.Markdown("### Phylogeny Tree")
+                    gr.Markdown("## Phylogeny Tree")
                     architecture_image = "phylogeny_tree.jpg" # Update this with the actual path
                     gr.Image(value=architecture_image, label="Phylogeny Tree")
 
                 with gr.Column():
-                    gr.Markdown("### Generated Image")
-                    output_image = gr.Image(label="Generated Image")
+
+                    gr.Markdown("## Generated Image")
+                    output_image = gr.Image(label="Generated Image", width=512, height=512)
+
+
+                    # Place to put example buttons
+                    gr.Markdown("## Select an example:")
+                    examples = [
+                        ("Gambusia Affinis", "None", "", "Level 3"),
+                        ("Lepomis Auritus", "None", "", "Level 3"),
+                        ("Lepomis Auritus", "Level 3", "", "Level 3"),
+                        ("Noturus nocturnus", "None", "Notropis dorsalis", "Level 2")]
+
+                    for text, level, swap_text, swap_level in examples:
+                        if level == "None" and swap_text == "":
+                            button = gr.Button(f"Species: {text}")
+                        elif level != "None":
+                            button = gr.Button(f"Species: {text} | Masking: {level}")
+                        elif swap_text != "":
+                            button = gr.Button(f"Species: {text} | Swapping with {swap_text} at {swap_level}")
+                        button.click(
+                            fn=lambda text=text, level=level, swap_text=swap_text, swap_level=swap_level: (text, level, swap_text, swap_level),
+                            inputs=[],
+                            outputs=[prompt_input, masking_level_input, swap_fish_name, swap_level_input]
+                        )
+
 
             # Display an image of the architecture
 
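One detail worth calling out in the example-button loop above: each callback is built as `lambda text=text, level=level, ...`, so every button captures the values of its own iteration; a bare `lambda: (text, level, ...)` would close over the loop variables and all buttons would fill in the last example. A standalone sketch of the same wiring (generic two-field layout; only the species names are taken from the examples in the diff):

```python
import gradio as gr

examples = [("Gambusia Affinis", "None"), ("Lepomis Auritus", "Level 3")]

with gr.Blocks() as demo:
    species = gr.Textbox(label="Species Name")
    level = gr.Dropdown(label="Masking level", choices=["None", "Level 2", "Level 3"], value="None")
    for name, lvl in examples:
        btn = gr.Button(f"Species: {name} | Masking: {lvl}")
        # Default args freeze name/lvl per iteration; the returned tuple fills `outputs` in order.
        btn.click(fn=lambda name=name, lvl=lvl: (name, lvl), inputs=[], outputs=[species, level])

demo.launch()
```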
 
 
phylogeny_tree.jpg CHANGED