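# NOTE: the six lines below appear to be Hugging Face file-viewer page text
# captured together with the file; they are kept as comments so the module parses.
# liuganghuggingface's picture
# Upload app.py with huggingface_hub
# 1037049 verified
# raw
# history blame
# 4.93 kB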
import spaces
import gradio as gr
import torch
import torch.nn as nn
import random
from rdkit import Chem
from rdkit.Chem import Draw
from graph_decoder.diffusion_model import GraphDiT
# Single-character symbols the placeholder SMILES generator samples from.
ATOM_SYMBOLS = ["C", "N", "O", "H"]

# Directory that holds the GraphDiT config and dataset metadata files.
path = 'model_labeled'

# Prefer CUDA when torch reports it as available; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Build the model from the two files under `path` and place it on `device`.
# NOTE(review): no explicit weight-loading call is made here -- confirm whether
# the GraphDiT constructor loads a checkpoint or the weights stay random.
model = GraphDiT(
    model_config_path=path + "/config.yaml",
    data_info_path=path + "/data.meta.json",
    model_dtype=torch.float32,
)
model.to(device)
def generate_random_smiles(length=10):
    """Return a random string of ``length`` symbols drawn from ATOM_SYMBOLS.

    Symbols are sampled independently with replacement and concatenated with
    no separators; the result is not checked for SMILES validity.
    """
    sampled_symbols = random.choices(ATOM_SYMBOLS, k=length)
    return "".join(sampled_symbols)
@spaces.GPU
def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
    """Produce a (SMILES, image) pair for the requested gas-property values.

    Args:
        CH4, CO2, H2, N2, O2: Numeric property values from the UI sliders.
        guidance_scale: Accepted for interface compatibility; the current
            placeholder implementation does not read it.

    Returns:
        ``(canonical_smiles, image)`` when the candidate string parses with
        RDKit, otherwise ``("Generation failed", None)``.
    """
    failure = ("Generation failed", None)

    # Packed into a (1, 5) float32 tensor for the (currently disabled) model
    # call; built outside the ``try`` so bad inputs still raise to the caller.
    properties = torch.tensor(
        [CH4, CO2, H2, N2, O2], dtype=torch.float32
    ).unsqueeze(0)

    print('in generate_polymer')
    try:
        # Placeholder: a random string stands in for the model output.
        # model.generate(properties, device)
        candidate = generate_random_smiles()
        parsed = Chem.MolFromSmiles(candidate)
        if parsed is None:
            # RDKit could not parse the candidate string.
            return failure
        canonical = Chem.MolToSmiles(parsed, isomericSmiles=True)
        rendering = Draw.MolToImage(parsed)
        return canonical, rendering
    except Exception as e:
        print(f"Error in generation: {e}")
        return failure
# Create the Gradio interface
# NOTE(review): the source reached review with its indentation stripped, so the
# nesting below (which widgets sit inside which Row) is reconstructed -- confirm
# against the deployed layout.

# (default value, label) for each gas-property slider, in display order.
_GAS_SLIDER_SPECS = (
    (2.5, "CH₄ (Barrier)"),
    (15.4, "CO₂ (Barrier)"),
    (21.0, "H₂ (Barrier)"),
    (1.5, "N₂ (Barrier)"),
    (2.8, "O₂ (Barrier)"),
)

with gr.Blocks(title="Simplified Polymer Design") as iface:
    gr.Markdown("## Polymer Design with Random Neural Network")

    with gr.Row():
        # Sliders are created in spec order, so they render in that order.
        CH4_input, CO2_input, H2_input, N2_input, O2_input = [
            gr.Slider(0, 100, value=default, label=caption)
            for default, caption in _GAS_SLIDER_SPECS
        ]

    guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")
    generate_btn = gr.Button("Generate Polymer")

    with gr.Row():
        result_smiles = gr.Textbox(label="Generated SMILES")
        result_image = gr.Image(label="Molecule Visualization", type="pil")

    # Slider values are passed positionally to generate_polymer; its two
    # return values fill the textbox and the image, in that order.
    generate_btn.click(
        fn=generate_polymer,
        inputs=[
            CH4_input,
            CO2_input,
            H2_input,
            N2_input,
            O2_input,
            guidance_scale,
        ],
        outputs=[result_smiles, result_image],
    )

if __name__ == "__main__":
    iface.launch()
# import spaces
# import gradio as gr
# import torch
# from rdkit import Chem
# from rdkit.Chem import Draw
# # from graph_decoder.diffusion_model import GraphDiT
# # Load the model
# def load_graph_decoder(path='model_labeled'):
# model = GraphDiT(
# model_config_path=f"{path}/config.yaml",
# data_info_path=f"{path}/data.meta.json",
# model_dtype=torch.float32,
# )
# model.init_model(path)
# model.disable_grads()
# return model
# # model = load_graph_decoder()
# # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# @spaces.GPU
# def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
# properties = [CH4, CO2, H2, N2, O2]
# try:
# model = load_graph_decoder()
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)
# print('enter function')
# generated_molecule, _ = model.generate(properties, device=device, guide_scale=guidance_scale)
# if generated_molecule is not None:
# mol = Chem.MolFromSmiles(generated_molecule)
# if mol is not None:
# standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
# img = Draw.MolToImage(mol)
# return standardized_smiles, img
# except Exception as e:
# print(f"Error in generation: {e}")
# return "Generation failed", None
# # Create the Gradio interface
# with gr.Blocks(title="Simplified Polymer Design") as iface:
# gr.Markdown("## Polymer Design with GraphDiT")
# with gr.Row():
# CH4_input = gr.Slider(0, 100, value=2.5, label="CH₄ (Barrier)")
# CO2_input = gr.Slider(0, 100, value=15.4, label="CO₂ (Barrier)")
# H2_input = gr.Slider(0, 100, value=21.0, label="H₂ (Barrier)")
# N2_input = gr.Slider(0, 100, value=1.5, label="N₂ (Barrier)")
# O2_input = gr.Slider(0, 100, value=2.8, label="O₂ (Barrier)")
# guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")
# generate_btn = gr.Button("Generate Polymer")
# with gr.Row():
# result_smiles = gr.Textbox(label="Generated SMILES")
# result_image = gr.Image(label="Molecule Visualization", type="pil")
# generate_btn.click(
# generate_polymer,
# inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
# outputs=[result_smiles, result_image]
# )
# if __name__ == "__main__":
# iface.launch()