import pandas as pd
import os

import torch
from transformers import RobertaTokenizerFast, RobertaForMaskedLM, DataCollatorWithPadding

import datasets
from datasets import disable_caching
disable_caching()

DEVICE = 'cuda:0'                                   # model device
ENCODER_MODEL_NAME = "entropy/roberta_zinc_480m"    # encoder name
ENCODER_BATCH_SIZE = 1024                           # batch size for computing embeddings

TOKENIZER_MAX_LEN = 256                             # max_length param on tokenizer
TOKENIZATION_NUM_PROC = 32                          # number of processes for tokenization

'''
The data source is expected to be a CSV file with a column of SMILES strings
named by `SMILES_COLUMN`. The CSV is read and processed in chunks of
`PROCESS_CHUNKSIZE` rows.

Each processed chunk is saved as a dataset shard at `SAVE_PATH/processed_shard_{i}.hf`.
'''
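
# A purely illustrative sketch of the expected CSV layout (the column name must
# match `SMILES_COLUMN`; the SMILES values below are hypothetical examples, not
# part of any real dataset):
#
#   smiles
#   CCO
#   CC(=O)Oc1ccccc1C(=O)O
#   c1ccccc1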

DATASET_CSV_FILENAME = None                         # path to data csv
PROCESS_CHUNKSIZE = 1000000                         # how many rows to process/save for each dataset shard
SMILES_COLUMN = 'smiles'                            # csv column holding smiles strings
MAX_CHUNKS = None                                   # total number of chunks to process (if None, all chunks are processed)
MAX_SMILES_LENGTH = 90                              # max smiles string length (exclusive)
MIN_SMILES_LENGTH = 5                               # min smiles string length (exclusive)
FILTER_NUM_PROC = 32                                # number of processes for filtering
SAVE_PATH = None                                    # directory to save data shards to

assert DATASET_CSV_FILENAME is not None, "must specify dataset filename"
assert SAVE_PATH is not None, "must specify save path"


def tokenization(example):
    return tokenizer(example[SMILES_COLUMN], add_special_tokens=True, 
                     truncation=True, max_length=TOKENIZER_MAX_LEN)

def embed(inputs):
    # keep only the tensors the model needs, pad them to a uniform length,
    # and move the batch to the model device
    inputs = {k: inputs[k] for k in ['input_ids', 'attention_mask']}
    inputs = collator(inputs)
    inputs = {k: v.to(DEVICE) for k, v in inputs.items()}

    with torch.no_grad():
        outputs = model(**inputs, output_hidden_states=True)
        full_embeddings = outputs.hidden_states[-1]     # last-layer token embeddings (batch, seq_len, hidden)
        mask = inputs['attention_mask']

        # masked mean pooling: average token embeddings over non-padding positions only
        mean_embeddings = (full_embeddings * mask.unsqueeze(-1)).sum(1) / mask.sum(-1).unsqueeze(-1)

    return {'encoder_hidden_states': mean_embeddings}

def length_filter_smiles(example):
    # check the type first so non-string values (e.g. NaN from empty CSV cells)
    # are dropped before len() is called on them
    if type(example[SMILES_COLUMN]) != str:
        return False
    min_check = (len(example[SMILES_COLUMN]) > MIN_SMILES_LENGTH) if (MIN_SMILES_LENGTH is not None) else True
    max_check = (len(example[SMILES_COLUMN]) < MAX_SMILES_LENGTH) if (MAX_SMILES_LENGTH is not None) else True
    return min_check and max_check


tokenizer = RobertaTokenizerFast.from_pretrained(ENCODER_MODEL_NAME, model_max_length=TOKENIZER_MAX_LEN)
collator = DataCollatorWithPadding(tokenizer, padding=True, return_tensors='pt')

model = RobertaForMaskedLM.from_pretrained(ENCODER_MODEL_NAME)
model.to(DEVICE)
model.eval()

df_iter = pd.read_csv(DATASET_CSV_FILENAME, chunksize=PROCESS_CHUNKSIZE, usecols=[SMILES_COLUMN])

for i, df in enumerate(df_iter):
    print(f'processing dataset chunk {i}')
    
    dataset = datasets.Dataset.from_pandas(df)
    
    dataset = dataset.filter(length_filter_smiles, num_proc=FILTER_NUM_PROC)
    
    dataset = dataset.map(tokenization, batched=True, num_proc=TOKENIZATION_NUM_PROC)
    
    dataset = dataset.map(embed, batched=True, batch_size=ENCODER_BATCH_SIZE)
    
    dataset.save_to_disk(f'{SAVE_PATH}/processed_shard_{i}.hf')
    
    if (MAX_CHUNKS is not None) and (i >= MAX_CHUNKS-1):
        break

print('finished data processing')
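
# A minimal, commented-out sketch of how the saved shards could be reloaded and
# combined downstream; `NUM_SHARDS` and `shard_paths` are assumed names, not
# defined anywhere above:
#
# from datasets import load_from_disk, concatenate_datasets
#
# NUM_SHARDS = 4  # however many processed_shard_{i}.hf directories were written
# shard_paths = [f'{SAVE_PATH}/processed_shard_{i}.hf' for i in range(NUM_SHARDS)]
# full_dataset = concatenate_datasets([load_from_disk(p) for p in shard_paths])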