liuganghuggingface commited on
Commit
26a7403
1 Parent(s): 6cc3c63

Upload app.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. app.py +10 -16
app.py CHANGED
@@ -6,24 +6,18 @@ import random
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
8
 
9
- class RandomPolymerGenerator(nn.Module):
10
- def __init__(self):
11
- super().__init__()
12
- self.fc1 = nn.Linear(5, 64)
13
- self.fc2 = nn.Linear(64, 128)
14
- self.fc3 = nn.Linear(128, 256)
15
- self.fc4 = nn.Linear(256, 100) # Output size set to 100 for simplicity
16
-
17
- def forward(self, x):
18
- x = torch.relu(self.fc1(x))
19
- x = torch.relu(self.fc2(x))
20
- x = torch.relu(self.fc3(x))
21
- return torch.sigmoid(self.fc4(x))
22
-
23
- def load_graph_decoder():
24
- model = RandomPolymerGenerator()
25
  return model
26
 
 
27
  ATOM_SYMBOLS = ['C', 'N', 'O', 'H']
28
 
29
  def generate_random_smiles(length=10):
 
6
  from rdkit import Chem
7
  from rdkit.Chem import Draw
8
 
9
+ from graph_decoder.diffusion_model import GraphDiT
10
+ def load_graph_decoder(path='model_labeled'):
11
+ model = GraphDiT(
12
+ model_config_path=f"{path}/config.yaml",
13
+ data_info_path=f"{path}/data.meta.json",
14
+ model_dtype=torch.float32,
15
+ )
16
+ model.init_model(path)
17
+ model.disable_grads()
 
 
 
 
 
 
 
18
  return model
19
 
20
+
21
  ATOM_SYMBOLS = ['C', 'N', 'O', 'H']
22
 
23
  def generate_random_smiles(length=10):