Spaces:
Running
on
Zero
Running
on
Zero
import os | |
from rdkit import Chem | |
from rdkit.Chem import Draw, AllChem | |
from rdkit.Geometry import Point3D | |
from rdkit import RDLogger | |
import numpy as np | |
import rdkit.Chem | |
class MolecularVisualization:
    """Build RDKit molecules from graph encodings and render them as images.

    The instance only stores ``atom_decoder``, an indexable lookup from an
    integer node label to the atom symbol handed to ``Chem.Atom``.
    """

    def __init__(self, atom_decoder):
        """Store the node-label -> atom-symbol lookup used when decoding graphs."""
        self.atom_decoder = atom_decoder

    @staticmethod
    def _bond_type(bond):
        """Map an adjacency entry (1-4) to an RDKit bond type, else ``None``."""
        # Plain equality tests are kept (instead of a dict lookup) so that an
        # entry only has to compare equal to the integer code; it does not
        # have to be a hashable Python int.
        if bond == 1:
            return Chem.rdchem.BondType.SINGLE
        if bond == 2:
            return Chem.rdchem.BondType.DOUBLE
        if bond == 3:
            return Chem.rdchem.BondType.TRIPLE
        if bond == 4:
            return Chem.rdchem.BondType.AROMATIC
        return None

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """Convert a single graph into an RDKit molecule.

        Args:
            node_list: node labels of one graph (length n). ``-1`` marks a
                slot without an atom; every other label is converted with
                ``int()`` and looked up in ``self.atom_decoder``.
            adjacency_matrix: n x n bond matrix of the same graph. Only the
                upper triangle is read. Entries 1, 2, 3 and 4 mean single,
                double, triple and aromatic bonds; any other value means
                "no bond".

        Returns:
            The assembled molecule, or ``None`` if ``KekulizeException`` is
            raised while finalising it.
        """
        # Editable molecule that atoms and bonds are appended to.
        mol = Chem.RWMol()

        # Graph position -> RDKit atom index. Positions labelled -1 get no
        # atom, so the two index spaces can differ.
        node_to_idx = {}
        for node_pos, label in enumerate(node_list):
            if label == -1:
                continue
            atom = Chem.Atom(self.atom_decoder[int(label)])
            node_to_idx[node_pos] = mol.AddAtom(atom)

        for ix, row in enumerate(adjacency_matrix):
            begin_idx = node_to_idx.get(ix)
            if begin_idx is None:
                # Row belongs to a slot without an atom: nothing to bond.
                continue
            for iy, bond in enumerate(row):
                # Only traverse half of the symmetric matrix.
                if iy <= ix:
                    continue
                bond_type = self._bond_type(bond)
                if bond_type is None:
                    continue
                end_idx = node_to_idx.get(iy)
                if end_idx is None:
                    # A bond pointing at a slot without an atom cannot be
                    # added; skip it instead of failing with KeyError.
                    continue
                mol.AddBond(begin_idx, end_idx, bond_type)

        try:
            mol = mol.GetMol()
        except rdkit.Chem.KekulizeException:
            # NOTE(review): no sanitisation is requested here, so it is
            # unclear whether GetMol() can actually raise this - confirm.
            print("Can't kekulize molecule")
            mol = None
        return mol

    def visualize_chain(self, nodes_list, adjacency_matrix):
        """Render every graph of a chain as an image in one shared 2D layout.

        The 2D coordinates computed for the last graph of the chain are
        copied, atom index by atom index, onto every graph so that all frames
        are drawn in the same layout.

        Args:
            nodes_list: batch of node-label vectors; must expose ``shape`` and
                be indexable along its first axis (one entry per frame).
            adjacency_matrix: batch of bond matrices, indexable in step with
                ``nodes_list``.

        Returns:
            A list with one 300x300 image per frame, each with the legend
            ``"Frame <index>"``. An empty chain yields an empty list.

        Side effects:
            Calls ``RDLogger.DisableLog('rdApp.*')`` and never re-enables it.
        """
        RDLogger.DisableLog('rdApp.*')

        # Convert each graph of the chain to an RDKit molecule.
        # NOTE(review): frames for which mol_from_graphs returned None are
        # not filtered out before being handed to RDKit below.
        mols = [
            self.mol_from_graphs(nodes_list[i], adjacency_matrix[i])
            for i in range(nodes_list.shape[0])
        ]
        if not mols:
            return []

        # Reference layout: the atom coordinates of the final molecule.
        final_molecule = mols[-1]
        AllChem.Compute2DCoords(final_molecule)
        final_conf = final_molecule.GetConformer()
        coords = []
        for atom_idx in range(final_molecule.GetNumAtoms()):
            pos = final_conf.GetAtomPosition(atom_idx)
            coords.append((pos.x, pos.y, pos.z))

        # Align all molecules to the reference layout. zip() stops at the
        # shorter sequence, so a frame with more atoms than the final molecule
        # keeps its own computed position for the surplus atoms.
        for mol in mols:
            AllChem.Compute2DCoords(mol)
            conf = mol.GetConformer()
            for atom_idx, (x, y, z) in zip(range(mol.GetNumAtoms()), coords):
                conf.SetAtomPosition(atom_idx, Point3D(x, y, z))

        # One image per frame, labelled with its position in the chain.
        return [
            Draw.MolToImage(mol, size=(300, 300), legend=f"Frame {frame}")
            for frame, mol in enumerate(mols)
        ]