liuganghuggingface commited on
Commit
8c02a88
1 Parent(s): 89df0a5

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +8 -32
app.py CHANGED
@@ -7,31 +7,16 @@ from rdkit import Chem
7
  from rdkit.Chem import Draw
8
  from graph_decoder.diffusion_model import GraphDiT
9
 
10
- class RandomPolymerGenerator(nn.Module):
11
- def __init__(self):
12
- super().__init__()
13
- self.fc1 = nn.Linear(5, 64)
14
- self.fc2 = nn.Linear(64, 128)
15
- self.fc3 = nn.Linear(128, 256)
16
- self.fc4 = nn.Linear(256, 100) # Output size set to 100 for simplicity
17
-
18
- def forward(self, x):
19
- x = torch.relu(self.fc1(x))
20
- x = torch.relu(self.fc2(x))
21
- x = torch.relu(self.fc3(x))
22
- return torch.sigmoid(self.fc4(x))
23
-
24
- def load_graph_decoder(path='model_labeled'):
25
- model = GraphDiT(
26
  model_config_path=f"{path}/config.yaml",
27
  data_info_path=f"{path}/data.meta.json",
28
- model_dtype=torch.float32,
29
- )
30
- # model.init_model(path)
31
- # model.disable_grads()
32
- return model
33
-
34
- ATOM_SYMBOLS = ['C', 'N', 'O', 'H']
35
 
36
  def generate_random_smiles(length=10):
37
  return ''.join(random.choices(ATOM_SYMBOLS, k=length))
@@ -42,15 +27,6 @@ def generate_polymer(CH4, CO2, H2, N2, O2, guidance_scale):
42
 
43
  print('in generate_polymer')
44
  try:
45
- model = load_graph_decoder()
46
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
47
- model.to(device)
48
- properties = properties.to(device)
49
-
50
- with torch.no_grad():
51
- output = model(properties)
52
- print('output', output)
53
-
54
  # Generate a random SMILES string (this is a placeholder)
55
  generated_molecule = generate_random_smiles()
56
 
 
7
  from rdkit.Chem import Draw
8
  from graph_decoder.diffusion_model import GraphDiT
9
 
10
+ ATOM_SYMBOLS = ['C', 'N', 'O', 'H']
11
+
12
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
13
+
14
+ model = GraphDiT(
 
 
 
 
 
 
 
 
 
 
 
15
  model_config_path=f"{path}/config.yaml",
16
  data_info_path=f"{path}/data.meta.json",
17
+ model_dtype=torch.float32
18
+ )
19
+ model.to(device)
 
 
 
 
20
 
21
  def generate_random_smiles(length=10):
22
  return ''.join(random.choices(ATOM_SYMBOLS, k=length))
 
27
 
28
  print('in generate_polymer')
29
  try:
 
 
 
 
 
 
 
 
 
30
  # Generate a random SMILES string (this is a placeholder)
31
  generated_molecule = generate_random_smiles()
32