File size: 2,321 Bytes
7afa2e2
e9af19b
 
750cfac
 
26a9b58
e9af19b
6a4bbaf
e9af19b
26a9b58
 
 
 
 
 
 
 
e9af19b
491cb22
6a4bbaf
 
491cb22
7afa2e2
6a4bbaf
e9af19b
 
6a4bbaf
e9af19b
26a9b58
6a4bbaf
e9af19b
 
 
 
 
 
6a4bbaf
 
 
e9af19b
6a4bbaf
e9af19b
6a4bbaf
 
 
 
e9af19b
6a4bbaf
 
 
 
 
 
e9af19b
6a4bbaf
e9af19b
 
6a4bbaf
 
491cb22
e9af19b
6a4bbaf
 
 
e9af19b
 
491cb22
6a4bbaf
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
import logging

import spaces
import gradio as gr
import torch
from rdkit import Chem
from rdkit.Chem import Draw

from graph_decoder.diffusion_model import GraphDiT

def load_graph_decoder(path='model_labeled'):
    """Build the GraphDiT decoder from a checkpoint directory for inference.

    Args:
        path: Directory containing ``config.yaml`` and ``data.meta.json``;
            it is also passed to ``GraphDiT.init_model`` (presumably to load
            the trained weights from it -- confirm against graph_decoder).

    Returns:
        The initialised ``GraphDiT`` instance, built with ``torch.float32``
        and with ``disable_grads()`` already applied.
    """
    model = GraphDiT(
        model_config_path=f"{path}/config.yaml",
        data_info_path=f"{path}/data.meta.json",
        model_dtype=torch.float32,
    )
    model.init_model(path)
    # This app only samples from the model, so gradients are never needed.
    model.disable_grads()
    return model

# A single decoder instance, created at import time and shared by every
# call to generate_polymer.
model = load_graph_decoder()
# Target device for sampling: the GPU when torch reports one, otherwise CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

@spaces.GPU
def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
    """Sample one polymer conditioned on five gas-property targets.

    The shared ``model`` is moved onto ``device`` on every call before
    sampling.

    Args:
        CH4, CO2, H2, N2, O2: Target property values (from the UI sliders),
            passed to ``model.generate`` as a list in exactly this order.
        guidance_scale: Forwarded to ``model.generate`` as ``guide_scale``.

    Returns:
        ``(smiles, image)``: RDKit's canonical isomeric SMILES and a PIL
        rendering of the molecule. Returns ``("Generation failed", None)``
        when sampling raises, produces no SMILES, or produces a SMILES that
        RDKit cannot turn into a non-empty molecule.
    """
    properties = [CH4, CO2, H2, N2, O2]

    try:
        model.to(device)
        generated_molecule, _ = model.generate(
            properties, device=device, guide_scale=guidance_scale
        )

        # Reject both None and "": RDKit parses "" into a zero-atom molecule
        # instead of returning None, which would otherwise look like success.
        if generated_molecule:
            mol = Chem.MolFromSmiles(generated_molecule)
            if mol is not None and mol.GetNumAtoms() > 0:
                standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
                img = Draw.MolToImage(mol)
                return standardized_smiles, img
    except Exception:
        # UI boundary: keep the app responsive, but record the full traceback
        # (not just the message) so failures can actually be diagnosed.
        logging.getLogger(__name__).exception("Error in generation")

    return "Generation failed", None

# Build the Gradio UI: five gas-property sliders and a guidance-scale slider
# feed generate_polymer, whose (SMILES, image) result fills the two outputs.
with gr.Blocks(title="Simplified Polymer Design") as iface:
    gr.Markdown("## Polymer Design with GraphDiT")

    with gr.Row():
        # Labels use "Barrer", the gas-permeability unit (previously misspelled
        # "Barrier"). An explicit step=0.1 keeps the fractional defaults
        # selectable; without it Gradio derives a step of 1 for a 0-100 range.
        CH4_input = gr.Slider(0, 100, value=2.5, step=0.1, label="CH₄ (Barrer)")
        CO2_input = gr.Slider(0, 100, value=15.4, step=0.1, label="CO₂ (Barrer)")
        H2_input = gr.Slider(0, 100, value=21.0, step=0.1, label="H₂ (Barrer)")
        N2_input = gr.Slider(0, 100, value=1.5, step=0.1, label="N₂ (Barrer)")
        O2_input = gr.Slider(0, 100, value=2.8, step=0.1, label="O₂ (Barrer)")
        guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")

    generate_btn = gr.Button("Generate Polymer")

    with gr.Row():
        result_smiles = gr.Textbox(label="Generated SMILES")
        result_image = gr.Image(label="Molecule Visualization", type="pil")

    # Input order must match generate_polymer's positional parameters.
    generate_btn.click(
        generate_polymer,
        inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
        outputs=[result_smiles, result_image]
    )

if __name__ == "__main__":
    iface.launch()