import numpy as np
import pandas as pd
import streamlit as st
from PIL import Image
import sys
import io
import os
import glob
import json
import zipfile
from tqdm import tqdm
from itertools import chain

import torch
from torch.utils.data import DataLoader

# Project-local packages live under ./src; the path is resolved against the
# current working directory, so the app must be launched from the repo root.
sys.path.insert(0, os.path.abspath("src/"))

from clip.clip import _transform
from training.datasets import CellPainting
from clip.model import convert_weights, CLIPGeneral

from rdkit import Chem
from rdkit.Chem import Draw
from rdkit.Chem import AllChem
from rdkit.Chem import DataStructs

basepath = os.path.dirname(__file__)
datapath = os.path.join(basepath, "data")

CLOOME_PATH = "/home/ana/gitrepos/hti-cloob"
MODEL_PATH = os.path.join(datapath, "epoch_55.pt")
npzs = os.path.join(datapath, "npzs")
molecule_features = os.path.join(datapath, "all_molecule_cellpainting_features.pkl")
mol_index_file = os.path.join(datapath, "cellpainting-unique-molecule.csv")
image_features = os.path.join(datapath, "subset_image_cellpainting_features.pkl")
images_arr = os.path.join(datapath, "subset_npzs_dict_.npz")
img_index_file = os.path.join(datapath, "cellpainting-all-imgpermol.csv")
imgname = "I1"

device = "cuda" if torch.cuda.is_available() else "cpu"
model_type = "RN50"
image_resolution = 520

# Maximum number of retrieval results shown per query (clamped to the number
# of available candidates, since torch.topk rejects k larger than that).
TOP_K = 5


######### CLOOME FUNCTIONS #########

def convert_models_to_fp32(model):
    """Cast every parameter of *model* (and its gradient, if present) to float32 in place."""
    for p in model.parameters():
        p.data = p.data.float()
        # `if p.grad:` would call bool() on a tensor, which raises for
        # multi-element gradients; test for presence instead.
        if p.grad is not None:
            p.grad.data = p.grad.data.float()


def load(model_path, device, model, image_resolution):
    """Build a CLIPGeneral model and load trained weights from a checkpoint.

    Args:
        model_path: checkpoint file for ``torch.load``; the weights are read
            from its ``"state_dict"`` entry.
        device: device the checkpoint is mapped to and the model is moved to.
        model: model type name (e.g. ``"RN50"``). The architecture config is
            read from ``<model with '/' replaced by '-'>.json``, resolved
            against the current working directory.
        image_resolution: unused; kept for interface compatibility.

    Returns:
        The loaded model, in eval mode, on *device*.

    Raises:
        FileNotFoundError: if the model config JSON does not exist.
    """
    checkpoint = torch.load(model_path, map_location=device)
    state_dict = checkpoint["state_dict"]

    model_config_file = f"{model.replace('/', '-')}.json"
    print('Loading model from', model_config_file)
    # An explicit raise (not assert) so the check survives `python -O`.
    if not os.path.exists(model_config_file):
        raise FileNotFoundError(f"Model config not found: {model_config_file}")
    with open(model_config_file, 'r') as f:
        model_info = json.load(f)

    clip_model = CLIPGeneral(**model_info)
    convert_weights(clip_model)
    convert_models_to_fp32(clip_model)
    if str(device) == "cpu":
        clip_model.float()
    print(device)

    # Keys are expected to carry a "module." prefix (presumably from a
    # DataParallel/DDP wrapper at training time). Strip it only where it is
    # present so that unprefixed keys are not truncated.
    prefix = 'module.'
    new_state_dict = {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }
    clip_model.load_state_dict(new_state_dict)
    clip_model.to(device)
    clip_model.eval()
    return clip_model


