liuganghuggingface commited on
Commit
6a4bbaf
1 Parent(s): 02d6370

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +29 -348
app.py CHANGED
@@ -1,383 +1,64 @@
1
- import spaces
2
  import gradio as gr
3
-
4
  import torch
5
- import numpy as np
6
- import pandas as pd
7
- import random
8
- import io
9
- import imageio
10
- import os
11
- import tempfile
12
- import atexit
13
- import glob
14
- import csv
15
- from datetime import datetime
16
- import json
17
-
18
  from rdkit import Chem
19
  from rdkit.Chem import Draw
20
-
21
- from evaluator import Evaluator
22
- # from loader import load_graph_decoder
23
-
24
- ### load model start
25
  from graph_decoder.diffusion_model import GraphDiT
26
- def count_parameters(model):
27
- r"""
28
- Returns the number of trainable parameters and number of all parameters in the model.
29
- """
30
- trainable_params, all_param = 0, 0
31
- for param in model.parameters():
32
- num_params = param.numel()
33
- all_param += num_params
34
- if param.requires_grad:
35
- trainable_params += num_params
36
-
37
- return trainable_params, all_param
38
 
 
39
  def load_graph_decoder(path='model_labeled'):
40
- model_config_path = f"{path}/config.yaml"
41
- data_info_path = f"{path}/data.meta.json"
42
-
43
  model = GraphDiT(
44
- model_config_path=model_config_path,
45
- data_info_path=data_info_path,
46
  model_dtype=torch.float32,
47
  )
48
  model.init_model(path)
49
  model.disable_grads()
50
-
51
- trainable_params, all_param = count_parameters(model)
52
- param_stats = "Loaded Graph DiT from {} trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format(
53
- path, trainable_params, all_param, 100 * trainable_params / all_param
54
- )
55
- print(param_stats)
56
  return model
57
- ### load model end
58
-
59
- # Load the CSV data
60
- known_labels = pd.read_csv('data/known_labels.csv')
61
- knwon_smiles = pd.read_csv('data/known_polymers.csv')
62
-
63
- all_properties = ['CH4', 'CO2', 'H2', 'N2', 'O2']
64
-
65
- # Initialize evaluators
66
- evaluators = {prop: Evaluator(f'evaluators/{prop}.joblib', prop) for prop in all_properties}
67
-
68
- # Get min and max values for each property
69
- property_ranges = {prop: (known_labels[prop].min(), known_labels[prop].max()) for prop in all_properties}
70
-
71
- # Create a temporary directory for GIFs
72
- temp_dir = tempfile.mkdtemp(prefix="polymer_gifs_")
73
-
74
- def cleanup_temp_files():
75
- """Clean up temporary GIF files on exit."""
76
- for file in glob.glob(os.path.join(temp_dir, "*.gif")):
77
- try:
78
- os.remove(file)
79
- except Exception as e:
80
- print(f"Error deleting {file}: {e}")
81
- try:
82
- os.rmdir(temp_dir)
83
- except Exception as e:
84
- print(f"Error deleting temporary directory {temp_dir}: {e}")
85
 
86
- # Register the cleanup function to be called on exit
87
- atexit.register(cleanup_temp_files)
88
 
89
- def random_properties():
90
- return known_labels[all_properties].sample(1).values.tolist()[0]
91
-
92
- def load_model(model_choice):
93
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
94
- model = load_graph_decoder(path=model_choice)
95
- return (model, device)
96
-
97
- # Create a flagged folder if it doesn't exist
98
- flagged_folder = "flagged"
99
- os.makedirs(flagged_folder, exist_ok=True)
100
-
101
- def save_interesting_log(smiles, properties, suggested_properties):
102
- """Save interesting polymer data to a CSV file."""
103
- log_file = os.path.join(flagged_folder, "log.csv")
104
- file_exists = os.path.isfile(log_file)
105
-
106
- with open(log_file, 'a', newline='') as csvfile:
107
- fieldnames = ['timestamp', 'smiles'] + all_properties + [f'suggested_{prop}' for prop in all_properties]
108
- writer = csv.DictWriter(csvfile, fieldnames=fieldnames)
109
-
110
- if not file_exists:
111
- writer.writeheader()
112
-
113
- log_data = {
114
- 'timestamp': datetime.now().strftime("%Y-%m-%d %H:%M:%S"),
115
- 'smiles': smiles,
116
- **{prop: value for prop, value in zip(all_properties, properties)},
117
- **{f'suggested_{prop}': value for prop, value in suggested_properties.items()}
118
- }
119
- writer.writerow(log_data)
120
-
121
- @spaces.GPU(duration=75)
122
- def generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
123
- print('in generate_graph')
124
- model, device = model_state
125
-
126
  properties = [CH4, CO2, H2, N2, O2]
