{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Demonstrator" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Load the model" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:sample:Compiling the model...\n" ] } ], "source": [ "import rdkit\n", "from rdkit import Chem\n", "import rdkit.rdBase as rkrb\n", "import rdkit.RDLogger as rkl\n", "import os\n", "import torch \n", "import logging\n", "import numpy as np\n", "from plot_utils import check_metrics\n", "from sample import Sampler\n", "import pandas as pd\n", "\n", "device = 'cuda' if torch.cuda.is_available() else 'cpu'\n", "\n", "if \"cuda\" in device:\n", " # dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'\n", " dtype = \"float16\" if torch.cuda.is_available() else \"float32\"\n", "else:\n", " dtype = \"float32\"\n", "\n", "logger = rkl.logger()\n", "logger.setLevel(rkl.ERROR)\n", "rkrb.DisableLog(\"rdApp.error\")\n", "\n", "torch.set_num_threads(8)\n", "logging.basicConfig(level=logging.INFO)\n", "logger = logging.getLogger(__name__)\n", "\n", "sampler = Sampler(\n", " load_path=os.path.join(\n", " os.getcwd(), \"out\", \"llama2-M-Full-RSS.pt\"\n", " ),\n", " device=device,\n", " seed=1234,\n", " dtype=dtype,\n", " compile=True,\n", ")\n", "\n", " \n", "num_samples = 100\n", "df_comp = pd.read_parquet(os.path.join(os.getcwd(),\"data\",\"OrganiX13.parquet\"))\n", "df_comp = df_comp.sample(n=2_500_000)\n", "comp_context_dict = {c: df_comp[c].to_numpy() for c in [\"logp\", \"sascore\", \"mol_weight\"]} \n", "comp_smiles = df_comp[\"smiles\"]\n", "\n" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "INFO:root:Wrote file /home/ndobberstein/Projekte/llama2-molgen/chemiscope_gen.json\n" ] } ], "source": [ "from typing import List, Dict\n", "import json\n", "from rdkit.Chem import AllChem\n", "\n", "@torch.no_grad()\n", "def convert_to_chemiscope(smiles_list : List[str], context_dict : Dict[str, List[float]]):\n", " # For more details on the file format: https://chemiscope.org/docs/tutorial/input-reference.html\n", "\n", " structures = []\n", " remove_list = []\n", " for i,smi in enumerate(smiles_list):\n", " mol = Chem.MolFromSmiles(smi)\n", " if mol is None:\n", " logging.info(f\"Mol invalid: {smi} ! Skipping...\")\n", " remove_list.append(i)\n", " continue\n", "\n", " res = AllChem.EmbedMolecule(mol,randomSeed=0xf00d, maxAttempts=20)\n", " # res = AllChem.Compute2DCoords(mol)\n", "\n", " if res != 0:\n", " logging.info(f\"Could not calculate coordinates for {smi}! Skipping..\")\n", " remove_list.append(i)\n", " continue\n", " \n", "\n", " conf = list(mol.GetConformers())[0]\n", " x,y,z = [],[],[]\n", " symbols = []\n", " for atom, coords in zip(mol.GetAtoms(), conf.GetPositions()):\n", " symbols.append(atom.GetSymbol())\n", " x.append(coords[0])\n", " y.append(coords[1])\n", " z.append(coords[2])\n", " \n", " structures.append({\n", " \"size\": len(x),\n", " \"names\": symbols,\n", " \"x\": x,\n", " \"y\": y,\n", " \"z\" : z\n", " })\n", "\n", "\n", "\n", " properties = {}\n", " \n", " for c in context_dict:\n", " properties[c] = {\n", " \"target\": \"structure\",\n", " \"values\": [v for i, v in enumerate(context_dict[c]) if i not in remove_list]\n", " }\n", " \n", "\n", "\n", " \n", " data = {\n", " \"meta\": {\n", " # // the name of the dataset\n", " \"name\": \"Test Dataset\",\n", " # // description of the dataset, OPTIONAL\n", " \"description\": \"This contains data from generated molecules\",\n", " # // authors of the dataset, OPTIONAL\n", " \"authors\": [\"Niklas Dobberstein, niklas.dobberstein@scai.fraunhofer.de\"],\n", " # // references for the dataset, OPTIONAL\n", " \"references\": [\n", " \"\",\n", " ],\n", " \n", " },\n", " \"properties\": properties,\n", " \"structures\": structures\n", " }\n", " \n", " out_path = os.path.join(os.getcwd(), \"chemiscope_gen.json\")\n", " with open(out_path, \"w\") as f:\n", " json.dump(data, f)\n", "\n", " logging.info(f\"Wrote file {out_path}\")\n", "\n", "convert_to_chemiscope([\n", " \"CC=O\",\n", " \"s1ccnc1\"\n", "], {\"logp\": [1.0,2.0], \"sascore\": [1.5,-2.0]})" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "8b28a4e692de4bb48fde10a88d9727ba", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(Checkbox(value=False, description='logp'), Checkbox(value=False, description='sascore'), Checkb…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "62331a62f2bf4d08a3a202ad277c6d92", "version_major": 2, "version_minor": 0 }, "text/plain": [ "HBox(children=(FloatSlider(value=0.0, description='logp:', max=7.0, min=-4.0, step=0.5), FloatSlider(value=2.0…" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "2d498af39f4046b0a5bb92080361dfec", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Text(value='', description='Context SMI:')" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ed8a755253444e9c83dc27c5f830588b", "version_major": 2, "version_minor": 0 }, "text/plain": [ "FloatSlider(value=0.8, description='Temperature:', max=2.0)" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "139e7d1e40984101800e2cbb740280b0", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Button(description='Generate', style=ButtonStyle())" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "4d119a3b477243ac916478a6ec2a55c7", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "dfce28d4f6a3414c838e6542ffb43fc6", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Output()" ] }, "metadata": {}, "output_type": "display_data" } ], "source": [ "import ipywidgets as widgets\n", "from IPython.display import display, clear_output, HTML\n", "import numpy as np\n", "import torch\n", "import matplotlib.pyplot as plt\n", "from rdkit import Chem\n", "from rdkit.Chem import Draw\n", "import logging\n", "from plot_utils import calc_context_from_smiles\n", "\n", "# Define the context_cols options and create checkboxes for them\n", "context_cols_options = [\"logp\", \"sascore\", \"mol_weight\"]\n", "context_cols_checkboxes = [widgets.Checkbox(description=col, value=False) for col in context_cols_options]\n", "\n", "# Create a text input for context_smi\n", "context_smi_input = widgets.Text(description=\"Context SMI:\", value=\"\")\n", "\n", "# Create sliders for temperature and context_cols values\n", "temperature_slider = widgets.FloatSlider(description=\"Temperature:\", min=0, max=2.0, step=0.1, value=0.8)\n", "\n", "logp_slider = widgets.FloatSlider(description=\"logp:\", min=-4, max=7, step=0.5, value=0.0)\n", "sascore_slider = widgets.FloatSlider(description=\"sascore:\", min=1, max=10, step=0.5, value=2.0)\n", "mol_weight_slider = widgets.FloatSlider(description=\"mol_weight:\", min=0.5, max=10, step=0.5, value=3.0)\n", "\n", "# Create a button to generate the code and display SMILES\n", "generate_button = widgets.Button(description=\"Generate\")\n", "\n", "# Create an output widget for displaying generated information\n", "output = widgets.Output()\n", "\n", "# Create an output widget for displaying the RDKit molecules\n", "molecule_output = widgets.Output()\n", "\n", "@torch.no_grad()\n", "def generate_code(_):\n", " with output:\n", " clear_output(wait=False)\n", " # logging.info(\"Parameters used in generation:\")\n", " \n", " # Get the selected context_cols\n", " selected_context_cols = [col for col, checkbox in zip(context_cols_options, context_cols_checkboxes) if checkbox.value]\n", " # logging.info(f\"Context Cols: {selected_context_cols}\")\n", " \n", " # Get the values of context_smi and temperature from the sliders\n", " context_smi = context_smi_input.value.strip()\n", " temperature = temperature_slider.value\n", " # logging.info(f\"Context Smiles: {context_smi}\")\n", " # logging.info(f\"Temperature: {temperature}\")\n", " \n", " # Get the values of logp, sascore, and mol_weight from the sliders\n", " context_dict = {} if len(selected_context_cols) != 0 else None\n", " for c in selected_context_cols:\n", " if c == \"logp\":\n", " val = logp_slider.value\n", " elif c == \"sascore\":\n", " val = sascore_slider.value\n", " else:\n", " val = mol_weight_slider.value\n", " val = round(val, 2)\n", " context_dict[c] = val*torch.ones((num_samples,),device=device,dtype=torch.float)\n", " # logging.info(f\"{c}: {val}\")\n", " \n", " # Generate SMILES using the provided context\n", " smiles, context = sampler.generate(\n", " context_cols=context_dict,\n", " context_smi=context_smi,\n", " start_smiles=None,\n", " num_samples=num_samples,\n", " max_new_tokens=256,\n", " temperature=temperature,\n", " top_k=25,\n", " total_gen_steps=int(np.ceil(num_samples / 1000)),\n", " return_context=True\n", " )\n", " \n", " with open(os.path.join(os.getcwd(), \"gen_smiles.txt\"), \"w\") as f:\n", " for s in smiles:\n", " f.write(f\"{s}\\n\")\n", " # Display SMILES as RDKit molecules\n", " display_molecules(smiles, context)\n", "\n", "\n", "\n", "def display_molecules(smiles_list, context_dict):\n", " with molecule_output:\n", " clear_output(wait=False)\n", " molecules = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]\n", " \n", " # Convert RDKit molecules to images and store them in a list\n", " images = [Draw.MolToImage(mol) for mol in molecules]\n", " \n", " # Create a subplot grid to display the images\n", " num_images = len(images)\n", " num_cols = 5 # Number of columns in the grid\n", " num_rows = (num_images + num_cols - 1) // num_cols # Calculate the number of rows\n", " \n", " fig, axes = plt.subplots(num_rows, num_cols, figsize=(25, 25))\n", " fig.subplots_adjust(hspace=0.5)\n", " calculated_context = {c:[] for c in context_dict}\n", " for i, ax in enumerate(axes.flat):\n", " if i < num_images:\n", " ax.imshow(images[i])\n", " for j, c in enumerate(context_dict):\n", " smiles = smiles_list[i]\n", " smi_con = round(calc_context_from_smiles([smiles], c)[0],2)\n", " calculated_context[c].append(smi_con)\n", " ax.text(0.5, -0.1 * j , f\"{c}: {context_dict[c][i]} vs {smi_con}\", transform=ax.transAxes, fontsize=10, ha='center')\n", " \n", " ax.axis('off')\n", " else:\n", " fig.delaxes(ax) # Remove empty subplots if there are more rows than images\n", " \n", "\n", " if len(context_dict) >= 2:\n", " convert_to_chemiscope(smiles_list, calculated_context)\n", "\n", " plt.savefig(\"gen_mols.png\")\n", " plt.show()\n", "\n", "# Attach the generate_code function to the button's click event\n", "generate_button.on_click(generate_code)\n", "\n", "# Display the widgets\n", "display(widgets.HBox(context_cols_checkboxes))\n", "display(widgets.HBox((logp_slider, sascore_slider, mol_weight_slider)))\n", "\n", "display(context_smi_input)\n", "display(temperature_slider)\n", "display(generate_button)\n", "display(output)\n", "display(molecule_output)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { "model_id": "ea96e00e0ea8448d97906ec965f04788", "version_major": 2, "version_minor": 0 }, "text/plain": [ "Batch: 0%| | 0/1 [00:00