File size: 5,191 Bytes
7afa2e2
e9af19b
 
6cc3c63
 
750cfac
 
6cc3c63
26a7403
 
 
 
 
 
 
 
 
e9af19b
491cb22
26a7403
6cc3c63
 
 
 
491cb22
7afa2e2
6a4bbaf
6cc3c63
e9af19b
6cc3c63
6a4bbaf
4fd362d
 
e9af19b
6cc3c63
e9af19b
6cc3c63
 
 
 
 
 
 
 
 
 
 
 
6a4bbaf
 
e9af19b
6a4bbaf
e9af19b
6a4bbaf
 
6cc3c63
6a4bbaf
e9af19b
6a4bbaf
 
 
 
 
 
e9af19b
6a4bbaf
e9af19b
 
6a4bbaf
 
491cb22
e9af19b
6a4bbaf
 
 
e9af19b
 
491cb22
6cc3c63
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
import spaces
import gradio as gr
import torch
import torch.nn as nn
import random
from rdkit import Chem
from rdkit.Chem import Draw

from graph_decoder.diffusion_model import GraphDiT
def load_graph_decoder(path='model_labeled'):
    """Construct a GraphDiT decoder from the checkpoint directory ``path``.

    The model is configured from ``{path}/config.yaml`` and
    ``{path}/data.meta.json`` with float32 weights, then
    ``init_model(path)`` and ``disable_grads()`` are called on it
    (presumably loading weights and freezing parameters — confirm against
    ``graph_decoder.diffusion_model``) before it is returned.
    """
    config_file = f"{path}/config.yaml"
    meta_file = f"{path}/data.meta.json"
    decoder = GraphDiT(
        model_config_path=config_file,
        data_info_path=meta_file,
        model_dtype=torch.float32,
    )
    decoder.init_model(path)
    decoder.disable_grads()
    return decoder


# Alphabet sampled by the placeholder generator below.
ATOM_SYMBOLS = ['C', 'N', 'O', 'H']


def generate_random_smiles(length=10):
    """Return a placeholder "SMILES" string of ``length`` random atom symbols.

    Symbols are drawn uniformly, with replacement, from ``ATOM_SYMBOLS`` using
    the module-level ``random`` generator, so results are reproducible under
    ``random.seed``. ``length=0`` yields an empty string.

    NOTE(review): a bare ``H`` (outside brackets) is not part of the SMILES
    organic subset, so strings containing it may be rejected by the parser —
    confirm against RDKit.
    """
    picks = random.choices(ATOM_SYMBOLS, k=length)
    return "".join(picks)

@spaces.GPU
def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
    """Gradio callback: run the decoder on the gas inputs and return a molecule.

    Args:
        CH4, CO2, H2, N2, O2: slider values, packed in this order into a
            float32 tensor of shape (1, 5).
        guidance_scale: value of the "Guidance Scale" slider.
            NOTE(review): currently unused; the earlier, commented-out version
            of this function passed it to ``model.generate(..., guide_scale=...)``.

    Returns:
        ``(canonical_smiles, PIL image)`` when RDKit parses the generated
        string, otherwise ``("Generation failed", None)``. The SMILES is
        currently a random placeholder and is not derived from the model output.
    """
    import traceback  # used only to report failures with full context

    properties = torch.tensor([CH4, CO2, H2, N2, O2], dtype=torch.float32).unsqueeze(0)

    print('in generate_polymer')
    try:
        # NOTE(review): the decoder is rebuilt from disk on every request,
        # which is slow; consider loading it once if the Space setup allows.
        model = load_graph_decoder()
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        model.to(device)
        properties = properties.to(device)

        with torch.no_grad():
            output = model(properties)
        print('output', output)

        # Placeholder: the model output is not decoded yet, so a random
        # string stands in for the generated molecule.
        generated_molecule = generate_random_smiles()

        mol = Chem.MolFromSmiles(generated_molecule)
        if mol is not None:
            standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
            img = Draw.MolToImage(mol)
            return standardized_smiles, img
        # Previously this path failed silently; say why generation failed.
        print(f"Generated string is not valid SMILES: {generated_molecule!r}")
    except Exception as e:
        # UI boundary: keep the app alive, but keep the traceback for debugging.
        print(f"Error in generation: {e}")
        traceback.print_exc()

    return "Generation failed", None