127
 
128
- def is_nan_like(x):
129
- return x == 0 or x == '' or (isinstance(x, float) and np.isnan(x))
130
-
131
- properties = [None if is_nan_like(prop) else prop for prop in properties]
132
-
133
- nan_message = "The following gas properties were treated as NaN: "
134
- nan_gases = [gas for gas, prop in zip(all_properties, properties) if prop is None]
135
- nan_message += ", ".join(nan_gases) if nan_gases else "None"
136
-
137
- num_nodes = None if num_nodes == 0 else num_nodes
138
-
139
- for _ in range(repeating_time):
140
- # try:
141
  model.to(device)
142
- generated_molecule, img_list = model.generate(properties, device=device, guide_scale=guidance_scale, num_nodes=num_nodes, number_chain_steps=num_chain_steps)
143
-
144
- # Create GIF if img_list is available
145
- gif_path = None
146
- if img_list and len(img_list) > 0:
147
- imgs = [np.array(pil_img) for pil_img in img_list]
148
- imgs.extend([imgs[-1]] * 10)
149
- gif_path = os.path.join(temp_dir, f"polymer_gen_{random.randint(0, 999999)}.gif")
150
- imageio.mimsave(gif_path, imgs, format='GIF', fps=fps, loop=0)
151
 
152
  if generated_molecule is not None:
153
  mol = Chem.MolFromSmiles(generated_molecule)
154
  if mol is not None:
155
  standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
156
- is_novel = standardized_smiles not in knwon_smiles['SMILES'].values
157
- novelty_status = "Novel (Not in Labeled Set)" if is_novel else "Not Novel (Exists in Labeled Set)"
158
  img = Draw.MolToImage(mol)
159
-
160
- # Evaluate the generated molecule
161
- suggested_properties = {}
162
- for prop, evaluator in evaluators.items():
163
- suggested_properties[prop] = evaluator([standardized_smiles])[0]
164
-
165
- suggested_properties_text = "\n".join([f"**Suggested {prop}:** {value:.2f}" for prop, value in suggested_properties.items()])
166
-
167
- return (
168
- f"**Generated polymer SMILES:** `{standardized_smiles}`\n\n"
169
- f"**{nan_message}**\n\n"
170
- f"**{novelty_status}**\n\n"
171
- f"**Suggested Properties:**\n{suggested_properties_text}",
172
- img,
173
- gif_path,
174
- properties, # Add this
175
- suggested_properties # Add this
176
- )
177
- else:
178
- return (
179
- f"**Generation failed:** Could not generate a valid molecule.\n\n**{nan_message}**",
180
- None,
181
- gif_path,
182
- properties,
183
- None,
184
- )
185
- # except Exception as e:
186
- # print(f"Error in generation: {e}")
187
- # continue
188
 