def get_features(dataset, model, device):
    """Encode every item of *dataset* with *model* and L2-normalise the embeddings.

    DataLoader batches are dispatched on their type: a dict is an image
    batch (keys ``"input"`` and ``"ID"``), a tensor is a molecule batch, and
    anything else is unpacked as ``(images, molecules)``.

    Returns:
        ``(image_features, text_features, ids)`` when both modalities were
        encoded, ``(image_features, ids)`` for images only,
        ``(text_features, ids)`` for molecules only (``ids`` is then empty,
        because IDs are only collected from image batches), or ``None`` if
        the dataset produced no batches.
    """
    all_image_features = []
    all_text_features = []
    all_ids = []
    print(f"get_features {device}")
    print(len(dataset))

    with torch.no_grad():
        for batch in tqdm(DataLoader(dataset, num_workers=1, batch_size=64)):
            if isinstance(batch, dict):
                imgs, mols = batch, None
            elif isinstance(batch, torch.Tensor):
                imgs, mols = None, batch
            else:
                imgs, mols = batch

            if mols is not None:
                text_features = model.encode_text(mols.to(device))
                text_features = text_features / text_features.norm(dim=-1, keepdim=True)
                all_text_features.append(text_features)

            if imgs is not None:
                images = imgs["input"]
                ids = imgs["ID"]
                img_features = model.encode_image(images.to(device))
                img_features = img_features / img_features.norm(dim=-1, keepdim=True)
                all_image_features.append(img_features)
                all_ids.append(ids)

    all_ids = list(chain.from_iterable(all_ids))

    # Decide the return shape from what was actually collected, not from the
    # last batch's loop variables (which are unbound for an empty dataset).
    if all_image_features and all_text_features:
        return torch.cat(all_image_features), torch.cat(all_text_features), all_ids
    if all_image_features:
        return torch.cat(all_image_features), all_ids
    if all_text_features:
        return torch.cat(all_text_features), all_ids
    return None


def read_array(file):
    """Load a saved feature file and return its ``(mol_features, mol_ids)`` entries."""
    t = torch.load(file)
    features = t["mol_features"]
    ids = t["mol_ids"]
    return features, ids


def main(df, model_path, model, img_path=None, mol_path=None, image_resolution=None):
    """Load the model, wrap *df* in a CellPainting dataset and encode it.

    Args:
        df: sample index handed straight to ``CellPainting`` (the callers in
            this file pass the CSV path returned by ``create_index``).
        model_path: checkpoint path, see :func:`load`.
        model: model type name, see :func:`load`.
        img_path: directory of image samples, or ``None``.
        mol_path: molecule fingerprint file, or ``None``.
        image_resolution: size passed to the validation transform.

    Returns:
        The tuple produced by :func:`get_features`: either
        ``(img_features, text_features, ids)`` or ``(features, ids)``.

    Raises:
        ValueError: if the dataset yields no batches.
    """
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(torch.cuda.device_count())
    clip_model = load(model_path, device, model, image_resolution)

    preprocess_val = _transform(image_resolution, image_resolution, is_train=False,
                                normalize="dataset", preprocess="downsize")

    val = CellPainting(df, img_path, mol_path, transforms=preprocess_val)

    print("getting_features")
    result = get_features(val, clip_model, device)
    if result is None:
        raise ValueError("The dataset produced no batches; nothing to encode.")
    # Already a 2- or 3-tuple in the documented order.
    return result


def img_to_numpy(file):
    """Read an image (path or file-like object) with PIL and return it as a NumPy array."""
    # The context manager releases the file handle PIL opens for paths;
    # np.array() forces the pixel data to load before the image is closed.
    with Image.open(file) as img:
        return np.array(img)


def illumination_threshold(arr, perc=0.0028):
    """Return the intensity above which the brightest *perc* percent of pixels lie.

    ``process_image`` uses the result as the display maximum, so these very
    bright pixels are clipped instead of washing out the rest of the image.

    Args:
        arr: 2-D image array.
        perc: percentage (not fraction) of pixels to clip. With the default,
            about 10 pixels of a 520x696 image are clipped.

    Returns:
        The smallest value among the ``n`` brightest pixels, where
        ``n = round(h * w * perc / 100)`` clamped to ``[1, arr.size]``.

    Raises:
        ValueError: if *arr* is empty.
    """
    flat = np.asarray(arr).ravel()
    if flat.size == 0:
        raise ValueError("illumination_threshold needs a non-empty image")
    h, w = arr.shape[0], arr.shape[1]
    n_pixels = int(np.around(h * w * (perc / 100)))
    # Keep at least one pixel: with n_pixels == 0 a slice like [-0:] would
    # select the whole image and the threshold would collapse to its minimum.
    n_pixels = min(max(n_pixels, 1), flat.size)
    # Partial sort: the n-th largest value is the minimum of the top-n pixels.
    return np.partition(flat, -n_pixels)[-n_pixels]


def sixteen_to_eight_bit(arr, display_max, display_min=0):
    """Linearly rescale *arr* into uint8 for display.

    Pixels at or below *display_min* map to 0. The rest are scaled by
    ``256 / (display_max - display_min)`` and clipped at 255, so pixels
    brighter than *display_max* saturate.
    """
    # Float arithmetic avoids unsigned wrap-around when display_max is a
    # NumPy unsigned scalar.
    span = float(display_max) - float(display_min)
    if span <= 0:
        # Degenerate window (e.g. a blank image). Avoid dividing by zero and
        # saturate every pixel brighter than display_min instead.
        return np.where(arr > display_min, 255, 0).astype(np.uint8)
    threshold_image = (arr.astype(float) - display_min) * (arr > display_min)
    scaled_image = threshold_image * (256. / span)
    return np.minimum(scaled_image, 255).astype(np.uint8)


