# TillCyrill
# scripts from other repo
# 0da959e
# raw
# history blame
# 6.22 kB
import numpy as np
import scipy.spatial as ss
import torch
import torch.nn.functional as F
from torch_geometric.utils import to_undirected
from torch_sparse import coalesce
# Feature index -> atom label used for protein atoms; the final entry is the
# catch-all "unknown" bin.
atom_mapping = dict(enumerate((
    'H', 'C', 'N', 'O', 'F', 'P', 'S', 'CL', 'BR', 'I', 'UNK',
)))

# Feature index -> three-letter residue label; the final entry is the
# catch-all "unknown" bin.
residue_mapping = dict(enumerate((
    'ALA', 'ARG', 'ASN', 'ASP', 'CYS', 'CYX', 'GLN', 'GLU', 'GLY', 'HIE',
    'ILE', 'LEU', 'LYS', 'MET', 'PHE', 'PRO', 'SER', 'THR', 'TRP', 'TYR',
    'VAL', 'UNK',
)))

# Ligand element code -> one-hot slot. The position of a code in the tuple
# below is the slot it is assigned.
# NOTE(review): the codes look like atomic numbers (8=O, 16=S, 6=C, ...) --
# confirm against the dataset that produces the ``element`` column.
ligand_atoms_mapping = {
    code: slot
    for slot, code in enumerate((
        8, 16, 6, 7, 1, 15, 17, 9, 53, 35, 5, 33,
        26, 14, 34, 44, 12, 23, 77, 27, 52, 30, 4, 45,
    ))
}
def prot_df_to_graph(item, df, edge_dist_cutoff, feat_col='element'):
    r"""
    Converts protein in dataframe representation to a graph compatible with Pytorch-Geometric, where each node is an atom.

    :param item: Record describing the structure being converted. Only ``item['id']`` is read, to label diagnostic messages.
    :type item: Mapping
    :param df: Protein structure in dataframe format. Must provide the columns ``x``, ``y``, ``z`` and ``feat_col``.
    :type df: pandas.DataFrame
    :param edge_dist_cutoff: Maximum distance (in the units of the coordinates, presumably Angstroms) at which two atoms are connected by an edge.
    :type edge_dist_cutoff: float
    :param feat_col: Column of ``df`` holding the node type of each atom, defaults to ``'element'``.
        Each value ``e`` is one-hot encoded at index ``e - 1`` of ``atom_mapping``, so values are assumed to be 1-based integer codes (TODO confirm against the dataset).
        Values that do not map to a known index fall into the last ("UNK") bin.
    :type feat_col: str, optional

    :return: tuple containing

        - node_feats (torch.FloatTensor): Features for each node, one-hot encoded over ``atom_mapping``.

        - edges (torch.LongTensor): Edges in COO format, shape ``[2, num_edges]`` (``[2, 0]`` when no pair of atoms is within the cutoff).

        - edge_weights (torch.FloatTensor): Edge weights, defined as a function of distance between atoms given by :math:`w_{i,j} = \frac{1}{d(i,j) + 10^{-5}}`, where :math:`d(i, j)` is the Euclidean distance between node :math:`i` and node :math:`j`.

        - node_pos (torch.FloatTensor): x-y-z coordinates of each node
    :rtype: Tuple
    :raises Exception: Whatever error the coordinate extraction or neighbour search raised; the structure id is printed first.
    """
    allowable_feats = atom_mapping

    try:
        node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())
        kd_tree = ss.KDTree(node_pos)
        edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
        if edge_tuples:
            edges = torch.LongTensor(edge_tuples).t().contiguous()
            edges = to_undirected(edges)
        else:
            # No pair within the cutoff: build a well-formed empty COO tensor
            # instead of handing a 1-D tensor to ``to_undirected``.
            edges = torch.empty((2, 0), dtype=torch.long)
    except Exception:
        # Name the offending structure, then let the real error propagate
        # rather than failing later on an undefined ``edges``/``node_pos``.
        print(f"Problem with PDB Id is {item['id']}")
        raise

    if not edge_tuples:
        # Keep the historical diagnostic for structures that yield no edges.
        print(f"Problem with PDB Id is {item['id']}")

    node_feats = torch.FloatTensor(
        [one_of_k_encoding_unk_indices(e - 1, allowable_feats) for e in df[feat_col]])

    # Inverse-distance weights for all edges at once; the 1e-5 offset guards
    # against division by zero for coincident atoms.
    src, dst = edges
    dist = (node_pos[src] - node_pos[dst]).norm(p=2, dim=1)
    edge_weights = (1.0 / (dist + 1e-5)).view(-1)

    return node_feats, edges, edge_weights, node_pos
