liuganghuggingface's picture
Upload app.py with huggingface_hub
6ad4174 verified
raw
history blame
2.34 kB
import spaces
import gradio as gr
import torch
from rdkit import Chem
from rdkit.Chem import Draw
# from graph_decoder.diffusion_model import GraphDiT
# Load the model
def load_graph_decoder(path='model_labeled'):
    """Build the GraphDiT decoder from the checkpoint directory *path*.

    Loading is currently stubbed out -- the ``GraphDiT`` import near the top
    of the file is commented out -- so *path* is ignored and ``None`` is
    returned. Callers must be prepared for a ``None`` model.
    """
    # To re-enable, restore the GraphDiT import and replace the return with:
    #     decoder = GraphDiT(
    #         model_config_path=f"{path}/config.yaml",
    #         data_info_path=f"{path}/data.meta.json",
    #         model_dtype=torch.float32,
    #     )
    #     decoder.init_model(path)
    #     decoder.disable_grads()
    #     return decoder
    return None
# Target device, chosen once at import time: CUDA when available, else CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
# Shared, module-level decoder read by generate_polymer on every request;
# may be None while loading is disabled in load_graph_decoder.
model = load_graph_decoder()
@spaces.GPU
def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
    """Generate one polymer conditioned on five per-gas property targets.

    Args:
        CH4, CO2, H2, N2, O2: Target values from the UI sliders; passed to
            ``model.generate`` as a list in exactly this order.
        guidance_scale: Forwarded to ``model.generate`` as ``guide_scale``.

    Returns:
        ``(smiles, image)``: canonical isomeric SMILES and an RDKit-rendered
        image on success, or ``("Generation failed", None)`` when the model
        is not loaded, generation raises, the generator returns ``None``, or
        the returned SMILES cannot be parsed.
    """
    failure = ("Generation failed", None)
    properties = [CH4, CO2, H2, N2, O2]

    # load_graph_decoder() can return None (loading disabled); report that
    # directly instead of surfacing a confusing AttributeError on None.to().
    if model is None:
        print("Error in generation: model is not loaded")
        return failure

    try:
        print('enter generate polymer')
        model.to(device)
        generated_molecule, _ = model.generate(
            properties, device=device, guide_scale=guidance_scale
        )
        if generated_molecule is None:
            print("Error in generation: model returned no molecule")
            return failure

        mol = Chem.MolFromSmiles(generated_molecule)
        if mol is None:
            # RDKit returns None (rather than raising) for unparseable SMILES.
            print(f"Error in generation: invalid SMILES {generated_molecule!r}")
            return failure

        standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
        img = Draw.MolToImage(mol)
        return standardized_smiles, img
    except Exception as e:
        # UI boundary: log and show a failure message instead of crashing
        # the Gradio request.
        print(f"Error in generation: {e}")
        return failure
# Create the Gradio interface.
# Slider labels use "Barrer", the gas-permeability unit the original
# "(Barrier)" labels appear to have misspelled (confirm against the
# training data's units). The five sliders feed generate_polymer in
# CH4, CO2, H2, N2, O2 order.
with gr.Blocks(title="Simplified Polymer Design") as iface:
    gr.Markdown("## Polymer Design with GraphDiT")
    with gr.Row():
        CH4_input = gr.Slider(0, 100, value=2.5, label="CH₄ (Barrer)")
        CO2_input = gr.Slider(0, 100, value=15.4, label="CO₂ (Barrer)")
        H2_input = gr.Slider(0, 100, value=21.0, label="H₂ (Barrer)")
        N2_input = gr.Slider(0, 100, value=1.5, label="N₂ (Barrer)")
        O2_input = gr.Slider(0, 100, value=2.8, label="O₂ (Barrer)")
    guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")
    generate_btn = gr.Button("Generate Polymer")
    with gr.Row():
        result_smiles = gr.Textbox(label="Generated SMILES")
        result_image = gr.Image(label="Molecule Visualization", type="pil")
    # Input order must match generate_polymer's positional parameters.
    generate_btn.click(
        generate_polymer,
        inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
        outputs=[result_smiles, result_image],
    )

if __name__ == "__main__":
    iface.launch()