def process_image(arr):
    """Convert a raw microscopy channel to 8-bit, clipping the brightest outlier pixels."""
    threshold = illumination_threshold(arr)
    return sixteen_to_eight_bit(arr, threshold)


def process_sample(imglst, channels, filenames, outdir, outfile):
    """Stack channel images into one 8-bit sample and save it as ``<outfile>.npz``.

    Images, channel names and file names are paired positionally; image ``i``
    is written to ``sample[:, :, i]``. The sample buffer is hard-coded to
    520x696x5, so every image must be 520x696.

    Returns:
        ``(sample_dict, outpath)``. ``sample_dict`` holds the stacked array
        plus ``{index: channel}`` and ``{channel: filename}`` maps.
        ``outpath`` is the written .npz file, which stores the channel and
        file-name lists as given.
    """
    sample = np.zeros((520, 696, 5), dtype=np.uint8)
    filenames_dict, channels_dict = {}, {}
    for i, (img, channel, fname) in enumerate(zip(imglst, channels, filenames)):
        print(channel)
        sample[:, :, i] = process_image(img_to_numpy(img))
        channels_dict[i] = channel
        filenames_dict[channel] = fname

    sample_dict = dict(sample=sample, channels=channels_dict, filenames=filenames_dict)
    outpath = os.path.join(outdir, outfile + ".npz")
    np.savez(outpath, sample=sample, channels=channels, filenames=filenames)
    return sample_dict, outpath


def display_cellpainting(sample):
    """Compose an RGB PIL preview from channels 0, 3 and 4 of ``sample["sample"]``."""
    arr = sample["sample"]
    r = arr[:, :, 0].astype(np.float32)
    g = arr[:, :, 3].astype(np.float32)
    b = arr[:, :, 4].astype(np.float32)
    rgb_arr = np.dstack((r, g, b))
    im = Image.fromarray(rgb_arr.astype("uint8"))
    return im.convert("RGB")


def morgan_from_smiles(smiles, radius=3, nbits=1024, chiral=True):
    """Return the Morgan fingerprint of *smiles* as a NumPy bit array.

    Args:
        smiles: SMILES string of the molecule.
        radius: Morgan radius.
        nbits: fingerprint length.
        chiral: whether chirality is included in the fingerprint.

    Raises:
        ValueError: if RDKit cannot parse *smiles*.
    """
    mol = Chem.MolFromSmiles(smiles)
    # MolFromSmiles returns None on unparsable input instead of raising.
    if mol is None:
        raise ValueError(f"Invalid SMILES string: {smiles!r}")
    # Previously radius=3 was hard-coded here and the parameter was ignored.
    fp = AllChem.GetMorganFingerprintAsBitVect(mol, radius=radius, nBits=nbits,
                                               useChirality=chiral)
    arr = np.zeros((0,), dtype=np.int8)
    # ConvertToNumpyArray resizes *arr* to the fingerprint length.
    DataStructs.ConvertToNumpyArray(fp, arr)
    return arr


def save_hdf(fps, index, outfile_hdf):
    """Write fingerprints to *outfile_hdf* as a DataFrame under key ``"df"``.

    Rows are indexed positionally (0..len(fps)-1) and columns are the bit
    positions as strings. *index* is accepted but currently unused.

    Returns:
        *outfile_hdf*.

    Raises:
        ValueError: if *fps* is empty.
    """
    if len(fps) == 0:
        raise ValueError("save_hdf needs at least one fingerprint")
    ids = list(range(len(fps)))
    columns = [str(i) for i in range(fps[0].shape[0])]
    df = pd.DataFrame(fps, index=ids, columns=columns)
    df.to_hdf(outfile_hdf, key="df", mode="w")
    return outfile_hdf


def create_index(outdir, ids, filename):
    """Write a one-column ``SAMPLE_KEY`` CSV to ``outdir/filename`` and return its path.

    *ids* may be a single string or an iterable of sample keys.
    """
    filepath = os.path.join(outdir, filename)
    values = [ids] if isinstance(ids, str) else ids
    data = {"SAMPLE_KEY": values}
    print(data)
    pd.DataFrame(data).to_csv(filepath)
    return filepath


def draw_molecules(smiles_lst):
    """Render each SMILES string in *smiles_lst* to a PIL image with RDKit."""
    mols = [Chem.MolFromSmiles(s) for s in smiles_lst]
    return [Draw.MolToImage(m) for m in mols]


