mridulk committed on
Commit
148daee
1 Parent(s): ed13e89

updated monkeys app and fish app

Files changed (3)
  1. app.py +4 -18
  2. monkeys_app.py +345 -0
  3. monkeys_masking_6levels.jpg +0 -0
app.py CHANGED
@@ -61,11 +61,6 @@ def generate_image(fish_name, masking_level_input,
 
      fish_name = fish_name.lower()
 
-     # ckpt_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/checkpoints/epoch=000119.ckpt'
-     # config_path = '/globalscratch/mridul/ldm/final_runs_eccv/fishes/2024-03-01T23-15-36_HLE_days3/configs/2024-03-01T23-15-36-project.yaml'
-
-
-
 
      label_to_class_mapping = {0: 'Alosa-chrysochloris', 1: 'Carassius-auratus', 2: 'Cyprinus-carpio', 3: 'Esox-americanus',
          4: 'Gambusia-affinis', 5: 'Lepisosteus-osseus', 6: 'Lepisosteus-platostomus', 7: 'Lepomis-auritus', 8: 'Lepomis-cyanellus',
@@ -82,19 +77,13 @@ def generate_image(fish_name, masking_level_input,
              if value == class_name:
                  return key
 
-     # config = OmegaConf.load(config_path) # TODO: Optionally download from same location as ckpt and chnage this logic
-     # model = load_model_from_config(config, ckpt_path) # TODO: check path
-
-     # device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-     # model = model.to(device)
 
      if opt.plms:
          sampler = PLMSSampler(model)
      else:
          sampler = DDIMSampler(model)
 
-     # os.makedirs(opt.outdir, exist_ok=True)
-     # outpath = opt.outdir
+
 
      prompt = opt.prompt
      all_images = []
@@ -105,12 +94,11 @@ def generate_image(fish_name, masking_level_input,
          class_to_node_dict = pickle.load(pickle_file)
 
      class_to_node_dict = {key.lower(): value for key, value in class_to_node_dict.items()}
-
-     # sample_path = os.path.join(outpath, opt.output_dir_name)
-     # os.makedirs(sample_path, exist_ok=True)
-     # base_count = len(os.listdir(sample_path))
+
 
      prompt = class_to_node_dict[fish_name]
+
+     ### Trait Swapping
      if swap_fish_name:
          swap_fish_name = swap_fish_name.lower()
          swap_level = int(swap_level_input.split(" ")[-1]) - 1
@@ -118,8 +106,6 @@ def generate_image(fish_name, masking_level_input,
 
          swap_fish_split = swap_fish[0].split(',')
          fish_name_split = prompt[0].split(',')
-         # print(swap_fish_split, fish_name_split)
-         # print(swap_level)
          fish_name_split[swap_level] = swap_fish_split[swap_level]
 
          prompt = [','.join(fish_name_split)]
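The block now labeled "### Trait Swapping" edits the conditioning prompt as a string: each species maps (via the class_to_node pickle) to a one-element list holding a comma-separated chain of ancestral labels, and swapping at "Level k" copies entry k-1 of that chain from the other species. A minimal standalone sketch of just that string manipulation, using made-up label chains rather than the real pickle:

# Hypothetical label chains; the app loads these from its class_to_node pickle.
class_to_node_dict = {
    "species-a": ["node3,node7,node12,node21,species-a"],
    "species-b": ["node3,node8,node15,node27,species-b"],
}

def swap_prompt(fish_name, swap_fish_name, swap_level_input):
    prompt = class_to_node_dict[fish_name.lower()]
    swap_fish = class_to_node_dict[swap_fish_name.lower()]
    swap_level = int(swap_level_input.split(" ")[-1]) - 1   # "Level 3" -> index 2
    fish_name_split = prompt[0].split(',')
    swap_fish_split = swap_fish[0].split(',')
    fish_name_split[swap_level] = swap_fish_split[swap_level]
    return [','.join(fish_name_split)]

print(swap_prompt("Species-A", "Species-B", "Level 3"))
# ['node3,node7,node15,node21,species-a']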
monkeys_app.py ADDED
@@ -0,0 +1,345 @@
+ import torch
+ import gradio as gr
+ 
+ 
+ import argparse, os, sys, glob
+ import torch
+ import pickle
+ import numpy as np
+ from omegaconf import OmegaConf
+ from PIL import Image
+ from tqdm import tqdm, trange
+ from einops import rearrange
+ from torchvision.utils import make_grid
+ 
+ from ldm.util import instantiate_from_config
+ from ldm.models.diffusion.ddim import DDIMSampler
+ from ldm.models.diffusion.plms import PLMSSampler
+ 
+ 
+ def load_model_from_config(config, ckpt, verbose=False):
+     print(f"Loading model from {ckpt}")
+     # pl_sd = torch.load(ckpt, map_location="cpu")
+     pl_sd = torch.load(ckpt)  # , map_location="cpu")
+     sd = pl_sd["state_dict"]
+     model = instantiate_from_config(config.model)
+     m, u = model.load_state_dict(sd, strict=False)
+     if len(m) > 0 and verbose:
+         print("missing keys:")
+         print(m)
+     if len(u) > 0 and verbose:
+         print("unexpected keys:")
+         print(u)
+ 
+     model.cuda()
+     model.eval()
+     return model
+ 
+ 
+ def masking_embed(embedding, levels=1):
+     """
+     size of embedding - nx1xd, n: number of samples, d - 512
+     replacing the last 128*levels from the embedding
+     """
+     replace_size = 128*levels
+     random_noise = torch.randn(embedding.shape[0], embedding.shape[1], replace_size)
+     embedding[:, :, -replace_size:] = random_noise
+     return embedding
+ 
+ 
+ # LOAD MODEL GLOBALLY
+ config_path = '/globalscratch/mridul/ldm/monkeys/2024-04-03T19-49-17_HLE_days1_lr1e-6_6levels/configs/2024-04-03T19-49-17-project.yaml'
+ ckpt_path = '/globalscratch/mridul/ldm/monkeys/2024-04-03T19-49-17_HLE_days1_lr1e-6_6levels/checkpoints/epoch=000335.ckpt'
+ config = OmegaConf.load(config_path)  # TODO: Optionally download from same location as ckpt and change this logic
+ model = load_model_from_config(config, ckpt_path)  # TODO: check path
+ 
+ device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
+ model = model.to(device)
+ 
+ # def generate_image(fish_name, masking_level_input,
+ #                    swap_fish_name, swap_level_input):
+ 
+ def generate_image(fish_name,
+                    swap_fish_name, swap_level_input):
+ 
+     fish_name = fish_name.lower()
+ 
+ 
+     label_to_class_mapping = {0: 'Alosa-chrysochloris', 1: 'Carassius-auratus', 2: 'Cyprinus-carpio', 3: 'Esox-americanus',
+         4: 'Gambusia-affinis', 5: 'Lepisosteus-osseus', 6: 'Lepisosteus-platostomus', 7: 'Lepomis-auritus', 8: 'Lepomis-cyanellus',
+         9: 'Lepomis-gibbosus', 10: 'Lepomis-gulosus', 11: 'Lepomis-humilis', 12: 'Lepomis-macrochirus', 13: 'Lepomis-megalotis',
+         14: 'Lepomis-microlophus', 15: 'Morone-chrysops', 16: 'Morone-mississippiensis', 17: 'Notropis-atherinoides',
+         18: 'Notropis-blennius', 19: 'Notropis-boops', 20: 'Notropis-buccatus', 21: 'Notropis-buchanani', 22: 'Notropis-dorsalis',
+         23: 'Notropis-hudsonius', 24: 'Notropis-leuciodus', 25: 'Notropis-nubilus', 26: 'Notropis-percobromus',
+         27: 'Notropis-stramineus', 28: 'Notropis-telescopus', 29: 'Notropis-texanus', 30: 'Notropis-volucellus',
+         31: 'Notropis-wickliffi', 32: 'Noturus-exilis', 33: 'Noturus-flavus', 34: 'Noturus-gyrinus', 35: 'Noturus-miurus',
+         36: 'Noturus-nocturnus', 37: 'Phenacobius-mirabilis'}
+ 
+     def get_label_from_class(class_name):
+         for key, value in label_to_class_mapping.items():
+             if value == class_name:
+                 return key
+ 
+ 
+     if opt.plms:
+         sampler = PLMSSampler(model)
+     else:
+         sampler = DDIMSampler(model)
+ 
+ 
+ 
+     prompt = opt.prompt
+     all_images = []
+     labels = []
+ 
+     # class_to_node = '/fastscratch/mridul/fishes/class_to_ancestral_label.pkl'
+     class_to_node = '/projects/ml4science/mridul/data/monkeys_dataset/monkey_HLE_labels_6levels.pkl'
+     with open(class_to_node, 'rb') as pickle_file:
+         class_to_node_dict = pickle.load(pickle_file)
+ 
+     class_to_node_dict = {key.lower(): value for key, value in class_to_node_dict.items()}
+ 
+ 
+     prompt = class_to_node_dict[fish_name]
+ 
+     ### Trait Swapping
+     if swap_fish_name:
+         swap_fish_name = swap_fish_name.lower()
+         swap_level = int(swap_level_input.split(" ")[-1]) - 1
+         swap_fish = class_to_node_dict[swap_fish_name]
+ 
+         swap_fish_split = swap_fish[0].split(',')
+         fish_name_split = prompt[0].split(',')
+         fish_name_split[swap_level] = swap_fish_split[swap_level]
+ 
+         prompt = [','.join(fish_name_split)]
+ 
+     all_samples = list()
+     with torch.no_grad():
+         with model.ema_scope():
+             uc = None
+             for n in trange(opt.n_iter, desc="Sampling"):
+ 
+                 all_prompts = opt.n_samples * (prompt)
+                 all_prompts = [tuple(all_prompts)]
+                 c = model.get_learned_conditioning({'class_to_node': all_prompts})
+                 # if masking_level_input != "None":
+                 #     masked_level = int(masking_level_input.split(" ")[-1])
+                 #     masked_level = 4-masked_level
+                 #     c = masking_embed(c, levels=masked_level)
+                 shape = [3, 64, 64]
+                 samples_ddim, _ = sampler.sample(S=opt.ddim_steps,
+                                                  conditioning=c,
+                                                  batch_size=opt.n_samples,
+                                                  shape=shape,
+                                                  verbose=False,
+                                                  unconditional_guidance_scale=opt.scale,
+                                                  unconditional_conditioning=uc,
+                                                  eta=opt.ddim_eta)
+ 
+                 x_samples_ddim = model.decode_first_stage(samples_ddim)
+                 x_samples_ddim = torch.clamp((x_samples_ddim+1.0)/2.0, min=0.0, max=1.0)
+ 
+                 all_samples.append(x_samples_ddim)
+ 
+     ###### to make grid
+     # additionally, save as grid
+     grid = torch.stack(all_samples, 0)
+     grid = rearrange(grid, 'n b c h w -> (n b) c h w')
+     grid = make_grid(grid, nrow=opt.n_samples)
+ 
+     # to image
+     grid = 255. * rearrange(grid, 'c h w -> h w c').cpu().numpy()
+     final_image = Image.fromarray(grid.astype(np.uint8))
+     # final_image.save(os.path.join(sample_path, f'{class_name.replace(" ", "-")}.png'))
+ 
+     return final_image
+ 
+ 
+ if __name__ == "__main__":
+     parser = argparse.ArgumentParser()
+ 
+     parser.add_argument(
+         "--prompt",
+         type=str,
+         nargs="?",
+         default="a painting of a virus monster playing guitar",
+         help="the prompt to render"
+     )
+ 
+     parser.add_argument(
+         "--outdir",
+         type=str,
+         nargs="?",
+         help="dir to write results to",
+         default="outputs/txt2img-samples"
+     )
+     parser.add_argument(
+         "--ddim_steps",
+         type=int,
+         default=200,
+         help="number of ddim sampling steps",
+     )
+ 
+     parser.add_argument(
+         "--plms",
+         action='store_true',
+         help="use plms sampling",
+     )
+ 
+     parser.add_argument(
+         "--ddim_eta",
+         type=float,
+         default=1.0,
+         help="ddim eta (eta=0.0 corresponds to deterministic sampling)",
+     )
+     parser.add_argument(
+         "--n_iter",
+         type=int,
+         default=1,
+         help="sample this often",
+     )
+ 
+     parser.add_argument(
+         "--H",
+         type=int,
+         default=256,
+         help="image height, in pixel space",
+     )
+ 
+     parser.add_argument(
+         "--W",
+         type=int,
+         default=256,
+         help="image width, in pixel space",
+     )
+ 
+     parser.add_argument(
+         "--n_samples",
+         type=int,
+         default=1,
+         help="how many samples to produce for the given prompt",
+     )
+ 
+     parser.add_argument(
+         "--output_dir_name",
+         type=str,
+         default='default_file',
+         help="name of folder",
+     )
+ 
+     parser.add_argument(
+         "--postfix",
+         type=str,
+         default='',
+         help="name of folder",
+     )
+ 
+     parser.add_argument(
+         "--scale",
+         type=float,
+         # default=5.0,
+         default=1.0,
+         help="unconditional guidance scale: eps = eps(x, empty) + scale * (eps(x, cond) - eps(x, empty))",
+     )
+     opt = parser.parse_args()
+ 
+     title = "🎞️ Phylo Diffusion - Generating Monkey Images Tool"
+     description = "Write the species name to generate an image.\nFor Trait Masking: specify the level information as well"
+ 
+ 
+     def load_example(prompt, level, option, components):
+         components['prompt_input'].value = prompt
+         components['masking_level_input'].value = level
+         # components['option'].value = option
+ 
+     def setup_interface():
+         with gr.Blocks() as demo:
+ 
+             gr.Markdown("# Phylo Diffusion - Generating Monkey Images Tool")
+             gr.Markdown("### Write the species name to generate a monkey image")
+             gr.Markdown("### 1. Trait Masking: Specify the level information as well")
+             gr.Markdown("### 2. Trait Swapping: Specify the species name to swap a trait with, and at what level")
+ 
+             with gr.Row():
+                 with gr.Column():
+                     gr.Markdown("## Generate Images Based on Prompts")
+                     gr.Markdown("Enter a prompt to generate an image:")
+                     prompt_input = gr.Textbox(label="Species Name")
+                     # gr.Markdown("Trait Masking")
+                     # with gr.Row():
+                     #     masking_level_input = gr.Dropdown(label="Select Ancestral Level", choices=["None", "Level 3", "Level 2"], value="None")
+                     # masking_level_input = "None"
+ 
+                     gr.Markdown("Trait Swapping")
+                     with gr.Row():
+                         swap_fish_name = gr.Textbox(label="Species Name to swap trait with:")
+                         swap_level_input = gr.Dropdown(label="Level of swapping", choices=["Level 5", "Level 4", "Level 3", "Level 2"], value="Level 5")
+                     submit_button = gr.Button("Generate")
+                     gr.Markdown("## Phylogeny Tree")
+                     architecture_image = "monkeys_masking_6levels.jpg"  # Update this with the actual path
+                     gr.Image(value=architecture_image, label="Phylogeny Tree")
+ 
+                 with gr.Column():
+ 
+                     gr.Markdown("## Generated Image")
+                     output_image = gr.Image(label="Generated Image", width=256, height=256)
+ 
+ 
+             # # Place to put example buttons
+             # gr.Markdown("## Select an example:")
+             # examples = [
+             #     ("Gambusia Affinis", "None", "", "Level 3"),
+             #     ("Lepomis Auritus", "None", "", "Level 3"),
+             #     ("Lepomis Auritus", "Level 3", "", "Level 3"),
+             #     ("Noturus nocturnus", "None", "Notropis dorsalis", "Level 2")]
+ 
+             # for text, level, swap_text, swap_level in examples:
+             #     if level == "None" and swap_text == "":
+             #         button = gr.Button(f"Species: {text}")
+             #     elif level != "None":
+             #         button = gr.Button(f"Species: {text} | Masking: {level}")
+             #     elif swap_text != "":
+             #         button = gr.Button(f"Species: {text} | Swapping with {swap_text} at {swap_level} ")
+             #     button.click(
+             #         fn=lambda text=text, level=level, swap_text=swap_text, swap_level=swap_level: (text, level, swap_text, swap_level),
+             #         inputs=[],
+             #         outputs=[prompt_input, masking_level_input, swap_fish_name, swap_level_input]
+             #     )
+ 
+ 
+             # Display an image of the architecture
+ 
+             # submit_button.click(
+             #     fn=generate_image,
+             #     inputs=[prompt_input,
+             #             swap_fish_name, swap_level_input],
+             #     outputs=output_image
+             # )
+ 
+ 
+ 
+             submit_button.click(
+                 fn=generate_image,
+                 inputs=[prompt_input,
+                         # masking_level_input,
+                         swap_fish_name, swap_level_input],
+                 outputs=output_image
+             )
+ 
+         return demo
+ 
+     # # Launch the interface
+     # iface = setup_interface()
+ 
+     # iface = gr.Interface(
+     #     fn=generate_image,
+     #     inputs=gr.Textbox(label="Prompt"),
+     #     outputs=[
+     #         gr.Image(label="Generated Image"),
+     #     ]
+     # )
+ 
+     iface = setup_interface()
+ 
+     iface.launch(share=True)
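For reference, the masking_embed helper kept in monkeys_app.py is unused in this version (the trait-masking branch in generate_image is commented out). It overwrites the last 128*levels dimensions of the learned conditioning with Gaussian noise. A self-contained sketch of that operation on a dummy tensor, assuming the 512-dimensional (4 x 128) embedding described in its docstring rather than the real model's conditioning:

import torch

def masking_embed(embedding, levels=1):
    # Replace the last 128*levels dimensions with random noise,
    # i.e. mask the finest `levels` levels of the hierarchical conditioning.
    replace_size = 128 * levels
    random_noise = torch.randn(embedding.shape[0], embedding.shape[1], replace_size)
    embedding[:, :, -replace_size:] = random_noise
    return embedding

c = torch.zeros(2, 1, 512)          # dummy conditioning, n=2 samples
c = masking_embed(c, levels=2)
print(c[:, :, :256].abs().sum())    # 0: coarser levels untouched
print(c[:, :, 256:].abs().sum())    # nonzero: last two levels replaced by noise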
monkeys_masking_6levels.jpg ADDED