def mol_df_to_graph_for_qm(df, bonds=None, allowable_atoms=None, edge_dist_cutoff=4.5, onehot_edges=True):
    """
    Converts molecule in dataframe to a graph compatible with Pytorch-Geometric

    Edges come from ``bonds`` when it is given; otherwise every pair of atoms
    closer than ``edge_dist_cutoff`` is connected.

    :param df: Molecule structure in dataframe format. Must provide the columns ``x``, ``y``, ``z`` and ``element``.
    :type df: pandas.DataFrame
    :param bonds: One row per bond, in any form accepted by ``torch.FloatTensor``. The first two columns are the indices of the bonded atoms and the last column is the bond type (1.0, 2.0, 3.0 or 1.5).
    :type bonds: array-like, optional
    :param allowable_atoms: Mapping from element code to one-hot slot, defaults to ``ligand_atoms_mapping``.
    :type allowable_atoms: dict, optional
    :param edge_dist_cutoff: Maximum distance at which two atoms are connected when ``bonds`` is None, defaults to 4.5.
    :type edge_dist_cutoff: float, optional
    :param onehot_edges: When ``bonds`` is given, one-hot encode the bond type (4 classes) instead of returning the raw bond-type value, defaults to True.
    :type onehot_edges: bool, optional

    :return: Tuple containing

        - node_feats (torch.FloatTensor): Features for each node, one-hot encoded by atom type in ``allowable_atoms`` plus one trailing "unknown" slot.

        - edge_index (torch.LongTensor): Edges in COO format, listed in both directions.

        - edge_feats (torch.FloatTensor): With ``bonds``: one-hot bond type of shape ``[num_edges, 4]`` (slots: single, double, triple, aromatic), or the raw bond-type values (Single = 1.0, Double = 2.0, Triple = 3.0, Aromatic = 1.5) of shape ``[num_edges]`` when ``onehot_edges`` is False. Without ``bonds``: inverse distance ``1 / (d + 1e-5)`` of shape ``[num_edges, 1]``.

        - node_pos (torch.FloatTensor): x-y-z coordinates of each node.
    :raises KeyError: If ``onehot_edges`` is True and a bond type is not one of 1.0, 2.0, 3.0, 1.5.
    """
    if allowable_atoms is None:
        allowable_atoms = ligand_atoms_mapping

    node_pos = torch.FloatTensor(df[['x', 'y', 'z']].to_numpy())

    if bonds is not None:
        N = df.shape[0]
        bond_mapping = {1.0: 0, 2.0: 1, 3.0: 2, 1.5: 3}
        bond_data = torch.FloatTensor(bonds)
        endpoints = bond_data[:, :2]
        bond_types = bond_data[:, -1]

        # Each bond is listed twice: (i, j) followed by (j, i).
        edge_tuples = torch.cat((endpoints, torch.flip(endpoints, dims=(1,))), dim=0)
        edge_index = edge_tuples.t().long().contiguous()

        if onehot_edges:
            # Same class index for both directions of every bond.
            bond_idx = [bond_mapping[b] for b in bond_types.tolist()] * 2
            # Explicit long dtype: an empty list would otherwise produce a
            # float tensor, which ``F.one_hot`` rejects.
            edge_attr = F.one_hot(torch.tensor(bond_idx, dtype=torch.long), num_classes=4).to(torch.float)
            edge_index, edge_attr = coalesce(edge_index, edge_attr, N, N)
        else:
            edge_attr = torch.cat((bond_types, bond_types), dim=0).view(-1)
    else:
        kd_tree = ss.KDTree(node_pos)
        edge_tuples = list(kd_tree.query_pairs(edge_dist_cutoff))
        if edge_tuples:
            edge_index = torch.LongTensor(edge_tuples).t().contiguous()
            edge_index = to_undirected(edge_index)
        else:
            # No pair within the cutoff: return a well-formed empty graph
            # instead of failing on a 1-D empty tensor.
            edge_index = torch.empty((2, 0), dtype=torch.long)

        # Inverse-distance weights for all edges at once; the 1e-5 offset
        # guards against division by zero for coincident atoms.
        src, dst = edge_index
        dist = (node_pos[src] - node_pos[dst]).norm(p=2, dim=1)
        edge_attr = (1.0 / (dist + 1e-5)).view(-1)
        edge_attr = edge_attr.unsqueeze(1)

    node_feats = torch.FloatTensor(
        [one_of_k_encoding_unk_indices_qm(e, allowable_atoms) for e in df['element']])

    return node_feats, edge_index, edge_attr, node_pos
def one_of_k_encoding_unk_indices(x, allowable_set):
    """Build a 1-hot list with one slot per entry of ``allowable_set``.

    ``x`` is used directly as the hot position when it is a member of
    ``allowable_set``; any other input lights up the last slot instead.
    """
    encoding = [0] * len(allowable_set)
    # Fall back to the final ("unknown") slot for non-members.
    hot = x if x in allowable_set else -1
    encoding[hot] = 1
    return encoding
def one_of_k_encoding_unk_indices_qm(x, allowable_set):
    """Build a 1-hot list with one slot per entry of ``allowable_set`` plus a trailing "unknown" slot.

    When ``x`` is a member of ``allowable_set`` the hot position is
    ``allowable_set[x]``; any other input lights up the extra last slot.
    """
    encoding = [0] * (len(allowable_set) + 1)
    # Look the slot up only for members; everything else goes to the extra bin.
    hot = allowable_set[x] if x in allowable_set else -1
    encoding[hot] = 1
    return encoding