def reshape_image(arr):
    """Convert a channel-first ``(c, h, w)`` array to an RGB PIL image.

    Only the first three channels are used. The old implementation left the
    remaining channels of an ``np.empty`` buffer uninitialised, which gave a
    garbage alpha channel for c == 4 and crashed ``Image.fromarray`` for c > 4.

    Raises:
        ValueError: if *arr* is not 3-D or has fewer than three channels.
    """
    if arr.ndim != 3 or arr.shape[0] < 3:
        raise ValueError(f"expected a (c>=3, h, w) array, got shape {arr.shape}")
    rgb = np.ascontiguousarray(np.transpose(arr[:3], (1, 2, 0)), dtype=np.uint8)
    return Image.fromarray(rgb)


# NOTE(review): an old TODO listed "save morgan to hdf, create index, load
# features"; save_hdf, create_index and read_array above now cover those.

##### STREAMLIT FUNCTIONS ######

st.title('CLOOME: Contrastive Learning for Molecule Representation with Microscopy Images and Chemical Structures')


def main_page():
    """Landing page: show the CLOOME abstract."""
    st.markdown(
        """ Contrastive learning for self-supervised representation learning has brought a strong improvement to many application areas, such as computer vision and natural language processing. With the availability of large collections of unlabeled data in vision and language, contrastive learning of language and image representations has shown impressive results. The contrastive learning methods CLIP and CLOOB have demonstrated that the learned representations are highly transferable to a large set of diverse tasks when trained on multi-modal data from two different domains. In drug discovery, similar large, multi-modal datasets comprising both cell-based microscopy images and chemical structures of molecules are available. However, contrastive learning has not yet been used for this type of multi-modal data, although transferable representations could be a remedy for the time-consuming and cost-expensive label acquisition in this domain. In this work, we present a contrastive learning method for image-based and structure-based representations of small molecules for drug discovery. Our method, Contrastive Leave One Out boost for Molecule Encoders (CLOOME), is based on CLOOB and comprises an encoder for microscopy data, an encoder for chemical structures and a contrastive learning objective.
On the benchmark dataset ”Cell Painting”, we demonstrate the ability of our method to learn transferable representations by performing linear probing for activity prediction tasks. Additionally, we show that the representations could also be useful for bioisosteric replacement tasks. """
    )


def molecules_from_image():
    """Streamlit page: rank molecules by similarity to an uploaded Cell Painting image.

    Candidates come from an optional uploaded CSV with a ``SMILES`` column;
    otherwise the precomputed molecule features on disk are used.
    """
    ## TODO: Check if expander can be automatically collapsed
    exp = st.expander("Upload a microscopy image")
    channels = ['Mito', 'ERSyto', 'ERSytoBleed', 'Ph_golgi', 'Hoechst']
    with exp:
        imglst, filenames = [], []
        for c in channels:
            file_obj = st.file_uploader(f'Choose a TIF image for {c}:', ".tif")
            if file_obj is not None:
                imglst.append(file_obj)
                filenames.append(file_obj.name)

        # process_sample pairs images with channel names positionally, so a
        # partial upload would silently file images under the wrong channel.
        all_channels_uploaded = len(imglst) == len(channels)
        if imglst and not all_channels_uploaded:
            st.info(f"Please upload an image for each of the {len(channels)} channels.")
        if all_channels_uploaded:
            os.makedirs(npzs, exist_ok=True)
            sample_dict, imgpath = process_sample(imglst, channels, filenames, npzs, imgname)
            print(imglst)
            st.image(display_cellpainting(sample_dict))

    uploaded_file = st.file_uploader("Choose a molecule file to retrieve from (optional)")

    if not all_channels_uploaded:
        return

    if uploaded_file is not None:
        molecule_df = pd.read_csv(uploaded_file)
        smiles = molecule_df["SMILES"].tolist()
        if not smiles:
            st.warning("The uploaded molecule file contains no SMILES.")
            return
        try:
            morgan = [morgan_from_smiles(s) for s in smiles]
        except ValueError as err:
            st.error(str(err))
            return
        molnames = [f"M{i}" for i in range(len(morgan))]
        mol_index_fname = "mol_index.csv"
        mol_index = create_index(datapath, molnames, mol_index_fname)
        molpath = os.path.join(datapath, "mols.hdf")
        save_hdf(morgan, molnames, molpath)
        mol_imgs = draw_molecules(smiles)
        mol_features, mol_ids = main(mol_index, MODEL_PATH, model_type,
                                     mol_path=molpath, image_resolution=image_resolution)
        predefined_features = False
    else:
        mol_index = pd.read_csv(mol_index_file)
        mol_features_torch = torch.load(molecule_features, map_location=device)
        mol_features = mol_features_torch["mol_features"]
        mol_ids = mol_features_torch["mol_ids"]
        print(len(mol_ids))
        predefined_features = True

    img_index_fname = "img_index.csv"
    img_index = create_index(datapath, imgname, img_index_fname)
    img_features, img_ids = main(img_index, MODEL_PATH, model_type,
                                 img_path=npzs, image_resolution=image_resolution)
    print(img_features.shape)
    print(mol_features.shape)

    logits = img_features @ mol_features.T
    # Scale the logits by a fixed factor of 30 before the softmax.
    mol_probs = (30.0 * logits).softmax(dim=-1)
    # torch.topk raises if k exceeds the number of candidate molecules.
    k = min(TOP_K, mol_probs.shape[-1])
    top_probs, top_labels = mol_probs.cpu().topk(k, dim=-1)
    # Flattening assumes a single query image; remove it to support several.
    top_probs = torch.flatten(top_probs)
    top_labels = torch.flatten(top_labels)
    print(top_probs.shape)
    print(top_labels.shape)
    top_label_list = top_labels.tolist()

    if predefined_features:
        mol_index.set_index(["SAMPLE_KEY"], inplace=True)
        top_ids = [mol_ids[i] for i in top_label_list]
        smiles = mol_index.loc[top_ids]["SMILES"].tolist()
        # Drawn in rank order, so column `rank` shows mol_imgs[rank].
        mol_imgs = draw_molecules(smiles)

    with st.container():
        columns = st.columns(len(top_probs))
        for rank, col in enumerate(columns):
            # Uploaded molecules were drawn in file order, so look them up by label.
            image_id = rank if predefined_features else top_label_list[rank]
            col.image(mol_imgs[image_id], width=140, caption=rank + 1)

    print(mol_probs.sum(dim=-1))
    print((top_probs, top_labels))


