|
from abc import ABC |
|
from dataclasses import dataclass |
|
from typing import List, Union |
|
import numpy as np |
|
from rdkit import Chem |
|
from rdkit.Chem.BRICS import BRICSDecompose |
|
from rdkit.Chem.Recap import RecapDecompose |
|
|
|
import random |
|
|
|
|
|
@dataclass
class Fragment:
    """
    Container for one fragment in either (or both) of its representations.

    Either field may be ``None`` when that representation is not available.
    """

    # SMILES text of the fragment, or None.
    smiles: Union[str, None]
    # Integer token ids of the fragment, or None.
    tokens: Union[List[int], None]
|
|
|
|
|
class BaseFragmentCreator(ABC):
    """
    Common parent class of all fragment creators.

    The base implementation performs no fragmentation at all: its
    ``create_fragment`` ignores the input and hands back an empty string.
    """

    def __init__(self) -> None:
        """
        Nothing to initialise; the base creator holds no state.
        """

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Return the "no fragment" value.

        NOTE: the value returned is the empty string ``""``, not a
        ``Fragment`` instance, regardless of the return annotation.
        """
        no_fragment = ""
        return no_fragment
|
|
|
|
|
|
|
class RandomSubsliceFragmentCreator(BaseFragmentCreator):
    """
    Creates a fragment by cutting a random contiguous slice out of the
    token list of the input fragment.
    """

    def __init__(self, max_fragment_size=50) -> None:
        """
        :param max_fragment_size: offset added to the start index to obtain
            the (exclusive) upper bound when drawing the slice end.
        """
        super().__init__()
        self.max_fragment_size = max_fragment_size

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Creates the random sub slice fragments from the tokens

        :param frag: fragment whose ``tokens`` are sliced; ``smiles`` is ignored.
        :return: a new ``Fragment`` with ``smiles=None`` and the chosen slice
            as ``tokens``.

        Because the upper bound of ``np.random.randint`` is exclusive, the
        slice never contains the last token and holds at most
        ``max_fragment_size - 1`` tokens (behaviour kept from the original).
        Inputs with fewer than two tokens are returned as an unsliced copy
        instead of raising ``ValueError`` from ``np.random.randint``.
        """
        tokens = frag.tokens

        # randint(0, len - 1) needs len >= 2, otherwise low >= high.
        if len(tokens) < 2:
            return Fragment(smiles=None, tokens=tokens[:])

        start_idx = np.random.randint(0, len(tokens) - 1)

        upper = min(len(tokens), start_idx + self.max_fragment_size)
        # Keep low < high even when max_fragment_size <= 1; start_idx + 2 is
        # always <= len(tokens) because start_idx <= len(tokens) - 2.
        upper = max(upper, start_idx + 2)

        end_idx = np.random.randint(start_idx + 1, upper)
        return Fragment(smiles=None, tokens=tokens[start_idx:end_idx])
|
|
|
|
|
class BricksFragmentCreator(BaseFragmentCreator):
    """
    Fragment creator backed by RDKit's BRICS decomposition.
    """

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Decompose ``frag.smiles`` with BRICS and pick one result at random.

        Returns the empty string when RDKit cannot parse the SMILES.
        NOTE: the value returned is an element of the ``BRICSDecompose``
        result (or ``""``), not a ``Fragment`` instance.
        """
        mol = Chem.MolFromSmiles(frag.smiles)
        if mol is not None:
            candidates = list(BRICSDecompose(mol, minFragmentSize=3))
            return random.choice(candidates)

        # Unparseable SMILES: signal "no fragment".
        return ""
|
|
|
|
|
class RecapFragmentCreator(BaseFragmentCreator):
    """
    Fragment creator backed by RDKit's RECAP decomposition.
    """

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Creates the Recap fragments and takes one randomly

        :param frag: fragment whose ``smiles`` is decomposed.
        :return: the SMILES of one randomly chosen RECAP child; the SMILES of
            the whole molecule when RECAP yields no children; ``""`` when the
            input SMILES cannot be parsed. NOTE: a plain string is returned,
            not a ``Fragment`` instance.
        """
        smiles = frag.smiles
        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return ""

        children = RecapDecompose(m, minFragmentSize=3).GetAllChildren()

        # GetAllChildren() returns a dict keyed by SMILES; random.choice needs
        # an indexable sequence, so choose among the keys.
        candidates = list(children)
        if not candidates:
            # Nothing could be cleaved: fall back to the molecule itself
            # rather than crashing on an empty choice.
            return Chem.MolToSmiles(m)

        return random.choice(candidates)
|
|
|
|
|
class MolFragsFragmentCreator(BaseFragmentCreator):
    """
    Fragment creator that splits a molecule with ``GetMolFrags`` and returns
    one of the resulting pieces.
    """

    def __init__(self) -> None:
        super().__init__()

    def create_fragment(self, frag: Fragment) -> Fragment:
        """
        Creates the GetMolFrags fragments and takes one randomly

        :param frag: fragment whose ``smiles`` is split.
        :return: the SMILES of one randomly chosen piece, or ``""`` when the
            SMILES cannot be parsed or no pieces are produced. NOTE: a plain
            string is returned, not a ``Fragment`` instance.
        """
        smiles = frag.smiles
        m = Chem.MolFromSmiles(smiles)
        if m is None:
            return ""

        pieces = Chem.rdmolops.GetMolFrags(m, asMols=True)
        piece_smiles = [Chem.MolToSmiles(piece) for piece in pieces]

        # Guard against an empty result (e.g. a molecule without atoms) so
        # random.choice does not raise IndexError.
        if not piece_smiles:
            return ""

        return random.choice(piece_smiles)
|
|
|
|
|
def fragment_creator_factory(key: Union[str, None]):
    """
    Build the fragment creator registered under ``key``.

    Recognised keys are ``"mol_frags"``, ``"recap"``, ``"bricks"`` and
    ``"rss"``. ``None`` yields ``None`` (no creator); any other key raises
    ``ValueError``.
    """
    # No key means no fragment creator is wanted.
    if key is None:
        return None

    if key == "mol_frags":
        return MolFragsFragmentCreator()
    if key == "recap":
        return RecapFragmentCreator()
    if key == "bricks":
        return BricksFragmentCreator()
    if key == "rss":
        return RandomSubsliceFragmentCreator()

    raise ValueError(f"Do not have factory for the given key: {key}")
|
|
|
|
|
if __name__ == "__main__":
    # Small manual demo: fragment one molecule with BRICS and tokenize it.
    from tokenizer import SmilesTokenizer

    tokenizer = SmilesTokenizer()

    creator = BricksFragmentCreator()

    # create_fragment reads ``frag.smiles``, so the SMILES string has to be
    # wrapped in a Fragment instead of being passed in bare.
    frag = creator.create_fragment(
        Fragment(smiles="CC(=O)NC1=CC=C(C=C1)O", tokens=None)
    )

    print(frag)
    tokens = tokenizer.encode(frag)
    print(tokens)
    print([tokenizer._convert_id_to_token(t) for t in tokens])
|
|