# Gradio UI: five gas-barrier sliders plus a guidance-scale slider feed
# generate_polymer, whose (SMILES, image) result fills the output row.
_GAS_SLIDER_SPECS = (
    ("CH₄ (Barrier)", 2.5),
    ("CO₂ (Barrier)", 15.4),
    ("H₂ (Barrier)", 21.0),
    ("N₂ (Barrier)", 1.5),
    ("O₂ (Barrier)", 2.8),
)

with gr.Blocks(title="Simplified Polymer Design") as iface:
    gr.Markdown("## Polymer Design with Random Neural Network")

    with gr.Row():
        # Sliders are created in spec order, which is also their layout order.
        CH4_input, CO2_input, H2_input, N2_input, O2_input = (
            gr.Slider(0, 100, value=default, label=label)
            for label, default in _GAS_SLIDER_SPECS
        )
        guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")

    generate_btn = gr.Button("Generate Polymer")

    with gr.Row():
        result_smiles = gr.Textbox(label="Generated SMILES")
        result_image = gr.Image(label="Molecule Visualization", type="pil")

    # Input order must match generate_polymer's positional parameters.
    callback_inputs = [CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale]
    callback_outputs = [result_smiles, result_image]
    generate_btn.click(generate_polymer, inputs=callback_inputs, outputs=callback_outputs)

if __name__ == "__main__":
    # Launch the Gradio app only when this file is executed directly.
    iface.launch()

# import spaces
# import gradio as gr
# import torch
# from rdkit import Chem
# from rdkit.Chem import Draw
# # from graph_decoder.diffusion_model import GraphDiT

# # Load the model
# def load_graph_decoder(path='model_labeled'):
#     model = GraphDiT(
#         model_config_path=f"{path}/config.yaml",
#         data_info_path=f"{path}/data.meta.json",
#         model_dtype=torch.float32,
#     )
#     model.init_model(path)
#     model.disable_grads()
#     return model

# # model = load_graph_decoder()
# # device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# @spaces.GPU
# def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
#     properties = [CH4, CO2, H2, N2, O2]
    
#     try:
#         model = load_graph_decoder()
#         device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#         model.to(device)
#         print('enter function')
#         generated_molecule, _ = model.generate(properties, device=device, guide_scale=guidance_scale)
        
#         if generated_molecule is not None:
#             mol = Chem.MolFromSmiles(generated_molecule)
#             if mol is not None:
#                 standardized_smiles = Chem.MolToSmiles(mol, isomericSmiles=True)
#                 img = Draw.MolToImage(mol)
#                 return standardized_smiles, img
#     except Exception as e:
#         print(f"Error in generation: {e}")
    
#     return "Generation failed", None

# # Create the Gradio interface
# with gr.Blocks(title="Simplified Polymer Design") as iface:
#     gr.Markdown("## Polymer Design with GraphDiT")
    
#     with gr.Row():
#         CH4_input = gr.Slider(0, 100, value=2.5, label="CH₄ (Barrier)")
#         CO2_input = gr.Slider(0, 100, value=15.4, label="CO₂ (Barrier)")
#         H2_input = gr.Slider(0, 100, value=21.0, label="H₂ (Barrier)")
#         N2_input = gr.Slider(0, 100, value=1.5, label="N₂ (Barrier)")
#         O2_input = gr.Slider(0, 100, value=2.8, label="O₂ (Barrier)")
#         guidance_scale = gr.Slider(1, 3, value=2, label="Guidance Scale")

#     generate_btn = gr.Button("Generate Polymer")

#     with gr.Row():
#         result_smiles = gr.Textbox(label="Generated SMILES")
#         result_image = gr.Image(label="Molecule Visualization", type="pil")

#     generate_btn.click(
#         generate_polymer,
#         inputs=[CH4_input, CO2_input, H2_input, N2_input, O2_input, guidance_scale],
#         outputs=[result_smiles, result_image]
#     )

# if __name__ == "__main__":
#     iface.launch()