Ana Sanchez commited on
Commit
4f08713
1 Parent(s): 364f895
Files changed (1) hide show
  1. cloome.py +498 -0
cloome.py ADDED
@@ -0,0 +1,498 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import pandas as pd
3
+ import streamlit as st
4
+ from PIL import Image
5
+
6
+ import sys
7
+ import io
8
+ import os
9
+ import glob
10
+ import json
11
+ import zipfile
12
+ from tqdm import tqdm
13
+ from itertools import chain
14
+
15
+ import torch
16
+ from torch.utils.data import DataLoader
17
+ from torch.utils.tensorboard import SummaryWriter
18
+
19
+ import clip.clip as clip
20
+ from clip.clip import _transform
21
+ from training.datasets import CellPainting
22
+ from clip.model import convert_weights, CLIPGeneral
23
+
24
+ from rdkit import Chem
25
+ from rdkit.Chem import Draw
26
+ from rdkit.Chem import AllChem
27
+ from rdkit.Chem import DataStructs
28
+
29
+
30
+
31
+
32
+ basepath = os.path.dirname(__file__)
33
+
34
+ MODEL_PATH = os.path.join(basepath, "epoch_55.pt")
35
+ CLOOME_PATH = "/home/ana/gitrepos/hti-cloob"
36
+ npzs = os.path.join(basepath, "npzs")
37
+ imgname = "I1"
38
+ molecule_features = "all_molecule_cellpainting_features.pkl"
39
+ image_features = "subset_image_cellpainting_features.pkl"
40
+ images_arr = "subset_npzs_dict_200.npz"
41
+
42
+ device = "cuda" if torch.cuda.is_available() else "cpu"
43
+ model_type = "RN50"
44
+ image_resolution = 520
45
+
46
+ ######### CLOOME FUNCTIONS #########
47
+ def convert_models_to_fp32(model):
48
+ for p in model.parameters():
49
+ p.data = p.data.float()
50
+ if p.grad:
51
+ p.grad.data = p.grad.data.float()
52
+
53
+
54
+ def load(model_path, device, model, image_resolution):
55
+ state_dict = torch.load(model_path, map_location="cpu")
56
+ state_dict = state_dict["state_dict"]
57
+
58
+ model_config_file = f"{model.replace('/', '-')}.json"
59
+ print('Loading model from', model_config_file)
60
+ assert os.path.exists(model_config_file)
61
+ with open(model_config_file, 'r') as f:
62
+ model_info = json.load(f)
63
+ model = CLIPGeneral(**model_info)
64
+ convert_weights(model)
65
+ convert_models_to_fp32(model)
66
+
67
+ if str(device) == "cpu":
68
+ model.float()
69
+ print(device)
70
+
71
+ new_state_dict = {k[len('module.'):]: v for k,v in state_dict.items()}
72
+
73
+ model.load_state_dict(new_state_dict)
74
+ model.to(device)
75
+ model.eval()
76
+
77
+ return model
78
+
79
+
80
+ def get_features(dataset, model, device):
81
+ all_image_features = []
82
+ all_text_features = []
83
+ all_ids = []
84
+
85
+ print(f"get_features {device}")
86
+ print(len(dataset))
87
+
88
+ with torch.no_grad():
89
+ for batch in tqdm(DataLoader(dataset, num_workers=1, batch_size=64)):
90
+ if type(batch) is dict:
91
+ imgs = batch
92
+ text_features = None
93
+ mols = None
94
+ elif type(batch) is torch.Tensor:
95
+ mols = batch
96
+ imgs = None
97
+ else:
98
+ imgs, mols = batch
99
+
100
+ if mols is not None:
101
+ text_features = model.encode_text(mols.to(device))
102
+ text_features = text_features / text_features.norm(dim=-1, keepdim=True)
103
+ all_text_features.append(text_features)
104
+ molecules_exist = True
105
+
106
+ if imgs is not None:
107
+ images = imgs["input"]
108
+ ids = imgs["ID"]
109
+
110
+ img_features = model.encode_image(images.to(device))
111
+ img_features = img_features / img_features.norm(dim=-1, keepdim=True)
112
+ all_image_features.append(img_features)
113
+
114
+ all_ids.append(ids)
115
+
116
+
117
+ all_ids = list(chain.from_iterable(all_ids))
118
+
119
+ if imgs is not None and mols is not None:
120
+ return torch.cat(all_image_features), torch.cat(all_text_features), all_ids
121
+ elif imgs is not None:
122
+ return torch.cat(all_image_features), all_ids
123
+ elif mols is not None:
124
+ return torch.cat(all_text_features), all_ids
125
+ return
126
+
127
+
128
+ def read_array(file):
129
+ t = torch.load(file)
130
+ features = t["mol_features"]
131
+ ids = t["mol_ids"]
132
+ return features, ids
133
+
134
+
135
+ def main(df, model_path, model, img_path=None, mol_path=None, image_resolution=None):
136
+ # Load the model
137
+ device = "cuda" if torch.cuda.is_available() else "cpu"
138
+ print(torch.cuda.device_count())
139
+
140
+ model = load(model_path, device, model, image_resolution)
141
+
142
+ preprocess_val = _transform(image_resolution, image_resolution, is_train=False, normalize="dataset", preprocess="downsize")
143
+
144
+ # Load the dataset
145
+ val = CellPainting(df,
146
+ img_path,
147
+ mol_path,
148
+ transforms = preprocess_val)
149
+
150
+ # Calculate the image features
151
+ print("getting_features")
152
+ result = get_features(val, model, device)
153
+
154
+ if len(result) > 2:
155
+ val_img_features, val_text_features, val_ids = result
156
+ return val_img_features, val_text_features, val_ids
157
+ else:
158
+ val_img_features, val_ids = result
159
+ return val_img_features, val_ids
160
+
161
+ #val_img_features, val_ids = get_features(val, model, device)
162
+
163
+ #return val_img_features, val_text_features, val_ids
164
+
165
+ def img_to_numpy(file):
166
+ img = Image.open(file)
167
+ arr = np.array(img)
168
+ return arr
169
+
170
+
171
+ def illumination_threshold(arr, perc=0.0028):
172
+ """ Return threshold value to not display a percentage of highest pixels"""
173
+
174
+ perc = perc/100
175
+
176
+ h = arr.shape[0]
177
+ w = arr.shape[1]
178
+
179
+ # find n pixels to delete
180
+ total_pixels = h * w
181
+ n_pixels = total_pixels * perc
182
+ n_pixels = int(np.around(n_pixels))
183
+
184
+ # find indexes of highest pixels
185
+ flat_inds = np.argpartition(arr, -n_pixels, axis=None)[-n_pixels:]
186
+ inds = np.array(np.unravel_index(flat_inds, arr.shape)).T
187
+
188
+ max_values = [arr[i, j] for i, j in inds]
189
+
190
+ threshold = min(max_values)
191
+
192
+ return threshold
193
+
194
+
195
+ def process_image(arr):
196
+ threshold = illumination_threshold(arr)
197
+ scaled_img = sixteen_to_eight_bit(arr, threshold)
198
+ return scaled_img
199
+
200
+
201
+ def sixteen_to_eight_bit(arr, display_max, display_min=0):
202
+ threshold_image = ((arr.astype(float) - display_min) * (arr > display_min))
203
+
204
+ scaled_image = (threshold_image * (256. / (display_max - display_min)))
205
+ scaled_image[scaled_image > 255] = 255
206
+
207
+ scaled_image = scaled_image.astype(np.uint8)
208
+
209
+ return scaled_image
210
+
211
+
212
+ def process_image(arr):
213
+ threshold = illumination_threshold(arr)
214
+ scaled_img = sixteen_to_eight_bit(arr, threshold)
215
+ return scaled_img
216
+
217
+
218
+ def process_sample(imglst, channels, filenames, outdir, outfile):
219
+ sample = np.zeros((520, 696, 5), dtype=np.uint8)
220
+
221
+ filenames_dict, channels_dict = {}, {}
222
+
223
+ for i, (img, channel, fname) in enumerate(zip(imglst, channels, filenames)):
224
+ print(channel)
225
+ arr = img_to_numpy(img)
226
+ arr = process_image(arr)
227
+
228
+ sample[:,:,i] = arr
229
+
230
+ channels_dict[i] = channel
231
+ filenames_dict[channel] = fname
232
+
233
+ sample_dict = dict(sample=sample,
234
+ channels=channels_dict,
235
+ filenames=filenames_dict)
236
+
237
+ outfile = outfile + ".npz"
238
+ outpath = os.path.join(outdir, outfile)
239
+
240
+ np.savez(outpath, sample=sample, channels=channels, filenames=filenames)
241
+
242
+ return sample_dict, outpath
243
+
244
+
245
+ def display_cellpainting(sample):
246
+ arr = sample["sample"]
247
+ r = arr[:, :, 0].astype(np.float32)
248
+ g = arr[:, :, 3].astype(np.float32)
249
+ b = arr[:, :, 4].astype(np.float32)
250
+
251
+ rgb_arr = np.dstack((r, g, b))
252
+
253
+ im = Image.fromarray(rgb_arr.astype("uint8"))
254
+ im_rgb = im.convert("RGB")
255
+ return im_rgb
256
+
257
+
258
+ def morgan_from_smiles(smiles, radius=3, nbits=1024, chiral=True):
259
+ mol = Chem.MolFromSmiles(smiles)
260
+ fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=3, nBits=nbits, useChirality=chiral)
261
+ arr = np.zeros((0,), dtype=np.int8)
262
+ DataStructs.ConvertToNumpyArray(fp,arr)
263
+ return arr
264
+
265
+
266
+ def save_hdf(fps, index, outfile_hdf):
267
+ ids = [i for i in range(len(fps))]
268
+ columns = [str(i) for i in range(fps[0].shape[0])]
269
+ df = pd.DataFrame(fps, index=ids, columns=columns)
270
+ df.to_hdf(outfile_hdf, key="df", mode="w")
271
+ return outfile_hdf
272
+
273
+
274
+ def create_index(outdir, ids, filename):
275
+ filepath = os.path.join(outdir, filename)
276
+ if type(ids) is str:
277
+ values = [ids]
278
+ else:
279
+ values = ids
280
+ data = {"SAMPLE_KEY": values}
281
+ print(data)
282
+ df = pd.DataFrame(data)
283
+ df.to_csv(filepath)
284
+ return filepath
285
+
286
+
287
+ def draw_molecules(smiles_lst):
288
+ mols = [Chem.MolFromSmiles(s) for s in smiles_lst]
289
+ mol_imgs = [Chem.Draw.MolToImage(m) for m in mols]
290
+ return mol_imgs
291
+
292
+
293
+ def reshape_image(arr):
294
+ c, h, w = arr.shape
295
+ reshaped_image = np.empty((h, w, c))
296
+
297
+ reshaped_image[:,:,0] = arr[0]
298
+ reshaped_image[:,:,1] = arr[1]
299
+ reshaped_image[:,:,2] = arr[2]
300
+
301
+ reshaped_pil = Image.fromarray(reshaped_image.astype("uint8"))
302
+
303
+ return reshaped_pil
304
+
305
+
306
+ # missing functions: save morgan to to_hdf, create index, load features, calculate similarities
307
+
308
+
309
+ #model = load(MODEL_PATH, device, model_type, image_resolution)
310
+
311
+ ##### STREAMLIT FUNCTIONS ######
312
+ st.title('CLOOME: Contrastive Learning for Molecule Representation with Microscopy Images and Chemical Structures')
313
+
314
+
315
+ def main_page():
316
+ st.markdown(
317
+ """
318
+ Contrastive learning for self-supervised representation learning has brought a
319
+ strong improvement to many application areas, such as computer vision and natural
320
+ language processing. With the availability of large collections of unlabeled data in
321
+ vision and language, contrastive learning of language and image representations
322
+ has shown impressive results. The contrastive learning methods CLIP and CLOOB
323
+ have demonstrated that the learned representations are highly transferable to a
324
+ large set of diverse tasks when trained on multi-modal data from two different
325
+ domains. In drug discovery, similar large, multi-modal datasets comprising both
326
+ cell-based microscopy images and chemical structures of molecules are available.
327
+
328
+ However, contrastive learning has not yet been used for this type of multi-modal data,
329
+ although transferable representations could be a remedy for the
330
+ time-consuming and cost-expensive label acquisition in this domain. In this work,
331
+ we present a contrastive learning method for image-based and structure-based
332
+ representations of small molecules for drug discovery.
333
+
334
+ Our method, Contrastive Leave One Out boost for Molecule Encoders (CLOOME), is based on CLOOB
335
+ and comprises an encoder for microscopy data, an encoder for chemical structures
336
+ and a contrastive learning objective. On the benchmark dataset ”Cell Painting”,
337
+ we demonstrate the ability of our method to learn transferable representations by
338
+ performing linear probing for activity prediction tasks. Additionally, we show that
339
+ the representations could also be useful for bioisosteric replacement tasks.
340
+ """
341
+ )
342
+
343
+
344
+ def molecules_from_image():
345
+ ## TODO: Check if expander can be automatically collapsed
346
+ exp = st.expander("Upload a microscopy image")
347
+ with exp:
348
+ channels = ['Mito', 'ERSyto', 'ERSytoBleed', 'Ph_golgi', 'Hoechst']
349
+ imglst, filenames = [], []
350
+
351
+ for c in channels:
352
+ file_obj = st.file_uploader(f'Choose a TIF image for {c}:', ".tif")
353
+ if file_obj is not None:
354
+ imglst.append(file_obj)
355
+ filenames.append(file_obj.name)
356
+
357
+
358
+ if imglst:
359
+ if not os.path.isdir(npzs):
360
+ os.mkdir(npzs)
361
+
362
+ sample_dict, imgpath = process_sample(imglst, channels, filenames, npzs, imgname)
363
+ print(imglst)
364
+
365
+
366
+ i = display_cellpainting(sample_dict)
367
+ st.image(i)
368
+
369
+ uploaded_file = st.file_uploader("Choose a molecule file to retrieve from (optional)")
370
+
371
+ if imglst:
372
+ if uploaded_file is not None:
373
+ molecule_df = pd.read_csv(uploaded_file)
374
+ smiles = molecule_df["SMILES"].tolist()
375
+ morgan = [morgan_from_smiles(s) for s in smiles]
376
+ molnames = [f"M{i}" for i in range(len(morgan))]
377
+ mol_index_fname = "mol_index.csv"
378
+ mol_index = create_index(basepath, molnames, mol_index_fname)
379
+ molpath = os.path.join(basepath, "mols.hdf")
380
+ fps_fname = save_hdf(morgan, molnames, molpath)
381
+ mol_imgs = draw_molecules(smiles)
382
+ mol_features, mol_ids = main(mol_index, MODEL_PATH, model_type, mol_path=molpath, image_resolution=image_resolution)
383
+ predefined_features = False
384
+ else:
385
+ mol_index = pd.read_csv("cellpainting-unique-molecule.csv")
386
+ mol_features_torch = torch.load("all_molecule_cellpainting_features.pkl")
387
+ mol_features = mol_features_torch["mol_features"]
388
+ mol_ids = mol_features_torch["mol_ids"]
389
+ print(len(mol_ids))
390
+ predefined_features = True
391
+
392
+ img_index_fname = "img_index.csv"
393
+ img_index = create_index(basepath, imgname, img_index_fname)
394
+ img_features, img_ids = main(img_index, MODEL_PATH, model_type, img_path=npzs, image_resolution=image_resolution)
395
+
396
+ print(img_features.shape)
397
+ print(mol_features.shape)
398
+
399
+ logits = img_features @ mol_features.T
400
+ mol_probs = (30.0 * logits).softmax(dim=-1)
401
+ top_probs, top_labels = mol_probs.cpu().topk(5, dim=-1)
402
+
403
+ # Delete this if want to allow retrieval for multiple images
404
+ top_probs = torch.flatten(top_probs)
405
+ top_labels = torch.flatten(top_labels)
406
+
407
+ print(top_probs.shape)
408
+ print(top_labels.shape)
409
+
410
+ if predefined_features:
411
+ mol_index.set_index(["SAMPLE_KEY"], inplace=True)
412
+ top_ids = [mol_ids[i] for i in top_labels]
413
+ smiles = mol_index.loc[top_ids]["SMILES"].tolist()
414
+ mol_imgs = draw_molecules(smiles)
415
+
416
+ with st.container():
417
+ #st.write("Ranking of most similar molecules")
418
+ columns = st.columns(len(top_probs))
419
+ for i, col in enumerate(columns):
420
+ if predefined_features:
421
+ image_id = i
422
+ else:
423
+ image_id = top_labels[i]
424
+ index = i+1
425
+ col.image(mol_imgs[image_id], width=140, caption=index)
426
+
427
+ print(mol_probs.sum(dim=-1))
428
+ print((top_probs, top_labels))
429
+
430
+ def images_from_molecule():
431
+ smiles = st.text_input("Enter a SMILES string", value="CC(=O)OC1=CC=CC=C1C(=O)O", placeholder="CC(=O)OC1=CC=CC=C1C(=O)O")
432
+ if smiles:
433
+ smiles = [smiles]
434
+ morgan = [morgan_from_smiles(s) for s in smiles]
435
+ molnames = [f"M{i}" for i in range(len(morgan))]
436
+ mol_index_fname = "mol_index.csv"
437
+ mol_index = create_index(basepath, molnames, mol_index_fname)
438
+ molpath = os.path.join(basepath, "mols.hdf")
439
+ fps_fname = save_hdf(morgan, molnames, molpath)
440
+ mol_imgs = draw_molecules(smiles)
441
+
442
+ mol_features, mol_ids = main(mol_index, MODEL_PATH, model_type, mol_path=molpath, image_resolution=image_resolution)
443
+
444
+ col1, col2, col3 = st.columns(3)
445
+
446
+ with col1:
447
+ st.write("")
448
+
449
+ with col2:
450
+ st.image(mol_imgs, width = 140)
451
+
452
+ with col3:
453
+ st.write("")
454
+
455
+
456
+ img_features_torch = torch.load(image_features)
457
+ img_features = img_features_torch["img_features"]
458
+ img_ids = img_features_torch["img_ids"]
459
+
460
+ logits = mol_features @ img_features.T
461
+ img_probs = (30.0 * logits).softmax(dim=-1)
462
+ top_probs, top_labels = img_probs.cpu().topk(5, dim=-1)
463
+
464
+ top_probs = torch.flatten(top_probs)
465
+ top_labels = torch.flatten(top_labels)
466
+
467
+ img_index = pd.read_csv("cellpainting-all-imgpermol.csv")
468
+ img_index.set_index(["SAMPLE_KEY"], inplace=True)
469
+ top_ids = [img_ids[i] for i in top_labels]
470
+
471
+ images_dict = np.load(images_arr, allow_pickle = True)
472
+
473
+ with st.container():
474
+ columns = st.columns(len(top_probs))
475
+ for i, col in enumerate(columns):
476
+ id = top_ids[i]
477
+ id = f"{id}.npz"
478
+ image = images_dict[id]
479
+
480
+ ## TODO: generalize and functionalize
481
+ im = reshape_image(image)
482
+
483
+ index = i+1
484
+ col.image(im, caption=index)
485
+
486
+
487
+ page_names_to_funcs = {
488
+ "-": main_page,
489
+ "Molecules from a microscopy image": molecules_from_image,
490
+ "Microscopy images from a molecule": images_from_molecule,
491
+ }
492
+
493
+
494
+ selected_page = st.sidebar.selectbox("What would you like to retrieve?", page_names_to_funcs.keys())
495
+ page_names_to_funcs[selected_page]()
496
+
497
+ # print(img_features.shape)
498
+ # print(img_ids)