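"""Pretokenization and data-loading utilities for SMILES datasets.

`pretokenize` tokenizes a parquet file of SMILES strings (plus optional context
properties such as logp, sascore and mol_weight) into a pickled cache file.
`PretokDataset` serves the cached examples as PyTorch tensors,
`padding_collate_fn` pads and collates them into batches, and
`SmilesTask.iter_batches` yields device-ready batches with optional
fragment/scaffold conditioning.
"""
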
import argparse
import json
import os
import pickle
import random
from functools import partial

import pandas as pd
import numpy as np
import requests
import torch
import torch.distributed as dist
from tqdm import tqdm
import multiprocessing
from multiprocessing import Pool
from fragment_creator import BaseFragmentCreator, BricksFragmentCreator, Fragment
from tokenizer import SmilesTokenizer
from torch.utils.data.distributed import DistributedSampler
from rdkit import Chem
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
from tqdm.contrib.concurrent import process_map, thread_map
from typing import List
import swifter

DATA_CACHE_DIR = "data"

def _tokenize_smiles(
    smi: str,
    tokenizer: SmilesTokenizer = None,
    max_smiles_len=256,
    log_output=True,
):
    tokens = tokenizer.encode(smi)
    if len(tokens) > max_smiles_len:
        if log_output:
            print(f"Removing too long SMILES {smi} with a token length of {len(tokens)}")
        return None

    return tokens


def _tokenize_scaffolds(smi: str, tokenizer=None, max_smiles_len=256, log_output=True):
    smi = MurckoScaffoldSmiles(smi)
    tokens = tokenizer.encode(smi)
    # Drop the special start/end tokens added by the tokenizer.
    tokens = tokens[1:-1]
    if len(tokens) > max_smiles_len:
        if log_output:
            print(f"Removing too long scaffold {smi} with a token length of {len(tokens)}")
        return None

    return tokens


def pad_batch(src, pad_idx):
    max_len = max([len(d) for d in src])

    padded_src = np.ones([len(src), max_len]) * pad_idx

    for i, j in enumerate(src):
        padded_src[i][0 : len(j)] = j

    # Return in (seq_len, batch) layout.
    padded_src = padded_src.T
    return padded_src

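# A minimal pad_batch example (illustrative token ids): two sequences of
# different length are right-padded with pad_idx and returned transposed,
# i.e. in (max_len, batch) layout.
#
#   pad_batch([[12, 5, 13], [12, 13]], pad_idx=0)
#   # -> [[12., 12.],
#   #     [ 5., 13.],
#   #     [13.,  0.]]
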
def pretokenize(
    data_file=os.path.join(
        DATA_CACHE_DIR, "FULL_combined_zinc_pubchemqc_qm9_pc9_reddb_chembl.parquet"
    ),
    tokenizer=SmilesTokenizer(),
    limit=None,
    context=["logp", "sascore", "mol_weight"],
    out_name: str = "processed_dataset",
    remove_nan_context_rows: bool = False,
):
    """Tokenize the SMILES in `data_file` and cache the result as a pickle in .cache/."""
    df = pd.read_parquet(data_file)

    if limit is not None:
        df = df.sample(n=limit)
    else:
        # Shuffle the full dataset.
        df = df.sample(frac=1.0)

    cpu_count = multiprocessing.cpu_count()
    print(f"Running on {cpu_count} CPUs")

    tqdm.pandas()

    # Discard multi-fragment SMILES (those containing a "." separator) by
    # mapping them to None; they are removed by the mask below.
    df["scaffolds"] = df["smiles"].progress_map(lambda s: None if "." in s else s)
    df["smiles"] = df["scaffolds"].copy()
    orig_len = len(df)
    if context is not None:
        if df.get("origin") is not None:
            # Encode the categorical "origin" column as integer ids and save
            # the mapping alongside the data file.
            origins = df["origin"].unique()
            origin_dics = {}
            for i, o in enumerate(origins):
                df.loc[df["origin"] == o, "origin"] = i
                origin_dics[o] = i
            df["origin"] = df["origin"].astype(float)
            with open(
                os.path.join(
                    DATA_CACHE_DIR, os.path.basename(data_file) + "_origins.json"
                ),
                "w",
            ) as f:
                json.dump(origin_dics, f)

        mask = (
            ~df["smiles"].isna()
            & (
                (~df[context].isna()).all(axis=1)
                if remove_nan_context_rows
                else np.ones(len(df["smiles"]), dtype=bool)
            )
            & ~df["scaffolds"].isna()
        )
    else:
        mask = ~df["smiles"].isna()
    error_count = np.count_nonzero(~mask)
    df = df[mask]

    # Tokenize in parallel with swifter; rows whose token sequence exceeds the
    # maximum length come back as None and are filtered out below.
    df["tokens"] = df["smiles"].swifter.apply(
        partial(_tokenize_smiles, tokenizer=tokenizer, log_output=False)
    )
    df["scaffolds"] = df["tokens"].copy()

    mask = ~df["tokens"].isna() & ~df["scaffolds"].isna()
    df = df[mask]
    error_count += np.count_nonzero(~mask)

    # Final shuffle before the train/val split performed in PretokDataset.
    df = df.sample(frac=1).reset_index(drop=True)

    if context is not None:
        context_list = df[context].to_numpy()
        context_dict = {k: context_list[:, i] for i, k in enumerate(context)}
    else:
        context_dict = {}

    print(f"Error count: {error_count} / {orig_len} = {error_count / orig_len}")

    cache_path = os.path.join(os.path.dirname(__file__), ".cache")
    os.makedirs(cache_path, exist_ok=True)
    out_path = os.path.join(cache_path, f"{out_name}_{limit}.pkl")
    with open(out_path, "wb") as f:
        pickle.dump(
            {
                "tokens": df["tokens"].tolist(),
                "smiles": df["smiles"].tolist(),
                "scaf": df["scaffolds"].tolist(),
                **context_dict,
            },
            f,
        )
    print(f"Saved to {out_path}")
    print("Done.")

class PretokDataset(torch.utils.data.Dataset):
    """Loads pretokenized examples from disk and returns them as PyTorch tensors."""

    def __init__(self, split, pad_token_id, dataset="processed_dataset.pkl"):
        super().__init__()
        self.split = split
        self.dataset = dataset
        self.pad_token_id = pad_token_id
        cache_path = os.path.join(os.path.dirname(__file__), ".cache")
        with open(os.path.join(cache_path, self.dataset), "rb") as f:
            self.data_dict = pickle.load(f)

        # Simple 90/10 train/validation split over the (already shuffled) data.
        split_ix = int(len(self.data_dict["tokens"]) * 0.9)
        if self.split == "train":
            self.data_dict = {k: self.data_dict[k][:split_ix] for k in self.data_dict}
        elif self.split == "val":
            self.data_dict = {k: self.data_dict[k][split_ix:] for k in self.data_dict}
        else:
            raise RuntimeError(f"Unknown split: {self.split}")

    def __len__(self):
        return len(self.data_dict["tokens"])

    def __getitem__(self, idx):
        m = self.data_dict

        start = idx
        end = idx + 1

        # Single-element "batches" in (seq_len, 1) layout; padding across
        # examples happens later in padding_collate_fn.
        padded_tokens = pad_batch(m["tokens"][start:end], self.pad_token_id)
        chunk = torch.from_numpy(padded_tokens.astype(np.int64))

        padded_scaffolds = torch.from_numpy(
            pad_batch(m["scaf"][start:end], self.pad_token_id).astype(np.int64)
        )

        item = {
            "seq": chunk,
            "scaf": padded_scaffolds,
            "smiles": m["smiles"][start:end],
            **{
                k: torch.tensor(m[k][start:end], dtype=torch.float32)
                for k in m
                if k != "scaf" and k != "tokens" and k != "smiles"
            },
        }

        return item

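# Example (a sketch; assumes pretokenize has already written
# .cache/processed_dataset_None.pkl next to this file):
#
#   tokenizer = SmilesTokenizer()
#   ds = PretokDataset("train", tokenizer.pad_token_id,
#                      dataset="processed_dataset_None.pkl")
#   item = ds[0]  # keys: "seq", "scaf", "smiles" plus the stored context columns
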
def padding_collate_fn(
    data, tokenizer: SmilesTokenizer, fragment_creator: BaseFragmentCreator
):
    """Pad a list of PretokDataset items into a single batch dictionary."""
    pad_idx = tokenizer.pad_token_id

    src = [d["seq"] for d in data]

    max_len = max([len(d) for d in src])
    padded_src = np.ones([len(src), max_len]) * pad_idx
    for i, j in enumerate(src):
        padded_src[i][0 : len(j)] = j.ravel()

    if fragment_creator is None:
        smiles_context = [d["scaf"] for d in data]
    else:
        # Build the conditioning fragment per molecule; fall back to encoding
        # the fragment SMILES when the creator does not return tokens directly.
        smiles_context = []
        for d in data:
            s = d["smiles"][0]
            tokens = d["seq"]
            frag = fragment_creator.create_fragment(Fragment(smiles=s, tokens=tokens))
            if frag.tokens is not None:
                smiles_context.append(frag.tokens)
            else:
                smiles_context.append(
                    torch.tensor(
                        tokenizer.encode(frag.smiles)[1:-1],
                        dtype=torch.long,
                        device=tokens.device,
                    )
                )

    max_len_ctx = max([len(d) for d in smiles_context])
    padded_smiles_context = np.ones([len(smiles_context), max_len_ctx]) * pad_idx
    for i, j in enumerate(smiles_context):
        padded_smiles_context[i][0 : len(j)] = j.ravel()

    padded_src = padded_src.T

    original_context_keys = [
        k for k in data[0].keys() if k != "seq" and k != "scaf" and k != "smiles"
    ]
    context_out_dict = {k: [] for k in original_context_keys}

    for k in original_context_keys:
        val_list = []
        for d in data:
            val_list.append(d[k])

        context_out_dict[k] = torch.concat(val_list, dim=0)

    # "src" and "fragment" are returned in (seq_len, batch) layout.
    return {
        "src": torch.tensor(padded_src, dtype=torch.long),
        "fragment": torch.tensor(padded_smiles_context.T, dtype=torch.long),
        "context": context_out_dict,
    }

class SmilesTask:
    @staticmethod
    def iter_batches(
        split,
        batch_size,
        device,
        context_keys: List[str],
        num_workers=0,
        dataset="processed_dataset.pkl",
        fragment_creator: BaseFragmentCreator = BricksFragmentCreator(),
    ):
        tokenizer = SmilesTokenizer()
        ds = PretokDataset(split, tokenizer.pad_token_id, dataset=dataset)
        # Use a DistributedSampler when running under torchrun/DDP.
        is_ddp = int(os.environ.get("RANK", -1)) != -1
        dl = torch.utils.data.DataLoader(
            ds,
            batch_size=batch_size,
            pin_memory=True,
            num_workers=num_workers,
            shuffle=False,
            sampler=DistributedSampler(ds) if is_ddp else None,
            collate_fn=lambda batch: padding_collate_fn(
                batch, tokenizer, fragment_creator
            ),
        )

        for data in dl:
            data["src"] = data["src"].to(device, non_blocking=True)
            data["tgt"] = data["src"].to(device, non_blocking=True)

            # Teacher forcing: the target is the source shifted by one token.
            data["src"] = data["src"][:-1, :].T
            data["tgt"] = data["tgt"][1:, :].T

            data["fragment"] = data["fragment"].to(device, non_blocking=True).T
            # Keep only the requested context properties and move them to the device.
            keys = list(data["context"].keys())
            for d in keys:
                if d not in context_keys:
                    del data["context"][d]
                else:
                    data["context"][d] = data["context"][d].to(
                        device, non_blocking=True
                    )

            yield data

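# Example consumption (a sketch; the dataset file and context keys must match
# what pretokenize produced, e.g. "processed_dataset_None.pkl" for limit=None):
#
#   for batch in SmilesTask.iter_batches(
#       split="train",
#       batch_size=32,
#       device="cpu",
#       context_keys=["logp"],
#       dataset="processed_dataset_None.pkl",
#       fragment_creator=None,  # fall back to the stored scaffold tokens
#   ):
#       src, tgt = batch["src"], batch["tgt"]  # (batch, seq_len - 1) each
#       break
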
if __name__ == "__main__":
    pretokenize(
        data_file=os.path.join(
            DATA_CACHE_DIR,
            "OrganiX13.parquet",
        ),
        limit=None,
        context=["logp", "sascore", "mol_weight"],
        out_name="processed_dataset",
        remove_nan_context_rows=False,
    )