189
- return f"**Generation failed:** Could not generate a valid molecule after {repeating_time} attempts.\n\n**{nan_message}**", None, None
190
-
191
- def set_random_properties():
192
- return random_properties()
193
-
194
- # Create a mapping of internal names to display names
195
- model_name_mapping = {
196
- "model_all": "Graph DiT (trained on labeled + unlabeled)",
197
- "model_labeled": "Graph DiT (trained on labeled)"
198
- }
199
-
200
- def numpy_to_python(obj):
201
- if isinstance(obj, np.integer):
202
- return int(obj)
203
- elif isinstance(obj, np.floating):
204
- return float(obj)
205
- elif isinstance(obj, np.ndarray):
206
- return obj.tolist()
207
- elif isinstance(obj, list):
208
- return [numpy_to_python(item) for item in obj]
209
- elif isinstance(obj, dict):
210
- return {k: numpy_to_python(v) for k, v in obj.items()}
211
- else:
212
- return obj
213
-
214
- def on_generate(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps):
215
- result = generate_graph(CH4, CO2, H2, N2, O2, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps)
216
- # Check if the generation was successful
217
- if result[0].startswith("**Generated polymer SMILES:**"):
218
- smiles = result[0].split("**Generated polymer SMILES:** `")[1].split("`")[0]
219
- properties = json.dumps(numpy_to_python(result[3]))
220
- suggested_properties = json.dumps(numpy_to_python(result[4]))
221
- # Return the result with an enabled feedback button
222
- return [*result[:3], smiles, properties, suggested_properties, gr.Button(interactive=True)]
223
- else:
224
- # Return the result with a disabled feedback button
225
- return [*result[:3], "", "[]", "[]", gr.Button(interactive=False)]
226
-
227
- def process_feedback(checkbox_value, smiles, properties, suggested_properties):
228
- if checkbox_value:
229
- # Check if properties and suggested_properties are already Python objects
230
- if isinstance(properties, str):
231
- properties = json.loads(properties)
232
- if isinstance(suggested_properties, str):
233
- suggested_properties = json.loads(suggested_properties)
234
-
235
- save_interesting_log(smiles, properties, suggested_properties)
236
- return gr.Textbox(value="Thank you for your feedback! This polymer has been saved to our interesting polymers log.", visible=True)
237
- else:
238
- return gr.Textbox(value="Thank you for your feedback!", visible=True)
239
-
240
- # ADD THIS FUNCTION
241
- def reset_feedback_button():
242
- return gr.Button(interactive=False)
243
-
244
- # Create the Gradio interface using Blocks
245
- with gr.Blocks(title="Polymer Design with GraphDiT") as iface:
246
- # Navigation Bar
247
- with gr.Row(elem_id="navbar"):
248
- gr.Markdown("""
249
- <div style="text-align: center;">
250
- <h1>🔗🔬 Polymer Design with GraphDiT</h1>
251
- <div style="display: flex; gap: 20px; justify-content: center; align-items: center; margin-top: 10px;">
252
- <a href="https://github.com/liugangcode/Graph-DiT" target="_blank" style="display: flex; align-items: center; gap: 5px; text-decoration: none; color: inherit;">
253
- <img src="https://img.icons8.com/ios-glyphs/30/000000/github.png" alt="GitHub" />
254
- <span>View Code</span>
255
- </a>
256
- <a href="https://arxiv.org/abs/2401.13858" target="_blank" style="text-decoration: none; color: inherit;">
257
- 📄 View Paper
258
- </a>
259
- </div>
260
- </div>
261
- """)
262
-
263
- # Main Description
264
- gr.Markdown("""
265
- ## Introduction
266
-
267
- Input the desired gas barrier properties for CH₄, CO₂, H₂, N₂, and O₂ to generate novel polymer structures. The results are visualized as molecular graphs and represented by SMILES strings if they are successfully generated. Note: Gas barrier values set to 0 will be treated as `NaN` (unconditionally). If the generation fails, please retry or increase the number of repetition attempts.
268
- """)
269
-
270
- # Model Selection
271
- model_choice = gr.Radio(
272
- choices=list(model_name_mapping.values()),
273
- label="Model Zoo",
274
- # value="Graph DiT (trained on labeled + unlabeled)"
275
- value="Graph DiT (trained on labeled)"
276
- )
277
-
278
- # Model Description Accordion
279
- with gr.Accordion("🔍 Model Description", open=False):
280
- gr.Markdown("""
281
- ### GraphDiT: Graph Diffusion Transformer
282
-
283
- GraphDiT is a graph diffusion model designed for targeted molecular generation. It employs a conditional diffusion process to iteratively refine molecular structures based on user-specified properties.
284
-
285
- We have collected a labeled polymer database for gas permeability from [Membrane Database](https://research.csiro.au/virtualscreening/membrane-database-polymer-gas-separation-membranes/). Additionally, we utilize unlabeled polymer structures from [PolyInfo](https://polymer.nims.go.jp/).
286
-
287
- The gas permeability ranges from 0 to over ten thousand, with only hundreds of labeled data points, making this task particularly challenging.
288
-
289
- We are actively working on improving the model. We welcome any feedback regarding model usage or suggestions for improvement.
290
-
291
- #### Currently, we have two variants of Graph DiT:
292
- - **Graph DiT (trained on labeled + unlabeled)**: This model uses both labeled and unlabeled data for training, potentially leading to more diverse/novel polymer generation.
293
- - **Graph DiT (trained on labeled)**: This model is trained exclusively on labeled data, which may result in higher validity but potentially less diverse/novel outputs.
294
- """)
295
-
296
- # Citation Accordion
297
- with gr.Accordion("📄 Citation", open=False):
298
- gr.Markdown("""
299
- If you use this model or interface useful, please cite the following paper:
300
- ```bibtex
301
- @article{graphdit2024,
302
- title={Graph Diffusion Transformers for Multi-Conditional Molecular Generation},
303
- author={Liu, Gang and Xu, Jiaxin and Luo, Tengfei and Jiang, Meng},
304
- journal={NeurIPS},
305
- year={2024},
306
- }
307
- ```
308
- """)
309
-
310
- model_state = gr.State(lambda: load_model("model_labeled"))
311
-
312
- with gr.Row():
313
- CH4_input = gr.Slider(0, property_ranges['CH4'][1], value=2.5, label=f"CH₄ (Barrier) [0-{property_ranges['CH4'][1]:.1f}]")
314
- CO2_input = gr.Slider(0, property_ranges['CO2'][1], value=15.4, label=f"CO₂ (Barrier) [0-{property_ranges['CO2'][1]:.1f}]")
315
- H2_input = gr.Slider(0, property_ranges['H2'][1], value=21.0, label=f"H₂ (Barrier) [0-{property_ranges['H2'][1]:.1f}]")
316
- N2_input = gr.Slider(0, property_ranges['N2'][1], value=1.5, label=f"N₂ (Barrier) [0-{property_ranges['N2'][1]:.1f}]")
317
- O2_input = gr.Slider(0, property_ranges['O2'][1], value=2.8, label=f"O₂ (Barrier) [0-{property_ranges['O2'][1]:.1f}]")
318
 
 
 
 
 
