liuganghuggingface's picture
Upload graph_decoder/visualize_utils.py with huggingface_hub
58dad72 verified
raw
history blame contribute delete
No virus
3.08 kB
import os
from rdkit import Chem
from rdkit.Chem import Draw, AllChem
from rdkit.Geometry import Point3D
from rdkit import RDLogger
import numpy as np
import rdkit.Chem
class MolecularVisualization:
    """Build RDKit molecules from integer-encoded graphs and render them."""

    def __init__(self, atom_decoder):
        """Store the lookup used to decode node labels into atoms.

        Args:
            atom_decoder: Object indexable by an integer node label; each
                entry is handed to ``Chem.Atom`` (e.g. an element symbol).
        """
        self.atom_decoder = atom_decoder

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """Convert a single graph into an RDKit molecule.

        Args:
            node_list: 1-D sequence of integer node labels for one graph.
                A label of ``-1`` marks a slot with no atom and is skipped.
            adjacency_matrix: 2-D (n x n) symmetric matrix of bond codes:
                1 = single, 2 = double, 3 = triple, 4 = aromatic. Any other
                value means "no bond". Only the upper triangle is read.

        Returns:
            The molecule returned by ``RWMol.GetMol()``, or ``None`` if RDKit
            raises a ``KekulizeException`` while building it.
        """
        atom_decoder = self.atom_decoder
        # Built per call (not at class level) so that importing this module
        # does not touch RDKit enums until a molecule is actually requested.
        bond_types = (
            (1, Chem.rdchem.BondType.SINGLE),
            (2, Chem.rdchem.BondType.DOUBLE),
            (3, Chem.rdchem.BondType.TRIPLE),
            (4, Chem.rdchem.BondType.AROMATIC),
        )

        # Create an empty editable molecule.
        mol = Chem.RWMol()

        # Add atoms, remembering graph-position -> molecule-index because
        # skipped (-1) slots make the two numberings diverge.
        node_to_idx = {}
        for i, node in enumerate(node_list):
            if node == -1:
                continue
            atom = Chem.Atom(atom_decoder[int(node)])
            node_to_idx[i] = mol.AddAtom(atom)

        for ix, row in enumerate(adjacency_matrix):
            for iy, bond in enumerate(row):
                # Only traverse the upper half of the symmetric matrix.
                if iy <= ix:
                    continue
                # Compare with ``==`` (rather than a dict lookup) so array or
                # tensor scalars match the same way plain ints do.
                bond_type = next(
                    (kind for code, kind in bond_types if bond == code),
                    None,
                )
                if bond_type is None:
                    continue
                # A bond touching a skipped/absent node has no atom to attach
                # to; ignore it instead of failing with a KeyError.
                if ix not in node_to_idx or iy not in node_to_idx:
                    continue
                mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        # Guard kept from the original implementation: report the failure and
        # signal it to the caller with None.
        try:
            return mol.GetMol()
        except Chem.KekulizeException:
            print("Can't kekulize molecule")
            return None

    def visualize_chain(self, nodes_list, adjacency_matrix):
        """Render every graph of a chain as an image with a shared 2-D layout.

        Each frame is converted with ``mol_from_graphs``. The 2-D coordinates
        computed for the final frame are then copied, by atom index, onto
        every frame so that atoms stay in place from one image to the next.

        Args:
            nodes_list: Sequence of per-frame node label sequences
                (frames x n).
            adjacency_matrix: Sequence of per-frame adjacency matrices
                (frames x n x n), indexed in step with ``nodes_list``.

        Returns:
            A list of 300x300 images produced by ``Draw.MolToImage``, each
            with the legend ``"Frame <index>"``. Frames whose molecule could
            not be built are omitted; the legend keeps the original frame
            index. An empty chain yields an empty list.
        """
        num_frames = len(nodes_list)
        if num_frames == 0:
            # Nothing to draw; avoids indexing the last frame of an empty list.
            return []

        # NOTE: this silences RDKit logging process-wide, not only here.
        RDLogger.DisableLog('rdApp.*')

        # Convert the graphs to RDKit molecules.
        mols = [
            self.mol_from_graphs(nodes_list[i], adjacency_matrix[i])
            for i in range(num_frames)
        ]

        # Reference layout: atom coordinates of the final molecule.
        final_molecule = mols[-1]
        coords = []
        if final_molecule is not None:
            AllChem.Compute2DCoords(final_molecule)
            final_conf = final_molecule.GetConformer()
            for i in range(final_molecule.GetNumAtoms()):
                pos = final_conf.GetAtomPosition(i)
                coords.append((pos.x, pos.y, pos.z))

        # Align every molecule onto the reference layout.
        for mol in mols:
            if mol is None:
                continue
            AllChem.Compute2DCoords(mol)
            conf = mol.GetConformer()
            # zip stops at the shorter side, so a frame with more atoms than
            # the final molecule keeps its own coordinates for the extras
            # instead of raising IndexError.
            for j, (x, y, z) in zip(range(mol.GetNumAtoms()), coords):
                conf.SetAtomPosition(j, Point3D(x, y, z))

        # Create the list of molecule images.
        mol_images = []
        for frame, mol in enumerate(mols):
            if mol is None:
                continue
            img = Draw.MolToImage(mol, size=(300, 300), legend=f"Frame {frame}")
            mol_images.append(img)
        return mol_images