import os

import numpy as np
import rdkit.Chem
from rdkit import Chem
from rdkit import RDLogger
from rdkit.Chem import AllChem, Draw
from rdkit.Geometry import Point3D


class MolecularVisualization:
    """Build RDKit molecules from integer-encoded graphs and render them."""

    def __init__(self, atom_decoder):
        """Store the decoder used to turn node labels into atoms.

        Args:
            atom_decoder: Container indexed by an integer node label whose
                values are passed directly to ``Chem.Atom``.
        """
        self.atom_decoder = atom_decoder

    def mol_from_graphs(self, node_list, adjacency_matrix):
        """Convert a single graph to an RDKit molecule.

        Args:
            node_list: Node labels of one graph (length n). Entries equal to
                -1 are treated as masked and produce no atom.
            adjacency_matrix: Bond labels of the same graph (n x n). Only the
                strict upper triangle is read. Values 1, 2, 3 and 4 create a
                single, double, triple and aromatic bond respectively; any
                other value means "no bond".

        Returns:
            The assembled ``Chem.Mol``, or ``None`` if RDKit raises a
            ``KekulizeException`` while finalising the molecule.
        """
        atom_decoder = self.atom_decoder

        # Editable molecule that atoms and bonds are appended to.
        mol = Chem.RWMol()

        # Graph positions and RDKit atom indices diverge as soon as a masked
        # node is skipped, so keep an explicit mapping between the two.
        node_to_idx = {}
        for i, node in enumerate(node_list):
            if node == -1:
                continue
            atom = Chem.Atom(atom_decoder[int(node)])
            node_to_idx[i] = mol.AddAtom(atom)

        for ix, row in enumerate(adjacency_matrix):
            for iy, bond in enumerate(row):
                # Only traverse half of the symmetric matrix.
                if iy <= ix:
                    continue
                if bond == 1:
                    bond_type = Chem.rdchem.BondType.SINGLE
                elif bond == 2:
                    bond_type = Chem.rdchem.BondType.DOUBLE
                elif bond == 3:
                    bond_type = Chem.rdchem.BondType.TRIPLE
                elif bond == 4:
                    bond_type = Chem.rdchem.BondType.AROMATIC
                else:
                    continue
                # A bond that touches a masked node has no atom to attach to;
                # skip it instead of failing with a KeyError.
                if ix not in node_to_idx or iy not in node_to_idx:
                    continue
                mol.AddBond(node_to_idx[ix], node_to_idx[iy], bond_type)

        try:
            mol = mol.GetMol()
        except rdkit.Chem.KekulizeException:
            print("Can't kekulize molecule")
            mol = None
        return mol

    def visualize_chain(self, nodes_list, adjacency_matrix):
        """Render every graph of a chain as an image with a shared layout.

        All frames are drawn with the 2D coordinates computed for the last
        frame so that atoms stay in place from one image to the next.

        Args:
            nodes_list: Array of node labels, one row per frame; must expose
                ``.shape[0]`` and integer indexing.
            adjacency_matrix: Bond labels, indexed by frame in the same way.

        Returns:
            A list with one 300x300 image per frame, labelled
            ``"Frame <index>"``. Empty when the chain has no frames.

        Note:
            ``mol_from_graphs`` may return ``None``; such frames are not
            filtered out here.
        """
        RDLogger.DisableLog('rdApp.*')

        # Convert every frame of the chain to an RDKit molecule.
        mols = [
            self.mol_from_graphs(nodes_list[i], adjacency_matrix[i])
            for i in range(nodes_list.shape[0])
        ]
        if not mols:
            return []

        # The layout of the final molecule is the reference for all frames.
        final_molecule = mols[-1]
        AllChem.Compute2DCoords(final_molecule)
        final_conf = final_molecule.GetConformer()
        coords = []
        for i in range(final_molecule.GetNumAtoms()):
            position = final_conf.GetAtomPosition(i)
            coords.append((position.x, position.y, position.z))

        # Align all the molecules on the reference layout.
        for mol in mols:
            AllChem.Compute2DCoords(mol)
            conf = mol.GetConformer()
            # zip stops at the shorter side: atoms without a counterpart in
            # the final frame keep the coordinates computed just above.
            for j, (x, y, z) in zip(range(mol.GetNumAtoms()), coords):
                conf.SetAtomPosition(j, Point3D(x, y, z))

        # Create the list of molecule images.
        return [
            Draw.MolToImage(mol, size=(300, 300), legend=f"Frame {frame}")
            for frame, mol in enumerate(mols)
        ]