319
  with gr.Row():
320
- guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale from Properties")
321
- num_nodes = gr.Slider(0, 50, step=1, value=0, label="Number of Nodes (0 for Random, Larger Graphs Take More Time)")
322
- repeating_time = gr.Slider(1, 10, step=1, value=3, label="Repetition Until Success")
323
- num_chain_steps = gr.Slider(0, 499, step=1, value=50, label="Number of Diffusion Steps to Visualize (Larger Numbers Take More Time)")
324
- fps = gr.Slider(0.25, 10, step=0.25, value=5, label="Frames Per Second")
 
325
 
326
- with gr.Row():
327
- random_btn = gr.Button("🔀 Randomize Properties (from Labeled Data)")
328
- generate_btn = gr.Button("🚀 Generate Polymer")
329
 
330
  with gr.Row():
331
- result_text = gr.Textbox(label="📝 Generation Result")
332
- result_image = gr.Image(label="Final Molecule Visualization", type="pil")
333
- result_gif = gr.Image(label="Generation Process Visualization", type="filepath", format="gif")
334
-
335
- with gr.Row() as feedback_row:
336
- feedback_btn = gr.Button("🌟 I think this polymer is interesting!", visible=True, interactive=False)
337
- feedback_result = gr.Textbox(label="Feedback Result", visible=False)
338
-
339
- # Add model switching functionality
340
- def switch_model(choice):
341
- # Convert display name back to internal name
342
- internal_name = next(key for key, value in model_name_mapping.items() if value == choice)
343
- return load_model(internal_name)
344
-
345
- model_choice.change(switch_model, inputs=[model_choice], outputs=[model_state])
346
-
347
- # Hidden components to store generation data
348
- hidden_smiles = gr.Textbox(visible=False)
349
- hidden_properties = gr.JSON(visible=False)
350
- hidden_suggested_properties = gr.JSON(visible=False)
351
-
352
- # Set up event handlers
353
- random_btn.click(
354
- set_random_properties,
355
- outputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input]
356
- )
357
 