def images_from_molecule():
    """Streamlit page: rank the stored Cell Painting images by similarity to a SMILES string."""
    smiles = st.text_input("Enter a SMILES string", value="CC(=O)OC1=CC=CC=C1C(=O)O",
                           placeholder="CC(=O)OC1=CC=CC=C1C(=O)O")
    if not smiles:
        return

    smiles = [smiles]
    try:
        morgan = [morgan_from_smiles(s) for s in smiles]
    except ValueError as err:
        st.error(str(err))
        return
    molnames = [f"M{i}" for i in range(len(morgan))]
    mol_index_fname = "mol_index.csv"
    mol_index = create_index(datapath, molnames, mol_index_fname)
    molpath = os.path.join(datapath, "mols.hdf")
    save_hdf(morgan, molnames, molpath)
    mol_imgs = draw_molecules(smiles)
    mol_features, _ = main(mol_index, MODEL_PATH, model_type,
                           mol_path=molpath, image_resolution=image_resolution)

    # Empty side columns centre the molecule drawing.
    col1, col2, col3 = st.columns(3)
    with col1:
        st.write("")
    with col2:
        st.image(mol_imgs, width=140)
    with col3:
        st.write("")

    img_features_torch = torch.load(image_features, map_location=device)
    img_features = img_features_torch["img_features"]
    img_ids = img_features_torch["img_ids"]

    logits = mol_features @ img_features.T
    img_probs = (30.0 * logits).softmax(dim=-1)
    # torch.topk raises if k exceeds the number of candidate images.
    k = min(TOP_K, img_probs.shape[-1])
    top_probs, top_labels = img_probs.cpu().topk(k, dim=-1)
    top_probs = torch.flatten(top_probs)
    top_labels = torch.flatten(top_labels)
    top_ids = [img_ids[i] for i in top_labels.tolist()]

    # allow_pickle is acceptable only because images_arr is a trusted local
    # file. The context manager closes the NpzFile archive after use.
    with np.load(images_arr, allow_pickle=True) as images_dict:
        with st.container():
            columns = st.columns(len(top_probs))
            for rank, col in enumerate(columns):
                ## TODO: generalize and functionalize
                image = images_dict[f"{top_ids[rank]}.npz"]
                col.image(reshape_image(image), caption=rank + 1)


page_names_to_funcs = {
    "-": main_page,
    "Molecules from a microscopy image": molecules_from_image,
    "Microscopy images from a molecule": images_from_molecule,
}

selected_page = st.sidebar.selectbox("What would you like to retrieve?", page_names_to_funcs.keys())
page_names_to_funcs[selected_page]()