liuganghuggingface committed on
Commit
78b3cac
1 Parent(s): 8957853

Upload graph_decoder/transformer.py with huggingface_hub
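
The commit message says the file was pushed with huggingface_hub rather than via git. For reference, a minimal sketch of such an upload; the repo id and commit message below are placeholders/assumptions, not values taken from this commit:

from huggingface_hub import HfApi

api = HfApi()
# Push a single file into a Hub repo; repo_id is a placeholder, not the actual repository.
api.upload_file(
    path_or_fileobj="graph_decoder/transformer.py",   # local file to upload
    path_in_repo="graph_decoder/transformer.py",      # destination path inside the repo
    repo_id="<namespace>/<repo-name>",
    repo_type="model",
    commit_message="Upload graph_decoder/transformer.py with huggingface_hub",
)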

Files changed (1)
  1. graph_decoder/transformer.py +180 -0
graph_decoder/transformer.py ADDED
@@ -0,0 +1,180 @@
+ import torch
+ import torch.nn as nn
+ from .layers import Attention, MLP
+ from .conditions import TimestepEmbedder, ConditionEmbedder
+ from .diffusion_utils import PlaceHolder
+
+ def modulate(x, shift, scale):
+     return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1)
+
+ class Transformer(nn.Module):
+     def __init__(
+         self,
+         max_n_nodes,
+         hidden_size=384,
+         depth=12,
+         num_heads=16,
+         mlp_ratio=4.0,
+         drop_condition=0.1,
+         Xdim=118,
+         Edim=5,
+         ydim=5,
+     ):
+         super().__init__()
+         self.num_heads = num_heads
+         self.ydim = ydim
+         self.x_embedder = nn.Sequential(
+             nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False),
+             nn.LayerNorm(hidden_size)
+         )
+
+         self.t_embedder = TimestepEmbedder(hidden_size)
+         self.y_embedder = ConditionEmbedder(ydim, hidden_size, drop_condition)
+
+         self.blocks = nn.ModuleList(
+             [
+                 Block(hidden_size, num_heads, mlp_ratio=mlp_ratio)
+                 for _ in range(depth)
+             ]
+         )
+         self.output_layer = OutputLayer(
+             max_n_nodes=max_n_nodes,
+             hidden_size=hidden_size,
+             atom_type=Xdim,
+             bond_type=Edim,
+             mlp_ratio=mlp_ratio,
+             num_heads=num_heads,
+         )
+         self.initialize_weights()
+
+     def initialize_weights(self):
+         # Initialize transformer layers:
+         def _basic_init(module):
+             if isinstance(module, nn.Linear):
+                 torch.nn.init.xavier_uniform_(module.weight)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, 0)
+
+         def _constant_init(module, i):
+             if isinstance(module, nn.Linear):
+                 nn.init.constant_(module.weight, i)
+                 if module.bias is not None:
+                     nn.init.constant_(module.bias, i)
+
+         self.apply(_basic_init)
+
+         for block in self.blocks:
+             _constant_init(block.adaLN_modulation[0], 0)
+         _constant_init(self.output_layer.adaLN_modulation[0], 0)
+
+     def disable_grads(self):
+         """
+         Disable gradients for all parameters in the model.
+         """
+         for param in self.parameters():
+             param.requires_grad = False
+
+     def print_trainable_parameters(self):
+         print("Trainable parameters:")
+         for name, param in self.named_parameters():
+             if param.requires_grad:
+                 print(f"{name}: {param.size()}")
+
+         # Calculate and print the total number of trainable parameters
+         total_params = sum(p.numel() for p in self.parameters() if p.requires_grad)
+         print(f"\nTotal trainable parameters: {total_params}")
+
+     def forward(self, X_in, E_in, node_mask, y_in, t, unconditioned):
+         bs, n, _ = X_in.size()
+         X = torch.cat([X_in, E_in.reshape(bs, n, -1)], dim=-1)
+         X = self.x_embedder(X)
+
+         c1 = self.t_embedder(t)
+         c2 = self.y_embedder(y_in, self.training, unconditioned)
+         c = c1 + c2
+
+         for i, block in enumerate(self.blocks):
+             X = block(X, c, node_mask)
+
+         # X: B * N * dx, E: B * N * N * de
+         X, E = self.output_layer(X, X_in, E_in, c, t, node_mask)
+         return PlaceHolder(X=X, E=E, y=None).mask(node_mask)
+
+ class Block(nn.Module):
+     def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
+         super().__init__()
+         self.attn_norm = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=False)
+         self.mlp_norm = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=False)
+
+         self.attn = Attention(
+             hidden_size, num_heads=num_heads, qkv_bias=False, qk_norm=True, **block_kwargs
+         )
+
+         self.mlp = MLP(
+             in_features=hidden_size,
+             hidden_features=int(hidden_size * mlp_ratio),
+         )
+
+         self.adaLN_modulation = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, 6 * hidden_size, bias=True),
+             nn.Softsign()
+         )
+
+     def forward(self, x, c, node_mask):
+         (
+             shift_msa,
+             scale_msa,
+             gate_msa,
+             shift_mlp,
+             scale_mlp,
+             gate_mlp,
+         ) = self.adaLN_modulation(c).chunk(6, dim=1)
+
+         x = x + gate_msa.unsqueeze(1) * modulate(self.attn_norm(self.attn(x, node_mask=node_mask)), shift_msa, scale_msa)
+         x = x + gate_mlp.unsqueeze(1) * modulate(self.mlp_norm(self.mlp(x)), shift_mlp, scale_mlp)
+
+         return x
+
+ class OutputLayer(nn.Module):
+     def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None):
+         super().__init__()
+         self.atom_type = atom_type
+         self.bond_type = bond_type
+         final_size = atom_type + max_n_nodes * bond_type
+         self.xedecoder = MLP(in_features=hidden_size,
+                              out_features=final_size, drop=0)
+
+         self.norm_final = nn.LayerNorm(final_size, eps=1e-05, elementwise_affine=False)
+         self.adaLN_modulation = nn.Sequential(
+             nn.Linear(hidden_size, hidden_size, bias=True),
+             nn.SiLU(),
+             nn.Linear(hidden_size, 2 * final_size, bias=True)
+         )
+
+     def forward(self, x, x_in, e_in, c, t, node_mask):
+         x_all = self.xedecoder(x)
+         B, N, D = x_all.size()
+         shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
+         x_all = modulate(self.norm_final(x_all), shift, scale)
+
+         atom_out = x_all[:, :, :self.atom_type]
+         atom_out = x_in + atom_out
+
+         bond_out = x_all[:, :, self.atom_type:].reshape(B, N, N, self.bond_type)
+         bond_out = e_in + bond_out
+
+         ##### standardize adj_out
+         edge_mask = (~node_mask)[:, :, None] & (~node_mask)[:, None, :]
+         diag_mask = (
+             torch.eye(N, dtype=torch.bool)
+             .unsqueeze(0)
+             .expand(B, -1, -1)
+             .type_as(edge_mask)
+         )
+         bond_out.masked_fill_(edge_mask[:, :, :, None], 0)
+         bond_out.masked_fill_(diag_mask[:, :, :, None], 0)
+         bond_out = 1 / 2 * (bond_out + torch.transpose(bond_out, 1, 2))
+
+         return atom_out, bond_out