358
  generate_btn.click(
359
- on_generate,
360
- inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale, num_nodes, repeating_time, model_state, num_chain_steps, fps],
361
- outputs=[result_text, result_image, result_gif, hidden_smiles, hidden_properties, hidden_suggested_properties, feedback_btn]
362
- )
363
-
364
- feedback_btn.click(
365
- process_feedback,
366
- inputs=[gr.Checkbox(value=True, visible=False), hidden_smiles, hidden_properties, hidden_suggested_properties],
367
- outputs=[feedback_result]
368
- ).then(
369
- lambda: gr.Button(interactive=False),
370
- outputs=[feedback_btn]
371
  )
372
-
373
- CH4_input.change(reset_feedback_button, outputs=[feedback_btn])
374
- CO2_input.change(reset_feedback_button, outputs=[feedback_btn])
375
- H2_input.change(reset_feedback_button, outputs=[feedback_btn])
376
- N2_input.change(reset_feedback_button, outputs=[feedback_btn])
377
- O2_input.change(reset_feedback_button, outputs=[feedback_btn])
378
- random_btn.click(reset_feedback_button, outputs=[feedback_btn])
379
 
380
- # Launch the interface
381
  if __name__ == "__main__":
382
- # iface.launch(share=True)
383
- iface.launch(share=False)
 
 
1
  import gradio as gr
 
2
  import torch
 
 
 
 
 
 
 
 
 
 
 
 
 
3
  from rdkit import Chem
4
  from rdkit.Chem import Draw
 
 
 
 
 
5
  from graph_decoder.diffusion_model import GraphDiT
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
+ # Load the model
8
  def load_graph_decoder(path='model_labeled'):
 
 
 
9
  model = GraphDiT(
10
+ model_config_path=f"{path}/config.yaml",
11
+ data_info_path=f"{path}/data.meta.json",
12
  model_dtype=torch.float32,
13
  )
14
  model.init_model(path)
15
  model.disable_grads()
 
 
 
 
 
 
16
  return model
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
17
 
18
+ model = load_graph_decoder()
19
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
20
 
21
+ def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
  properties = [CH4, CO2, H2, N2, O2]
23
 
24
+ try:
 
 
 
 
 
 
 
 
 
 
 
 
25
  model.to(device)
26
+ generated_molecule, _ = model.generate(properties, device=device, guide_scale=guidance_scale)
 
 
 
 
 
 
 
 
27
 
28
  if generated_molecule is not None:
29
  mol = Chem.MolFromSmiles(generated_molecule)
30
  if mol is not None:
31
  standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
 
 
32
  img = Draw.MolToImage(mol)
33
+ return standardized_smiles, img
34
+ except Exception as e:
35
+ print(f"Error in generation: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
36
 
37
+ return "Generation failed", None
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
38
 
39
+ # Create the Gradio interface
40
+ with gr.Blocks(title="Simplified Polymer Design") as iface:
41
+ gr.Markdown("## Polymer Design with GraphDiT")
42
+
43
  with gr.Row():
44
+ CH4_input = gr.Slider(0, 100, value=2.5, label="CH₄ (Barrier)")
45
+ CO2_input = gr.Slider(0, 100, value=15.4, label="CO₂ (Barrier)")
46
+ H2_input = gr.Slider(0, 100, value=21.0, label="H₂ (Barrier)")
47
+ N2_input = gr.Slider(0, 100, value=1.5, label="N₂ (Barrier)")
48
+ O2_input = gr.Slider(0, 100, value=2.8, label="O₂ (Barrier)")
49
+ guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")
50
 
51
+ generate_btn = gr.Button("Generate Polymer")
 
 
52
 
53
  with gr.Row():
54
+ result_smiles = gr.Textbox(label="Generated SMILES")
55
+ result_image = gr.Image(label="Molecule Visualization", type="pil")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
56
 
57
  generate_btn.click(
58
+ generate_polymer,
59
+ inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
60
+ outputs=[result_smiles, result_image]
 
 
 
 
 
 
 
 
 
61
  )
 
 
 
 
 
 
 
62
 
 
63
  if __name__ == "__main__":
64
+ iface.launch()