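# NOTE: the lines below appear to be file-viewer page residue captured with
# the source; they are kept as comments so the module remains importable.
# liuganghuggingface's picture
# Update graph_decoder/transformer.py
# 7f99386 verified
# raw
# history blame contribute delete
# No virus
# 6.28 kB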
import torch
import torch.nn as nn
from .layers import Attention, MLP
from .conditions import TimestepEmbedder, ConditionEmbedder
from .diffusion_utils import PlaceHolder
def modulate(x, shift, scale):
    """Apply a conditioning-driven affine transform to ``x``.

    Computes ``x * (1 + scale) + shift`` after inserting a singleton
    dimension at position 1 of ``shift`` and ``scale`` so they broadcast
    along dimension 1 of ``x``.

    Args:
        x: Tensor to modulate (assumed ``(batch, nodes, dim)`` -- confirm
            against callers).
        shift: Additive term (assumed ``(batch, dim)``).
        scale: Multiplicative term, applied as ``1 + scale`` so that a zero
            scale leaves ``x`` unscaled (assumed ``(batch, dim)``).

    Returns:
        The modulated tensor, broadcast-shaped like ``x``.
    """
    per_token_shift = shift.unsqueeze(1)
    per_token_gain = 1 + scale.unsqueeze(1)
    return x * per_token_gain + per_token_shift
class Transformer(nn.Module):
    """Conditional transformer denoiser over per-node graph tokens.

    Every node becomes one token whose input features are its own node
    features concatenated with its flattened row of edge features.  The sum
    of a timestep embedding and a condition embedding modulates each
    ``Block`` and the final ``OutputLayer``.

    Args:
        max_n_nodes: Number of node slots; fixes the width of the flattened
            edge row (``max_n_nodes * Edim``).
        hidden_size: Token embedding width.
        depth: Number of ``Block`` layers.
        num_heads: Attention head count handed to each ``Block``.
        mlp_ratio: MLP width multiplier handed to each ``Block`` and to the
            ``OutputLayer``.
        drop_condition: Forwarded to ``ConditionEmbedder`` (presumably the
            condition-dropout rate -- confirm in ``conditions``).
        Xdim: Width of the node features.
        Edim: Width of the edge features per node pair.
        ydim: Width of the condition input, forwarded to
            ``ConditionEmbedder``.
    """

    def __init__(
        self,
        max_n_nodes,
        hidden_size=384,
        depth=12,
        num_heads=16,
        mlp_ratio=4.0,
        drop_condition=0.1,
        Xdim=118,
        Edim=5,
        ydim=5,
    ):
        super().__init__()
        self.num_heads = num_heads
        self.ydim = ydim

        # One token per node: node features plus that node's full edge row.
        token_in_features = Xdim + max_n_nodes * Edim
        self.x_embedder = nn.Sequential(
            nn.Linear(token_in_features, hidden_size, bias=False),
            nn.LayerNorm(hidden_size),
        )

        self.t_embedder = TimestepEmbedder(hidden_size)
        self.y_embedder = ConditionEmbedder(ydim, hidden_size, drop_condition)

        layers = [
            Block(hidden_size, num_heads, mlp_ratio=mlp_ratio)
            for _ in range(depth)
        ]
        self.blocks = nn.ModuleList(layers)

        self.output_layer = OutputLayer(
            max_n_nodes=max_n_nodes,
            hidden_size=hidden_size,
            atom_type=Xdim,
            bond_type=Edim,
            mlp_ratio=mlp_ratio,
            num_heads=num_heads,
        )

        self.initialize_weights()

    def initialize_weights(self):
        """Initialise all linear layers, then zero the adaLN entry layers.

        Every ``nn.Linear`` in the module tree receives Xavier-uniform
        weights and a zero bias.  Afterwards the first layer of each
        ``adaLN_modulation`` stack (in every block and in the output layer)
        is overwritten with zeros.
        """

        def _xavier_linear(module):
            if not isinstance(module, nn.Linear):
                return
            torch.nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.constant_(module.bias, 0)

        def _fill_linear(module, value):
            if not isinstance(module, nn.Linear):
                return
            nn.init.constant_(module.weight, value)
            if module.bias is not None:
                nn.init.constant_(module.bias, value)

        self.apply(_xavier_linear)

        entry_layers = [block.adaLN_modulation[0] for block in self.blocks]
        entry_layers.append(self.output_layer.adaLN_modulation[0])
        for layer in entry_layers:
            _fill_linear(layer, 0)

    def disable_grads(self):
        """Freeze the model by turning off gradients on every parameter."""
        for weight in self.parameters():
            weight.requires_grad_(False)

    def print_trainable_parameters(self):
        """Print the name and size of each trainable parameter and the total."""
        trainable = [
            (name, weight)
            for name, weight in self.named_parameters()
            if weight.requires_grad
        ]

        print("Trainable parameters:")
        for name, weight in trainable:
            print(f"{name}: {weight.size()}")

        # Total element count over the trainable parameters only.
        n_trainable = sum(weight.numel() for _, weight in trainable)
        print(f"\nTotal trainable parameters: {n_trainable}")

    def forward(self, X_in, E_in, node_mask, y_in, t, unconditioned):
        """Run the denoiser on a batch of graphs.

        Args:
            X_in: Node features, rank 3 (``B * N * dx``).
            E_in: Edge features (``B * N * N * de``); flattened per node
                before being concatenated onto ``X_in``.
            node_mask: Node mask handed to every block, the output layer and
                the final ``PlaceHolder.mask`` call.
            y_in: Condition input for ``y_embedder``.
            t: Timestep input for ``t_embedder``.
            unconditioned: Forwarded to ``y_embedder`` together with
                ``self.training``.

        Returns:
            The result of ``PlaceHolder(X=..., E=..., y=None).mask(node_mask)``.
        """
        batch_size, n_nodes, _ = X_in.size()
        edge_rows = E_in.reshape(batch_size, n_nodes, -1)
        tokens = self.x_embedder(torch.cat([X_in, edge_rows], dim=-1))

        # Timestep embedding is computed before the condition embedding.
        cond = self.t_embedder(t) + self.y_embedder(
            y_in, self.training, unconditioned
        )

        for block in self.blocks:
            tokens = block(tokens, cond, node_mask)

        # X: B * N * dx, E: B * N * N * de
        X_out, E_out = self.output_layer(tokens, X_in, E_in, cond, t, node_mask)
        return PlaceHolder(X=X_out, E=E_out, y=None).mask(node_mask)
class Block(nn.Module):
    """Transformer layer with conditioning-driven modulation and gating.

    Each sub-layer (attention, then MLP) is followed by a non-affine
    ``LayerNorm``; the normalised output is shifted and scaled via
    ``modulate``, multiplied by a gate, and added back to the residual
    stream.  All six modulation tensors come from one projection of the
    conditioning vector that ends in ``Softsign``.

    Args:
        hidden_size: Token embedding width.
        num_heads: Number of attention heads.
        mlp_ratio: MLP hidden width as a multiple of ``hidden_size``.
        **block_kwargs: Extra keyword arguments forwarded to ``Attention``.
    """

    def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs):
        super().__init__()

        # Norms carry no learnable affine; shift/scale come from adaLN instead.
        self.attn_norm = nn.LayerNorm(
            hidden_size, eps=1e-05, elementwise_affine=False
        )
        self.mlp_norm = nn.LayerNorm(
            hidden_size, eps=1e-05, elementwise_affine=False
        )

        self.attn = Attention(
            hidden_size,
            num_heads=num_heads,
            qkv_bias=False,
            qk_norm=True,
            **block_kwargs,
        )

        mlp_hidden = int(hidden_size * mlp_ratio)
        self.mlp = MLP(in_features=hidden_size, hidden_features=mlp_hidden)

        # Produces shift / scale / gate for both sub-layers in one pass.
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, 6 * hidden_size, bias=True),
            nn.Softsign(),
        )

    def forward(self, x, c, node_mask):
        """Apply the attention and MLP sub-layers to ``x``.

        Args:
            x: Token tensor (assumed ``(batch, nodes, hidden)`` -- confirm
                against callers).
            c: Conditioning tensor; its projection is split into six chunks
                along dimension 1.
            node_mask: Passed to the attention module as ``node_mask``.

        Returns:
            The updated token tensor.
        """
        modulation = self.adaLN_modulation(c).chunk(6, dim=1)
        shift_attn, scale_attn, gate_attn, shift_ff, scale_ff, gate_ff = modulation

        # Attention branch: norm is applied to the sub-layer output.
        attn_branch = gate_attn.unsqueeze(1) * modulate(
            self.attn_norm(self.attn(x, node_mask=node_mask)),
            shift_attn,
            scale_attn,
        )
        x = x + attn_branch

        # MLP branch operates on the already-updated residual stream.
        ff_branch = gate_ff.unsqueeze(1) * modulate(
            self.mlp_norm(self.mlp(x)),
            shift_ff,
            scale_ff,
        )
        x = x + ff_branch
        return x
class OutputLayer(nn.Module):
    """Decode node tokens into residual atom and bond predictions.

    Each token is projected to ``atom_type + max_n_nodes * bond_type``
    values, normalised, modulated by the conditioning vector, and split into
    an atom part and a per-neighbour bond part.  Both parts are added to the
    corresponding inputs; the bond tensor is then cleaned (selected entries
    zeroed) and symmetrised.

    Args:
        max_n_nodes: Number of node slots; fixes the width of the bond part.
        hidden_size: Width of the incoming tokens.
        atom_type: Number of atom output channels.
        bond_type: Number of bond output channels per node pair.
        mlp_ratio: Accepted for interface compatibility; not used here.
        num_heads: Accepted for interface compatibility; not used here.
    """

    def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None):
        super().__init__()
        self.atom_type = atom_type
        self.bond_type = bond_type
        final_size = atom_type + max_n_nodes * bond_type
        self.xedecoder = MLP(in_features=hidden_size,
                             out_features=final_size, drop=0)
        self.norm_final = nn.LayerNorm(final_size, eps=1e-05, elementwise_affine=False)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(hidden_size, hidden_size, bias=True),
            nn.SiLU(),
            nn.Linear(hidden_size, 2 * final_size, bias=True)
        )

    def forward(self, x, x_in, e_in, c, t, node_mask):
        """Produce atom and bond outputs from the final token states.

        Args:
            x: Token tensor, rank 3; its second dimension ``N`` must equal
                ``max_n_nodes`` for the bond reshape to succeed.
            x_in: Input node features, added to the atom prediction.
            e_in: Input edge features, added to the bond prediction.
            c: Conditioning tensor; its projection is split into shift and
                scale along dimension 1.
            t: Unused; kept so the call signature stays unchanged.
            node_mask: Node mask (assumed boolean with True marking real
                nodes -- confirm against the caller).

        Returns:
            Tuple ``(atom_out, bond_out)`` where ``bond_out`` has shape
            ``(B, N, N, bond_type)`` and is symmetric in its two node axes.
        """
        x_all = self.xedecoder(x)
        B, N, _ = x_all.size()
        shift, scale = self.adaLN_modulation(c).chunk(2, dim=1)
        x_all = modulate(self.norm_final(x_all), shift, scale)

        # Residual predictions on top of the inputs.
        atom_out = x_in + x_all[:, :, :self.atom_type]
        bond_out = e_in + x_all[:, :, self.atom_type:].reshape(B, N, N, self.bond_type)

        ##### standardize adj_out
        # NOTE(review): this clears only pairs whose endpoints are BOTH masked
        # out; pairs with exactly one masked endpoint are left untouched here.
        # Transformer.forward calls PlaceHolder.mask afterwards -- confirm that
        # it covers those pairs before relying on this output directly.
        masked_out = ~node_mask
        edge_mask = masked_out[:, :, None] & masked_out[:, None, :]
        # Build the self-loop mask directly on the target device instead of
        # allocating it on CPU and transferring it on every forward pass; the
        # (N, N) mask broadcasts over the batch dimension.
        diag_mask = torch.eye(N, dtype=torch.bool, device=edge_mask.device)
        # bond_out is a fresh tensor (result of the addition above), so the
        # in-place fill does not touch e_in.
        bond_out.masked_fill_((edge_mask | diag_mask)[:, :, :, None], 0)

        # Average with the transpose so the result is symmetric in (i, j).
        bond_out = 1 / 2 * (bond_out + torch.transpose(bond_out, 1, 2))
        return atom_out, bond_out