Add LlaMol codes
Browse files- .gitignore +183 -0
- LICENSE +437 -0
- README.md +162 -5
- assets/llamol.png +0 -0
- config/config.yaml +2 -0
- config/train/llama2-Debug.yaml +47 -0
- config/train/llama2-DebugGPU.yaml +46 -0
- config/train/llama2-M-Full-BRICKS.yaml +46 -0
- config/train/llama2-M-Full-RSS.yaml +46 -0
- config/train/llama2-M-Full.yaml +46 -0
- data/Full_PC9_GAP.parquet +3 -0
- data/RedDB_Full.parquet +3 -0
- data/chembl_log_sascore.parquet +3 -0
- data/combine_all.py +164 -0
- data/opv/prepare_opv.py +265 -0
- data/pubchemqc2020_energy.parquet +3 -0
- data/pubchemqc_energy.parquet +3 -0
- data/qm9_zinc250k_cep/convert_to_parquet.py +41 -0
- data/qm9_zinc250k_cep/qm9_zinc250_cep.parquet +3 -0
- data/vocab.txt +612 -0
- data/zinc/convert_to_parquet.py +67 -0
- data/zinc/zinc_complete/download_zinc.sh +300 -0
- data/zinc/zinc_complete/run_download.py +21 -0
- demonstrator.ipynb +521 -0
- fragment_creator.py +136 -0
- generate_paper_graphs.sh +19 -0
- get_fragment_table.sh +42 -0
- model.py +787 -0
- out/llama2-M-Full-RSS.pt +3 -0
- plot_utils.py +513 -0
- preprocess_dataset.py +370 -0
- requirements.txt +8 -0
- sample.py +616 -0
- tokenizer.py +404 -0
- torch2-env.yaml +29 -0
- train.py +101 -0
- trainLLamaMol.sh +19 -0
- trainLLamaMolDDPSingleNode.sh +28 -0
- trainer.py +513 -0
.gitignore
ADDED
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
*.out
|
2 |
+
debug
|
3 |
+
debug-gpu
|
4 |
+
outputs
|
5 |
+
chemiscope_gen.json
|
6 |
+
gen_smiles.txt
|
7 |
+
__pycache__
|
8 |
+
*.png
|
9 |
+
*.csv
|
10 |
+
*.json
|
11 |
+
# Byte-compiled / optimized / DLL files
|
12 |
+
__pycache__/
|
13 |
+
*.py[cod]
|
14 |
+
*$py.class
|
15 |
+
data/opv/download
|
16 |
+
data/opv/opv.parquet
|
17 |
+
data/qm9_zinc250k_cep/zinc_properties.csv
|
18 |
+
data/qm9_zinc250k_cep/qm9_zinc250k_cep.parquet
|
19 |
+
data/zinc/zinc_complete/*/*.txt
|
20 |
+
!data/zinc/zinc_complete/download_zinc.sh
|
21 |
+
!data/zinc/zinc_complete/run_download.py
|
22 |
+
data/zinc/zinc_processed
|
23 |
+
data/zinc/zinc_processed.parquet
|
24 |
+
data/zinc/zinc_full.parquet
|
25 |
+
data/OrganiX13.parquet
|
26 |
+
.cache
|
27 |
+
out/plots
|
28 |
+
# C extensions
|
29 |
+
*.so
|
30 |
+
|
31 |
+
# Distribution / packaging
|
32 |
+
.Python
|
33 |
+
build/
|
34 |
+
develop-eggs/
|
35 |
+
dist/
|
36 |
+
downloads/
|
37 |
+
eggs/
|
38 |
+
.eggs/
|
39 |
+
lib/
|
40 |
+
lib64/
|
41 |
+
parts/
|
42 |
+
sdist/
|
43 |
+
var/
|
44 |
+
wheels/
|
45 |
+
share/python-wheels/
|
46 |
+
*.egg-info/
|
47 |
+
.installed.cfg
|
48 |
+
*.egg
|
49 |
+
MANIFEST
|
50 |
+
|
51 |
+
# PyInstaller
|
52 |
+
# Usually these files are written by a python script from a template
|
53 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
54 |
+
*.manifest
|
55 |
+
*.spec
|
56 |
+
|
57 |
+
# Installer logs
|
58 |
+
pip-log.txt
|
59 |
+
pip-delete-this-directory.txt
|
60 |
+
|
61 |
+
# Unit test / coverage reports
|
62 |
+
htmlcov/
|
63 |
+
.tox/
|
64 |
+
.nox/
|
65 |
+
.coverage
|
66 |
+
.coverage.*
|
67 |
+
.cache
|
68 |
+
nosetests.xml
|
69 |
+
coverage.xml
|
70 |
+
*.cover
|
71 |
+
*.py,cover
|
72 |
+
.hypothesis/
|
73 |
+
.pytest_cache/
|
74 |
+
cover/
|
75 |
+
|
76 |
+
# Translations
|
77 |
+
*.mo
|
78 |
+
*.pot
|
79 |
+
|
80 |
+
# Django stuff:
|
81 |
+
*.log
|
82 |
+
local_settings.py
|
83 |
+
db.sqlite3
|
84 |
+
db.sqlite3-journal
|
85 |
+
|
86 |
+
# Flask stuff:
|
87 |
+
instance/
|
88 |
+
.webassets-cache
|
89 |
+
|
90 |
+
# Scrapy stuff:
|
91 |
+
.scrapy
|
92 |
+
|
93 |
+
# Sphinx documentation
|
94 |
+
docs/_build/
|
95 |
+
|
96 |
+
# PyBuilder
|
97 |
+
.pybuilder/
|
98 |
+
target/
|
99 |
+
|
100 |
+
# Jupyter Notebook
|
101 |
+
.ipynb_checkpoints
|
102 |
+
|
103 |
+
# IPython
|
104 |
+
profile_default/
|
105 |
+
ipython_config.py
|
106 |
+
|
107 |
+
# pyenv
|
108 |
+
# For a library or package, you might want to ignore these files since the code is
|
109 |
+
# intended to run in multiple environments; otherwise, check them in:
|
110 |
+
# .python-version
|
111 |
+
|
112 |
+
# pipenv
|
113 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
114 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
115 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
116 |
+
# install all needed dependencies.
|
117 |
+
#Pipfile.lock
|
118 |
+
|
119 |
+
# poetry
|
120 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
121 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
122 |
+
# commonly ignored for libraries.
|
123 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
124 |
+
#poetry.lock
|
125 |
+
|
126 |
+
# pdm
|
127 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
128 |
+
#pdm.lock
|
129 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
130 |
+
# in version control.
|
131 |
+
# https://pdm.fming.dev/#use-with-ide
|
132 |
+
.pdm.toml
|
133 |
+
|
134 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
135 |
+
__pypackages__/
|
136 |
+
|
137 |
+
# Celery stuff
|
138 |
+
celerybeat-schedule
|
139 |
+
celerybeat.pid
|
140 |
+
|
141 |
+
# SageMath parsed files
|
142 |
+
*.sage.py
|
143 |
+
|
144 |
+
# Environments
|
145 |
+
.env
|
146 |
+
.venv
|
147 |
+
env/
|
148 |
+
venv/
|
149 |
+
ENV/
|
150 |
+
env.bak/
|
151 |
+
venv.bak/
|
152 |
+
|
153 |
+
# Spyder project settings
|
154 |
+
.spyderproject
|
155 |
+
.spyproject
|
156 |
+
|
157 |
+
# Rope project settings
|
158 |
+
.ropeproject
|
159 |
+
|
160 |
+
# mkdocs documentation
|
161 |
+
/site
|
162 |
+
|
163 |
+
# mypy
|
164 |
+
.mypy_cache/
|
165 |
+
.dmypy.json
|
166 |
+
dmypy.json
|
167 |
+
|
168 |
+
# Pyre type checker
|
169 |
+
.pyre/
|
170 |
+
|
171 |
+
# pytype static type analyzer
|
172 |
+
.pytype/
|
173 |
+
|
174 |
+
# Cython debug symbols
|
175 |
+
cython_debug/
|
176 |
+
|
177 |
+
# PyCharm
|
178 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
179 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
180 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
181 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
182 |
+
#.idea/
|
183 |
+
!assets/*.png
|
LICENSE
ADDED
@@ -0,0 +1,437 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Attribution-NonCommercial-ShareAlike 4.0 International
|
2 |
+
|
3 |
+
=======================================================================
|
4 |
+
|
5 |
+
Creative Commons Corporation ("Creative Commons") is not a law firm and
|
6 |
+
does not provide legal services or legal advice. Distribution of
|
7 |
+
Creative Commons public licenses does not create a lawyer-client or
|
8 |
+
other relationship. Creative Commons makes its licenses and related
|
9 |
+
information available on an "as-is" basis. Creative Commons gives no
|
10 |
+
warranties regarding its licenses, any material licensed under their
|
11 |
+
terms and conditions, or any related information. Creative Commons
|
12 |
+
disclaims all liability for damages resulting from their use to the
|
13 |
+
fullest extent possible.
|
14 |
+
|
15 |
+
Using Creative Commons Public Licenses
|
16 |
+
|
17 |
+
Creative Commons public licenses provide a standard set of terms and
|
18 |
+
conditions that creators and other rights holders may use to share
|
19 |
+
original works of authorship and other material subject to copyright
|
20 |
+
and certain other rights specified in the public license below. The
|
21 |
+
following considerations are for informational purposes only, are not
|
22 |
+
exhaustive, and do not form part of our licenses.
|
23 |
+
|
24 |
+
Considerations for licensors: Our public licenses are
|
25 |
+
intended for use by those authorized to give the public
|
26 |
+
permission to use material in ways otherwise restricted by
|
27 |
+
copyright and certain other rights. Our licenses are
|
28 |
+
irrevocable. Licensors should read and understand the terms
|
29 |
+
and conditions of the license they choose before applying it.
|
30 |
+
Licensors should also secure all rights necessary before
|
31 |
+
applying our licenses so that the public can reuse the
|
32 |
+
material as expected. Licensors should clearly mark any
|
33 |
+
material not subject to the license. This includes other CC-
|
34 |
+
licensed material, or material used under an exception or
|
35 |
+
limitation to copyright. More considerations for licensors:
|
36 |
+
wiki.creativecommons.org/Considerations_for_licensors
|
37 |
+
|
38 |
+
Considerations for the public: By using one of our public
|
39 |
+
licenses, a licensor grants the public permission to use the
|
40 |
+
licensed material under specified terms and conditions. If
|
41 |
+
the licensor's permission is not necessary for any reason--for
|
42 |
+
example, because of any applicable exception or limitation to
|
43 |
+
copyright--then that use is not regulated by the license. Our
|
44 |
+
licenses grant only permissions under copyright and certain
|
45 |
+
other rights that a licensor has authority to grant. Use of
|
46 |
+
the licensed material may still be restricted for other
|
47 |
+
reasons, including because others have copyright or other
|
48 |
+
rights in the material. A licensor may make special requests,
|
49 |
+
such as asking that all changes be marked or described.
|
50 |
+
Although not required by our licenses, you are encouraged to
|
51 |
+
respect those requests where reasonable. More considerations
|
52 |
+
for the public:
|
53 |
+
wiki.creativecommons.org/Considerations_for_licensees
|
54 |
+
|
55 |
+
=======================================================================
|
56 |
+
|
57 |
+
Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International
|
58 |
+
Public License
|
59 |
+
|
60 |
+
By exercising the Licensed Rights (defined below), You accept and agree
|
61 |
+
to be bound by the terms and conditions of this Creative Commons
|
62 |
+
Attribution-NonCommercial-ShareAlike 4.0 International Public License
|
63 |
+
("Public License"). To the extent this Public License may be
|
64 |
+
interpreted as a contract, You are granted the Licensed Rights in
|
65 |
+
consideration of Your acceptance of these terms and conditions, and the
|
66 |
+
Licensor grants You such rights in consideration of benefits the
|
67 |
+
Licensor receives from making the Licensed Material available under
|
68 |
+
these terms and conditions.
|
69 |
+
|
70 |
+
|
71 |
+
Section 1 -- Definitions.
|
72 |
+
|
73 |
+
a. Adapted Material means material subject to Copyright and Similar
|
74 |
+
Rights that is derived from or based upon the Licensed Material
|
75 |
+
and in which the Licensed Material is translated, altered,
|
76 |
+
arranged, transformed, or otherwise modified in a manner requiring
|
77 |
+
permission under the Copyright and Similar Rights held by the
|
78 |
+
Licensor. For purposes of this Public License, where the Licensed
|
79 |
+
Material is a musical work, performance, or sound recording,
|
80 |
+
Adapted Material is always produced where the Licensed Material is
|
81 |
+
synched in timed relation with a moving image.
|
82 |
+
|
83 |
+
b. Adapter's License means the license You apply to Your Copyright
|
84 |
+
and Similar Rights in Your contributions to Adapted Material in
|
85 |
+
accordance with the terms and conditions of this Public License.
|
86 |
+
|
87 |
+
c. BY-NC-SA Compatible License means a license listed at
|
88 |
+
creativecommons.org/compatiblelicenses, approved by Creative
|
89 |
+
Commons as essentially the equivalent of this Public License.
|
90 |
+
|
91 |
+
d. Copyright and Similar Rights means copyright and/or similar rights
|
92 |
+
closely related to copyright including, without limitation,
|
93 |
+
performance, broadcast, sound recording, and Sui Generis Database
|
94 |
+
Rights, without regard to how the rights are labeled or
|
95 |
+
categorized. For purposes of this Public License, the rights
|
96 |
+
specified in Section 2(b)(1)-(2) are not Copyright and Similar
|
97 |
+
Rights.
|
98 |
+
|
99 |
+
e. Effective Technological Measures means those measures that, in the
|
100 |
+
absence of proper authority, may not be circumvented under laws
|
101 |
+
fulfilling obligations under Article 11 of the WIPO Copyright
|
102 |
+
Treaty adopted on December 20, 1996, and/or similar international
|
103 |
+
agreements.
|
104 |
+
|
105 |
+
f. Exceptions and Limitations means fair use, fair dealing, and/or
|
106 |
+
any other exception or limitation to Copyright and Similar Rights
|
107 |
+
that applies to Your use of the Licensed Material.
|
108 |
+
|
109 |
+
g. License Elements means the license attributes listed in the name
|
110 |
+
of a Creative Commons Public License. The License Elements of this
|
111 |
+
Public License are Attribution, NonCommercial, and ShareAlike.
|
112 |
+
|
113 |
+
h. Licensed Material means the artistic or literary work, database,
|
114 |
+
or other material to which the Licensor applied this Public
|
115 |
+
License.
|
116 |
+
|
117 |
+
i. Licensed Rights means the rights granted to You subject to the
|
118 |
+
terms and conditions of this Public License, which are limited to
|
119 |
+
all Copyright and Similar Rights that apply to Your use of the
|
120 |
+
Licensed Material and that the Licensor has authority to license.
|
121 |
+
|
122 |
+
j. Licensor means the individual(s) or entity(ies) granting rights
|
123 |
+
under this Public License.
|
124 |
+
|
125 |
+
k. NonCommercial means not primarily intended for or directed towards
|
126 |
+
commercial advantage or monetary compensation. For purposes of
|
127 |
+
this Public License, the exchange of the Licensed Material for
|
128 |
+
other material subject to Copyright and Similar Rights by digital
|
129 |
+
file-sharing or similar means is NonCommercial provided there is
|
130 |
+
no payment of monetary compensation in connection with the
|
131 |
+
exchange.
|
132 |
+
|
133 |
+
l. Share means to provide material to the public by any means or
|
134 |
+
process that requires permission under the Licensed Rights, such
|
135 |
+
as reproduction, public display, public performance, distribution,
|
136 |
+
dissemination, communication, or importation, and to make material
|
137 |
+
available to the public including in ways that members of the
|
138 |
+
public may access the material from a place and at a time
|
139 |
+
individually chosen by them.
|
140 |
+
|
141 |
+
m. Sui Generis Database Rights means rights other than copyright
|
142 |
+
resulting from Directive 96/9/EC of the European Parliament and of
|
143 |
+
the Council of 11 March 1996 on the legal protection of databases,
|
144 |
+
as amended and/or succeeded, as well as other essentially
|
145 |
+
equivalent rights anywhere in the world.
|
146 |
+
|
147 |
+
n. You means the individual or entity exercising the Licensed Rights
|
148 |
+
under this Public License. Your has a corresponding meaning.
|
149 |
+
|
150 |
+
|
151 |
+
Section 2 -- Scope.
|
152 |
+
|
153 |
+
a. License grant.
|
154 |
+
|
155 |
+
1. Subject to the terms and conditions of this Public License,
|
156 |
+
the Licensor hereby grants You a worldwide, royalty-free,
|
157 |
+
non-sublicensable, non-exclusive, irrevocable license to
|
158 |
+
exercise the Licensed Rights in the Licensed Material to:
|
159 |
+
|
160 |
+
a. reproduce and Share the Licensed Material, in whole or
|
161 |
+
in part, for NonCommercial purposes only; and
|
162 |
+
|
163 |
+
b. produce, reproduce, and Share Adapted Material for
|
164 |
+
NonCommercial purposes only.
|
165 |
+
|
166 |
+
2. Exceptions and Limitations. For the avoidance of doubt, where
|
167 |
+
Exceptions and Limitations apply to Your use, this Public
|
168 |
+
License does not apply, and You do not need to comply with
|
169 |
+
its terms and conditions.
|
170 |
+
|
171 |
+
3. Term. The term of this Public License is specified in Section
|
172 |
+
6(a).
|
173 |
+
|
174 |
+
4. Media and formats; technical modifications allowed. The
|
175 |
+
Licensor authorizes You to exercise the Licensed Rights in
|
176 |
+
all media and formats whether now known or hereafter created,
|
177 |
+
and to make technical modifications necessary to do so. The
|
178 |
+
Licensor waives and/or agrees not to assert any right or
|
179 |
+
authority to forbid You from making technical modifications
|
180 |
+
necessary to exercise the Licensed Rights, including
|
181 |
+
technical modifications necessary to circumvent Effective
|
182 |
+
Technological Measures. For purposes of this Public License,
|
183 |
+
simply making modifications authorized by this Section 2(a)
|
184 |
+
(4) never produces Adapted Material.
|
185 |
+
|
186 |
+
5. Downstream recipients.
|
187 |
+
|
188 |
+
a. Offer from the Licensor -- Licensed Material. Every
|
189 |
+
recipient of the Licensed Material automatically
|
190 |
+
receives an offer from the Licensor to exercise the
|
191 |
+
Licensed Rights under the terms and conditions of this
|
192 |
+
Public License.
|
193 |
+
|
194 |
+
b. Additional offer from the Licensor -- Adapted Material.
|
195 |
+
Every recipient of Adapted Material from You
|
196 |
+
automatically receives an offer from the Licensor to
|
197 |
+
exercise the Licensed Rights in the Adapted Material
|
198 |
+
under the conditions of the Adapter's License You apply.
|
199 |
+
|
200 |
+
c. No downstream restrictions. You may not offer or impose
|
201 |
+
any additional or different terms or conditions on, or
|
202 |
+
apply any Effective Technological Measures to, the
|
203 |
+
Licensed Material if doing so restricts exercise of the
|
204 |
+
Licensed Rights by any recipient of the Licensed
|
205 |
+
Material.
|
206 |
+
|
207 |
+
6. No endorsement. Nothing in this Public License constitutes or
|
208 |
+
may be construed as permission to assert or imply that You
|
209 |
+
are, or that Your use of the Licensed Material is, connected
|
210 |
+
with, or sponsored, endorsed, or granted official status by,
|
211 |
+
the Licensor or others designated to receive attribution as
|
212 |
+
provided in Section 3(a)(1)(A)(i).
|
213 |
+
|
214 |
+
b. Other rights.
|
215 |
+
|
216 |
+
1. Moral rights, such as the right of integrity, are not
|
217 |
+
licensed under this Public License, nor are publicity,
|
218 |
+
privacy, and/or other similar personality rights; however, to
|
219 |
+
the extent possible, the Licensor waives and/or agrees not to
|
220 |
+
assert any such rights held by the Licensor to the limited
|
221 |
+
extent necessary to allow You to exercise the Licensed
|
222 |
+
Rights, but not otherwise.
|
223 |
+
|
224 |
+
2. Patent and trademark rights are not licensed under this
|
225 |
+
Public License.
|
226 |
+
|
227 |
+
3. To the extent possible, the Licensor waives any right to
|
228 |
+
collect royalties from You for the exercise of the Licensed
|
229 |
+
Rights, whether directly or through a collecting society
|
230 |
+
under any voluntary or waivable statutory or compulsory
|
231 |
+
licensing scheme. In all other cases the Licensor expressly
|
232 |
+
reserves any right to collect such royalties, including when
|
233 |
+
the Licensed Material is used other than for NonCommercial
|
234 |
+
purposes.
|
235 |
+
|
236 |
+
|
237 |
+
Section 3 -- License Conditions.
|
238 |
+
|
239 |
+
Your exercise of the Licensed Rights is expressly made subject to the
|
240 |
+
following conditions.
|
241 |
+
|
242 |
+
a. Attribution.
|
243 |
+
|
244 |
+
1. If You Share the Licensed Material (including in modified
|
245 |
+
form), You must:
|
246 |
+
|
247 |
+
a. retain the following if it is supplied by the Licensor
|
248 |
+
with the Licensed Material:
|
249 |
+
|
250 |
+
i. identification of the creator(s) of the Licensed
|
251 |
+
Material and any others designated to receive
|
252 |
+
attribution, in any reasonable manner requested by
|
253 |
+
the Licensor (including by pseudonym if
|
254 |
+
designated);
|
255 |
+
|
256 |
+
ii. a copyright notice;
|
257 |
+
|
258 |
+
iii. a notice that refers to this Public License;
|
259 |
+
|
260 |
+
iv. a notice that refers to the disclaimer of
|
261 |
+
warranties;
|
262 |
+
|
263 |
+
v. a URI or hyperlink to the Licensed Material to the
|
264 |
+
extent reasonably practicable;
|
265 |
+
|
266 |
+
b. indicate if You modified the Licensed Material and
|
267 |
+
retain an indication of any previous modifications; and
|
268 |
+
|
269 |
+
c. indicate the Licensed Material is licensed under this
|
270 |
+
Public License, and include the text of, or the URI or
|
271 |
+
hyperlink to, this Public License.
|
272 |
+
|
273 |
+
2. You may satisfy the conditions in Section 3(a)(1) in any
|
274 |
+
reasonable manner based on the medium, means, and context in
|
275 |
+
which You Share the Licensed Material. For example, it may be
|
276 |
+
reasonable to satisfy the conditions by providing a URI or
|
277 |
+
hyperlink to a resource that includes the required
|
278 |
+
information.
|
279 |
+
3. If requested by the Licensor, You must remove any of the
|
280 |
+
information required by Section 3(a)(1)(A) to the extent
|
281 |
+
reasonably practicable.
|
282 |
+
|
283 |
+
b. ShareAlike.
|
284 |
+
|
285 |
+
In addition to the conditions in Section 3(a), if You Share
|
286 |
+
Adapted Material You produce, the following conditions also apply.
|
287 |
+
|
288 |
+
1. The Adapter's License You apply must be a Creative Commons
|
289 |
+
license with the same License Elements, this version or
|
290 |
+
later, or a BY-NC-SA Compatible License.
|
291 |
+
|
292 |
+
2. You must include the text of, or the URI or hyperlink to, the
|
293 |
+
Adapter's License You apply. You may satisfy this condition
|
294 |
+
in any reasonable manner based on the medium, means, and
|
295 |
+
context in which You Share Adapted Material.
|
296 |
+
|
297 |
+
3. You may not offer or impose any additional or different terms
|
298 |
+
or conditions on, or apply any Effective Technological
|
299 |
+
Measures to, Adapted Material that restrict exercise of the
|
300 |
+
rights granted under the Adapter's License You apply.
|
301 |
+
|
302 |
+
|
303 |
+
Section 4 -- Sui Generis Database Rights.
|
304 |
+
|
305 |
+
Where the Licensed Rights include Sui Generis Database Rights that
|
306 |
+
apply to Your use of the Licensed Material:
|
307 |
+
|
308 |
+
a. for the avoidance of doubt, Section 2(a)(1) grants You the right
|
309 |
+
to extract, reuse, reproduce, and Share all or a substantial
|
310 |
+
portion of the contents of the database for NonCommercial purposes
|
311 |
+
only;
|
312 |
+
|
313 |
+
b. if You include all or a substantial portion of the database
|
314 |
+
contents in a database in which You have Sui Generis Database
|
315 |
+
Rights, then the database in which You have Sui Generis Database
|
316 |
+
Rights (but not its individual contents) is Adapted Material,
|
317 |
+
including for purposes of Section 3(b); and
|
318 |
+
|
319 |
+
c. You must comply with the conditions in Section 3(a) if You Share
|
320 |
+
all or a substantial portion of the contents of the database.
|
321 |
+
|
322 |
+
For the avoidance of doubt, this Section 4 supplements and does not
|
323 |
+
replace Your obligations under this Public License where the Licensed
|
324 |
+
Rights include other Copyright and Similar Rights.
|
325 |
+
|
326 |
+
|
327 |
+
Section 5 -- Disclaimer of Warranties and Limitation of Liability.
|
328 |
+
|
329 |
+
a. UNLESS OTHERWISE SEPARATELY UNDERTAKEN BY THE LICENSOR, TO THE
|
330 |
+
EXTENT POSSIBLE, THE LICENSOR OFFERS THE LICENSED MATERIAL AS-IS
|
331 |
+
AND AS-AVAILABLE, AND MAKES NO REPRESENTATIONS OR WARRANTIES OF
|
332 |
+
ANY KIND CONCERNING THE LICENSED MATERIAL, WHETHER EXPRESS,
|
333 |
+
IMPLIED, STATUTORY, OR OTHER. THIS INCLUDES, WITHOUT LIMITATION,
|
334 |
+
WARRANTIES OF TITLE, MERCHANTABILITY, FITNESS FOR A PARTICULAR
|
335 |
+
PURPOSE, NON-INFRINGEMENT, ABSENCE OF LATENT OR OTHER DEFECTS,
|
336 |
+
ACCURACY, OR THE PRESENCE OR ABSENCE OF ERRORS, WHETHER OR NOT
|
337 |
+
KNOWN OR DISCOVERABLE. WHERE DISCLAIMERS OF WARRANTIES ARE NOT
|
338 |
+
ALLOWED IN FULL OR IN PART, THIS DISCLAIMER MAY NOT APPLY TO YOU.
|
339 |
+
|
340 |
+
b. TO THE EXTENT POSSIBLE, IN NO EVENT WILL THE LICENSOR BE LIABLE
|
341 |
+
TO YOU ON ANY LEGAL THEORY (INCLUDING, WITHOUT LIMITATION,
|
342 |
+
NEGLIGENCE) OR OTHERWISE FOR ANY DIRECT, SPECIAL, INDIRECT,
|
343 |
+
INCIDENTAL, CONSEQUENTIAL, PUNITIVE, EXEMPLARY, OR OTHER LOSSES,
|
344 |
+
COSTS, EXPENSES, OR DAMAGES ARISING OUT OF THIS PUBLIC LICENSE OR
|
345 |
+
USE OF THE LICENSED MATERIAL, EVEN IF THE LICENSOR HAS BEEN
|
346 |
+
ADVISED OF THE POSSIBILITY OF SUCH LOSSES, COSTS, EXPENSES, OR
|
347 |
+
DAMAGES. WHERE A LIMITATION OF LIABILITY IS NOT ALLOWED IN FULL OR
|
348 |
+
IN PART, THIS LIMITATION MAY NOT APPLY TO YOU.
|
349 |
+
|
350 |
+
c. The disclaimer of warranties and limitation of liability provided
|
351 |
+
above shall be interpreted in a manner that, to the extent
|
352 |
+
possible, most closely approximates an absolute disclaimer and
|
353 |
+
waiver of all liability.
|
354 |
+
|
355 |
+
|
356 |
+
Section 6 -- Term and Termination.
|
357 |
+
|
358 |
+
a. This Public License applies for the term of the Copyright and
|
359 |
+
Similar Rights licensed here. However, if You fail to comply with
|
360 |
+
this Public License, then Your rights under this Public License
|
361 |
+
terminate automatically.
|
362 |
+
|
363 |
+
b. Where Your right to use the Licensed Material has terminated under
|
364 |
+
Section 6(a), it reinstates:
|
365 |
+
|
366 |
+
1. automatically as of the date the violation is cured, provided
|
367 |
+
it is cured within 30 days of Your discovery of the
|
368 |
+
violation; or
|
369 |
+
|
370 |
+
2. upon express reinstatement by the Licensor.
|
371 |
+
|
372 |
+
For the avoidance of doubt, this Section 6(b) does not affect any
|
373 |
+
right the Licensor may have to seek remedies for Your violations
|
374 |
+
of this Public License.
|
375 |
+
|
376 |
+
c. For the avoidance of doubt, the Licensor may also offer the
|
377 |
+
Licensed Material under separate terms or conditions or stop
|
378 |
+
distributing the Licensed Material at any time; however, doing so
|
379 |
+
will not terminate this Public License.
|
380 |
+
|
381 |
+
d. Sections 1, 5, 6, 7, and 8 survive termination of this Public
|
382 |
+
License.
|
383 |
+
|
384 |
+
|
385 |
+
Section 7 -- Other Terms and Conditions.
|
386 |
+
|
387 |
+
a. The Licensor shall not be bound by any additional or different
|
388 |
+
terms or conditions communicated by You unless expressly agreed.
|
389 |
+
|
390 |
+
b. Any arrangements, understandings, or agreements regarding the
|
391 |
+
Licensed Material not stated herein are separate from and
|
392 |
+
independent of the terms and conditions of this Public License.
|
393 |
+
|
394 |
+
|
395 |
+
Section 8 -- Interpretation.
|
396 |
+
|
397 |
+
a. For the avoidance of doubt, this Public License does not, and
|
398 |
+
shall not be interpreted to, reduce, limit, restrict, or impose
|
399 |
+
conditions on any use of the Licensed Material that could lawfully
|
400 |
+
be made without permission under this Public License.
|
401 |
+
|
402 |
+
b. To the extent possible, if any provision of this Public License is
|
403 |
+
deemed unenforceable, it shall be automatically reformed to the
|
404 |
+
minimum extent necessary to make it enforceable. If the provision
|
405 |
+
cannot be reformed, it shall be severed from this Public License
|
406 |
+
without affecting the enforceability of the remaining terms and
|
407 |
+
conditions.
|
408 |
+
|
409 |
+
c. No term or condition of this Public License will be waived and no
|
410 |
+
failure to comply consented to unless expressly agreed to by the
|
411 |
+
Licensor.
|
412 |
+
|
413 |
+
d. Nothing in this Public License constitutes or may be interpreted
|
414 |
+
as a limitation upon, or waiver of, any privileges and immunities
|
415 |
+
that apply to the Licensor or You, including from the legal
|
416 |
+
processes of any jurisdiction or authority.
|
417 |
+
|
418 |
+
=======================================================================
|
419 |
+
|
420 |
+
Creative Commons is not a party to its public
|
421 |
+
licenses. Notwithstanding, Creative Commons may elect to apply one of
|
422 |
+
its public licenses to material it publishes and in those instances
|
423 |
+
will be considered the “Licensor.” The text of the Creative Commons
|
424 |
+
public licenses is dedicated to the public domain under the CC0 Public
|
425 |
+
Domain Dedication. Except for the limited purpose of indicating that
|
426 |
+
material is shared under a Creative Commons public license or as
|
427 |
+
otherwise permitted by the Creative Commons policies published at
|
428 |
+
creativecommons.org/policies, Creative Commons does not authorize the
|
429 |
+
use of the trademark "Creative Commons" or any other trademark or logo
|
430 |
+
of Creative Commons without its prior written consent including,
|
431 |
+
without limitation, in connection with any unauthorized modifications
|
432 |
+
to any of its public licenses or any other arrangements,
|
433 |
+
understandings, or agreements concerning use of licensed material. For
|
434 |
+
the avoidance of doubt, this paragraph does not form part of the
|
435 |
+
public licenses.
|
436 |
+
|
437 |
+
Creative Commons may be contacted at creativecommons.org.
|
README.md
CHANGED
@@ -1,5 +1,162 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Llamol
|
2 |
+
|
3 |
+
<p align="center">
|
4 |
+
<img src="assets/llamol.png" width="300" height="300" alt="LLamol">
|
5 |
+
</p>
|
6 |
+
|
7 |
+
This is the official repository for the paper ["LLamol: A Dynamic Multi-Conditional Generative Transformer for De Novo Molecular Design"](https://arxiv.org/abs/2311.14407).
|
8 |
+
In this repository are the weights for LLamol (`out/llama2-M-Full-RSS.pt`) and the dataset OrganiX13.
|
9 |
+
|
10 |
+
Image made with [Hotspot.ai](https://hotpot.ai/art-generator)
|
11 |
+
## Installation
|
12 |
+
Install using Mamba to be fast: https://mamba.readthedocs.io/en/latest/micromamba-installation.html
|
13 |
+
|
14 |
+
|
15 |
+
```bash
|
16 |
+
$ "${SHELL}" <(curl -L micro.mamba.pm/install.sh)
|
17 |
+
$ micromamba env create -f torch2-env.yaml
|
18 |
+
$ micromamba activate torch2-llamol
|
19 |
+
$ python sample.py
|
20 |
+
```
|
21 |
+
# Download and preprocess the OrganiX13 dataset:
|
22 |
+
If you want to train with the full 13 Million dataset do the following steps. These are *not* necessary if you just want to use the model for inference:
|
23 |
+
1. Download and preprocess the OPV dataset by running `/data/opv/prepare_opv.py`
|
24 |
+
2. Download and preprocess the ZINC dataset by running `/data/zinc/zinc_complete/run_download.py` followed by `/data/zinc/convert_to_parquet.py`
|
25 |
+
(we recommend at least 16GB RAM for this)
|
26 |
+
3. Download and preprocess the ZINC dataset by running `/data/qm9_zinc250k_cep/convert_to_parquet.py`
|
27 |
+
|
28 |
+
4. Run `data/combine_all.py` to combine the dataset to `data/OrganiX13.parquet` (this can take a while, especially on the zinc dataset. In total it took ~2 hours when using my Laptop, which has 16 GB ram and an Intel i7 10th Gen)
|
29 |
+
5. Run `preprocess_dataset.py` which should create the file `.cache/processed_dataset_None.pkl`
|
30 |
+
|
31 |
+
Now you can use that in the training of the model by specifing the file under the `processed_dataset_ckpt` of the training .yaml files.
|
32 |
+
|
33 |
+
|
34 |
+
|
35 |
+
# Interactive Demo
|
36 |
+
|
37 |
+
After installation you can play around with the model using the `demonstrator.ipynb` file. Just run all and scroll down to the last cell.
|
38 |
+
After a short time there should be a UI where you can play around with the model.
|
39 |
+
|
40 |
+
|
41 |
+
## Training
|
42 |
+
|
43 |
+
First the env needs to be activated so:
|
44 |
+
```bash
|
45 |
+
$ conda activate torch2-llamol # When installed with conda instead of micromamba
|
46 |
+
OR
|
47 |
+
$ micromamba activate torch2-llamol
|
48 |
+
``````
|
49 |
+
|
50 |
+
To train locally you can run:
|
51 |
+
```bash
|
52 |
+
# To set the config that you want to train with
|
53 |
+
$ python train.py train=llama2-M-Full-RSS
|
54 |
+
```
|
55 |
+
|
56 |
+
Parameters can also be overriden by using the following, for example:
|
57 |
+
```bash
|
58 |
+
$ python train.py train=llama2-M-Full-RSS train.model.dim=1024
|
59 |
+
```
|
60 |
+
For more information look at [Hydra](https://hydra.cc/docs/1.3/intro/)
|
61 |
+
|
62 |
+
To start a job on a SLURM cluster use the following script:
|
63 |
+
```bash
|
64 |
+
$ sbatch trainLLamaMol.sh
|
65 |
+
``````
|
66 |
+
|
67 |
+
## Training Multi-GPU on 1 Node with multiple GPUS (nproc_per_node)
|
68 |
+
```bash
|
69 |
+
torchrun --standalone --max_restarts=3 --nnodes=1 --nproc_per_node=2 --rdzv-backend=c10d --rdzv-endpoint="$localhost:12345" train.py train=llama2-M-Full-RSS > "train_runs/run_MultiGPU.out"
|
70 |
+
```
|
71 |
+
## Training Multi-GPU on 1 Node with multiple GPUS on a Cluster
|
72 |
+
Currently there is only one script to train with DDP. To change the number of GPUS in that script you have to change the bash script itself.
|
73 |
+
TODO: Make it more dynamic, with allowing console commands to change the number of GPUS etc.
|
74 |
+
```bash
|
75 |
+
sbatch trainLLamaMolDDPSingleNode.sh
|
76 |
+
```
|
77 |
+
|
78 |
+
## Sampling
|
79 |
+
Sampling can be changed by the OPTIONAL parameters as shown below.
|
80 |
+
```bash
|
81 |
+
$ python sample.py --help
|
82 |
+
|
83 |
+
$ python sample.py --num_samples 2000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --seed 4312 --context_cols logp sascore mol_weight --temperature 0.8
|
84 |
+
```
|
85 |
+
|
86 |
+
|
87 |
+
## Using own dataset
|
88 |
+
|
89 |
+
Use the `preprocess_dataset.py` file to tokenize the dataset. The dataset should be either in the parquet or csv format.
|
90 |
+
The SMILES used for training should be in the `smiles` column in the dataset. All conditions, should be given to the pretokenize function.
|
91 |
+
After the preprocessing is done a file should be stored in the .cache directory with the name `processed_dataset_{limit}.pkl`.
|
92 |
+
You could also rename this file to not overwrite it every time you run the preprocessing.
|
93 |
+
|
94 |
+
The `.cache/processed_dataset_{limit}.pkl` can then be set in the `config/train/llama2-M-Full-RSS.yaml file` to change the training with the new dataset in the `processed_dataset_ckpt` field in the yaml file.
|
95 |
+
|
96 |
+
# Training methods
|
97 |
+
|
98 |
+
The training method we used and described in the paper is here called RSS for "Random Smiles Sampling" which was the method then described in the "Stochastic Context Learning" as taking a random subsequence from the current SMILES while training and feeding that into the model as a token sequence condition. So the model we used in the paper was the `out/llama2-M-Full-RSS.pt`.
|
99 |
+
|
100 |
+
We also tried other approached for including the token sequence.
|
101 |
+
One was using murcko scaffolds as they were used in the MolGPT paper, but this approach did not yield great results for our purposes.
|
102 |
+
The other was using BRICKS decomposition, which also did not yield very good results.
|
103 |
+
|
104 |
+
The different methods are implemented in the `fragment_creator.py` file.
|
105 |
+
Each of the models were trained with their respective configurations in the `config/train` folder.
|
106 |
+
|
107 |
+
# Thanks
|
108 |
+
|
109 |
+
|
110 |
+
- [Karpathy](https://github.com/karpathy/llama2.c) for the implementation of the Llama 2 architecture and training code
|
111 |
+
|
112 |
+
- [DeepChem](https://github.com/deepchem/deepchem) for the SmilesTokenizer
|
113 |
+
|
114 |
+
- [TorchDrug](https://github.com/DeepGraphLearning/torchdrug/) for the downloads scripts for the OPV and CEP datasets
|
115 |
+
|
116 |
+
- Zinc 15 dataset (Teague Sterling and John J. Irwin. ZINC 15 – ligand discovery for everyone. Journal of Chemical Information
|
117 |
+
and Modeling, 55(11):2324–2337, November 2015.)
|
118 |
+
|
119 |
+
- QM9 dataset (
|
120 |
+
Raghunathan Ramakrishnan, Pavlo O. Dral, Matthias Rupp, and O. Anatole von Lilienfeld. Quantum chemistry
|
121 |
+
structures and properties of 134 kilo molecules. Scientific Data, 1(1), aug 2014.)
|
122 |
+
|
123 |
+
- PC9 dataset (Marta Glavatskikh, Jules Leguy, Gilles Hunault, Thomas Cauchy, and Benoit Da Mota. Dataset’s chemical
|
124 |
+
diversity limits the generalizability of machine learning predictions. Journal of Cheminformatics, 11(1), nov 2019)
|
125 |
+
|
126 |
+
- ZINC 250k (Rafael Gó mez-Bombarelli, Jennifer N. Wei, David Duvenaud, José Miguel Hernández-Lobato, Benjamín
|
127 |
+
Sánchez-Lengeling, Dennis Sheberla, Jorge Aguilera-Iparraguirre, Timothy D. Hirzel, Ryan P. Adams, and Alán
|
128 |
+
Aspuru-Guzik. Automatic chemical design using a data-driven continuous representation of molecules. ACS
|
129 |
+
Central Science, 4(2):268–276, jan 2018.)
|
130 |
+
|
131 |
+
- RedDB (Elif Sorkun, Qi Zhang, Abhishek Khetan, Murat Cihan Sorkun, and Süleyman Er. RedDB, a computational
|
132 |
+
database of electroactive molecules for aqueous redox flow batteries. Scientific Data, 9(1), nov 2022.)
|
133 |
+
|
134 |
+
- OPV (Peter C. St. John, Caleb Phillips, Travis W. Kemper, A. Nolan Wilson, Yanfei Guan, Michael F. Crowley, Mark R.
|
135 |
+
Nimlos, and Ross E. Larsen. Message-passing neural networks for high-throughput polymer screening. The
|
136 |
+
Journal of Chemical Physics, 150(23):234111, jun 2019.)
|
137 |
+
|
138 |
+
- PubchemQC 2020 (Maho Nakata, Tomomi Shimazaki, Masatomo Hashimoto, and Toshiyuki Maeda. PubChemQC PM6: Data sets
|
139 |
+
of 221 million molecules with optimized molecular geometries and electronic properties. Journal of Chemical
|
140 |
+
Information and Modeling, 60(12):5891–5899, oct 2020.)
|
141 |
+
|
142 |
+
- PubchemQC 2017 (Maho Nakata and Tomomi Shimazaki. PubChemQC project: A large-scale first-principles electronic structure
|
143 |
+
database for data-driven chemistry. Journal of Chemical Information and Modeling, 57(6):1300–1308, may 2017.)
|
144 |
+
|
145 |
+
- CEP (Johannes Hachmann, Roberto Olivares-Amaya, Sule Atahan-Evrenk, Carlos Amador-Bedolla, Roel S. Sánchez-
|
146 |
+
Carrera, Aryeh Gold-Parker, Leslie Vogt, Anna M. Brockway, and Alán Aspuru-Guzik. The Harvard clean energy
|
147 |
+
project: Large-scale computational screening and design of organic photovoltaics on the world community grid.
|
148 |
+
The Journal of Physical Chemistry Letters, 2(17):2241–2251, aug 2011.) subset ( David Duvenaud, Dougal Maclaurin, Jorge Aguilera-Iparraguirre, Rafael Gómez-Bombarelli, Timothy Hirzel,
|
149 |
+
Alán Aspuru-Guzik, and Ryan P. Adams. Convolutional networks on graphs for learning molecular fingerprints,
|
150 |
+
2015.)
|
151 |
+
- ChEMBL (James Blackshaw, Anna Gaulton, A. Patrícia Bento, Marleen De Veij, David Mendez Lopez, Nicolas Bosc, Juan
|
152 |
+
Felipe Mosquera Morales, María Paula Margariños, Andrew Leach, Emma Manners, Barbara Zdrazil, Harris
|
153 |
+
Ioannidis, Fiona Hunter, Eloy Félix, and Ricardo Arcila Toro. CHEMBL database release 31, September 2009.)
|
154 |
+
|
155 |
+
# Funding disclaimer
|
156 |
+
|
157 |
+
This project has received funding from the European Union’s Horizon 2020 research and innovation programme under Grant Agreement no. 875489.
|
158 |
+
|
159 |
+
This website reflects only the author’s view. The funding agency is not responsible for any use made of the information it contains.
|
160 |
+
|
161 |
+
# License
|
162 |
+
<p xmlns:cc="http://creativecommons.org/ns#" xmlns:dct="http://purl.org/dc/terms/"><span property="dct:title">LLamol is licensed under <a href="http://creativecommons.org/licenses/by-nc-sa/4.0/?ref=chooser-v1" target="_blank" rel="license noopener noreferrer" style="display:inline-block;">CC BY-NC-SA 4.0<img style="height:22px!important;margin-left:3px;vertical-align:text-bottom;" src="https://mirrors.creativecommons.org/presskit/icons/cc.svg?ref=chooser-v1"><img style="height:22px!important;margin-left:3px;vertical-align:text-bottom;" src="https://mirrors.creativecommons.org/presskit/icons/by.svg?ref=chooser-v1"><img style="height:22px!important;margin-left:3px;vertical-align:text-bottom;" src="https://mirrors.creativecommons.org/presskit/icons/nc.svg?ref=chooser-v1"><img style="height:22px!important;margin-left:3px;vertical-align:text-bottom;" src="https://mirrors.creativecommons.org/presskit/icons/sa.svg?ref=chooser-v1"></a></p>
|
assets/llamol.png
ADDED
config/config.yaml
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
defaults:
|
2 |
+
- train: "llama2-Debug"
|
config/train/llama2-Debug.yaml
ADDED
@@ -0,0 +1,47 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
io:
|
2 |
+
# I/O
|
3 |
+
out_dir : "debug"
|
4 |
+
eval_interval : 10
|
5 |
+
log_interval : 10
|
6 |
+
eval_iters : 5
|
7 |
+
eval_only : false # if True, script exits right after the first eval
|
8 |
+
always_save_checkpoint : true # if True, always save a checkpoint after each eval
|
9 |
+
init_from : "scratch" # 'scratch' or 'resume'
|
10 |
+
resume_when_snapshot_available: false
|
11 |
+
|
12 |
+
loader:
|
13 |
+
batch_size : 4 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
14 |
+
max_seq_len : 768
|
15 |
+
dataset : "smiles"
|
16 |
+
processed_dataset_ckpt : "processed_dataset_500000.pkl"
|
17 |
+
fragment_creator : "rss"
|
18 |
+
|
19 |
+
model:
|
20 |
+
dim : 32
|
21 |
+
n_layers : 1
|
22 |
+
n_heads : 1
|
23 |
+
multiple_of : 16
|
24 |
+
dropout : 0.1
|
25 |
+
|
26 |
+
context:
|
27 |
+
context_keys: ["logp", "sascore", "mol_weight"]
|
28 |
+
context_dims : [1,1,1]
|
29 |
+
|
30 |
+
optimizer:
|
31 |
+
gradient_accumulation_steps : 4 # used to simulate larger batch sizes
|
32 |
+
learning_rate : 1e-4 # max learning rate
|
33 |
+
max_iters : 20 # total number of training iterations
|
34 |
+
weight_decay : 1e-1
|
35 |
+
beta1 : 0.9
|
36 |
+
beta2 : 0.95
|
37 |
+
grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
|
38 |
+
# learning rate decay settings
|
39 |
+
decay_lr : true # whether to decay the learning rate
|
40 |
+
warmup_iters : 10 # how many steps to warm up for
|
41 |
+
lr_decay_iters : 100 # should be ~= max_iters per Chinchilla
|
42 |
+
min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
43 |
+
|
44 |
+
dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
|
45 |
+
compile : false # Use torch.compile, but in my test this is really slow
|
46 |
+
label : "llama2-Debug"
|
47 |
+
profile : false
|
config/train/llama2-DebugGPU.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
io:
|
2 |
+
# I/O
|
3 |
+
out_dir : "debug-gpu"
|
4 |
+
eval_interval : 10
|
5 |
+
log_interval : 10
|
6 |
+
eval_iters : 5
|
7 |
+
eval_only : false # if True, script exits right after the first eval
|
8 |
+
always_save_checkpoint : true # if True, always save a checkpoint after each eval
|
9 |
+
init_from : "scratch" # 'scratch' or 'resume'
|
10 |
+
resume_when_snapshot_available: false
|
11 |
+
|
12 |
+
loader:
|
13 |
+
batch_size : 256 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
14 |
+
max_seq_len : 256
|
15 |
+
dataset : "smiles"
|
16 |
+
processed_dataset_ckpt : "processed_dataset_500000.pkl"
|
17 |
+
|
18 |
+
model:
|
19 |
+
dim : 256
|
20 |
+
n_layers : 8
|
21 |
+
n_heads : 8
|
22 |
+
multiple_of : 128
|
23 |
+
dropout : 0.1
|
24 |
+
|
25 |
+
context:
|
26 |
+
context_keys: ["logp", "sascore", "mol_weight"]
|
27 |
+
context_dims : [1,1,1]
|
28 |
+
|
29 |
+
optimizer:
|
30 |
+
gradient_accumulation_steps : 4 # used to simulate larger batch sizes
|
31 |
+
learning_rate : 1e-4 # max learning rate
|
32 |
+
max_iters : 25 # total number of training iterations
|
33 |
+
weight_decay : 1e-1
|
34 |
+
beta1 : 0.9
|
35 |
+
beta2 : 0.95
|
36 |
+
grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
|
37 |
+
# learning rate decay settings
|
38 |
+
decay_lr : true # whether to decay the learning rate
|
39 |
+
warmup_iters : 10 # how many steps to warm up for
|
40 |
+
lr_decay_iters : 100 # should be ~= max_iters per Chinchilla
|
41 |
+
min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
42 |
+
|
43 |
+
dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
|
44 |
+
compile : false # Use torch.compile, but in my test this is really slow
|
45 |
+
label : "llama2-Debug"
|
46 |
+
profile: true # Profile the run
|
config/train/llama2-M-Full-BRICKS.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
io:
|
2 |
+
# I/O
|
3 |
+
out_dir : "out"
|
4 |
+
eval_interval : 500
|
5 |
+
log_interval : 10
|
6 |
+
eval_iters : 10
|
7 |
+
eval_only : false # if True, script exits right after the first eval
|
8 |
+
always_save_checkpoint : false # if True, always save a checkpoint after each eval
|
9 |
+
init_from : "scratch" # 'scratch' or 'resume'
|
10 |
+
resume_when_snapshot_available: true
|
11 |
+
|
12 |
+
loader:
|
13 |
+
batch_size : 384 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
14 |
+
max_seq_len : 768
|
15 |
+
dataset : "smiles"
|
16 |
+
processed_dataset_ckpt : "processed_dataset_None.pkl"
|
17 |
+
fragment_creator : "bricks"
|
18 |
+
|
19 |
+
model:
|
20 |
+
dim : 256
|
21 |
+
n_layers : 8
|
22 |
+
n_heads : 8
|
23 |
+
multiple_of : 128
|
24 |
+
dropout : 0.1
|
25 |
+
|
26 |
+
context:
|
27 |
+
context_keys: ["logp", "sascore", "mol_weight"]
|
28 |
+
context_dims : [1,1,1]
|
29 |
+
|
30 |
+
optimizer:
|
31 |
+
gradient_accumulation_steps : 4 # used to simulate larger batch sizes
|
32 |
+
learning_rate : 1e-4 # max learning rate
|
33 |
+
max_iters : 100000 # total number of training iterations
|
34 |
+
weight_decay : 1e-1
|
35 |
+
beta1 : 0.9
|
36 |
+
beta2 : 0.95
|
37 |
+
grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
|
38 |
+
# learning rate decay settings
|
39 |
+
decay_lr : true # whether to decay the learning rate
|
40 |
+
warmup_iters : 1000 # how many steps to warm up for
|
41 |
+
lr_decay_iters : 100000 # should be ~= max_iters per Chinchilla
|
42 |
+
min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
43 |
+
|
44 |
+
dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
|
45 |
+
compile : false # Use torch.compile, but in my test this is really slow
|
46 |
+
label : "llama2-M-Full-BRICKS"
|
config/train/llama2-M-Full-RSS.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
io:
|
2 |
+
# I/O
|
3 |
+
out_dir : "out"
|
4 |
+
eval_interval : 500
|
5 |
+
log_interval : 10
|
6 |
+
eval_iters : 10
|
7 |
+
eval_only : false # if True, script exits right after the first eval
|
8 |
+
always_save_checkpoint : false # if True, always save a checkpoint after each eval
|
9 |
+
init_from : "scratch" # 'scratch' or 'resume'
|
10 |
+
resume_when_snapshot_available: true # resume the training always, when the `snapshot_` is available in the out/ folder
|
11 |
+
|
12 |
+
loader:
|
13 |
+
batch_size : 256 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
14 |
+
max_seq_len : 256 # the maximum sequence length we want to use in the training data.
|
15 |
+
dataset : "smiles"
|
16 |
+
processed_dataset_ckpt : "processed_dataset_None.pkl"
|
17 |
+
fragment_creator : "rss" # the method we want to use to train with the token_sequence
|
18 |
+
|
19 |
+
model:
|
20 |
+
dim : 384
|
21 |
+
n_layers : 8
|
22 |
+
n_heads : 8
|
23 |
+
multiple_of : 128
|
24 |
+
dropout : 0.1
|
25 |
+
|
26 |
+
context:
|
27 |
+
context_keys: ["logp", "sascore", "mol_weight"]
|
28 |
+
context_dims : [1,1,1]
|
29 |
+
|
30 |
+
optimizer:
|
31 |
+
gradient_accumulation_steps : 4 # used to simulate larger batch sizes
|
32 |
+
learning_rate : 1e-4 # max learning rate
|
33 |
+
max_iters : 100000 # total number of training iterations
|
34 |
+
weight_decay : 1e-1
|
35 |
+
beta1 : 0.9
|
36 |
+
beta2 : 0.95
|
37 |
+
grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
|
38 |
+
# learning rate decay settings
|
39 |
+
decay_lr : true # whether to decay the learning rate
|
40 |
+
warmup_iters : 1000 # how many steps to warm up for
|
41 |
+
lr_decay_iters : 100000 # should be ~= max_iters per Chinchilla
|
42 |
+
min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
43 |
+
|
44 |
+
dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
|
45 |
+
compile : false # Use torch.compile, but in my test this is really slow
|
46 |
+
label : "llama2-M-Full-RSS" # the name of the output file / model
|
config/train/llama2-M-Full.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
io:
|
2 |
+
# I/O
|
3 |
+
out_dir : "out"
|
4 |
+
eval_interval : 500
|
5 |
+
log_interval : 10
|
6 |
+
eval_iters : 10
|
7 |
+
eval_only : false # if True, script exits right after the first eval
|
8 |
+
always_save_checkpoint : false # if True, always save a checkpoint after each eval
|
9 |
+
init_from : "scratch" # 'scratch' or 'resume'
|
10 |
+
resume_when_snapshot_available: true
|
11 |
+
|
12 |
+
loader:
|
13 |
+
batch_size : 384 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
14 |
+
max_seq_len : 768
|
15 |
+
dataset : "smiles"
|
16 |
+
processed_dataset_ckpt : "processed_dataset_None.pkl"
|
17 |
+
fragment_creator : null
|
18 |
+
|
19 |
+
model:
|
20 |
+
dim : 256
|
21 |
+
n_layers : 8
|
22 |
+
n_heads : 8
|
23 |
+
multiple_of : 128
|
24 |
+
dropout : 0.1
|
25 |
+
|
26 |
+
context:
|
27 |
+
context_keys: ["logp", "sascore", "mol_weight"]
|
28 |
+
context_dims : [1,1,1]
|
29 |
+
|
30 |
+
optimizer:
|
31 |
+
gradient_accumulation_steps : 4 # used to simulate larger batch sizes
|
32 |
+
learning_rate : 1e-4 # max learning rate
|
33 |
+
max_iters : 100000 # total number of training iterations
|
34 |
+
weight_decay : 1e-1
|
35 |
+
beta1 : 0.9
|
36 |
+
beta2 : 0.95
|
37 |
+
grad_clip : 1.0 # clip gradients at this value, or disable if == 0.0
|
38 |
+
# learning rate decay settings
|
39 |
+
decay_lr : true # whether to decay the learning rate
|
40 |
+
warmup_iters : 1000 # how many steps to warm up for
|
41 |
+
lr_decay_iters : 100000 # should be ~= max_iters per Chinchilla
|
42 |
+
min_lr : 0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
43 |
+
|
44 |
+
dtype : "float16" # Use float16 for training, could also be changed to float32 or bfloat16
|
45 |
+
compile : false # Use torch.compile, but in my test this is really slow
|
46 |
+
label : "llama2-M-Full"
|
data/Full_PC9_GAP.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:0e1c1932284e5987ff997675b3f8ad2a8763c4dc864315e78a774841fb6b6791
|
3 |
+
size 38893336
|
data/RedDB_Full.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:543e98ba1b622a2a949a3818d047daa478658d2d91923a291907a2d9c8c886bd
|
3 |
+
size 1024066
|
data/chembl_log_sascore.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:30d04f6f1f01caec6164d85b23ba1282dfe63ec1b245e4c358aa216831c32ee8
|
3 |
+
size 99582099
|
data/combine_all.py
ADDED
@@ -0,0 +1,164 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import os
|
4 |
+
from rdkit import Chem
|
5 |
+
from rdkit.Chem import Descriptors
|
6 |
+
import multiprocessing
|
7 |
+
|
8 |
+
from rdkit import Chem
|
9 |
+
from rdkit.Chem import RDConfig
|
10 |
+
import os
|
11 |
+
import sys
|
12 |
+
sys.path.append(os.path.join(RDConfig.RDContribDir, 'SA_Score'))
|
13 |
+
# now you can import sascore!
|
14 |
+
import sascorer
|
15 |
+
|
16 |
+
np.random.seed(42)
|
17 |
+
|
18 |
+
def calcLogPIfMol(smi):
|
19 |
+
m = Chem.MolFromSmiles(smi)
|
20 |
+
if m is not None:
|
21 |
+
return Descriptors.MolLogP(m)
|
22 |
+
else:
|
23 |
+
return None
|
24 |
+
|
25 |
+
def calcMol(smi):
|
26 |
+
return Chem.MolFromSmiles(smi)
|
27 |
+
|
28 |
+
def calcMolWeight(smi):
|
29 |
+
mol = Chem.MolFromSmiles(smi)
|
30 |
+
return Descriptors.ExactMolWt(mol)
|
31 |
+
|
32 |
+
def calcSascore(smi):
|
33 |
+
mol = Chem.MolFromSmiles(smi)
|
34 |
+
|
35 |
+
return sascorer.calculateScore(mol)
|
36 |
+
|
37 |
+
def calculateValues(smi: pd.Series):
|
38 |
+
|
39 |
+
|
40 |
+
with multiprocessing.Pool(8) as pool:
|
41 |
+
print("Starting logps")
|
42 |
+
logps = pool.map(calcLogPIfMol, smi)
|
43 |
+
print("Done logps")
|
44 |
+
valid_mols = ~pd.isna(logps)
|
45 |
+
logps = pd.Series(logps)[valid_mols]
|
46 |
+
smi = pd.Series(smi)[valid_mols]
|
47 |
+
logps.reset_index(drop=True,inplace=True)
|
48 |
+
smi.reset_index(drop=True,inplace=True)
|
49 |
+
print("Starting mol weights")
|
50 |
+
mol_weights = pool.map(calcMolWeight, smi)
|
51 |
+
print("Done mol weights")
|
52 |
+
print("Starting sascores")
|
53 |
+
sascores = pool.map(calcSascore, smi)
|
54 |
+
print("Done sascores")
|
55 |
+
|
56 |
+
return smi, logps, mol_weights,sascores
|
57 |
+
|
58 |
+
def calculateProperties(df):
|
59 |
+
|
60 |
+
smi, logps, mol_weights,sascores = calculateValues(df["smiles"])
|
61 |
+
out_df = pd.DataFrame({"smiles": smi, "logp":logps, "mol_weight":mol_weights, "sascore":sascores })
|
62 |
+
|
63 |
+
return out_df
|
64 |
+
|
65 |
+
if __name__ == "__main__":
|
66 |
+
|
67 |
+
cwd = os.path.dirname(__file__)
|
68 |
+
|
69 |
+
print("df_pc9")
|
70 |
+
df_pc9 = pd.read_parquet(os.path.join(cwd, "Full_PC9_GAP.parquet"))
|
71 |
+
df_pc9 = calculateProperties(df_pc9)
|
72 |
+
|
73 |
+
|
74 |
+
print("df_zinc_full")
|
75 |
+
|
76 |
+
df_zinc_full = pd.read_parquet(
|
77 |
+
os.path.join(cwd, "zinc", "zinc_processed.parquet")
|
78 |
+
)
|
79 |
+
df_zinc_full = df_zinc_full.sample(n=5_000_000)
|
80 |
+
df_zinc_full = calculateProperties(df_zinc_full)
|
81 |
+
|
82 |
+
|
83 |
+
print("df_zinc_qm9")
|
84 |
+
df_zinc_qm9 = pd.read_parquet(os.path.join(cwd,"qm9_zinc250k_cep", "qm9_zinc250_cep.parquet"))
|
85 |
+
df_zinc_qm9 = calculateProperties(df_zinc_qm9)
|
86 |
+
|
87 |
+
print("df_opv")
|
88 |
+
df_opv = pd.read_parquet(os.path.join(cwd,"opv", "opv.parquet"))
|
89 |
+
df_opv = calculateProperties(df_opv)
|
90 |
+
|
91 |
+
|
92 |
+
print("df_reddb")
|
93 |
+
# Source: https://dataverse.harvard.edu/dataset.xhtml?persistentId=doi:10.7910/DVN/F3QFSQ
|
94 |
+
df_reddb = pd.read_parquet(os.path.join(cwd,"RedDB_Full.parquet"))
|
95 |
+
df_reddb = calculateProperties(df_reddb)
|
96 |
+
|
97 |
+
print("df_chembl")
|
98 |
+
df_chembl = pd.read_parquet(
|
99 |
+
os.path.join(cwd, "chembl_log_sascore.parquet")
|
100 |
+
)
|
101 |
+
df_chembl = calculateProperties(df_chembl)
|
102 |
+
|
103 |
+
|
104 |
+
print("df_pubchemqc_2017")
|
105 |
+
df_pubchemqc_2017 = pd.read_parquet(
|
106 |
+
os.path.join(cwd, "pubchemqc_energy.parquet")
|
107 |
+
)
|
108 |
+
df_pubchemqc_2017 = calculateProperties(df_pubchemqc_2017)
|
109 |
+
|
110 |
+
|
111 |
+
print("df_pubchemqc_2020")
|
112 |
+
|
113 |
+
df_pubchemqc_2020 = pd.read_parquet(
|
114 |
+
os.path.join(cwd, "pubchemqc2020_energy.parquet")
|
115 |
+
)
|
116 |
+
df_pubchemqc_2020 = calculateProperties(df_pubchemqc_2020)
|
117 |
+
|
118 |
+
|
119 |
+
|
120 |
+
df_list = [
|
121 |
+
df_zinc_qm9,
|
122 |
+
df_opv,
|
123 |
+
df_pubchemqc_2017,
|
124 |
+
df_pubchemqc_2020,
|
125 |
+
df_zinc_full,
|
126 |
+
df_reddb,
|
127 |
+
df_pc9,
|
128 |
+
df_chembl,
|
129 |
+
]
|
130 |
+
|
131 |
+
print(f"ZINC QM9 {len(df_zinc_qm9)}")
|
132 |
+
print(f"df_opv {len(df_opv)}")
|
133 |
+
print(f"df_pubchemqc_2017 {len(df_pubchemqc_2017)}")
|
134 |
+
print(f"df_pubchemqc_2020 {len(df_pubchemqc_2020)}")
|
135 |
+
print(f"df_zinc_full {len(df_zinc_full)}")
|
136 |
+
print(f"df_reddb {len(df_reddb)}")
|
137 |
+
print(f"df_pc9 {len(df_pc9)}")
|
138 |
+
print(f"df_chembl {len(df_chembl)}")
|
139 |
+
|
140 |
+
|
141 |
+
|
142 |
+
|
143 |
+
|
144 |
+
all_columns = [
|
145 |
+
"smiles",
|
146 |
+
"logp",
|
147 |
+
"sascore",
|
148 |
+
"mol_weight"
|
149 |
+
] # set([*df_zinc_qm9.columns.tolist(),*df_pubchemqc_2017.columns.tolist(),*df_pubchemqc_2020.columns.tolist(),*df_zinc_full.columns.tolist()] )
|
150 |
+
print("concatenting")
|
151 |
+
df = pd.concat(
|
152 |
+
df_list, axis=0, ignore_index=True
|
153 |
+
) # pd.DataFrame(columns=all_columns)
|
154 |
+
df = df[all_columns] # .fillna(0)
|
155 |
+
# df = df.sample(n=7_500_000)
|
156 |
+
df.reset_index(drop=True, inplace=True)
|
157 |
+
df["mol_weight"] = df["mol_weight"] / 100.0
|
158 |
+
|
159 |
+
print(df.head())
|
160 |
+
print("saving")
|
161 |
+
print("Combined len:", len(df))
|
162 |
+
df.to_parquet(
|
163 |
+
os.path.join(cwd, "OrganiX13.parquet")
|
164 |
+
)
|
data/opv/prepare_opv.py
ADDED
@@ -0,0 +1,265 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
import os
|
3 |
+
import struct
|
4 |
+
import logging
|
5 |
+
from tqdm import tqdm
|
6 |
+
import csv
|
7 |
+
from collections import defaultdict
|
8 |
+
import pandas as pd
|
9 |
+
|
10 |
+
logger = logging.getLogger(__name__)
|
11 |
+
|
12 |
+
# Taken from here https://torchdrug.ai/docs/_modules/torchdrug/utils/file.html#download
|
13 |
+
def download(url, path, save_file=None, md5=None):
|
14 |
+
"""
|
15 |
+
Download a file from the specified url.
|
16 |
+
Skip the downloading step if there exists a file satisfying the given MD5.
|
17 |
+
|
18 |
+
Parameters:
|
19 |
+
url (str): URL to download
|
20 |
+
path (str): path to store the downloaded file
|
21 |
+
save_file (str, optional): name of save file. If not specified, infer the file name from the URL.
|
22 |
+
md5 (str, optional): MD5 of the file
|
23 |
+
"""
|
24 |
+
from six.moves.urllib.request import urlretrieve
|
25 |
+
|
26 |
+
if save_file is None:
|
27 |
+
save_file = os.path.basename(url)
|
28 |
+
if "?" in save_file:
|
29 |
+
save_file = save_file[:save_file.find("?")]
|
30 |
+
save_file = os.path.join(path, save_file)
|
31 |
+
|
32 |
+
if not os.path.exists(save_file) or compute_md5(save_file) != md5:
|
33 |
+
logger.info("Downloading %s to %s" % (url, save_file))
|
34 |
+
urlretrieve(url, save_file)
|
35 |
+
return save_file
|
36 |
+
|
37 |
+
|
38 |
+
|
39 |
+
def smart_open(file_name, mode="rb"):
|
40 |
+
"""
|
41 |
+
Open a regular file or a zipped file.
|
42 |
+
|
43 |
+
This function can be used as drop-in replacement of the builtin function `open()`.
|
44 |
+
|
45 |
+
Parameters:
|
46 |
+
file_name (str): file name
|
47 |
+
mode (str, optional): open mode for the file stream
|
48 |
+
"""
|
49 |
+
import bz2
|
50 |
+
import gzip
|
51 |
+
|
52 |
+
extension = os.path.splitext(file_name)[1]
|
53 |
+
if extension == '.bz2':
|
54 |
+
return bz2.BZ2File(file_name, mode)
|
55 |
+
elif extension == '.gz':
|
56 |
+
return gzip.GzipFile(file_name, mode)
|
57 |
+
else:
|
58 |
+
return open(file_name, mode)
|
59 |
+
|
60 |
+
|
61 |
+
def extract(zip_file, member=None):
|
62 |
+
"""
|
63 |
+
Extract files from a zip file. Currently, ``zip``, ``gz``, ``tar.gz``, ``tar`` file types are supported.
|
64 |
+
|
65 |
+
Parameters:
|
66 |
+
zip_file (str): file name
|
67 |
+
member (str, optional): extract specific member from the zip file.
|
68 |
+
If not specified, extract all members.
|
69 |
+
"""
|
70 |
+
import gzip
|
71 |
+
import shutil
|
72 |
+
import zipfile
|
73 |
+
import tarfile
|
74 |
+
|
75 |
+
zip_name, extension = os.path.splitext(zip_file)
|
76 |
+
if zip_name.endswith(".tar"):
|
77 |
+
extension = ".tar" + extension
|
78 |
+
zip_name = zip_name[:-4]
|
79 |
+
save_path = os.path.dirname(zip_file)
|
80 |
+
|
81 |
+
if extension == ".gz":
|
82 |
+
member = os.path.basename(zip_name)
|
83 |
+
members = [member]
|
84 |
+
save_files = [os.path.join(save_path, member)]
|
85 |
+
for _member, save_file in zip(members, save_files):
|
86 |
+
with open(zip_file, "rb") as fin:
|
87 |
+
fin.seek(-4, 2)
|
88 |
+
file_size = struct.unpack("<I", fin.read())[0]
|
89 |
+
with gzip.open(zip_file, "rb") as fin:
|
90 |
+
if not os.path.exists(save_file) or file_size != os.path.getsize(save_file):
|
91 |
+
logger.info("Extracting %s to %s" % (zip_file, save_file))
|
92 |
+
with open(save_file, "wb") as fout:
|
93 |
+
shutil.copyfileobj(fin, fout)
|
94 |
+
elif extension in [".tar.gz", ".tgz", ".tar"]:
|
95 |
+
tar = tarfile.open(zip_file, "r")
|
96 |
+
if member is not None:
|
97 |
+
members = [member]
|
98 |
+
save_files = [os.path.join(save_path, os.path.basename(member))]
|
99 |
+
logger.info("Extracting %s from %s to %s" % (member, zip_file, save_files[0]))
|
100 |
+
else:
|
101 |
+
members = tar.getnames()
|
102 |
+
save_files = [os.path.join(save_path, _member) for _member in members]
|
103 |
+
logger.info("Extracting %s to %s" % (zip_file, save_path))
|
104 |
+
for _member, save_file in zip(members, save_files):
|
105 |
+
if tar.getmember(_member).isdir():
|
106 |
+
os.makedirs(save_file, exist_ok=True)
|
107 |
+
continue
|
108 |
+
os.makedirs(os.path.dirname(save_file), exist_ok=True)
|
109 |
+
if not os.path.exists(save_file) or tar.getmember(_member).size != os.path.getsize(save_file):
|
110 |
+
with tar.extractfile(_member) as fin, open(save_file, "wb") as fout:
|
111 |
+
shutil.copyfileobj(fin, fout)
|
112 |
+
elif extension == ".zip":
|
113 |
+
zipped = zipfile.ZipFile(zip_file)
|
114 |
+
if member is not None:
|
115 |
+
members = [member]
|
116 |
+
save_files = [os.path.join(save_path, os.path.basename(member))]
|
117 |
+
logger.info("Extracting %s from %s to %s" % (member, zip_file, save_files[0]))
|
118 |
+
else:
|
119 |
+
members = zipped.namelist()
|
120 |
+
save_files = [os.path.join(save_path, _member) for _member in members]
|
121 |
+
logger.info("Extracting %s to %s" % (zip_file, save_path))
|
122 |
+
for _member, save_file in zip(members, save_files):
|
123 |
+
if zipped.getinfo(_member).is_dir():
|
124 |
+
os.makedirs(save_file, exist_ok=True)
|
125 |
+
continue
|
126 |
+
os.makedirs(os.path.dirname(save_file), exist_ok=True)
|
127 |
+
if not os.path.exists(save_file) or zipped.getinfo(_member).file_size != os.path.getsize(save_file):
|
128 |
+
with zipped.open(_member, "r") as fin, open(save_file, "wb") as fout:
|
129 |
+
shutil.copyfileobj(fin, fout)
|
130 |
+
else:
|
131 |
+
raise ValueError("Unknown file extension `%s`" % extension)
|
132 |
+
|
133 |
+
if len(save_files) == 1:
|
134 |
+
return save_files[0]
|
135 |
+
else:
|
136 |
+
return save_path
|
137 |
+
|
138 |
+
|
139 |
+
|
140 |
+
def compute_md5(file_name, chunk_size=65536):
|
141 |
+
"""
|
142 |
+
Compute MD5 of the file.
|
143 |
+
|
144 |
+
Parameters:
|
145 |
+
file_name (str): file name
|
146 |
+
chunk_size (int, optional): chunk size for reading large files
|
147 |
+
"""
|
148 |
+
import hashlib
|
149 |
+
|
150 |
+
md5 = hashlib.md5()
|
151 |
+
with open(file_name, "rb") as fin:
|
152 |
+
chunk = fin.read(chunk_size)
|
153 |
+
while chunk:
|
154 |
+
md5.update(chunk)
|
155 |
+
chunk = fin.read(chunk_size)
|
156 |
+
return md5.hexdigest()
|
157 |
+
|
158 |
+
|
159 |
+
|
160 |
+
def get_line_count(file_name, chunk_size=8192*1024):
|
161 |
+
"""
|
162 |
+
Get the number of lines in a file.
|
163 |
+
|
164 |
+
Parameters:
|
165 |
+
file_name (str): file name
|
166 |
+
chunk_size (int, optional): chunk size for reading large files
|
167 |
+
"""
|
168 |
+
count = 0
|
169 |
+
with open(file_name, "rb") as fin:
|
170 |
+
chunk = fin.read(chunk_size)
|
171 |
+
while chunk:
|
172 |
+
count += chunk.count(b"\n")
|
173 |
+
chunk = fin.read(chunk_size)
|
174 |
+
return count
|
175 |
+
|
176 |
+
|
177 |
+
class OPV:
|
178 |
+
"""
|
179 |
+
Quantum mechanical calculations on organic photovoltaic candidate molecules.
|
180 |
+
|
181 |
+
Statistics:
|
182 |
+
- #Molecule: 94,576
|
183 |
+
- #Regression task: 8
|
184 |
+
|
185 |
+
Parameters:
|
186 |
+
path (str): path to store the dataset
|
187 |
+
verbose (int, optional): output verbose level
|
188 |
+
**kwargs
|
189 |
+
"""
|
190 |
+
|
191 |
+
train_url = "https://cscdata.nrel.gov/api/datasets/ad5d2c9a-af0a-4d72-b943-1e433d5750d6/download/" \
|
192 |
+
"b69cf9a5-e7e0-405b-88cb-40df8007242e"
|
193 |
+
valid_url = "https://cscdata.nrel.gov/api/datasets/ad5d2c9a-af0a-4d72-b943-1e433d5750d6/download/" \
|
194 |
+
"1c8e7379-3071-4360-ba8e-0c6481c33d2c"
|
195 |
+
test_url = "https://cscdata.nrel.gov/api/datasets/ad5d2c9a-af0a-4d72-b943-1e433d5750d6/download/" \
|
196 |
+
"4ef40592-0080-4f00-9bb7-34b25f94962a"
|
197 |
+
train_md5 = "16e439b7411ea0a8d3a56ba4802b61b1"
|
198 |
+
valid_md5 = "3aa2ac62015932ca84661feb5d29adda"
|
199 |
+
test_md5 = "bad072224f0755478f0729476ca99a33"
|
200 |
+
target_fields = ["gap", "homo", "lumo", "spectral_overlap", "gap_extrapolated", "homo_extrapolated",
|
201 |
+
"lumo_extrapolated", "optical_lumo_extrapolated"]
|
202 |
+
|
203 |
+
def read_csv(self, csv_file, smiles_field="smiles", target_fields=None, verbose=0):
|
204 |
+
if target_fields is not None:
|
205 |
+
target_fields = set(target_fields)
|
206 |
+
|
207 |
+
with open(csv_file, "r") as fin:
|
208 |
+
reader = csv.reader(fin)
|
209 |
+
if verbose:
|
210 |
+
reader = iter(tqdm(reader, "Loading %s" % csv_file, get_line_count(csv_file)))
|
211 |
+
fields = next(reader)
|
212 |
+
smiles = []
|
213 |
+
targets = defaultdict(list)
|
214 |
+
for i, values in enumerate(reader):
|
215 |
+
if not any(values):
|
216 |
+
continue
|
217 |
+
if smiles_field is None:
|
218 |
+
smiles.append("")
|
219 |
+
for field, value in zip(fields, values):
|
220 |
+
if field == smiles_field:
|
221 |
+
smiles.append(value)
|
222 |
+
elif target_fields is None or field in target_fields:
|
223 |
+
pass
|
224 |
+
# value = eval(value)
|
225 |
+
# if value == "":
|
226 |
+
# value = math.nan
|
227 |
+
# targets[field].append(value)
|
228 |
+
|
229 |
+
return smiles, targets
|
230 |
+
|
231 |
+
def __init__(self, path, verbose=1, **kwargs):
|
232 |
+
path = os.path.expanduser(path)
|
233 |
+
if not os.path.exists(path):
|
234 |
+
os.makedirs(path)
|
235 |
+
self.path = path
|
236 |
+
|
237 |
+
train_zip_file = download(self.train_url, path, save_file="mol_train.csv.gz", md5=self.train_md5)
|
238 |
+
valid_zip_file = download(self.valid_url, path, save_file="mol_valid.csv.gz", md5=self.valid_md5)
|
239 |
+
test_zip_file = download(self.test_url, path, save_file="mol_test.csv.gz", md5=self.test_md5)
|
240 |
+
train_file = extract(train_zip_file)
|
241 |
+
valid_file = extract(valid_zip_file)
|
242 |
+
test_file = extract(test_zip_file)
|
243 |
+
|
244 |
+
train_smiles, train_targets = self.read_csv(train_file, smiles_field="smile", target_fields=self.target_fields)
|
245 |
+
valid_smiles, valid_targets = self.read_csv(valid_file, smiles_field="smile", target_fields=self.target_fields)
|
246 |
+
test_smiles, test_targets = self.read_csv(test_file, smiles_field="smile", target_fields=self.target_fields)
|
247 |
+
self.num_train = len(train_smiles)
|
248 |
+
self.num_valid = len(valid_smiles)
|
249 |
+
self.num_test = len(test_smiles)
|
250 |
+
|
251 |
+
smiles = train_smiles + valid_smiles + test_smiles
|
252 |
+
targets = {k: train_targets[k] + valid_targets[k] + test_targets[k] for k in train_targets}
|
253 |
+
|
254 |
+
# self.load_smiles(smiles, targets, verbose=verbose, **kwargs)
|
255 |
+
print(smiles[:10])
|
256 |
+
df_out = pd.DataFrame({"smiles": smiles})
|
257 |
+
df_out.to_parquet(os.path.join(os.path.dirname(__file__), "opv.parquet"))
|
258 |
+
|
259 |
+
|
260 |
+
if __name__ == "__main__":
|
261 |
+
logging.basicConfig(level=logging.INFO)
|
262 |
+
cwd = os.path.join(os.path.dirname(__file__), "download")
|
263 |
+
os.makedirs(cwd,exist_ok=True)
|
264 |
+
d = OPV(cwd)
|
265 |
+
|
data/pubchemqc2020_energy.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:4d5ef9f419a48be52b1fe6332eb08d77df0b6ff7ec34f8c99c06e63fa232abf1
|
3 |
+
size 39165769
|
data/pubchemqc_energy.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:b5ca78b6f81f04ddcc2ed6e031d86f0a2f1e38d6c4001bfd93a28005b7168cf8
|
3 |
+
size 89749991
|
data/qm9_zinc250k_cep/convert_to_parquet.py
ADDED
@@ -0,0 +1,41 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import requests
|
3 |
+
import hashlib
|
4 |
+
import os
|
5 |
+
# Download and read zinc_properties file
|
6 |
+
zinc_url = "https://raw.githubusercontent.com/aspuru-guzik-group/chemical_vae/master/models/zinc_properties/250k_rndm_zinc_drugs_clean_3.csv"
|
7 |
+
zinc_md5 = "b59078b2b04c6e9431280e3dc42048d5"
|
8 |
+
zinc_filename = "zinc_properties.csv"
|
9 |
+
|
10 |
+
response = requests.get(zinc_url)
|
11 |
+
downloaded_data = response.content
|
12 |
+
|
13 |
+
downloaded_md5 = hashlib.md5(downloaded_data).hexdigest()
|
14 |
+
if zinc_md5 == downloaded_md5:
|
15 |
+
with open(zinc_filename, 'wb') as f:
|
16 |
+
f.write(downloaded_data)
|
17 |
+
print(f"File '{zinc_filename}' downloaded and saved.")
|
18 |
+
else:
|
19 |
+
raise ValueError("MD5 checksum does not match")
|
20 |
+
|
21 |
+
zinc_df = pd.read_csv(zinc_filename)
|
22 |
+
zinc_df = zinc_df[["smiles"]]
|
23 |
+
|
24 |
+
cwd = os.path.dirname(__file__)
|
25 |
+
|
26 |
+
qm9_filename = os.path.join(cwd,"QM9IsoFull.csv")
|
27 |
+
cep_filename = os.path.join(cwd,"cep-processed.csv")
|
28 |
+
|
29 |
+
qm9_df = pd.read_csv(qm9_filename, sep="|")
|
30 |
+
qm9_df = qm9_df[["smiles"]]
|
31 |
+
|
32 |
+
cep_df = pd.read_csv(cep_filename)
|
33 |
+
cep_df = cep_df[["smiles"]]
|
34 |
+
|
35 |
+
# Combine the dataframes into one large dataframe
|
36 |
+
combined_df = pd.concat([zinc_df, qm9_df, cep_df], axis=0)
|
37 |
+
|
38 |
+
# Save the combined dataframe to a Parquet file
|
39 |
+
output_filename = "qm9_zinc250_cep.parquet"
|
40 |
+
combined_df.to_parquet(output_filename, index=False)
|
41 |
+
print(f"Combined dataframe saved to '{output_filename}' as Parquet file.")
|
data/qm9_zinc250k_cep/qm9_zinc250_cep.parquet
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:3003c48cff3793646f07692b85745786d4d9b103323b3b59b3ae5b23af071d3a
|
3 |
+
size 7580076
|
data/vocab.txt
ADDED
@@ -0,0 +1,612 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[PAD]
|
2 |
+
[unused1]
|
3 |
+
[unused2]
|
4 |
+
[unused3]
|
5 |
+
[unused4]
|
6 |
+
[unused5]
|
7 |
+
[unused6]
|
8 |
+
[unused7]
|
9 |
+
[unused8]
|
10 |
+
[unused9]
|
11 |
+
[unused10]
|
12 |
+
[UNK]
|
13 |
+
[CLS]
|
14 |
+
[SEP]
|
15 |
+
[MASK]
|
16 |
+
c
|
17 |
+
C
|
18 |
+
(
|
19 |
+
)
|
20 |
+
O
|
21 |
+
1
|
22 |
+
2
|
23 |
+
=
|
24 |
+
N
|
25 |
+
.
|
26 |
+
n
|
27 |
+
3
|
28 |
+
F
|
29 |
+
Cl
|
30 |
+
>>
|
31 |
+
~
|
32 |
+
-
|
33 |
+
4
|
34 |
+
[C@H]
|
35 |
+
S
|
36 |
+
[C@@H]
|
37 |
+
[O-]
|
38 |
+
Br
|
39 |
+
#
|
40 |
+
/
|
41 |
+
[nH]
|
42 |
+
[N+]
|
43 |
+
s
|
44 |
+
5
|
45 |
+
o
|
46 |
+
P
|
47 |
+
[Na+]
|
48 |
+
[Si]
|
49 |
+
I
|
50 |
+
[Na]
|
51 |
+
[Pd]
|
52 |
+
[K+]
|
53 |
+
[K]
|
54 |
+
[P]
|
55 |
+
B
|
56 |
+
[C@]
|
57 |
+
[C@@]
|
58 |
+
[Cl-]
|
59 |
+
6
|
60 |
+
[OH-]
|
61 |
+
\
|
62 |
+
[N-]
|
63 |
+
[Li]
|
64 |
+
[H]
|
65 |
+
[2H]
|
66 |
+
[NH4+]
|
67 |
+
[c-]
|
68 |
+
[P-]
|
69 |
+
[Cs+]
|
70 |
+
[Li+]
|
71 |
+
[Cs]
|
72 |
+
[NaH]
|
73 |
+
[H-]
|
74 |
+
[O+]
|
75 |
+
[BH4-]
|
76 |
+
[Cu]
|
77 |
+
7
|
78 |
+
[Mg]
|
79 |
+
[Fe+2]
|
80 |
+
[n+]
|
81 |
+
[Sn]
|
82 |
+
[BH-]
|
83 |
+
[Pd+2]
|
84 |
+
[CH]
|
85 |
+
[I-]
|
86 |
+
[Br-]
|
87 |
+
[C-]
|
88 |
+
[Zn]
|
89 |
+
[B-]
|
90 |
+
[F-]
|
91 |
+
[Al]
|
92 |
+
[P+]
|
93 |
+
[BH3-]
|
94 |
+
[Fe]
|
95 |
+
[C]
|
96 |
+
[AlH4]
|
97 |
+
[Ni]
|
98 |
+
[SiH]
|
99 |
+
8
|
100 |
+
[Cu+2]
|
101 |
+
[Mn]
|
102 |
+
[AlH]
|
103 |
+
[nH+]
|
104 |
+
[AlH4-]
|
105 |
+
[O-2]
|
106 |
+
[Cr]
|
107 |
+
[Mg+2]
|
108 |
+
[NH3+]
|
109 |
+
[S@]
|
110 |
+
[Pt]
|
111 |
+
[Al+3]
|
112 |
+
[S@@]
|
113 |
+
[S-]
|
114 |
+
[Ti]
|
115 |
+
[Zn+2]
|
116 |
+
[PH]
|
117 |
+
[NH2+]
|
118 |
+
[Ru]
|
119 |
+
[Ag+]
|
120 |
+
[S+]
|
121 |
+
[I+3]
|
122 |
+
[NH+]
|
123 |
+
[Ca+2]
|
124 |
+
[Ag]
|
125 |
+
9
|
126 |
+
[Os]
|
127 |
+
[Se]
|
128 |
+
[SiH2]
|
129 |
+
[Ca]
|
130 |
+
[Ti+4]
|
131 |
+
[Ac]
|
132 |
+
[Cu+]
|
133 |
+
[S]
|
134 |
+
[Rh]
|
135 |
+
[Cl+3]
|
136 |
+
[cH-]
|
137 |
+
[Zn+]
|
138 |
+
[O]
|
139 |
+
[Cl+]
|
140 |
+
[SH]
|
141 |
+
[H+]
|
142 |
+
[Pd+]
|
143 |
+
[se]
|
144 |
+
[PH+]
|
145 |
+
[I]
|
146 |
+
[Pt+2]
|
147 |
+
[C+]
|
148 |
+
[Mg+]
|
149 |
+
[Hg]
|
150 |
+
[W]
|
151 |
+
[SnH]
|
152 |
+
[SiH3]
|
153 |
+
[Fe+3]
|
154 |
+
[NH]
|
155 |
+
[Mo]
|
156 |
+
[CH2+]
|
157 |
+
%10
|
158 |
+
[CH2-]
|
159 |
+
[CH2]
|
160 |
+
[n-]
|
161 |
+
[Ce+4]
|
162 |
+
[NH-]
|
163 |
+
[Co]
|
164 |
+
[I+]
|
165 |
+
[PH2]
|
166 |
+
[Pt+4]
|
167 |
+
[Ce]
|
168 |
+
[B]
|
169 |
+
[Sn+2]
|
170 |
+
[Ba+2]
|
171 |
+
%11
|
172 |
+
[Fe-3]
|
173 |
+
[18F]
|
174 |
+
[SH-]
|
175 |
+
[Pb+2]
|
176 |
+
[Os-2]
|
177 |
+
[Zr+4]
|
178 |
+
[N]
|
179 |
+
[Ir]
|
180 |
+
[Bi]
|
181 |
+
[Ni+2]
|
182 |
+
[P@]
|
183 |
+
[Co+2]
|
184 |
+
[s+]
|
185 |
+
[As]
|
186 |
+
[P+3]
|
187 |
+
[Hg+2]
|
188 |
+
[Yb+3]
|
189 |
+
[CH-]
|
190 |
+
[Zr+2]
|
191 |
+
[Mn+2]
|
192 |
+
[CH+]
|
193 |
+
[In]
|
194 |
+
[KH]
|
195 |
+
[Ce+3]
|
196 |
+
[Zr]
|
197 |
+
[AlH2-]
|
198 |
+
[OH2+]
|
199 |
+
[Ti+3]
|
200 |
+
[Rh+2]
|
201 |
+
[Sb]
|
202 |
+
[S-2]
|
203 |
+
%12
|
204 |
+
[P@@]
|
205 |
+
[Si@H]
|
206 |
+
[Mn+4]
|
207 |
+
p
|
208 |
+
[Ba]
|
209 |
+
[NH2-]
|
210 |
+
[Ge]
|
211 |
+
[Pb+4]
|
212 |
+
[Cr+3]
|
213 |
+
[Au]
|
214 |
+
[LiH]
|
215 |
+
[Sc+3]
|
216 |
+
[o+]
|
217 |
+
[Rh-3]
|
218 |
+
%13
|
219 |
+
[Br]
|
220 |
+
[Sb-]
|
221 |
+
[S@+]
|
222 |
+
[I+2]
|
223 |
+
[Ar]
|
224 |
+
[V]
|
225 |
+
[Cu-]
|
226 |
+
[Al-]
|
227 |
+
[Te]
|
228 |
+
[13c]
|
229 |
+
[13C]
|
230 |
+
[Cl]
|
231 |
+
[PH4+]
|
232 |
+
[SiH4]
|
233 |
+
[te]
|
234 |
+
[CH3-]
|
235 |
+
[S@@+]
|
236 |
+
[Rh+3]
|
237 |
+
[SH+]
|
238 |
+
[Bi+3]
|
239 |
+
[Br+2]
|
240 |
+
[La]
|
241 |
+
[La+3]
|
242 |
+
[Pt-2]
|
243 |
+
[N@@]
|
244 |
+
[PH3+]
|
245 |
+
[N@]
|
246 |
+
[Si+4]
|
247 |
+
[Sr+2]
|
248 |
+
[Al+]
|
249 |
+
[Pb]
|
250 |
+
[SeH]
|
251 |
+
[Si-]
|
252 |
+
[V+5]
|
253 |
+
[Y+3]
|
254 |
+
[Re]
|
255 |
+
[Ru+]
|
256 |
+
[Sm]
|
257 |
+
*
|
258 |
+
[3H]
|
259 |
+
[NH2]
|
260 |
+
[Ag-]
|
261 |
+
[13CH3]
|
262 |
+
[OH+]
|
263 |
+
[Ru+3]
|
264 |
+
[OH]
|
265 |
+
[Gd+3]
|
266 |
+
[13CH2]
|
267 |
+
[In+3]
|
268 |
+
[Si@@]
|
269 |
+
[Si@]
|
270 |
+
[Ti+2]
|
271 |
+
[Sn+]
|
272 |
+
[Cl+2]
|
273 |
+
[AlH-]
|
274 |
+
[Pd-2]
|
275 |
+
[SnH3]
|
276 |
+
[B+3]
|
277 |
+
[Cu-2]
|
278 |
+
[Nd+3]
|
279 |
+
[Pb+3]
|
280 |
+
[13cH]
|
281 |
+
[Fe-4]
|
282 |
+
[Ga]
|
283 |
+
[Sn+4]
|
284 |
+
[Hg+]
|
285 |
+
[11CH3]
|
286 |
+
[Hf]
|
287 |
+
[Pr]
|
288 |
+
[Y]
|
289 |
+
[S+2]
|
290 |
+
[Cd]
|
291 |
+
[Cr+6]
|
292 |
+
[Zr+3]
|
293 |
+
[Rh+]
|
294 |
+
[CH3]
|
295 |
+
[N-3]
|
296 |
+
[Hf+2]
|
297 |
+
[Th]
|
298 |
+
[Sb+3]
|
299 |
+
%14
|
300 |
+
[Cr+2]
|
301 |
+
[Ru+2]
|
302 |
+
[Hf+4]
|
303 |
+
[14C]
|
304 |
+
[Ta]
|
305 |
+
[Tl+]
|
306 |
+
[B+]
|
307 |
+
[Os+4]
|
308 |
+
[PdH2]
|
309 |
+
[Pd-]
|
310 |
+
[Cd+2]
|
311 |
+
[Co+3]
|
312 |
+
[S+4]
|
313 |
+
[Nb+5]
|
314 |
+
[123I]
|
315 |
+
[c+]
|
316 |
+
[Rb+]
|
317 |
+
[V+2]
|
318 |
+
[CH3+]
|
319 |
+
[Ag+2]
|
320 |
+
[cH+]
|
321 |
+
[Mn+3]
|
322 |
+
[Se-]
|
323 |
+
[As-]
|
324 |
+
[Eu+3]
|
325 |
+
[SH2]
|
326 |
+
[Sm+3]
|
327 |
+
[IH+]
|
328 |
+
%15
|
329 |
+
[OH3+]
|
330 |
+
[PH3]
|
331 |
+
[IH2+]
|
332 |
+
[SH2+]
|
333 |
+
[Ir+3]
|
334 |
+
[AlH3]
|
335 |
+
[Sc]
|
336 |
+
[Yb]
|
337 |
+
[15NH2]
|
338 |
+
[Lu]
|
339 |
+
[sH+]
|
340 |
+
[Gd]
|
341 |
+
[18F-]
|
342 |
+
[SH3+]
|
343 |
+
[SnH4]
|
344 |
+
[TeH]
|
345 |
+
[Si@@H]
|
346 |
+
[Ga+3]
|
347 |
+
[CaH2]
|
348 |
+
[Tl]
|
349 |
+
[Ta+5]
|
350 |
+
[GeH]
|
351 |
+
[Br+]
|
352 |
+
[Sr]
|
353 |
+
[Tl+3]
|
354 |
+
[Sm+2]
|
355 |
+
[PH5]
|
356 |
+
%16
|
357 |
+
[N@@+]
|
358 |
+
[Au+3]
|
359 |
+
[C-4]
|
360 |
+
[Nd]
|
361 |
+
[Ti+]
|
362 |
+
[IH]
|
363 |
+
[N@+]
|
364 |
+
[125I]
|
365 |
+
[Eu]
|
366 |
+
[Sn+3]
|
367 |
+
[Nb]
|
368 |
+
[Er+3]
|
369 |
+
[123I-]
|
370 |
+
[14c]
|
371 |
+
%17
|
372 |
+
[SnH2]
|
373 |
+
[YH]
|
374 |
+
[Sb+5]
|
375 |
+
[Pr+3]
|
376 |
+
[Ir+]
|
377 |
+
[N+3]
|
378 |
+
[AlH2]
|
379 |
+
[19F]
|
380 |
+
%18
|
381 |
+
[Tb]
|
382 |
+
[14CH]
|
383 |
+
[Mo+4]
|
384 |
+
[Si+]
|
385 |
+
[BH]
|
386 |
+
[Be]
|
387 |
+
[Rb]
|
388 |
+
[pH]
|
389 |
+
%19
|
390 |
+
%20
|
391 |
+
[Xe]
|
392 |
+
[Ir-]
|
393 |
+
[Be+2]
|
394 |
+
[C+4]
|
395 |
+
[RuH2]
|
396 |
+
[15NH]
|
397 |
+
[U+2]
|
398 |
+
[Au-]
|
399 |
+
%21
|
400 |
+
%22
|
401 |
+
[Au+]
|
402 |
+
[15n]
|
403 |
+
[Al+2]
|
404 |
+
[Tb+3]
|
405 |
+
[15N]
|
406 |
+
[V+3]
|
407 |
+
[W+6]
|
408 |
+
[14CH3]
|
409 |
+
[Cr+4]
|
410 |
+
[ClH+]
|
411 |
+
b
|
412 |
+
[Ti+6]
|
413 |
+
[Nd+]
|
414 |
+
[Zr+]
|
415 |
+
[PH2+]
|
416 |
+
[Fm]
|
417 |
+
[N@H+]
|
418 |
+
[RuH]
|
419 |
+
[Dy+3]
|
420 |
+
%23
|
421 |
+
[Hf+3]
|
422 |
+
[W+4]
|
423 |
+
[11C]
|
424 |
+
[13CH]
|
425 |
+
[Er]
|
426 |
+
[124I]
|
427 |
+
[LaH]
|
428 |
+
[F]
|
429 |
+
[siH]
|
430 |
+
[Ga+]
|
431 |
+
[Cm]
|
432 |
+
[GeH3]
|
433 |
+
[IH-]
|
434 |
+
[U+6]
|
435 |
+
[SeH+]
|
436 |
+
[32P]
|
437 |
+
[SeH-]
|
438 |
+
[Pt-]
|
439 |
+
[Ir+2]
|
440 |
+
[se+]
|
441 |
+
[U]
|
442 |
+
[F+]
|
443 |
+
[BH2]
|
444 |
+
[As+]
|
445 |
+
[Cf]
|
446 |
+
[ClH2+]
|
447 |
+
[Ni+]
|
448 |
+
[TeH3]
|
449 |
+
[SbH2]
|
450 |
+
[Ag+3]
|
451 |
+
%24
|
452 |
+
[18O]
|
453 |
+
[PH4]
|
454 |
+
[Os+2]
|
455 |
+
[Na-]
|
456 |
+
[Sb+2]
|
457 |
+
[V+4]
|
458 |
+
[Ho+3]
|
459 |
+
[68Ga]
|
460 |
+
[PH-]
|
461 |
+
[Bi+2]
|
462 |
+
[Ce+2]
|
463 |
+
[Pd+3]
|
464 |
+
[99Tc]
|
465 |
+
[13C@@H]
|
466 |
+
[Fe+6]
|
467 |
+
[c]
|
468 |
+
[GeH2]
|
469 |
+
[10B]
|
470 |
+
[Cu+3]
|
471 |
+
[Mo+2]
|
472 |
+
[Cr+]
|
473 |
+
[Pd+4]
|
474 |
+
[Dy]
|
475 |
+
[AsH]
|
476 |
+
[Ba+]
|
477 |
+
[SeH2]
|
478 |
+
[In+]
|
479 |
+
[TeH2]
|
480 |
+
[BrH+]
|
481 |
+
[14cH]
|
482 |
+
[W+]
|
483 |
+
[13C@H]
|
484 |
+
[AsH2]
|
485 |
+
[In+2]
|
486 |
+
[N+2]
|
487 |
+
[N@@H+]
|
488 |
+
[SbH]
|
489 |
+
[60Co]
|
490 |
+
[AsH4+]
|
491 |
+
[AsH3]
|
492 |
+
[18OH]
|
493 |
+
[Ru-2]
|
494 |
+
[Na-2]
|
495 |
+
[CuH2]
|
496 |
+
[31P]
|
497 |
+
[Ti+5]
|
498 |
+
[35S]
|
499 |
+
[P@@H]
|
500 |
+
[ArH]
|
501 |
+
[Co+]
|
502 |
+
[Zr-2]
|
503 |
+
[BH2-]
|
504 |
+
[131I]
|
505 |
+
[SH5]
|
506 |
+
[VH]
|
507 |
+
[B+2]
|
508 |
+
[Yb+2]
|
509 |
+
[14C@H]
|
510 |
+
[211At]
|
511 |
+
[NH3+2]
|
512 |
+
[IrH]
|
513 |
+
[IrH2]
|
514 |
+
[Rh-]
|
515 |
+
[Cr-]
|
516 |
+
[Sb+]
|
517 |
+
[Ni+3]
|
518 |
+
[TaH3]
|
519 |
+
[Tl+2]
|
520 |
+
[64Cu]
|
521 |
+
[Tc]
|
522 |
+
[Cd+]
|
523 |
+
[1H]
|
524 |
+
[15nH]
|
525 |
+
[AlH2+]
|
526 |
+
[FH+2]
|
527 |
+
[BiH3]
|
528 |
+
[Ru-]
|
529 |
+
[Mo+6]
|
530 |
+
[AsH+]
|
531 |
+
[BaH2]
|
532 |
+
[BaH]
|
533 |
+
[Fe+4]
|
534 |
+
[229Th]
|
535 |
+
[Th+4]
|
536 |
+
[As+3]
|
537 |
+
[NH+3]
|
538 |
+
[P@H]
|
539 |
+
[Li-]
|
540 |
+
[7NaH]
|
541 |
+
[Bi+]
|
542 |
+
[PtH+2]
|
543 |
+
[p-]
|
544 |
+
[Re+5]
|
545 |
+
[NiH]
|
546 |
+
[Ni-]
|
547 |
+
[Xe+]
|
548 |
+
[Ca+]
|
549 |
+
[11c]
|
550 |
+
[Rh+4]
|
551 |
+
[AcH]
|
552 |
+
[HeH]
|
553 |
+
[Sc+2]
|
554 |
+
[Mn+]
|
555 |
+
[UH]
|
556 |
+
[14CH2]
|
557 |
+
[SiH4+]
|
558 |
+
[18OH2]
|
559 |
+
[Ac-]
|
560 |
+
[Re+4]
|
561 |
+
[118Sn]
|
562 |
+
[153Sm]
|
563 |
+
[P+2]
|
564 |
+
[9CH]
|
565 |
+
[9CH3]
|
566 |
+
[Y-]
|
567 |
+
[NiH2]
|
568 |
+
[Si+2]
|
569 |
+
[Mn+6]
|
570 |
+
[ZrH2]
|
571 |
+
[C-2]
|
572 |
+
[Bi+5]
|
573 |
+
[24NaH]
|
574 |
+
[Fr]
|
575 |
+
[15CH]
|
576 |
+
[Se+]
|
577 |
+
[At]
|
578 |
+
[P-3]
|
579 |
+
[124I-]
|
580 |
+
[CuH2-]
|
581 |
+
[Nb+4]
|
582 |
+
[Nb+3]
|
583 |
+
[MgH]
|
584 |
+
[Ir+4]
|
585 |
+
[67Ga+3]
|
586 |
+
[67Ga]
|
587 |
+
[13N]
|
588 |
+
[15OH2]
|
589 |
+
[2NH]
|
590 |
+
[Ho]
|
591 |
+
[Cn]
|
592 |
+
[0*]
|
593 |
+
[1*]
|
594 |
+
[2*]
|
595 |
+
[3*]
|
596 |
+
[4*]
|
597 |
+
[5*]
|
598 |
+
[6*]
|
599 |
+
[7*]
|
600 |
+
[8*]
|
601 |
+
[9*]
|
602 |
+
[10*]
|
603 |
+
[11*]
|
604 |
+
[12*]
|
605 |
+
[13*]
|
606 |
+
[14*]
|
607 |
+
[15*]
|
608 |
+
[16*]
|
609 |
+
[17*]
|
610 |
+
[18*]
|
611 |
+
[19*]
|
612 |
+
[20*]
|
data/zinc/convert_to_parquet.py
ADDED
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import pandas as pd
|
2 |
+
import os.path as osp
|
3 |
+
import os
|
4 |
+
from tqdm import tqdm
|
5 |
+
import dask.dataframe as dd
|
6 |
+
import pandas as pd
|
7 |
+
import pyarrow as pa
|
8 |
+
import shutil
|
9 |
+
cwd = osp.abspath(osp.dirname(__file__))
|
10 |
+
zinc_path = os.path.join(cwd, "zinc_complete")
|
11 |
+
alls_dirs = [
|
12 |
+
osp.join(zinc_path, f)
|
13 |
+
for f in os.listdir(zinc_path)
|
14 |
+
if osp.isdir(osp.join(zinc_path, f))
|
15 |
+
]
|
16 |
+
|
17 |
+
|
18 |
+
print("Number of dirs: ", len(alls_dirs))
|
19 |
+
all_dfs = []
|
20 |
+
for d in alls_dirs:
|
21 |
+
print(f"Read: {d }")
|
22 |
+
df = dd.read_csv(
|
23 |
+
os.path.join(cwd, "zinc_complete", f"{d}/*.txt"),
|
24 |
+
sep="\t",
|
25 |
+
usecols=["smiles"],
|
26 |
+
)
|
27 |
+
all_dfs.append(df)
|
28 |
+
|
29 |
+
concatenated_df = dd.concat(all_dfs)
|
30 |
+
# res = df["logp"].map_partitions(lambda d, bins: pd.cut(d, bins), 25).compute()
|
31 |
+
# print(res)
|
32 |
+
|
33 |
+
print("Writing")
|
34 |
+
# print(df)
|
35 |
+
# name_function = lambda x: f"zincfull-{x}.parquet"
|
36 |
+
concatenated_df = concatenated_df.repartition(npartitions=1)
|
37 |
+
concatenated_df = concatenated_df.reset_index(drop=True)
|
38 |
+
concatenated_df.to_parquet(
|
39 |
+
os.path.join(cwd, "zinc_processed"),
|
40 |
+
)
|
41 |
+
print("Done Writing")
|
42 |
+
print(len(concatenated_df))
|
43 |
+
shutil.copy(
|
44 |
+
os.path.join(cwd, "zinc_processed", "part.0.parquet"),
|
45 |
+
os.path.join(cwd, "zinc_processed.parquet")
|
46 |
+
)
|
47 |
+
|
48 |
+
# df = None
|
49 |
+
# for d in tqdm(alls_dirs):
|
50 |
+
# if df is not None:
|
51 |
+
# print(len(df))
|
52 |
+
# files = [osp.join(d,f) for f in os.listdir(d)]
|
53 |
+
# for f in files:
|
54 |
+
# try:
|
55 |
+
# df_extra = pd.read_csv(f,sep="\t")
|
56 |
+
# except Exception as e:
|
57 |
+
# print(f"Got error {f}: {e}")
|
58 |
+
# continue
|
59 |
+
# # print(df)
|
60 |
+
# if df is None:
|
61 |
+
# df = df_extra
|
62 |
+
|
63 |
+
# else:
|
64 |
+
# df = df.append(df_extra)
|
65 |
+
|
66 |
+
|
67 |
+
# df.to_parquet(osp.join(cwd, "zinc_combined.parquet"))
|
data/zinc/zinc_complete/download_zinc.sh
ADDED
@@ -0,0 +1,300 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AAAA.txt -O AA/AAAA.txt
|
2 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AAAB.txt -O AA/AAAB.txt
|
3 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AAAC.txt -O AA/AAAC.txt
|
4 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AAAD.txt -O AA/AAAD.txt
|
5 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AABA.txt -O AA/AABA.txt
|
6 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AABB.txt -O AA/AABB.txt
|
7 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AABC.txt -O AA/AABC.txt
|
8 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AABD.txt -O AA/AABD.txt
|
9 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AACA.txt -O AA/AACA.txt
|
10 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AACB.txt -O AA/AACB.txt
|
11 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AACC.txt -O AA/AACC.txt
|
12 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AACD.txt -O AA/AACD.txt
|
13 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AAEA.txt -O AA/AAEA.txt
|
14 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AAEB.txt -O AA/AAEB.txt
|
15 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AAEC.txt -O AA/AAEC.txt
|
16 |
+
mkdir -pv AA && wget http://files.docking.org/2D/AA/AAED.txt -O AA/AAED.txt
|
17 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BAAA.txt -O BA/BAAA.txt
|
18 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BAAB.txt -O BA/BAAB.txt
|
19 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BAAC.txt -O BA/BAAC.txt
|
20 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BAAD.txt -O BA/BAAD.txt
|
21 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BABA.txt -O BA/BABA.txt
|
22 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BABB.txt -O BA/BABB.txt
|
23 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BABC.txt -O BA/BABC.txt
|
24 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BABD.txt -O BA/BABD.txt
|
25 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BACA.txt -O BA/BACA.txt
|
26 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BACB.txt -O BA/BACB.txt
|
27 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BACC.txt -O BA/BACC.txt
|
28 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BACD.txt -O BA/BACD.txt
|
29 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BAEA.txt -O BA/BAEA.txt
|
30 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BAEB.txt -O BA/BAEB.txt
|
31 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BAEC.txt -O BA/BAEC.txt
|
32 |
+
mkdir -pv BA && wget http://files.docking.org/2D/BA/BAED.txt -O BA/BAED.txt
|
33 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CAAA.txt -O CA/CAAA.txt
|
34 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CAAB.txt -O CA/CAAB.txt
|
35 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CAAC.txt -O CA/CAAC.txt
|
36 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CAAD.txt -O CA/CAAD.txt
|
37 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CABA.txt -O CA/CABA.txt
|
38 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CABB.txt -O CA/CABB.txt
|
39 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CABC.txt -O CA/CABC.txt
|
40 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CABD.txt -O CA/CABD.txt
|
41 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CACA.txt -O CA/CACA.txt
|
42 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CACB.txt -O CA/CACB.txt
|
43 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CACC.txt -O CA/CACC.txt
|
44 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CACD.txt -O CA/CACD.txt
|
45 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CAEA.txt -O CA/CAEA.txt
|
46 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CAEB.txt -O CA/CAEB.txt
|
47 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CAEC.txt -O CA/CAEC.txt
|
48 |
+
mkdir -pv CA && wget http://files.docking.org/2D/CA/CAED.txt -O CA/CAED.txt
|
49 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DAAA.txt -O DA/DAAA.txt
|
50 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DAAB.txt -O DA/DAAB.txt
|
51 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DAAC.txt -O DA/DAAC.txt
|
52 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DAAD.txt -O DA/DAAD.txt
|
53 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DABA.txt -O DA/DABA.txt
|
54 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DABB.txt -O DA/DABB.txt
|
55 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DABC.txt -O DA/DABC.txt
|
56 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DABD.txt -O DA/DABD.txt
|
57 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DACA.txt -O DA/DACA.txt
|
58 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DACB.txt -O DA/DACB.txt
|
59 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DACC.txt -O DA/DACC.txt
|
60 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DACD.txt -O DA/DACD.txt
|
61 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DAEA.txt -O DA/DAEA.txt
|
62 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DAEB.txt -O DA/DAEB.txt
|
63 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DAEC.txt -O DA/DAEC.txt
|
64 |
+
mkdir -pv DA && wget http://files.docking.org/2D/DA/DAED.txt -O DA/DAED.txt
|
65 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EAAA.txt -O EA/EAAA.txt
|
66 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EAAB.txt -O EA/EAAB.txt
|
67 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EAAC.txt -O EA/EAAC.txt
|
68 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EAAD.txt -O EA/EAAD.txt
|
69 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EABA.txt -O EA/EABA.txt
|
70 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EABB.txt -O EA/EABB.txt
|
71 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EABC.txt -O EA/EABC.txt
|
72 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EABD.txt -O EA/EABD.txt
|
73 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EACA.txt -O EA/EACA.txt
|
74 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EACB.txt -O EA/EACB.txt
|
75 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EACC.txt -O EA/EACC.txt
|
76 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EACD.txt -O EA/EACD.txt
|
77 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EAEA.txt -O EA/EAEA.txt
|
78 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EAEB.txt -O EA/EAEB.txt
|
79 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EAEC.txt -O EA/EAEC.txt
|
80 |
+
mkdir -pv EA && wget http://files.docking.org/2D/EA/EAED.txt -O EA/EAED.txt
|
81 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FAAA.txt -O FA/FAAA.txt
|
82 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FAAB.txt -O FA/FAAB.txt
|
83 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FAAC.txt -O FA/FAAC.txt
|
84 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FAAD.txt -O FA/FAAD.txt
|
85 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FABA.txt -O FA/FABA.txt
|
86 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FABB.txt -O FA/FABB.txt
|
87 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FABC.txt -O FA/FABC.txt
|
88 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FABD.txt -O FA/FABD.txt
|
89 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FACA.txt -O FA/FACA.txt
|
90 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FACB.txt -O FA/FACB.txt
|
91 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FACC.txt -O FA/FACC.txt
|
92 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FACD.txt -O FA/FACD.txt
|
93 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FAEA.txt -O FA/FAEA.txt
|
94 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FAEB.txt -O FA/FAEB.txt
|
95 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FAEC.txt -O FA/FAEC.txt
|
96 |
+
mkdir -pv FA && wget http://files.docking.org/2D/FA/FAED.txt -O FA/FAED.txt
|
97 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GAAA.txt -O GA/GAAA.txt
|
98 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GAAB.txt -O GA/GAAB.txt
|
99 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GAAC.txt -O GA/GAAC.txt
|
100 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GAAD.txt -O GA/GAAD.txt
|
101 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABAA.txt -O AB/ABAA.txt
|
102 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABAB.txt -O AB/ABAB.txt
|
103 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABAC.txt -O AB/ABAC.txt
|
104 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABAD.txt -O AB/ABAD.txt
|
105 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABBA.txt -O AB/ABBA.txt
|
106 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABBB.txt -O AB/ABBB.txt
|
107 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABBC.txt -O AB/ABBC.txt
|
108 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABBD.txt -O AB/ABBD.txt
|
109 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABCA.txt -O AB/ABCA.txt
|
110 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABCB.txt -O AB/ABCB.txt
|
111 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABCC.txt -O AB/ABCC.txt
|
112 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABCD.txt -O AB/ABCD.txt
|
113 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABEA.txt -O AB/ABEA.txt
|
114 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABEB.txt -O AB/ABEB.txt
|
115 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABEC.txt -O AB/ABEC.txt
|
116 |
+
mkdir -pv AB && wget http://files.docking.org/2D/AB/ABED.txt -O AB/ABED.txt
|
117 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBAA.txt -O BB/BBAA.txt
|
118 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBAB.txt -O BB/BBAB.txt
|
119 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBAC.txt -O BB/BBAC.txt
|
120 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBAD.txt -O BB/BBAD.txt
|
121 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBBA.txt -O BB/BBBA.txt
|
122 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBBB.txt -O BB/BBBB.txt
|
123 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBBC.txt -O BB/BBBC.txt
|
124 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBBD.txt -O BB/BBBD.txt
|
125 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GABA.txt -O GA/GABA.txt
|
126 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GABB.txt -O GA/GABB.txt
|
127 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GABC.txt -O GA/GABC.txt
|
128 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GABD.txt -O GA/GABD.txt
|
129 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GACA.txt -O GA/GACA.txt
|
130 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GACB.txt -O GA/GACB.txt
|
131 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GACC.txt -O GA/GACC.txt
|
132 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GACD.txt -O GA/GACD.txt
|
133 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GAEA.txt -O GA/GAEA.txt
|
134 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GAEB.txt -O GA/GAEB.txt
|
135 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GAEC.txt -O GA/GAEC.txt
|
136 |
+
mkdir -pv GA && wget http://files.docking.org/2D/GA/GAED.txt -O GA/GAED.txt
|
137 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HAAA.txt -O HA/HAAA.txt
|
138 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HAAB.txt -O HA/HAAB.txt
|
139 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HAAC.txt -O HA/HAAC.txt
|
140 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HAAD.txt -O HA/HAAD.txt
|
141 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HABA.txt -O HA/HABA.txt
|
142 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HABB.txt -O HA/HABB.txt
|
143 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HABC.txt -O HA/HABC.txt
|
144 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HABD.txt -O HA/HABD.txt
|
145 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HACA.txt -O HA/HACA.txt
|
146 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HACB.txt -O HA/HACB.txt
|
147 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HACC.txt -O HA/HACC.txt
|
148 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HACD.txt -O HA/HACD.txt
|
149 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HAEA.txt -O HA/HAEA.txt
|
150 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HAEB.txt -O HA/HAEB.txt
|
151 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HAEC.txt -O HA/HAEC.txt
|
152 |
+
mkdir -pv HA && wget http://files.docking.org/2D/HA/HAED.txt -O HA/HAED.txt
|
153 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IAAA.txt -O IA/IAAA.txt
|
154 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IAAB.txt -O IA/IAAB.txt
|
155 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IAAC.txt -O IA/IAAC.txt
|
156 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IAAD.txt -O IA/IAAD.txt
|
157 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IABA.txt -O IA/IABA.txt
|
158 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IABB.txt -O IA/IABB.txt
|
159 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IABC.txt -O IA/IABC.txt
|
160 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IABD.txt -O IA/IABD.txt
|
161 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IACA.txt -O IA/IACA.txt
|
162 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IACB.txt -O IA/IACB.txt
|
163 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IACC.txt -O IA/IACC.txt
|
164 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IACD.txt -O IA/IACD.txt
|
165 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IAEA.txt -O IA/IAEA.txt
|
166 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IAEB.txt -O IA/IAEB.txt
|
167 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IAEC.txt -O IA/IAEC.txt
|
168 |
+
mkdir -pv IA && wget http://files.docking.org/2D/IA/IAED.txt -O IA/IAED.txt
|
169 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JAAA.txt -O JA/JAAA.txt
|
170 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JAAB.txt -O JA/JAAB.txt
|
171 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JAAC.txt -O JA/JAAC.txt
|
172 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JAAD.txt -O JA/JAAD.txt
|
173 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JABA.txt -O JA/JABA.txt
|
174 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JABB.txt -O JA/JABB.txt
|
175 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JABC.txt -O JA/JABC.txt
|
176 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JABD.txt -O JA/JABD.txt
|
177 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JACA.txt -O JA/JACA.txt
|
178 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JACB.txt -O JA/JACB.txt
|
179 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JACC.txt -O JA/JACC.txt
|
180 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JACD.txt -O JA/JACD.txt
|
181 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JAEA.txt -O JA/JAEA.txt
|
182 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JAEB.txt -O JA/JAEB.txt
|
183 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JAEC.txt -O JA/JAEC.txt
|
184 |
+
mkdir -pv JA && wget http://files.docking.org/2D/JA/JAED.txt -O JA/JAED.txt
|
185 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KAAA.txt -O KA/KAAA.txt
|
186 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KAAB.txt -O KA/KAAB.txt
|
187 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KAAC.txt -O KA/KAAC.txt
|
188 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KAAD.txt -O KA/KAAD.txt
|
189 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KABA.txt -O KA/KABA.txt
|
190 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KABB.txt -O KA/KABB.txt
|
191 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KABC.txt -O KA/KABC.txt
|
192 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KABD.txt -O KA/KABD.txt
|
193 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KACA.txt -O KA/KACA.txt
|
194 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KACB.txt -O KA/KACB.txt
|
195 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KACC.txt -O KA/KACC.txt
|
196 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KACD.txt -O KA/KACD.txt
|
197 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KAEA.txt -O KA/KAEA.txt
|
198 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KAEB.txt -O KA/KAEB.txt
|
199 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KAEC.txt -O KA/KAEC.txt
|
200 |
+
mkdir -pv KA && wget http://files.docking.org/2D/KA/KAED.txt -O KA/KAED.txt
|
201 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBCA.txt -O BB/BBCA.txt
|
202 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBCB.txt -O BB/BBCB.txt
|
203 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBCC.txt -O BB/BBCC.txt
|
204 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBCD.txt -O BB/BBCD.txt
|
205 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBEA.txt -O BB/BBEA.txt
|
206 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBEB.txt -O BB/BBEB.txt
|
207 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBEC.txt -O BB/BBEC.txt
|
208 |
+
mkdir -pv BB && wget http://files.docking.org/2D/BB/BBED.txt -O BB/BBED.txt
|
209 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBAA.txt -O CB/CBAA.txt
|
210 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBAB.txt -O CB/CBAB.txt
|
211 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBAC.txt -O CB/CBAC.txt
|
212 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBAD.txt -O CB/CBAD.txt
|
213 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBBA.txt -O CB/CBBA.txt
|
214 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBBB.txt -O CB/CBBB.txt
|
215 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBBC.txt -O CB/CBBC.txt
|
216 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBBD.txt -O CB/CBBD.txt
|
217 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBCA.txt -O CB/CBCA.txt
|
218 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBCB.txt -O CB/CBCB.txt
|
219 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBCC.txt -O CB/CBCC.txt
|
220 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBCD.txt -O CB/CBCD.txt
|
221 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBEA.txt -O CB/CBEA.txt
|
222 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBEB.txt -O CB/CBEB.txt
|
223 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBEC.txt -O CB/CBEC.txt
|
224 |
+
mkdir -pv CB && wget http://files.docking.org/2D/CB/CBED.txt -O CB/CBED.txt
|
225 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBAA.txt -O DB/DBAA.txt
|
226 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBAB.txt -O DB/DBAB.txt
|
227 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBAC.txt -O DB/DBAC.txt
|
228 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBAD.txt -O DB/DBAD.txt
|
229 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBBA.txt -O DB/DBBA.txt
|
230 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBBB.txt -O DB/DBBB.txt
|
231 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBBC.txt -O DB/DBBC.txt
|
232 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBBD.txt -O DB/DBBD.txt
|
233 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBCA.txt -O DB/DBCA.txt
|
234 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBCB.txt -O DB/DBCB.txt
|
235 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBCC.txt -O DB/DBCC.txt
|
236 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBCD.txt -O DB/DBCD.txt
|
237 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBEA.txt -O DB/DBEA.txt
|
238 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBEB.txt -O DB/DBEB.txt
|
239 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBEC.txt -O DB/DBEC.txt
|
240 |
+
mkdir -pv DB && wget http://files.docking.org/2D/DB/DBED.txt -O DB/DBED.txt
|
241 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBAA.txt -O EB/EBAA.txt
|
242 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBAB.txt -O EB/EBAB.txt
|
243 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBAC.txt -O EB/EBAC.txt
|
244 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBAD.txt -O EB/EBAD.txt
|
245 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBBA.txt -O EB/EBBA.txt
|
246 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBBB.txt -O EB/EBBB.txt
|
247 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBBC.txt -O EB/EBBC.txt
|
248 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBBD.txt -O EB/EBBD.txt
|
249 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBCA.txt -O EB/EBCA.txt
|
250 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBCB.txt -O EB/EBCB.txt
|
251 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBCC.txt -O EB/EBCC.txt
|
252 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBCD.txt -O EB/EBCD.txt
|
253 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBEA.txt -O EB/EBEA.txt
|
254 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBEB.txt -O EB/EBEB.txt
|
255 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBEC.txt -O EB/EBEC.txt
|
256 |
+
mkdir -pv EB && wget http://files.docking.org/2D/EB/EBED.txt -O EB/EBED.txt
|
257 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBAA.txt -O FB/FBAA.txt
|
258 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBAB.txt -O FB/FBAB.txt
|
259 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBAC.txt -O FB/FBAC.txt
|
260 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBAD.txt -O FB/FBAD.txt
|
261 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBBA.txt -O FB/FBBA.txt
|
262 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBBB.txt -O FB/FBBB.txt
|
263 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBBC.txt -O FB/FBBC.txt
|
264 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBBD.txt -O FB/FBBD.txt
|
265 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBCA.txt -O FB/FBCA.txt
|
266 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBCB.txt -O FB/FBCB.txt
|
267 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBCC.txt -O FB/FBCC.txt
|
268 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBCD.txt -O FB/FBCD.txt
|
269 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBEA.txt -O FB/FBEA.txt
|
270 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBEB.txt -O FB/FBEB.txt
|
271 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBEC.txt -O FB/FBEC.txt
|
272 |
+
mkdir -pv FB && wget http://files.docking.org/2D/FB/FBED.txt -O FB/FBED.txt
|
273 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBAA.txt -O GB/GBAA.txt
|
274 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBAB.txt -O GB/GBAB.txt
|
275 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBAC.txt -O GB/GBAC.txt
|
276 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBAD.txt -O GB/GBAD.txt
|
277 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBBA.txt -O GB/GBBA.txt
|
278 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBBB.txt -O GB/GBBB.txt
|
279 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBBC.txt -O GB/GBBC.txt
|
280 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBBD.txt -O GB/GBBD.txt
|
281 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBCA.txt -O GB/GBCA.txt
|
282 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBCB.txt -O GB/GBCB.txt
|
283 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBCC.txt -O GB/GBCC.txt
|
284 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBCD.txt -O GB/GBCD.txt
|
285 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBEA.txt -O GB/GBEA.txt
|
286 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBEB.txt -O GB/GBEB.txt
|
287 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBEC.txt -O GB/GBEC.txt
|
288 |
+
mkdir -pv GB && wget http://files.docking.org/2D/GB/GBED.txt -O GB/GBED.txt
|
289 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBAA.txt -O HB/HBAA.txt
|
290 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBAB.txt -O HB/HBAB.txt
|
291 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBAC.txt -O HB/HBAC.txt
|
292 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBAD.txt -O HB/HBAD.txt
|
293 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBBA.txt -O HB/HBBA.txt
|
294 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBBB.txt -O HB/HBBB.txt
|
295 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBBC.txt -O HB/HBBC.txt
|
296 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBBD.txt -O HB/HBBD.txt
|
297 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBCA.txt -O HB/HBCA.txt
|
298 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBCB.txt -O HB/HBCB.txt
|
299 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBCC.txt -O HB/HBCC.txt
|
300 |
+
mkdir -pv HB && wget http://files.docking.org/2D/HB/HBCD.txt -O HB/HBCD.txt
|
data/zinc/zinc_complete/run_download.py
ADDED
@@ -0,0 +1,21 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import concurrent.futures
|
2 |
+
import subprocess
|
3 |
+
|
4 |
+
shell_file = "download_zinc.sh"
|
5 |
+
num_parallel = 8
|
6 |
+
|
7 |
+
def execute_command(command):
|
8 |
+
print("Running: ", command)
|
9 |
+
subprocess.run(command, shell=True)
|
10 |
+
|
11 |
+
commands = []
|
12 |
+
with open(shell_file, "r") as file:
|
13 |
+
for line in file:
|
14 |
+
line = line.strip()
|
15 |
+
if line.startswith("mkdir") and "wget" in line:
|
16 |
+
commands.append(line)
|
17 |
+
|
18 |
+
with concurrent.futures.ThreadPoolExecutor() as executor:
|
19 |
+
executor.map(execute_command, commands, chunksize=num_parallel)
|
20 |
+
|
21 |
+
print("Downloads completed")
|
demonstrator.ipynb
ADDED
@@ -0,0 +1,521 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# Demonstrator"
|
8 |
+
]
|
9 |
+
},
|
10 |
+
{
|
11 |
+
"cell_type": "markdown",
|
12 |
+
"metadata": {},
|
13 |
+
"source": [
|
14 |
+
"### Load the model"
|
15 |
+
]
|
16 |
+
},
|
17 |
+
{
|
18 |
+
"cell_type": "code",
|
19 |
+
"execution_count": 5,
|
20 |
+
"metadata": {},
|
21 |
+
"outputs": [
|
22 |
+
{
|
23 |
+
"name": "stderr",
|
24 |
+
"output_type": "stream",
|
25 |
+
"text": [
|
26 |
+
"INFO:sample:Compiling the model...\n"
|
27 |
+
]
|
28 |
+
}
|
29 |
+
],
|
30 |
+
"source": [
|
31 |
+
"import rdkit\n",
|
32 |
+
"from rdkit import Chem\n",
|
33 |
+
"import rdkit.rdBase as rkrb\n",
|
34 |
+
"import rdkit.RDLogger as rkl\n",
|
35 |
+
"import os\n",
|
36 |
+
"import torch \n",
|
37 |
+
"import logging\n",
|
38 |
+
"import numpy as np\n",
|
39 |
+
"from plot_utils import check_metrics\n",
|
40 |
+
"from sample import Sampler\n",
|
41 |
+
"import pandas as pd\n",
|
42 |
+
"\n",
|
43 |
+
"device = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
44 |
+
"\n",
|
45 |
+
"if \"cuda\" in device:\n",
|
46 |
+
" # dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'\n",
|
47 |
+
" dtype = \"float16\" if torch.cuda.is_available() else \"float32\"\n",
|
48 |
+
"else:\n",
|
49 |
+
" dtype = \"float32\"\n",
|
50 |
+
"\n",
|
51 |
+
"logger = rkl.logger()\n",
|
52 |
+
"logger.setLevel(rkl.ERROR)\n",
|
53 |
+
"rkrb.DisableLog(\"rdApp.error\")\n",
|
54 |
+
"\n",
|
55 |
+
"torch.set_num_threads(8)\n",
|
56 |
+
"logging.basicConfig(level=logging.INFO)\n",
|
57 |
+
"logger = logging.getLogger(__name__)\n",
|
58 |
+
"\n",
|
59 |
+
"sampler = Sampler(\n",
|
60 |
+
" load_path=os.path.join(\n",
|
61 |
+
" os.getcwd(), \"out\", \"llama2-M-Full-RSS.pt\"\n",
|
62 |
+
" ),\n",
|
63 |
+
" device=device,\n",
|
64 |
+
" seed=1234,\n",
|
65 |
+
" dtype=dtype,\n",
|
66 |
+
" compile=True,\n",
|
67 |
+
")\n",
|
68 |
+
"\n",
|
69 |
+
" \n",
|
70 |
+
"num_samples = 100\n",
|
71 |
+
"df_comp = pd.read_parquet(os.path.join(os.getcwd(),\"data\",\"OrganiX13.parquet\"))\n",
|
72 |
+
"df_comp = df_comp.sample(n=2_500_000)\n",
|
73 |
+
"comp_context_dict = {c: df_comp[c].to_numpy() for c in [\"logp\", \"sascore\", \"mol_weight\"]} \n",
|
74 |
+
"comp_smiles = df_comp[\"smiles\"]\n",
|
75 |
+
"\n"
|
76 |
+
]
|
77 |
+
},
|
78 |
+
{
|
79 |
+
"cell_type": "code",
|
80 |
+
"execution_count": 6,
|
81 |
+
"metadata": {},
|
82 |
+
"outputs": [
|
83 |
+
{
|
84 |
+
"name": "stderr",
|
85 |
+
"output_type": "stream",
|
86 |
+
"text": [
|
87 |
+
"INFO:root:Wrote file /home/ndobberstein/Projekte/llama2-molgen/chemiscope_gen.json\n"
|
88 |
+
]
|
89 |
+
}
|
90 |
+
],
|
91 |
+
"source": [
|
92 |
+
"from typing import List, Dict\n",
|
93 |
+
"import json\n",
|
94 |
+
"from rdkit.Chem import AllChem\n",
|
95 |
+
"\n",
|
96 |
+
"@torch.no_grad()\n",
|
97 |
+
"def convert_to_chemiscope(smiles_list : List[str], context_dict : Dict[str, List[float]]):\n",
|
98 |
+
" # For more details on the file format: https://chemiscope.org/docs/tutorial/input-reference.html\n",
|
99 |
+
"\n",
|
100 |
+
" structures = []\n",
|
101 |
+
" remove_list = []\n",
|
102 |
+
" for i,smi in enumerate(smiles_list):\n",
|
103 |
+
" mol = Chem.MolFromSmiles(smi)\n",
|
104 |
+
" if mol is None:\n",
|
105 |
+
" logging.info(f\"Mol invalid: {smi} ! Skipping...\")\n",
|
106 |
+
" remove_list.append(i)\n",
|
107 |
+
" continue\n",
|
108 |
+
"\n",
|
109 |
+
" res = AllChem.EmbedMolecule(mol,randomSeed=0xf00d, maxAttempts=20)\n",
|
110 |
+
" # res = AllChem.Compute2DCoords(mol)\n",
|
111 |
+
"\n",
|
112 |
+
" if res != 0:\n",
|
113 |
+
" logging.info(f\"Could not calculate coordinates for {smi}! Skipping..\")\n",
|
114 |
+
" remove_list.append(i)\n",
|
115 |
+
" continue\n",
|
116 |
+
" \n",
|
117 |
+
"\n",
|
118 |
+
" conf = list(mol.GetConformers())[0]\n",
|
119 |
+
" x,y,z = [],[],[]\n",
|
120 |
+
" symbols = []\n",
|
121 |
+
" for atom, coords in zip(mol.GetAtoms(), conf.GetPositions()):\n",
|
122 |
+
" symbols.append(atom.GetSymbol())\n",
|
123 |
+
" x.append(coords[0])\n",
|
124 |
+
" y.append(coords[1])\n",
|
125 |
+
" z.append(coords[2])\n",
|
126 |
+
" \n",
|
127 |
+
" structures.append({\n",
|
128 |
+
" \"size\": len(x),\n",
|
129 |
+
" \"names\": symbols,\n",
|
130 |
+
" \"x\": x,\n",
|
131 |
+
" \"y\": y,\n",
|
132 |
+
" \"z\" : z\n",
|
133 |
+
" })\n",
|
134 |
+
"\n",
|
135 |
+
"\n",
|
136 |
+
"\n",
|
137 |
+
" properties = {}\n",
|
138 |
+
" \n",
|
139 |
+
" for c in context_dict:\n",
|
140 |
+
" properties[c] = {\n",
|
141 |
+
" \"target\": \"structure\",\n",
|
142 |
+
" \"values\": [v for i, v in enumerate(context_dict[c]) if i not in remove_list]\n",
|
143 |
+
" }\n",
|
144 |
+
" \n",
|
145 |
+
"\n",
|
146 |
+
"\n",
|
147 |
+
" \n",
|
148 |
+
" data = {\n",
|
149 |
+
" \"meta\": {\n",
|
150 |
+
" # // the name of the dataset\n",
|
151 |
+
" \"name\": \"Test Dataset\",\n",
|
152 |
+
" # // description of the dataset, OPTIONAL\n",
|
153 |
+
" \"description\": \"This contains data from generated molecules\",\n",
|
154 |
+
" # // authors of the dataset, OPTIONAL\n",
|
155 |
+
" \"authors\": [\"Niklas Dobberstein, [email protected]\"],\n",
|
156 |
+
" # // references for the dataset, OPTIONAL\n",
|
157 |
+
" \"references\": [\n",
|
158 |
+
" \"\",\n",
|
159 |
+
" ],\n",
|
160 |
+
" \n",
|
161 |
+
" },\n",
|
162 |
+
" \"properties\": properties,\n",
|
163 |
+
" \"structures\": structures\n",
|
164 |
+
" }\n",
|
165 |
+
" \n",
|
166 |
+
" out_path = os.path.join(os.getcwd(), \"chemiscope_gen.json\")\n",
|
167 |
+
" with open(out_path, \"w\") as f:\n",
|
168 |
+
" json.dump(data, f)\n",
|
169 |
+
"\n",
|
170 |
+
" logging.info(f\"Wrote file {out_path}\")\n",
|
171 |
+
"\n",
|
172 |
+
"convert_to_chemiscope([\n",
|
173 |
+
" \"CC=O\",\n",
|
174 |
+
" \"s1ccnc1\"\n",
|
175 |
+
"], {\"logp\": [1.0,2.0], \"sascore\": [1.5,-2.0]})"
|
176 |
+
]
|
177 |
+
},
|
178 |
+
{
|
179 |
+
"cell_type": "code",
|
180 |
+
"execution_count": 7,
|
181 |
+
"metadata": {},
|
182 |
+
"outputs": [
|
183 |
+
{
|
184 |
+
"data": {
|
185 |
+
"application/vnd.jupyter.widget-view+json": {
|
186 |
+
"model_id": "8b28a4e692de4bb48fde10a88d9727ba",
|
187 |
+
"version_major": 2,
|
188 |
+
"version_minor": 0
|
189 |
+
},
|
190 |
+
"text/plain": [
|
191 |
+
"HBox(children=(Checkbox(value=False, description='logp'), Checkbox(value=False, description='sascore'), Checkb…"
|
192 |
+
]
|
193 |
+
},
|
194 |
+
"metadata": {},
|
195 |
+
"output_type": "display_data"
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"data": {
|
199 |
+
"application/vnd.jupyter.widget-view+json": {
|
200 |
+
"model_id": "62331a62f2bf4d08a3a202ad277c6d92",
|
201 |
+
"version_major": 2,
|
202 |
+
"version_minor": 0
|
203 |
+
},
|
204 |
+
"text/plain": [
|
205 |
+
"HBox(children=(FloatSlider(value=0.0, description='logp:', max=7.0, min=-4.0, step=0.5), FloatSlider(value=2.0…"
|
206 |
+
]
|
207 |
+
},
|
208 |
+
"metadata": {},
|
209 |
+
"output_type": "display_data"
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"data": {
|
213 |
+
"application/vnd.jupyter.widget-view+json": {
|
214 |
+
"model_id": "2d498af39f4046b0a5bb92080361dfec",
|
215 |
+
"version_major": 2,
|
216 |
+
"version_minor": 0
|
217 |
+
},
|
218 |
+
"text/plain": [
|
219 |
+
"Text(value='', description='Context SMI:')"
|
220 |
+
]
|
221 |
+
},
|
222 |
+
"metadata": {},
|
223 |
+
"output_type": "display_data"
|
224 |
+
},
|
225 |
+
{
|
226 |
+
"data": {
|
227 |
+
"application/vnd.jupyter.widget-view+json": {
|
228 |
+
"model_id": "ed8a755253444e9c83dc27c5f830588b",
|
229 |
+
"version_major": 2,
|
230 |
+
"version_minor": 0
|
231 |
+
},
|
232 |
+
"text/plain": [
|
233 |
+
"FloatSlider(value=0.8, description='Temperature:', max=2.0)"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
"metadata": {},
|
237 |
+
"output_type": "display_data"
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"data": {
|
241 |
+
"application/vnd.jupyter.widget-view+json": {
|
242 |
+
"model_id": "139e7d1e40984101800e2cbb740280b0",
|
243 |
+
"version_major": 2,
|
244 |
+
"version_minor": 0
|
245 |
+
},
|
246 |
+
"text/plain": [
|
247 |
+
"Button(description='Generate', style=ButtonStyle())"
|
248 |
+
]
|
249 |
+
},
|
250 |
+
"metadata": {},
|
251 |
+
"output_type": "display_data"
|
252 |
+
},
|
253 |
+
{
|
254 |
+
"data": {
|
255 |
+
"application/vnd.jupyter.widget-view+json": {
|
256 |
+
"model_id": "4d119a3b477243ac916478a6ec2a55c7",
|
257 |
+
"version_major": 2,
|
258 |
+
"version_minor": 0
|
259 |
+
},
|
260 |
+
"text/plain": [
|
261 |
+
"Output()"
|
262 |
+
]
|
263 |
+
},
|
264 |
+
"metadata": {},
|
265 |
+
"output_type": "display_data"
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"data": {
|
269 |
+
"application/vnd.jupyter.widget-view+json": {
|
270 |
+
"model_id": "dfce28d4f6a3414c838e6542ffb43fc6",
|
271 |
+
"version_major": 2,
|
272 |
+
"version_minor": 0
|
273 |
+
},
|
274 |
+
"text/plain": [
|
275 |
+
"Output()"
|
276 |
+
]
|
277 |
+
},
|
278 |
+
"metadata": {},
|
279 |
+
"output_type": "display_data"
|
280 |
+
}
|
281 |
+
],
|
282 |
+
"source": [
|
283 |
+
"import ipywidgets as widgets\n",
|
284 |
+
"from IPython.display import display, clear_output, HTML\n",
|
285 |
+
"import numpy as np\n",
|
286 |
+
"import torch\n",
|
287 |
+
"import matplotlib.pyplot as plt\n",
|
288 |
+
"from rdkit import Chem\n",
|
289 |
+
"from rdkit.Chem import Draw\n",
|
290 |
+
"import logging\n",
|
291 |
+
"from plot_utils import calc_context_from_smiles\n",
|
292 |
+
"\n",
|
293 |
+
"# Define the context_cols options and create checkboxes for them\n",
|
294 |
+
"context_cols_options = [\"logp\", \"sascore\", \"mol_weight\"]\n",
|
295 |
+
"context_cols_checkboxes = [widgets.Checkbox(description=col, value=False) for col in context_cols_options]\n",
|
296 |
+
"\n",
|
297 |
+
"# Create a text input for context_smi\n",
|
298 |
+
"context_smi_input = widgets.Text(description=\"Context SMI:\", value=\"\")\n",
|
299 |
+
"\n",
|
300 |
+
"# Create sliders for temperature and context_cols values\n",
|
301 |
+
"temperature_slider = widgets.FloatSlider(description=\"Temperature:\", min=0, max=2.0, step=0.1, value=0.8)\n",
|
302 |
+
"\n",
|
303 |
+
"logp_slider = widgets.FloatSlider(description=\"logp:\", min=-4, max=7, step=0.5, value=0.0)\n",
|
304 |
+
"sascore_slider = widgets.FloatSlider(description=\"sascore:\", min=1, max=10, step=0.5, value=2.0)\n",
|
305 |
+
"mol_weight_slider = widgets.FloatSlider(description=\"mol_weight:\", min=0.5, max=10, step=0.5, value=3.0)\n",
|
306 |
+
"\n",
|
307 |
+
"# Create a button to generate the code and display SMILES\n",
|
308 |
+
"generate_button = widgets.Button(description=\"Generate\")\n",
|
309 |
+
"\n",
|
310 |
+
"# Create an output widget for displaying generated information\n",
|
311 |
+
"output = widgets.Output()\n",
|
312 |
+
"\n",
|
313 |
+
"# Create an output widget for displaying the RDKit molecules\n",
|
314 |
+
"molecule_output = widgets.Output()\n",
|
315 |
+
"\n",
|
316 |
+
"@torch.no_grad()\n",
|
317 |
+
"def generate_code(_):\n",
|
318 |
+
" with output:\n",
|
319 |
+
" clear_output(wait=False)\n",
|
320 |
+
" # logging.info(\"Parameters used in generation:\")\n",
|
321 |
+
" \n",
|
322 |
+
" # Get the selected context_cols\n",
|
323 |
+
" selected_context_cols = [col for col, checkbox in zip(context_cols_options, context_cols_checkboxes) if checkbox.value]\n",
|
324 |
+
" # logging.info(f\"Context Cols: {selected_context_cols}\")\n",
|
325 |
+
" \n",
|
326 |
+
" # Get the values of context_smi and temperature from the sliders\n",
|
327 |
+
" context_smi = context_smi_input.value.strip()\n",
|
328 |
+
" temperature = temperature_slider.value\n",
|
329 |
+
" # logging.info(f\"Context Smiles: {context_smi}\")\n",
|
330 |
+
" # logging.info(f\"Temperature: {temperature}\")\n",
|
331 |
+
" \n",
|
332 |
+
" # Get the values of logp, sascore, and mol_weight from the sliders\n",
|
333 |
+
" context_dict = {} if len(selected_context_cols) != 0 else None\n",
|
334 |
+
" for c in selected_context_cols:\n",
|
335 |
+
" if c == \"logp\":\n",
|
336 |
+
" val = logp_slider.value\n",
|
337 |
+
" elif c == \"sascore\":\n",
|
338 |
+
" val = sascore_slider.value\n",
|
339 |
+
" else:\n",
|
340 |
+
" val = mol_weight_slider.value\n",
|
341 |
+
" val = round(val, 2)\n",
|
342 |
+
" context_dict[c] = val*torch.ones((num_samples,),device=device,dtype=torch.float)\n",
|
343 |
+
" # logging.info(f\"{c}: {val}\")\n",
|
344 |
+
" \n",
|
345 |
+
" # Generate SMILES using the provided context\n",
|
346 |
+
" smiles, context = sampler.generate(\n",
|
347 |
+
" context_cols=context_dict,\n",
|
348 |
+
" context_smi=context_smi,\n",
|
349 |
+
" start_smiles=None,\n",
|
350 |
+
" num_samples=num_samples,\n",
|
351 |
+
" max_new_tokens=256,\n",
|
352 |
+
" temperature=temperature,\n",
|
353 |
+
" top_k=25,\n",
|
354 |
+
" total_gen_steps=int(np.ceil(num_samples / 1000)),\n",
|
355 |
+
" return_context=True\n",
|
356 |
+
" )\n",
|
357 |
+
" \n",
|
358 |
+
" with open(os.path.join(os.getcwd(), \"gen_smiles.txt\"), \"w\") as f:\n",
|
359 |
+
" for s in smiles:\n",
|
360 |
+
" f.write(f\"{s}\\n\")\n",
|
361 |
+
" # Display SMILES as RDKit molecules\n",
|
362 |
+
" display_molecules(smiles, context)\n",
|
363 |
+
"\n",
|
364 |
+
"\n",
|
365 |
+
"\n",
|
366 |
+
"def display_molecules(smiles_list, context_dict):\n",
|
367 |
+
" with molecule_output:\n",
|
368 |
+
" clear_output(wait=False)\n",
|
369 |
+
" molecules = [Chem.MolFromSmiles(smiles) for smiles in smiles_list]\n",
|
370 |
+
" \n",
|
371 |
+
" # Convert RDKit molecules to images and store them in a list\n",
|
372 |
+
" images = [Draw.MolToImage(mol) for mol in molecules]\n",
|
373 |
+
" \n",
|
374 |
+
" # Create a subplot grid to display the images\n",
|
375 |
+
" num_images = len(images)\n",
|
376 |
+
" num_cols = 5 # Number of columns in the grid\n",
|
377 |
+
" num_rows = (num_images + num_cols - 1) // num_cols # Calculate the number of rows\n",
|
378 |
+
" \n",
|
379 |
+
" fig, axes = plt.subplots(num_rows, num_cols, figsize=(25, 25))\n",
|
380 |
+
" fig.subplots_adjust(hspace=0.5)\n",
|
381 |
+
" calculated_context = {c:[] for c in context_dict}\n",
|
382 |
+
" for i, ax in enumerate(axes.flat):\n",
|
383 |
+
" if i < num_images:\n",
|
384 |
+
" ax.imshow(images[i])\n",
|
385 |
+
" for j, c in enumerate(context_dict):\n",
|
386 |
+
" smiles = smiles_list[i]\n",
|
387 |
+
" smi_con = round(calc_context_from_smiles([smiles], c)[0],2)\n",
|
388 |
+
" calculated_context[c].append(smi_con)\n",
|
389 |
+
" ax.text(0.5, -0.1 * j , f\"{c}: {context_dict[c][i]} vs {smi_con}\", transform=ax.transAxes, fontsize=10, ha='center')\n",
|
390 |
+
" \n",
|
391 |
+
" ax.axis('off')\n",
|
392 |
+
" else:\n",
|
393 |
+
" fig.delaxes(ax) # Remove empty subplots if there are more rows than images\n",
|
394 |
+
" \n",
|
395 |
+
"\n",
|
396 |
+
" if len(context_dict) >= 2:\n",
|
397 |
+
" convert_to_chemiscope(smiles_list, calculated_context)\n",
|
398 |
+
"\n",
|
399 |
+
" plt.savefig(\"gen_mols.png\")\n",
|
400 |
+
" plt.show()\n",
|
401 |
+
"\n",
|
402 |
+
"# Attach the generate_code function to the button's click event\n",
|
403 |
+
"generate_button.on_click(generate_code)\n",
|
404 |
+
"\n",
|
405 |
+
"# Display the widgets\n",
|
406 |
+
"display(widgets.HBox(context_cols_checkboxes))\n",
|
407 |
+
"display(widgets.HBox((logp_slider, sascore_slider, mol_weight_slider)))\n",
|
408 |
+
"\n",
|
409 |
+
"display(context_smi_input)\n",
|
410 |
+
"display(temperature_slider)\n",
|
411 |
+
"display(generate_button)\n",
|
412 |
+
"display(output)\n",
|
413 |
+
"display(molecule_output)"
|
414 |
+
]
|
415 |
+
},
|
416 |
+
{
|
417 |
+
"cell_type": "code",
|
418 |
+
"execution_count": null,
|
419 |
+
"metadata": {},
|
420 |
+
"outputs": [
|
421 |
+
{
|
422 |
+
"data": {
|
423 |
+
"application/vnd.jupyter.widget-view+json": {
|
424 |
+
"model_id": "ea96e00e0ea8448d97906ec965f04788",
|
425 |
+
"version_major": 2,
|
426 |
+
"version_minor": 0
|
427 |
+
},
|
428 |
+
"text/plain": [
|
429 |
+
"Batch: 0%| | 0/1 [00:00<?, ?it/s]"
|
430 |
+
]
|
431 |
+
},
|
432 |
+
"metadata": {},
|
433 |
+
"output_type": "display_data"
|
434 |
+
},
|
435 |
+
{
|
436 |
+
"data": {
|
437 |
+
"application/vnd.jupyter.widget-view+json": {
|
438 |
+
"model_id": "77ba2d72172846e18572c94bc5b3bd6f",
|
439 |
+
"version_major": 2,
|
440 |
+
"version_minor": 0
|
441 |
+
},
|
442 |
+
"text/plain": [
|
443 |
+
"Generation: 0%| | 0/256 [00:00<?, ?it/s]"
|
444 |
+
]
|
445 |
+
},
|
446 |
+
"metadata": {},
|
447 |
+
"output_type": "display_data"
|
448 |
+
},
|
449 |
+
{
|
450 |
+
"name": "stderr",
|
451 |
+
"output_type": "stream",
|
452 |
+
"text": [
|
453 |
+
"INFO:sample:Number valid generated: 68.0 %\n",
|
454 |
+
"INFO:sample:---------------\n"
|
455 |
+
]
|
456 |
+
}
|
457 |
+
],
|
458 |
+
"source": [
|
459 |
+
"selected_context_cols = [\"logp\", \"sascore\", \"mol_weight\"]\n",
|
460 |
+
"num_samples = 25\n",
|
461 |
+
"context_dict = {} if len(selected_context_cols) != 0 else None\n",
|
462 |
+
"for c in selected_context_cols:\n",
|
463 |
+
" if c == \"logp\":\n",
|
464 |
+
" v = 0.5 * torch.randint(\n",
|
465 |
+
" -8, 14, (num_samples,), device=device, dtype=torch.float\n",
|
466 |
+
" )\n",
|
467 |
+
" context_dict[c] = v.sort()[0]\n",
|
468 |
+
" elif c == \"sascore\":\n",
|
469 |
+
" v = 0.5 * torch.randint(\n",
|
470 |
+
" 1, 20, (num_samples,), device=device, dtype=torch.float\n",
|
471 |
+
" )\n",
|
472 |
+
" context_dict[c] = v.sort()[0]\n",
|
473 |
+
" else:\n",
|
474 |
+
" v = 0.5 * torch.randint(\n",
|
475 |
+
" 1, 20, (num_samples,), device=device, dtype=torch.float\n",
|
476 |
+
" )\n",
|
477 |
+
" \n",
|
478 |
+
" context_dict[c] = v.sort()[0]\n",
|
479 |
+
" # logging.info(f\"{c}: {val}\")\n",
|
480 |
+
"\n",
|
481 |
+
"# Generate SMILES using the provided context\n",
|
482 |
+
"smiles, context = sampler.generate(\n",
|
483 |
+
" context_cols=context_dict,\n",
|
484 |
+
" context_smi=None,\n",
|
485 |
+
" start_smiles=None,\n",
|
486 |
+
" num_samples=num_samples,\n",
|
487 |
+
" max_new_tokens=256,\n",
|
488 |
+
" temperature=0.8,\n",
|
489 |
+
" top_k=25,\n",
|
490 |
+
" total_gen_steps=int(np.ceil(num_samples / 1000)),\n",
|
491 |
+
" return_context=True\n",
|
492 |
+
")\n",
|
493 |
+
"\n",
|
494 |
+
"# Display SMILES as RDKit molecules\n",
|
495 |
+
"display_molecules(smiles, context)\n"
|
496 |
+
]
|
497 |
+
}
|
498 |
+
],
|
499 |
+
"metadata": {
|
500 |
+
"kernelspec": {
|
501 |
+
"display_name": "torch2-bachelor",
|
502 |
+
"language": "python",
|
503 |
+
"name": "python3"
|
504 |
+
},
|
505 |
+
"language_info": {
|
506 |
+
"codemirror_mode": {
|
507 |
+
"name": "ipython",
|
508 |
+
"version": 3
|
509 |
+
},
|
510 |
+
"file_extension": ".py",
|
511 |
+
"mimetype": "text/x-python",
|
512 |
+
"name": "python",
|
513 |
+
"nbconvert_exporter": "python",
|
514 |
+
"pygments_lexer": "ipython3",
|
515 |
+
"version": "3.8.18"
|
516 |
+
},
|
517 |
+
"orig_nbformat": 4
|
518 |
+
},
|
519 |
+
"nbformat": 4,
|
520 |
+
"nbformat_minor": 2
|
521 |
+
}
|
fragment_creator.py
ADDED
@@ -0,0 +1,136 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from abc import ABC
|
2 |
+
from dataclasses import dataclass
|
3 |
+
from typing import List, Union
|
4 |
+
import numpy as np
|
5 |
+
from rdkit import Chem
|
6 |
+
from rdkit.Chem.BRICS import BRICSDecompose
|
7 |
+
from rdkit.Chem.Recap import RecapDecompose
|
8 |
+
|
9 |
+
import random
|
10 |
+
|
11 |
+
|
12 |
+
@dataclass
|
13 |
+
class Fragment:
|
14 |
+
smiles: Union[str, None]
|
15 |
+
tokens: Union[List[int], None]
|
16 |
+
|
17 |
+
|
18 |
+
class BaseFragmentCreator(ABC):
|
19 |
+
"""
|
20 |
+
Is the base class for all fragment creator and does nothing to the smiles
|
21 |
+
"""
|
22 |
+
|
23 |
+
def __init__(self) -> None:
|
24 |
+
pass
|
25 |
+
|
26 |
+
def create_fragment(self, frag: Fragment) -> Fragment:
|
27 |
+
return ""
|
28 |
+
|
29 |
+
|
30 |
+
# This is the method used in the paper
|
31 |
+
class RandomSubsliceFragmentCreator(BaseFragmentCreator):
|
32 |
+
def __init__(self, max_fragment_size=50) -> None:
|
33 |
+
super().__init__()
|
34 |
+
self.max_fragment_size = max_fragment_size
|
35 |
+
|
36 |
+
def create_fragment(self, frag: Fragment) -> Fragment:
|
37 |
+
"""
|
38 |
+
Creates the random sub slice fragments from the tokens
|
39 |
+
"""
|
40 |
+
tokens = frag.tokens
|
41 |
+
|
42 |
+
startIdx = np.random.randint(0, len(tokens) - 1)
|
43 |
+
|
44 |
+
endIdx = np.random.randint(
|
45 |
+
startIdx + 1, min(len(tokens), startIdx + self.max_fragment_size)
|
46 |
+
)
|
47 |
+
return Fragment(smiles=None, tokens=tokens[startIdx:endIdx])
|
48 |
+
|
49 |
+
|
50 |
+
class BricksFragmentCreator(BaseFragmentCreator):
|
51 |
+
def __init__(self) -> None:
|
52 |
+
super().__init__()
|
53 |
+
|
54 |
+
def create_fragment(self, frag: Fragment) -> Fragment:
|
55 |
+
"""
|
56 |
+
Creates the Bricks fragments and takes one randomly
|
57 |
+
"""
|
58 |
+
smiles = frag.smiles
|
59 |
+
m = Chem.MolFromSmiles(smiles)
|
60 |
+
if m is None:
|
61 |
+
return ""
|
62 |
+
|
63 |
+
res = list(BRICSDecompose(m, minFragmentSize=3))
|
64 |
+
# print(res)
|
65 |
+
return random.choice(res)
|
66 |
+
|
67 |
+
|
68 |
+
class RecapFragmentCreator(BaseFragmentCreator):
|
69 |
+
def __init__(self) -> None:
|
70 |
+
super().__init__()
|
71 |
+
|
72 |
+
def create_fragment(self, frag: Fragment) -> Fragment:
|
73 |
+
"""
|
74 |
+
Creates the Recap fragments and takes one randomly
|
75 |
+
"""
|
76 |
+
smiles = frag.smiles
|
77 |
+
m = Chem.MolFromSmiles(smiles)
|
78 |
+
if m is None:
|
79 |
+
return ""
|
80 |
+
|
81 |
+
res = RecapDecompose(m, minFragmentSize=3).GetAllChildren()
|
82 |
+
# print(res)
|
83 |
+
return random.choice(res)
|
84 |
+
|
85 |
+
|
86 |
+
class MolFragsFragmentCreator(BaseFragmentCreator):
|
87 |
+
def __init__(self) -> None:
|
88 |
+
super().__init__()
|
89 |
+
|
90 |
+
def create_fragment(self, frag: Fragment) -> Fragment:
|
91 |
+
"""
|
92 |
+
Creates the Bricks fragments and takes one randomly
|
93 |
+
"""
|
94 |
+
smiles = frag.smiles
|
95 |
+
m = Chem.MolFromSmiles(smiles)
|
96 |
+
if m is None:
|
97 |
+
return ""
|
98 |
+
|
99 |
+
res = list(Chem.rdmolops.GetMolFrags(m, asMols=True))
|
100 |
+
res = [Chem.MolToSmiles(m) for m in res]
|
101 |
+
# print(res)
|
102 |
+
return random.choice(res)
|
103 |
+
|
104 |
+
|
105 |
+
def fragment_creator_factory(key: Union[str, None]):
|
106 |
+
if key is None:
|
107 |
+
return None
|
108 |
+
|
109 |
+
if key == "mol_frags":
|
110 |
+
return MolFragsFragmentCreator()
|
111 |
+
elif key == "recap":
|
112 |
+
return RecapFragmentCreator()
|
113 |
+
elif key == "bricks":
|
114 |
+
return BricksFragmentCreator()
|
115 |
+
elif key == "rss":
|
116 |
+
return RandomSubsliceFragmentCreator()
|
117 |
+
else:
|
118 |
+
raise ValueError(f"Do not have factory for the given key: {key}")
|
119 |
+
|
120 |
+
|
121 |
+
if __name__ == "__main__":
|
122 |
+
from tokenizer import SmilesTokenizer
|
123 |
+
|
124 |
+
tokenizer = SmilesTokenizer()
|
125 |
+
|
126 |
+
creator = BricksFragmentCreator()
|
127 |
+
# creator = MolFragsFragmentCreator()
|
128 |
+
|
129 |
+
# creator = RecapFragmentCreator()
|
130 |
+
|
131 |
+
frag = creator.create_fragment("CC(=O)NC1=CC=C(C=C1)O")
|
132 |
+
|
133 |
+
print(frag)
|
134 |
+
tokens = tokenizer.encode(frag)
|
135 |
+
print(tokens)
|
136 |
+
print([tokenizer._convert_id_to_token(t) for t in tokens])
|
generate_paper_graphs.sh
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# TODO: Change FULL_PATH_TO_CONDA to the binary where the conda folder is: see https://github.com/conda/conda/issues/8536
|
4 |
+
conda activate FULL_PATH_TO_CONDA/torch2-llamol
|
5 |
+
|
6 |
+
array=( logp sascore mol_weight )
|
7 |
+
# python sample.py --num_samples 20000 --num_samples_per_step 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet"
|
8 |
+
# for i in "${array[@]}"
|
9 |
+
# do
|
10 |
+
# python sample.py --num_samples 10000 --num_samples_per_step 500 --kv_caching --ckpt_path "out/llama2-M-Full-RSS.pt" --context_cols "$i" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet"
|
11 |
+
# done
|
12 |
+
|
13 |
+
# 2 Combinations
|
14 |
+
python sample.py --num_samples 1000 --seed 4321 --kv_caching --ckpt_path "out/llama2-M-Full-RSS.pt" --context_cols logp sascore --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet"
|
15 |
+
python sample.py --num_samples 1000 --seed 4321 --kv_caching --ckpt_path "out/llama2-M-Full-RSS.pt" --context_cols logp mol_weight --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet"
|
16 |
+
python sample.py --num_samples 1000 --seed 4321 --kv_caching --ckpt_path "out/llama2-M-Full-RSS.pt" --context_cols sascore mol_weight --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet"
|
17 |
+
|
18 |
+
# # # All 3
|
19 |
+
# python sample.py --num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --context_cols logp sascore mol_weight --kv_caching --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --seed 4312
|
get_fragment_table.sh
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# TODO: Change FULL_PATH_TO_CONDA to the binary where the conda folder is: see https://github.com/conda/conda/issues/8536
|
4 |
+
conda activate FULL_PATH_TO_CONDA/torch2-llamol
|
5 |
+
|
6 |
+
|
7 |
+
# context_smiles=("c1ccccc1" "s1cccc1" "C1=CSC=C1" "CC1=CSC=C1" "C1=CC=C2C(=C1)C3=CC=CC=C3S2" "CCO" "CC=O" "CC(=O)OC1=CC=CC=C1C(=O)O" "CC(=O)NC1=CC=C(C=C1)O" "CC(C)CC1=CC=C(C=C1)C(C)C(=O)O" "OC(=O)C(C)c1ccc(cc1)CC(C)C" "C1C(=O)NC(=O)NC1=O" "CN1C=NC2=C1C(=O)N(C(=O)N2C)C" "CN1CCC23C4C1CC5=C2C(=C(C=C5)O)OC3C(C=C4)O" "CN1CCC23C4C1CC5=C2C(=C(C=C5)OC)OC3C(=O)CC4")
|
8 |
+
# context_smiles=("CN1CCC23C4C1CC5=C2C(=C(C=C5)O)OC3C(C=C4)O" "CN1CC[C@]23[C@@H]4[C@H]1CC5=C2C(=C(C=C5)O)O[C@H]3[C@H](C=C4)O" "CN1CCC23C4C1CC5=C2C(=C(C=C5)OC)OC3C(=O)CC4" "CN1CC[C@]23[C@@H]4[C@H]1CC5=C2C(=C(C=C5)OC)O[C@H]3C(=O)CC4" )
|
9 |
+
# context_smiles=("C1=CSC=C1" )
|
10 |
+
context_smiles=("C1=CSC=C1" "CC=O" "CC(=O)NC1=CC=C(C=C1)O" "CN1C=NC2=C1C(=O)N(C(=O)N2C)C")
|
11 |
+
for smi in "${context_smiles[@]}"; do
|
12 |
+
# Only fragment generation
|
13 |
+
# output=$(python sample.py --num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --context_smi "$smi")
|
14 |
+
|
15 |
+
# Fragment and LogP
|
16 |
+
# output=$(python sample.py --num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --context_smi "$smi" --context_cols "logp" )
|
17 |
+
|
18 |
+
# Fragment and Sascore
|
19 |
+
# output=$(python sample.py --num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --context_smi "$smi" --context_cols "sascore" )
|
20 |
+
|
21 |
+
# Fragment and Mol weight
|
22 |
+
# output=$(python sample.py --num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --context_smi "$smi" --context_cols "mol_weight" )
|
23 |
+
|
24 |
+
# Multi Fragment Condition
|
25 |
+
|
26 |
+
# Logp + Sascore
|
27 |
+
# output=$(python sample.py --num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --context_smi "$smi" --context_cols "logp" "sascore" )
|
28 |
+
|
29 |
+
|
30 |
+
# Logp + Mol Weight
|
31 |
+
# output=$(python sample.py --num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --context_smi "$smi" --context_cols "logp" "mol_weight" )
|
32 |
+
|
33 |
+
# Sascore + Mol Weight
|
34 |
+
# output=$(python sample.py --num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --context_smi "$smi" --context_cols "sascore" "mol_weight" )
|
35 |
+
|
36 |
+
# Logp + Sascore + Mol Weight
|
37 |
+
output=$(python sample.py --num_samples 1000 --ckpt_path "out/llama2-M-Full-RSS.pt" --max_new_tokens 256 --cmp_dataset_path="data/OrganiX13.parquet" --context_smi "$smi" --context_cols "logp" "sascore" "mol_weight" )
|
38 |
+
|
39 |
+
|
40 |
+
echo "SMI: $smi"
|
41 |
+
echo "----------------------"
|
42 |
+
done
|
model.py
ADDED
@@ -0,0 +1,787 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from __future__ import annotations
|
2 |
+
|
3 |
+
import math
|
4 |
+
import pickle
|
5 |
+
import struct
|
6 |
+
import inspect
|
7 |
+
from dataclasses import dataclass, field
|
8 |
+
from typing import Any, Dict, Optional, Tuple, List, Union
|
9 |
+
|
10 |
+
import numpy as np
|
11 |
+
import torch
|
12 |
+
import torch.nn.functional as F
|
13 |
+
from torch import nn
|
14 |
+
from tqdm.auto import tqdm
|
15 |
+
|
16 |
+
from tokenizer import SmilesTokenizer
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class ModelArgs:
|
21 |
+
dim: int = 4096
|
22 |
+
n_layers: int = 32
|
23 |
+
n_heads: int = 32
|
24 |
+
n_kv_heads: Optional[int] = None
|
25 |
+
vocab_size: int = -1 # defined later by tokenizer
|
26 |
+
multiple_of: int = 256 # make SwiGLU hidden layer size multiple of large power of 2
|
27 |
+
norm_eps: float = 1e-5
|
28 |
+
max_seq_len: int = 2048
|
29 |
+
dropout: float = 0.0
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class ContextArgs:
|
34 |
+
context_keys: List[str] = field(default_factory=list)
|
35 |
+
context_dims: List[int] = field(default_factory=list)
|
36 |
+
|
37 |
+
|
38 |
+
class RMSNorm(torch.nn.Module):
|
39 |
+
def __init__(self, dim: int, eps: float):
|
40 |
+
super().__init__()
|
41 |
+
self.eps = eps
|
42 |
+
self.weight = nn.Parameter(torch.ones(dim))
|
43 |
+
|
44 |
+
def _norm(self, x):
|
45 |
+
return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
|
46 |
+
|
47 |
+
def forward(self, x):
|
48 |
+
output = self._norm(x.float()).type_as(x)
|
49 |
+
return output * self.weight
|
50 |
+
|
51 |
+
|
52 |
+
def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
|
53 |
+
freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
|
54 |
+
t = torch.arange(end, device=freqs.device) # type: ignore
|
55 |
+
freqs = torch.outer(t, freqs).float() # type: ignore
|
56 |
+
freqs_cos = torch.cos(freqs) # real part
|
57 |
+
freqs_sin = torch.sin(freqs) # imaginary part
|
58 |
+
return freqs_cos, freqs_sin
|
59 |
+
|
60 |
+
|
61 |
+
def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
|
62 |
+
ndim = x.ndim
|
63 |
+
assert 0 <= 1 < ndim
|
64 |
+
assert freqs_cis.shape == (x.shape[1], x.shape[-1])
|
65 |
+
shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
|
66 |
+
return freqs_cis.view(shape)
|
67 |
+
|
68 |
+
|
69 |
+
def apply_rotary_emb(
|
70 |
+
xq: torch.Tensor, xk: torch.Tensor, freqs_cos: torch.Tensor, freqs_sin: torch.Tensor
|
71 |
+
) -> Tuple[torch.Tensor, torch.Tensor]:
|
72 |
+
# reshape xq and xk to match the complex representation
|
73 |
+
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
|
74 |
+
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
|
75 |
+
|
76 |
+
# reshape freqs_cos and freqs_sin for broadcasting
|
77 |
+
freqs_cos = reshape_for_broadcast(freqs_cos, xq_r)
|
78 |
+
freqs_sin = reshape_for_broadcast(freqs_sin, xq_r)
|
79 |
+
|
80 |
+
# apply rotation using real numbers
|
81 |
+
xq_out_r = xq_r * freqs_cos - xq_i * freqs_sin
|
82 |
+
xq_out_i = xq_r * freqs_sin + xq_i * freqs_cos
|
83 |
+
xk_out_r = xk_r * freqs_cos - xk_i * freqs_sin
|
84 |
+
xk_out_i = xk_r * freqs_sin + xk_i * freqs_cos
|
85 |
+
|
86 |
+
# flatten last two dimensions
|
87 |
+
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
|
88 |
+
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
|
89 |
+
|
90 |
+
return xq_out.type_as(xq), xk_out.type_as(xk)
|
91 |
+
|
92 |
+
|
93 |
+
def repeat_kv(x: torch.Tensor, n_rep: int) -> torch.Tensor:
|
94 |
+
"""torch.repeat_interleave(x, dim=2, repeats=n_rep)"""
|
95 |
+
bs, slen, n_kv_heads, head_dim = x.shape
|
96 |
+
if n_rep == 1:
|
97 |
+
return x
|
98 |
+
return (
|
99 |
+
x[:, :, :, None, :]
|
100 |
+
.expand(bs, slen, n_kv_heads, n_rep, head_dim)
|
101 |
+
.reshape(bs, slen, n_kv_heads * n_rep, head_dim)
|
102 |
+
)
|
103 |
+
|
104 |
+
|
105 |
+
class Attention(nn.Module):
|
106 |
+
def __init__(self, args: ModelArgs):
|
107 |
+
super().__init__()
|
108 |
+
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
109 |
+
model_parallel_size = 1
|
110 |
+
self.n_local_heads = args.n_heads // model_parallel_size
|
111 |
+
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
112 |
+
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
113 |
+
self.head_dim = args.dim // args.n_heads
|
114 |
+
self.wq = nn.Linear(args.dim, args.n_heads * self.head_dim, bias=False)
|
115 |
+
self.wk = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
116 |
+
self.wv = nn.Linear(args.dim, self.n_kv_heads * self.head_dim, bias=False)
|
117 |
+
self.wo = nn.Linear(args.n_heads * self.head_dim, args.dim, bias=False)
|
118 |
+
self.attn_dropout = nn.Dropout(args.dropout)
|
119 |
+
self.resid_dropout = nn.Dropout(args.dropout)
|
120 |
+
self.dropout = args.dropout
|
121 |
+
self.cache_hash = None
|
122 |
+
|
123 |
+
# use flash attention or a manual implementation?
|
124 |
+
self.flash = hasattr(torch.nn.functional, "scaled_dot_product_attention")
|
125 |
+
if not self.flash:
|
126 |
+
print(
|
127 |
+
"WARNING: using slow attention. Flash Attention requires PyTorch >= 2.0"
|
128 |
+
)
|
129 |
+
mask = torch.full((1, 1, args.max_seq_len, args.max_seq_len), float("-inf"))
|
130 |
+
mask = torch.triu(mask, diagonal=1)
|
131 |
+
self.register_buffer("mask", mask)
|
132 |
+
|
133 |
+
def forward(
|
134 |
+
self,
|
135 |
+
x: torch.Tensor,
|
136 |
+
freqs_cos: torch.Tensor,
|
137 |
+
freqs_sin: torch.Tensor,
|
138 |
+
):
|
139 |
+
bsz, seqlen, _ = x.shape
|
140 |
+
|
141 |
+
# QKV
|
142 |
+
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
143 |
+
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
144 |
+
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
145 |
+
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
146 |
+
|
147 |
+
# RoPE relative positional embeddings
|
148 |
+
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
|
149 |
+
|
150 |
+
# grouped multiquery attention: expand out keys and values
|
151 |
+
xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
152 |
+
xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
153 |
+
|
154 |
+
# make heads into a batch dimension
|
155 |
+
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
156 |
+
xk = xk.transpose(1, 2)
|
157 |
+
xv = xv.transpose(1, 2)
|
158 |
+
|
159 |
+
# flash implementation
|
160 |
+
if self.flash:
|
161 |
+
output = torch.nn.functional.scaled_dot_product_attention(
|
162 |
+
xq,
|
163 |
+
xk,
|
164 |
+
xv,
|
165 |
+
attn_mask=None,
|
166 |
+
dropout_p=self.dropout if self.training else 0.0,
|
167 |
+
is_causal=True,
|
168 |
+
)
|
169 |
+
else:
|
170 |
+
# manual implementation
|
171 |
+
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
|
172 |
+
assert hasattr(self, "mask")
|
173 |
+
scores = (
|
174 |
+
scores + self.mask[:, :, :seqlen, :seqlen]
|
175 |
+
) # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
176 |
+
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
177 |
+
scores = self.attn_dropout(scores)
|
178 |
+
output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
|
179 |
+
|
180 |
+
# restore time as batch dimension and concat heads
|
181 |
+
output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
182 |
+
|
183 |
+
# final projection into the residual stream
|
184 |
+
output = self.wo(output)
|
185 |
+
output = self.resid_dropout(output)
|
186 |
+
return output
|
187 |
+
|
188 |
+
def forward_with_kvcache(
|
189 |
+
self,
|
190 |
+
x: torch.Tensor,
|
191 |
+
freqs_cos: torch.Tensor,
|
192 |
+
freqs_sin: torch.Tensor,
|
193 |
+
cache_id: int = 1,
|
194 |
+
):
|
195 |
+
bsz, seqlen, _ = x.shape
|
196 |
+
|
197 |
+
original_x = x
|
198 |
+
use_cache = self.cache_hash == cache_id
|
199 |
+
if use_cache:
|
200 |
+
x = x[:, -1, :].unsqueeze(1) # only need the last new token
|
201 |
+
# QKV
|
202 |
+
xq, xk, xv = self.wq(x), self.wk(x), self.wv(x)
|
203 |
+
if use_cache:
|
204 |
+
# comp_xq, comp_xk, comp_xv = self.wq(original_x), self.wk(original_x), self.wv(original_x)
|
205 |
+
# comp_xq = comp_xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
206 |
+
# comp_xk = comp_xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
207 |
+
# comp_xv = comp_xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
208 |
+
|
209 |
+
# # RoPE relative positional embeddings
|
210 |
+
# comp_xq, comp_xk = apply_rotary_emb(comp_xq, comp_xk, freqs_cos, freqs_sin)
|
211 |
+
|
212 |
+
self.k_cache = torch.concat([self.k_cache, xk.clone()], dim=1)
|
213 |
+
self.v_cache = torch.concat([self.v_cache, xv.clone()], dim=1)
|
214 |
+
# print("Before positional xk:", torch.all(self.k_cache == self.wk(original_x)))
|
215 |
+
# print("Before positional xv:", torch.all(self.v_cache == self.wv(original_x)))
|
216 |
+
|
217 |
+
seqlen = self.k_cache.size(1)
|
218 |
+
xk = self.k_cache
|
219 |
+
xv = self.v_cache
|
220 |
+
self.cache_hash = cache_id
|
221 |
+
xq = xq.view(bsz, 1, self.n_local_heads, self.head_dim)
|
222 |
+
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
223 |
+
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
224 |
+
|
225 |
+
# RoPE relative positional embeddings
|
226 |
+
# xq, xk = apply_rotary_emb(xq, xk[:,-1,:,:].unsqueeze(1), freqs_cos[-1,:].unsqueeze(0), freqs_sin[-1,:].unsqueeze(0))
|
227 |
+
# reshape xq and xk to match the complex representation
|
228 |
+
xq_r, xq_i = xq.float().reshape(xq.shape[:-1] + (-1, 2)).unbind(-1)
|
229 |
+
xk_r, xk_i = xk.float().reshape(xk.shape[:-1] + (-1, 2)).unbind(-1)
|
230 |
+
|
231 |
+
# reshape freqs_cos and freqs_sin for broadcasting
|
232 |
+
q_freq_cos = freqs_cos[-1, :].unsqueeze(0)
|
233 |
+
q_freq_sin = freqs_sin[-1, :].unsqueeze(0)
|
234 |
+
freqs_cos_q = reshape_for_broadcast(q_freq_cos, xq_r)
|
235 |
+
freqs_sin_q = reshape_for_broadcast(q_freq_sin, xq_r)
|
236 |
+
|
237 |
+
freqs_cos_k = reshape_for_broadcast(freqs_cos, xk_r)
|
238 |
+
freqs_sin_k = reshape_for_broadcast(freqs_sin, xk_r)
|
239 |
+
|
240 |
+
# apply rotation using real numbers
|
241 |
+
xq_out_r = xq_r * freqs_cos_q - xq_i * freqs_sin_q
|
242 |
+
xq_out_i = xq_r * freqs_sin_q + xq_i * freqs_cos_q
|
243 |
+
xk_out_r = xk_r * freqs_cos_k - xk_i * freqs_sin_k
|
244 |
+
xk_out_i = xk_r * freqs_sin_k + xk_i * freqs_cos_k
|
245 |
+
|
246 |
+
# flatten last two dimensions
|
247 |
+
xq_out = torch.stack([xq_out_r, xq_out_i], dim=-1).flatten(3)
|
248 |
+
xk_out = torch.stack([xk_out_r, xk_out_i], dim=-1).flatten(3)
|
249 |
+
|
250 |
+
xq, xk = xq_out.type_as(xq), xk_out.type_as(xk)
|
251 |
+
# print(f"Seq len {xk.shape[1]} xq:", torch.allclose(xq , comp_xq[:,-1,:].unsqueeze(1), atol=1e-7), torch.mean(xq - comp_xq[:,-1,:].unsqueeze(1)))
|
252 |
+
# print(f"Seq len {xk.shape[1]} xk:", torch.allclose(xk ,comp_xk, atol=1e-7), torch.mean(xk - comp_xk))
|
253 |
+
# print(f"Seq len {xk.shape[1]} xv:", torch.allclose(xv , comp_xv, atol=1e-7), torch.mean(xv - comp_xv))
|
254 |
+
# print("-"*10)
|
255 |
+
# self.old_x = original_x
|
256 |
+
else:
|
257 |
+
self.k_cache = xk
|
258 |
+
self.v_cache = xv
|
259 |
+
self.old_x = x
|
260 |
+
|
261 |
+
xq = xq.view(bsz, seqlen, self.n_local_heads, self.head_dim)
|
262 |
+
xk = xk.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
263 |
+
xv = xv.view(bsz, seqlen, self.n_local_kv_heads, self.head_dim)
|
264 |
+
|
265 |
+
self.cache_hash = cache_id
|
266 |
+
|
267 |
+
# RoPE relative positional embeddings
|
268 |
+
xq, xk = apply_rotary_emb(xq, xk, freqs_cos, freqs_sin)
|
269 |
+
|
270 |
+
# grouped multiquery attention: expand out keys and values
|
271 |
+
xk = repeat_kv(xk, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
272 |
+
xv = repeat_kv(xv, self.n_rep) # (bs, seqlen, n_local_heads, head_dim)
|
273 |
+
|
274 |
+
# make heads into a batch dimension
|
275 |
+
xq = xq.transpose(1, 2) # (bs, n_local_heads, seqlen, head_dim)
|
276 |
+
xk = xk.transpose(1, 2)
|
277 |
+
xv = xv.transpose(1, 2)
|
278 |
+
|
279 |
+
# flash implementation
|
280 |
+
if self.flash:
|
281 |
+
output = torch.nn.functional.scaled_dot_product_attention(
|
282 |
+
xq,
|
283 |
+
xk,
|
284 |
+
xv,
|
285 |
+
attn_mask=None,
|
286 |
+
dropout_p=self.dropout if self.training else 0.0,
|
287 |
+
# NOTE: VERY IMPORTANT to set is_causal=False, OTHERWISE the KV-Caching just breaks
|
288 |
+
is_causal=False,
|
289 |
+
)
|
290 |
+
else:
|
291 |
+
# manual implementation
|
292 |
+
scores = torch.matmul(xq, xk.transpose(2, 3)) / math.sqrt(self.head_dim)
|
293 |
+
assert hasattr(self, "mask")
|
294 |
+
scores = (
|
295 |
+
scores + self.mask[:, :, :seqlen, :seqlen]
|
296 |
+
) # (bs, n_local_heads, seqlen, cache_len + seqlen)
|
297 |
+
scores = F.softmax(scores.float(), dim=-1).type_as(xq)
|
298 |
+
scores = self.attn_dropout(scores)
|
299 |
+
output = torch.matmul(scores, xv) # (bs, n_local_heads, seqlen, head_dim)
|
300 |
+
|
301 |
+
# restore time as batch dimension and concat heads
|
302 |
+
# if use_cache:
|
303 |
+
# # original_x[:,-1,:] = output.transpose(1, 2).contiguous().view(bsz,-1)
|
304 |
+
# # output = original_x
|
305 |
+
# output = torch.concat( [self.out_cache, output.transpose(1, 2).view(bsz,1,-1)], dim=1).contiguous()
|
306 |
+
# self.out_cache = output
|
307 |
+
# else:
|
308 |
+
# output = output.transpose(1, 2).contiguous().view(bsz, seqlen, -1)
|
309 |
+
# self.out_cache = output
|
310 |
+
|
311 |
+
# NOTE: only work when fed in one token at a time (e.g. seq = 1)
|
312 |
+
output = output.transpose(1, 2).contiguous().view(bsz, x.size(1), -1)
|
313 |
+
|
314 |
+
# final projection into the residual stream
|
315 |
+
output = self.wo(output)
|
316 |
+
output = self.resid_dropout(output)
|
317 |
+
return output
|
318 |
+
|
319 |
+
|
320 |
+
class FeedForward(nn.Module):
|
321 |
+
def __init__(self, dim: int, hidden_dim: int, multiple_of: int, dropout: float):
|
322 |
+
super().__init__()
|
323 |
+
hidden_dim = int(2 * hidden_dim / 3)
|
324 |
+
hidden_dim = multiple_of * ((hidden_dim + multiple_of - 1) // multiple_of)
|
325 |
+
self.w1 = nn.Linear(dim, hidden_dim, bias=False)
|
326 |
+
self.w2 = nn.Linear(hidden_dim, dim, bias=False)
|
327 |
+
self.w3 = nn.Linear(dim, hidden_dim, bias=False)
|
328 |
+
self.dropout = nn.Dropout(dropout)
|
329 |
+
|
330 |
+
def forward(self, x):
|
331 |
+
return self.dropout(self.w2(F.silu(self.w1(x)) * self.w3(x)))
|
332 |
+
|
333 |
+
|
334 |
+
class TransformerBlock(nn.Module):
|
335 |
+
def __init__(self, layer_id: int, args: ModelArgs):
|
336 |
+
super().__init__()
|
337 |
+
self.n_heads = args.n_heads
|
338 |
+
self.dim = args.dim
|
339 |
+
self.head_dim = args.dim // args.n_heads
|
340 |
+
self.attention = Attention(args)
|
341 |
+
self.feed_forward = FeedForward(
|
342 |
+
dim=args.dim,
|
343 |
+
hidden_dim=4 * args.dim,
|
344 |
+
multiple_of=args.multiple_of,
|
345 |
+
dropout=args.dropout,
|
346 |
+
)
|
347 |
+
self.layer_id = layer_id
|
348 |
+
self.attention_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
349 |
+
self.ffn_norm = RMSNorm(args.dim, eps=args.norm_eps)
|
350 |
+
|
351 |
+
def forward(self, x, freqs_cos, freqs_sin):
|
352 |
+
h = x + self.attention.forward(self.attention_norm(x), freqs_cos, freqs_sin)
|
353 |
+
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
354 |
+
return out
|
355 |
+
|
356 |
+
def forward_with_kvcache(self, x, freqs_cos, freqs_sin, cache_id=1):
|
357 |
+
h = x + self.attention.forward_with_kvcache(
|
358 |
+
self.attention_norm(x), freqs_cos, freqs_sin, cache_id=cache_id
|
359 |
+
)
|
360 |
+
out = h + self.feed_forward.forward(self.ffn_norm(h))
|
361 |
+
return out
|
362 |
+
|
363 |
+
|
364 |
+
class Transformer(nn.Module):
|
365 |
+
last_loss: Optional[torch.Tensor]
|
366 |
+
|
367 |
+
def __init__(self, params: ModelArgs, context_params: ContextArgs):
|
368 |
+
super().__init__()
|
369 |
+
self.params = params
|
370 |
+
self.context_params = context_params
|
371 |
+
self.vocab_size = params.vocab_size
|
372 |
+
self.n_layers = params.n_layers
|
373 |
+
|
374 |
+
self.tok_embeddings = nn.Embedding(params.vocab_size, params.dim)
|
375 |
+
|
376 |
+
self.frag_embeddings = nn.Embedding(params.vocab_size, params.dim)
|
377 |
+
self.frag_type_embedding = nn.Embedding(1, params.dim)
|
378 |
+
|
379 |
+
self.context_lookup = {k: i for i, k in enumerate(context_params.context_keys)}
|
380 |
+
self.conditions_type_embeddings = nn.Embedding(
|
381 |
+
len(context_params.context_keys), params.dim
|
382 |
+
)
|
383 |
+
self.conditions_embeddings_lookup = nn.ModuleDict(
|
384 |
+
{
|
385 |
+
k: nn.Sequential(
|
386 |
+
nn.Linear(dim, params.dim, bias=True),
|
387 |
+
)
|
388 |
+
for k, dim in zip(
|
389 |
+
context_params.context_keys, context_params.context_dims
|
390 |
+
)
|
391 |
+
}
|
392 |
+
)
|
393 |
+
|
394 |
+
self.dropout = nn.Dropout(params.dropout)
|
395 |
+
self.layers = torch.nn.ModuleList()
|
396 |
+
for layer_id in range(params.n_layers):
|
397 |
+
self.layers.append(TransformerBlock(layer_id, params))
|
398 |
+
self.norm = RMSNorm(params.dim, eps=params.norm_eps)
|
399 |
+
self.output = nn.Linear(params.dim, params.vocab_size, bias=False)
|
400 |
+
|
401 |
+
# share the unembedding parameters with the embedding parameters
|
402 |
+
self.tok_embeddings.weight = (
|
403 |
+
self.output.weight
|
404 |
+
) # https://paperswithcode.com/method/weight-tying
|
405 |
+
|
406 |
+
# some useful precompute for the RoPE relative positional embeddings
|
407 |
+
freqs_cos, freqs_sin = precompute_freqs_cis(
|
408 |
+
self.params.dim // self.params.n_heads, self.params.max_seq_len
|
409 |
+
)
|
410 |
+
self.register_buffer("freqs_cos", freqs_cos, persistent=False)
|
411 |
+
self.register_buffer("freqs_sin", freqs_sin, persistent=False)
|
412 |
+
|
413 |
+
# init all weights
|
414 |
+
self.apply(self._init_weights)
|
415 |
+
# apply special scaled init to the residual projections, per GPT-2 paper
|
416 |
+
for pn, p in self.named_parameters():
|
417 |
+
if pn.endswith("w3.weight") or pn.endswith("wo.weight"):
|
418 |
+
torch.nn.init.normal_(
|
419 |
+
p, mean=0.0, std=0.02 / math.sqrt(2 * params.n_layers)
|
420 |
+
)
|
421 |
+
|
422 |
+
# Initialize attribute for the loss of the last forward call. This will be set if the forward is called with a targets tensor.
|
423 |
+
self.last_loss = None
|
424 |
+
|
425 |
+
def _init_weights(self, module):
|
426 |
+
if isinstance(module, nn.Linear):
|
427 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
428 |
+
if module.bias is not None:
|
429 |
+
torch.nn.init.zeros_(module.bias)
|
430 |
+
elif isinstance(module, nn.Embedding):
|
431 |
+
torch.nn.init.normal_(module.weight, mean=0.0, std=0.02)
|
432 |
+
|
433 |
+
def forward(
|
434 |
+
self,
|
435 |
+
tokens: torch.Tensor,
|
436 |
+
targets: Optional[torch.Tensor] = None,
|
437 |
+
context: Optional[Dict[str, torch.Tensor]] = None,
|
438 |
+
fragment: Optional[torch.Tensor] = None,
|
439 |
+
) -> torch.Tensor:
|
440 |
+
bsz, seqlen = tokens.shape
|
441 |
+
device = tokens.device
|
442 |
+
|
443 |
+
h = self._add_context_to_seq(tokens, context, fragment, bsz, device)
|
444 |
+
|
445 |
+
context_seq_len = h.shape[1] - seqlen
|
446 |
+
|
447 |
+
bsz, seqlen, _ = h.shape
|
448 |
+
|
449 |
+
freqs_cos = self.freqs_cos[:seqlen]
|
450 |
+
freqs_sin = self.freqs_sin[:seqlen]
|
451 |
+
|
452 |
+
for layer in self.layers:
|
453 |
+
h = layer(h, freqs_cos, freqs_sin)
|
454 |
+
h = self.norm(h)
|
455 |
+
|
456 |
+
h = h[:, context_seq_len:]
|
457 |
+
if targets is not None:
|
458 |
+
# if we are given some desired targets also calculate the loss
|
459 |
+
logits = self.output(h)
|
460 |
+
tmp_last_loss = F.cross_entropy(
|
461 |
+
logits.reshape(-1, logits.size(-1)),
|
462 |
+
targets.reshape(-1),
|
463 |
+
ignore_index=0, # Ignore Pad Tokens
|
464 |
+
)
|
465 |
+
|
466 |
+
# NOTE: This essentially does nothing for the computation,
|
467 |
+
# because we are multiplying the weights by zero.
|
468 |
+
# This *needs* to be done, so that we can train with DDP
|
469 |
+
# As due to the random training process some of the weights are not used in the forward pass
|
470 |
+
# That is unacceptable for the for the c10 backend and the training errors out.
|
471 |
+
# Maybe there is a better fix in the future, see:
|
472 |
+
# https://github.com/pytorch/pytorch/issues/43259
|
473 |
+
ddp_fix = sum(p.sum() for p in self.parameters())
|
474 |
+
zero_sum = ddp_fix * 0.0
|
475 |
+
|
476 |
+
self.last_loss = tmp_last_loss + zero_sum
|
477 |
+
else:
|
478 |
+
# inference-time mini-optimization: only forward the output on the very last position
|
479 |
+
logits = self.output(
|
480 |
+
h[:, [-1], :]
|
481 |
+
) # note: using list [-1] to preserve the time dim
|
482 |
+
self.last_loss = None
|
483 |
+
|
484 |
+
return logits
|
485 |
+
|
486 |
+
def forward_with_kvcache(
|
487 |
+
self,
|
488 |
+
tokens: torch.Tensor,
|
489 |
+
targets: Optional[torch.Tensor] = None,
|
490 |
+
context: Optional[Dict[str, torch.Tensor]] = None,
|
491 |
+
fragment: Optional[torch.Tensor] = None,
|
492 |
+
cache_id: int = 1,
|
493 |
+
pos_seq_len: Optional[int] = None,
|
494 |
+
) -> torch.Tensor:
|
495 |
+
bsz, seqlen = tokens.shape
|
496 |
+
device = tokens.device
|
497 |
+
|
498 |
+
h = self._add_context_to_seq(tokens, context, fragment, bsz, device)
|
499 |
+
|
500 |
+
context_seq_len = h.shape[1] - seqlen
|
501 |
+
|
502 |
+
bsz, seqlen, _ = h.shape
|
503 |
+
if pos_seq_len is None:
|
504 |
+
pos_seq_len = seqlen
|
505 |
+
else:
|
506 |
+
pos_seq_len = max(seqlen, pos_seq_len + context_seq_len)
|
507 |
+
|
508 |
+
freqs_cos = self.freqs_cos[:pos_seq_len]
|
509 |
+
freqs_sin = self.freqs_sin[:pos_seq_len]
|
510 |
+
|
511 |
+
for layer in self.layers:
|
512 |
+
h = layer.forward_with_kvcache(h, freqs_cos, freqs_sin, cache_id=cache_id)
|
513 |
+
h = self.norm(h)
|
514 |
+
|
515 |
+
h = h[:, context_seq_len:]
|
516 |
+
if targets is not None:
|
517 |
+
# if we are given some desired targets also calculate the loss
|
518 |
+
logits = self.output(h)
|
519 |
+
tmp_last_loss = F.cross_entropy(
|
520 |
+
logits.reshape(-1, logits.size(-1)),
|
521 |
+
targets.reshape(-1),
|
522 |
+
ignore_index=0, # Ignore Pad Tokens
|
523 |
+
)
|
524 |
+
|
525 |
+
# NOTE: This essentially does nothing for the computation,
|
526 |
+
# because we are multiplying the weights by zero.
|
527 |
+
# This *needs* to be done, so that we can train with DDP
|
528 |
+
# As due to the random training process some of the weights are not used in the forward pass
|
529 |
+
# That is unacceptable for the for the c10 backend and the training errors out.
|
530 |
+
# Maybe there is a better fix in the future, see:
|
531 |
+
# https://github.com/pytorch/pytorch/issues/43259
|
532 |
+
ddp_fix = sum(p.sum() for p in self.parameters())
|
533 |
+
zero_sum = ddp_fix * 0.0
|
534 |
+
|
535 |
+
self.last_loss = tmp_last_loss + zero_sum
|
536 |
+
else:
|
537 |
+
# inference-time mini-optimization: only forward the output on the very last position
|
538 |
+
logits = self.output(
|
539 |
+
h[:, [-1], :]
|
540 |
+
) # note: using list [-1] to preserve the time dim
|
541 |
+
self.last_loss = None
|
542 |
+
|
543 |
+
return logits
|
544 |
+
|
545 |
+
def _add_context_to_seq(self, tokens, context, fragment, bsz, device):
|
546 |
+
h = self.tok_embeddings(tokens)
|
547 |
+
h = self.dropout(h)
|
548 |
+
|
549 |
+
if fragment is not None:
|
550 |
+
fragment_type_enc = torch.zeros_like(
|
551 |
+
fragment, dtype=torch.long, device=device
|
552 |
+
)
|
553 |
+
|
554 |
+
h = torch.concat(
|
555 |
+
(
|
556 |
+
self.tok_embeddings(fragment)
|
557 |
+
+ self.frag_embeddings(fragment)
|
558 |
+
+ self.frag_type_embedding(fragment_type_enc),
|
559 |
+
h,
|
560 |
+
),
|
561 |
+
dim=1,
|
562 |
+
)
|
563 |
+
|
564 |
+
if context is not None and len(context) != 0:
|
565 |
+
# context is a dictionary with key : context_tensor of shape (batch_size, context_dim)
|
566 |
+
type_ids = []
|
567 |
+
context_vals = []
|
568 |
+
|
569 |
+
for emb_key, context_val in context.items():
|
570 |
+
emb_context_val = self.conditions_embeddings_lookup[emb_key](
|
571 |
+
context_val.unsqueeze(1).to(device)
|
572 |
+
).unsqueeze(1)
|
573 |
+
|
574 |
+
context_vals.append(emb_context_val)
|
575 |
+
type_ids_tensor = torch.tensor(
|
576 |
+
[self.context_lookup[emb_key]], device=device, dtype=torch.long
|
577 |
+
)
|
578 |
+
type_ids.append(type_ids_tensor)
|
579 |
+
|
580 |
+
context_types = (
|
581 |
+
torch.concat(type_ids, dim=0).reshape(-1, 1).expand(-1, bsz).T
|
582 |
+
)
|
583 |
+
# shape(len(context),batch_size, emb_size)
|
584 |
+
context_types = self.conditions_type_embeddings(context_types)
|
585 |
+
|
586 |
+
context_vals = torch.concat(context_vals, dim=1).to(device)
|
587 |
+
|
588 |
+
# SHAPE
|
589 |
+
h = torch.concat([context_vals + context_types, h], dim=1)
|
590 |
+
return h
|
591 |
+
|
592 |
+
def configure_optimizers(self, weight_decay, learning_rate, betas, device_type):
|
593 |
+
# start with all of the candidate parameters
|
594 |
+
param_dict = {pn: p for pn, p in self.named_parameters()}
|
595 |
+
# filter out those that do not require grad
|
596 |
+
param_dict = {pn: p for pn, p in param_dict.items() if p.requires_grad}
|
597 |
+
# create optim groups. Any parameters that is 2D will be weight decayed, otherwise no.
|
598 |
+
# i.e. all weight tensors in matmuls + embeddings decay, all biases and layernorms don't.
|
599 |
+
decay_params = [p for n, p in param_dict.items() if p.dim() >= 2]
|
600 |
+
nodecay_params = [p for n, p in param_dict.items() if p.dim() < 2]
|
601 |
+
optim_groups = [
|
602 |
+
{"params": decay_params, "weight_decay": weight_decay},
|
603 |
+
{"params": nodecay_params, "weight_decay": 0.0},
|
604 |
+
]
|
605 |
+
num_decay_params = sum(p.numel() for p in decay_params)
|
606 |
+
num_nodecay_params = sum(p.numel() for p in nodecay_params)
|
607 |
+
print(
|
608 |
+
f"num decayed parameter tensors: {len(decay_params)}, with {num_decay_params:,} parameters"
|
609 |
+
)
|
610 |
+
print(
|
611 |
+
f"num non-decayed parameter tensors: {len(nodecay_params)}, with {num_nodecay_params:,} parameters"
|
612 |
+
)
|
613 |
+
# Create AdamW optimizer and use the fused version if it is available
|
614 |
+
fused_available = "fused" in inspect.signature(torch.optim.AdamW).parameters
|
615 |
+
use_fused = fused_available and device_type == "cuda"
|
616 |
+
extra_args = dict(fused=True) if use_fused else dict()
|
617 |
+
optimizer = torch.optim.AdamW(
|
618 |
+
optim_groups, lr=learning_rate, betas=betas, **extra_args
|
619 |
+
)
|
620 |
+
print(f"using fused AdamW: {use_fused}")
|
621 |
+
|
622 |
+
return optimizer
|
623 |
+
|
624 |
+
def estimate_mfu(self, fwdbwd_per_iter, dt):
|
625 |
+
"""estimate model flops utilization (MFU) in units of A100 bfloat16 peak FLOPS"""
|
626 |
+
# first estimate the number of flops we do per iteration.
|
627 |
+
# see PaLM paper Appendix B as ref: https://arxiv.org/abs/2204.02311
|
628 |
+
N = sum(p.numel() for p in self.parameters())
|
629 |
+
cfg = self.params
|
630 |
+
L, H, Q, T = cfg.n_layers, cfg.n_heads, cfg.dim // cfg.n_heads, cfg.max_seq_len
|
631 |
+
flops_per_token = 6 * N + 12 * L * H * Q * T
|
632 |
+
flops_per_fwdbwd = flops_per_token * T
|
633 |
+
flops_per_iter = flops_per_fwdbwd * fwdbwd_per_iter
|
634 |
+
# express our flops throughput as ratio of A100 bfloat16 peak flops
|
635 |
+
flops_achieved = flops_per_iter * (1.0 / dt) # per second
|
636 |
+
flops_promised = 312e12 # A100 GPU bfloat16 peak flops is 312 TFLOPS
|
637 |
+
mfu = flops_achieved / flops_promised
|
638 |
+
return mfu
|
639 |
+
|
640 |
+
@torch.inference_mode()
|
641 |
+
def generate(
|
642 |
+
self,
|
643 |
+
tokenizer: SmilesTokenizer,
|
644 |
+
context: Union[torch.Tensor, None] = None,
|
645 |
+
fragments: Union[torch.Tensor, None] = None,
|
646 |
+
max_length: int = 50,
|
647 |
+
num_gen: int = 200,
|
648 |
+
start_smiles: Union[str, None] = None,
|
649 |
+
temperature: float = 1.0,
|
650 |
+
top_k: Union[int, None] = None,
|
651 |
+
device: torch.device = torch.device("cpu"),
|
652 |
+
cache_kv: bool = False,
|
653 |
+
) -> List[str]:
|
654 |
+
batch_size = num_gen
|
655 |
+
if start_smiles is not None:
|
656 |
+
tokenized_start_selfie = tokenizer.encode(start_smiles)[
|
657 |
+
:-1
|
658 |
+
] # remove <eos> token
|
659 |
+
tokenized_start_selfie = torch.tensor(
|
660 |
+
tokenized_start_selfie, device=device, dtype=torch.long
|
661 |
+
).view(-1, 1)
|
662 |
+
tokenized_start_selfie = tokenized_start_selfie.repeat(1, batch_size)
|
663 |
+
|
664 |
+
outputs = tokenized_start_selfie.T
|
665 |
+
else:
|
666 |
+
outputs = (
|
667 |
+
torch.LongTensor([[tokenizer.cls_token_id] * batch_size]).to(device)
|
668 |
+
).T # batch_size
|
669 |
+
self.eval()
|
670 |
+
|
671 |
+
start_len = outputs.shape[1]
|
672 |
+
has_end_idx = np.array([0] * batch_size)
|
673 |
+
cache_id = np.random.randint(0, int(1e10), 1).item()
|
674 |
+
with torch.no_grad():
|
675 |
+
with tqdm(total=max_length, desc="Generation") as pbar:
|
676 |
+
for i in range(start_len, max_length):
|
677 |
+
# trg_tensor = #torch.LongTensor(outputs).to(model.device)
|
678 |
+
if not cache_kv:
|
679 |
+
logits = self(outputs, context=context, fragment=fragments)
|
680 |
+
else:
|
681 |
+
# logits_ = self(outputs, context=context, fragment=fragments)
|
682 |
+
if i == start_len:
|
683 |
+
# When starting pass the whole input, so that "start_smiles" works, then only the newly generated token, because of the cache
|
684 |
+
func_input = outputs
|
685 |
+
else:
|
686 |
+
func_input = outputs[:, -1].unsqueeze(-1)
|
687 |
+
logits = self.forward_with_kvcache(
|
688 |
+
func_input,
|
689 |
+
context=context,
|
690 |
+
fragment=fragments,
|
691 |
+
cache_id=cache_id,
|
692 |
+
pos_seq_len=outputs.size(-1),
|
693 |
+
)
|
694 |
+
|
695 |
+
# raise NotImplementedError("Currently not working / right implemented")
|
696 |
+
# logits = self.forward_with_kvcache(outputs, context=context, fragment=fragments,cache_id = cache_id)
|
697 |
+
|
698 |
+
logits = logits[:, -1, :] # crop to just the final time step
|
699 |
+
if temperature == 0.0:
|
700 |
+
# "sample" the single most likely index
|
701 |
+
_, logits = torch.topk(logits, k=1, dim=-1)
|
702 |
+
else:
|
703 |
+
# pluck the logits at the final step and scale by desired temperature
|
704 |
+
logits = logits / temperature
|
705 |
+
# optionally crop the logits to only the top k options
|
706 |
+
if top_k is not None:
|
707 |
+
v, _ = torch.topk(logits, min(top_k, logits.size(-1)))
|
708 |
+
logits[logits < v[:, [-1]]] = -float("Inf")
|
709 |
+
|
710 |
+
probs = F.softmax(logits, dim=-1)
|
711 |
+
idx_next = torch.multinomial(probs, num_samples=1)
|
712 |
+
|
713 |
+
ended_sentences = idx_next == tokenizer.sep_token_id
|
714 |
+
if torch.count_nonzero(ended_sentences) != 0:
|
715 |
+
indicies = torch.nonzero(ended_sentences)
|
716 |
+
indicies = indicies.cpu().numpy()
|
717 |
+
for end_idx in indicies[:, 0]:
|
718 |
+
if has_end_idx[end_idx] == 0:
|
719 |
+
has_end_idx[end_idx] = i
|
720 |
+
|
721 |
+
# print(has_end_idx)
|
722 |
+
|
723 |
+
if all([idx != 0 for idx in has_end_idx]):
|
724 |
+
break
|
725 |
+
|
726 |
+
# outputs.append(best_guesses)
|
727 |
+
# outputs = torch.row_stack((outputs, idx_next))
|
728 |
+
outputs = torch.cat((outputs, idx_next), dim=1)
|
729 |
+
pbar.update(1)
|
730 |
+
|
731 |
+
out_selfies = []
|
732 |
+
for output, end_idx in zip(outputs.cpu().numpy(), has_end_idx):
|
733 |
+
# Incase of limiting the max_len
|
734 |
+
if end_idx == 0:
|
735 |
+
selfie = [tokenizer._convert_id_to_token(idx) for idx in output[:]]
|
736 |
+
else:
|
737 |
+
selfie = [
|
738 |
+
tokenizer._convert_id_to_token(idx) for idx in output[:end_idx]
|
739 |
+
]
|
740 |
+
selfie = "".join(selfie[1:])
|
741 |
+
out_selfies.append(selfie)
|
742 |
+
|
743 |
+
# for indicies in outputs:
|
744 |
+
# translated_sentence = [tokenizer.idx_to_tokens[idx] for idx in outputs]
|
745 |
+
# remove start token
|
746 |
+
return out_selfies
|
747 |
+
|
748 |
+
@staticmethod
|
749 |
+
def load(path, device: torch.device = torch.device("cpu")) -> Transformer:
|
750 |
+
data = torch.load(path, map_location=device)
|
751 |
+
|
752 |
+
newinstace = Transformer(data["model_params"], data["context_params"])
|
753 |
+
newinstace.load_state_dict(data["state_dict"])
|
754 |
+
return newinstace.to(device)
|
755 |
+
|
756 |
+
def save(self, filepath):
|
757 |
+
torch.save(
|
758 |
+
{
|
759 |
+
"state_dict": self.state_dict(),
|
760 |
+
**dict(model_params=self.params, context_params=self.context_params),
|
761 |
+
},
|
762 |
+
filepath,
|
763 |
+
)
|
764 |
+
|
765 |
+
def getNumberTrainableParams(self) -> int:
|
766 |
+
return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
767 |
+
|
768 |
+
def getNumberParams(self) -> int:
|
769 |
+
return sum(p.numel() for p in self.parameters())
|
770 |
+
|
771 |
+
|
772 |
+
if __name__ == "__main__":
|
773 |
+
m = Transformer(
|
774 |
+
ModelArgs(dim=128, n_layers=8, n_heads=8, vocab_size=512, max_seq_len=1024),
|
775 |
+
context_params=ContextArgs(
|
776 |
+
context_keys=["logp", "sascore", "mol_weight"], context_dims=[1, 1, 1]
|
777 |
+
),
|
778 |
+
)
|
779 |
+
seq = torch.ones((128, 50), dtype=torch.long)
|
780 |
+
frag = torch.ones((128, 10), dtype=torch.long)
|
781 |
+
context = {
|
782 |
+
"logp": torch.ones((128,), dtype=torch.float32),
|
783 |
+
# "sascore": torch.ones((128,), dtype=torch.float32),
|
784 |
+
"mol_weight": torch.ones((128,), dtype=torch.float32),
|
785 |
+
}
|
786 |
+
|
787 |
+
print(m.forward(seq, targets=seq, context=context, fragment=frag))
|
out/llama2-M-Full-RSS.pt
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:83571f8f8936a4eac8ac4541282ff99a3e942c07ee4aaef82abdc2f52e1731ae
|
3 |
+
size 58587134
|
plot_utils.py
ADDED
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import Dict, List, Union
|
2 |
+
import matplotlib.pyplot as plt
|
3 |
+
import seaborn as sns
|
4 |
+
import os
|
5 |
+
import numpy as np
|
6 |
+
import pandas as pd
|
7 |
+
|
8 |
+
from rdkit.Chem import AllChem, Descriptors, RDConfig
|
9 |
+
|
10 |
+
import sys
|
11 |
+
|
12 |
+
sys.path.append(os.path.join(RDConfig.RDContribDir, "SA_Score"))
|
13 |
+
# now you can import sascore!
|
14 |
+
import sascorer
|
15 |
+
from rdkit import Chem
|
16 |
+
import logging
|
17 |
+
|
18 |
+
logger = logging.getLogger(__name__)
|
19 |
+
# plt.rcParams.update({'font.size': 13.1})
|
20 |
+
plt.rcParams.update({"font.size": 12.5})
|
21 |
+
|
22 |
+
COL_TO_DISPLAY_NAME = {
|
23 |
+
"logp": "LogP",
|
24 |
+
"sascore": "SAScore",
|
25 |
+
"mol_weight": "Molecular Weight",
|
26 |
+
}
|
27 |
+
|
28 |
+
|
29 |
+
def calcContextSAScore(smiles: List[str]):
|
30 |
+
sasc = []
|
31 |
+
for smi in smiles:
|
32 |
+
mol = Chem.MolFromSmiles(smi)
|
33 |
+
sa = sascorer.calculateScore(mol)
|
34 |
+
sasc.append(sa)
|
35 |
+
|
36 |
+
return np.array(sasc)
|
37 |
+
|
38 |
+
|
39 |
+
def calcContextLogP(smiles: List[str]):
|
40 |
+
logps = []
|
41 |
+
for smi in smiles:
|
42 |
+
mol = Chem.MolFromSmiles(smi)
|
43 |
+
logp = Descriptors.MolLogP(mol)
|
44 |
+
logps.append(logp)
|
45 |
+
|
46 |
+
return np.array(logps)
|
47 |
+
|
48 |
+
|
49 |
+
def calcContextEnergy(smiles, num_confs=5):
|
50 |
+
contexts = []
|
51 |
+
for smi in smiles:
|
52 |
+
# print("Calculating Energy:",smi)
|
53 |
+
mol = Chem.AddHs(Chem.MolFromSmiles(smi))
|
54 |
+
AllChem.EmbedMultipleConfs(mol, num_confs, numThreads=48)
|
55 |
+
generated_smiles = AllChem.MMFFOptimizeMoleculeConfs(mol, numThreads=48)
|
56 |
+
energies = []
|
57 |
+
for coverged, energy in generated_smiles:
|
58 |
+
if coverged != 0:
|
59 |
+
print("Not converged!", smi)
|
60 |
+
energies.append(energy)
|
61 |
+
|
62 |
+
# print(energy)
|
63 |
+
# kcal/mol
|
64 |
+
mean_en = np.mean(energies)
|
65 |
+
# to hartree
|
66 |
+
mean_en = mean_en * 0.0016
|
67 |
+
contexts.append(mean_en)
|
68 |
+
|
69 |
+
return np.array(contexts)
|
70 |
+
|
71 |
+
|
72 |
+
def calcContextMolWeight(smiles: List[str]):
|
73 |
+
con = []
|
74 |
+
for _, smi in enumerate(smiles):
|
75 |
+
mol = Chem.MolFromSmiles(smi)
|
76 |
+
c = Descriptors.ExactMolWt(mol) / 100
|
77 |
+
con.append(c)
|
78 |
+
|
79 |
+
return np.array(con)
|
80 |
+
|
81 |
+
|
82 |
+
def plot_1D_condition(
|
83 |
+
context_col,
|
84 |
+
save_path,
|
85 |
+
new_context,
|
86 |
+
generated_smiles,
|
87 |
+
temperature,
|
88 |
+
context_dict,
|
89 |
+
context_scaler=None,
|
90 |
+
):
|
91 |
+
for con_col in context_col:
|
92 |
+
save_path = os.path.join(
|
93 |
+
save_path, f"{con_col}_{'-'.join(context_col)}_temp{temperature}"
|
94 |
+
)
|
95 |
+
os.makedirs(save_path, exist_ok=True)
|
96 |
+
|
97 |
+
current_context = new_context[con_col].cpu().detach().numpy()
|
98 |
+
if con_col == "mol_weight":
|
99 |
+
predicted_context = calcContextMolWeight(generated_smiles)
|
100 |
+
elif con_col == "logp":
|
101 |
+
predicted_context = calcContextLogP(generated_smiles)
|
102 |
+
elif con_col == "sascore":
|
103 |
+
predicted_context = calcContextSAScore(generated_smiles)
|
104 |
+
elif con_col == "energy":
|
105 |
+
# TODO: Change to something better
|
106 |
+
predicted_context = calcContextEnergy(generated_smiles)
|
107 |
+
|
108 |
+
if context_scaler is not None:
|
109 |
+
raise NotImplementedError("Not implemented yet")
|
110 |
+
# context_list = context_scaler.inverse_transform(context_list)
|
111 |
+
|
112 |
+
mean_vals_pred = []
|
113 |
+
labels = np.unique(current_context)
|
114 |
+
mse_value = []
|
115 |
+
mad_value = []
|
116 |
+
for label in labels:
|
117 |
+
mask = (current_context == label).reshape(-1)
|
118 |
+
mean_val = np.mean(predicted_context[mask])
|
119 |
+
mean_vals_pred.append(mean_val)
|
120 |
+
mse_value.extend((predicted_context[mask] - label) ** 2)
|
121 |
+
mad_value.extend(abs(predicted_context[mask] - label))
|
122 |
+
|
123 |
+
mse = np.mean(mse_value)
|
124 |
+
mad = np.mean(mad_value)
|
125 |
+
logger.info(f"MSE {mse}")
|
126 |
+
logger.info(f"MAD {mad}")
|
127 |
+
logger.info(f"SD: {np.std(mad_value)}")
|
128 |
+
|
129 |
+
current_context = current_context.reshape(-1)
|
130 |
+
|
131 |
+
# Create a figure and axes
|
132 |
+
fig, ax1 = plt.subplots()
|
133 |
+
|
134 |
+
# Scatter plot
|
135 |
+
ax1.scatter(
|
136 |
+
current_context,
|
137 |
+
predicted_context,
|
138 |
+
label="Ground Truth vs Prediction",
|
139 |
+
c="blue",
|
140 |
+
alpha=0.5,
|
141 |
+
)
|
142 |
+
ax1.plot(
|
143 |
+
np.arange(np.min(current_context), np.max(current_context) + 1),
|
144 |
+
np.arange(np.min(current_context), np.max(current_context) + 1),
|
145 |
+
label="y=x",
|
146 |
+
c="black",
|
147 |
+
)
|
148 |
+
ax1.scatter(labels, mean_vals_pred, label="Mean predicted values", c="red")
|
149 |
+
ax1.set_xlabel("Ground Truth")
|
150 |
+
ax1.set_ylabel("Prediction")
|
151 |
+
|
152 |
+
# Histogram
|
153 |
+
ax2 = ax1.twinx() # Create a twin Axes sharing the x-axis
|
154 |
+
sns.histplot(
|
155 |
+
context_dict[con_col],
|
156 |
+
# bins=200,
|
157 |
+
label="Dataset distribution",
|
158 |
+
alpha=0.5,
|
159 |
+
# kde=True,
|
160 |
+
# element="poly",
|
161 |
+
ax=ax2,
|
162 |
+
)
|
163 |
+
# ax2.hist(
|
164 |
+
# context_dict[con_col],
|
165 |
+
# bins=200,
|
166 |
+
# label="Dataset distribution",
|
167 |
+
# alpha=0.5,
|
168 |
+
# )
|
169 |
+
ax2.set_ylabel("Frequency")
|
170 |
+
|
171 |
+
# Combine legends
|
172 |
+
handles1, labels1 = ax1.get_legend_handles_labels()
|
173 |
+
handles2, labels2 = ax2.get_legend_handles_labels()
|
174 |
+
|
175 |
+
ax1.legend(handles1 + handles2, labels1 + labels2)
|
176 |
+
|
177 |
+
plt.xlim((np.min(current_context), np.max(current_context) + 1))
|
178 |
+
# Set title
|
179 |
+
display_name = COL_TO_DISPLAY_NAME[con_col]
|
180 |
+
plt.title(f"{display_name} - temperature: {temperature} - mse: {round(mse, 4)}")
|
181 |
+
|
182 |
+
out_df = pd.DataFrame(
|
183 |
+
{
|
184 |
+
"smiles": generated_smiles,
|
185 |
+
f"{con_col}": predicted_context.tolist(),
|
186 |
+
f"target_{con_col}": current_context.tolist(),
|
187 |
+
}
|
188 |
+
)
|
189 |
+
out_df.to_csv(os.path.join(save_path, "predictions.csv"), index=False)
|
190 |
+
out_path = os.path.join(save_path, "graph.png")
|
191 |
+
print(f"Saved to {out_path}")
|
192 |
+
plt.savefig(out_path)
|
193 |
+
plt.clf()
|
194 |
+
|
195 |
+
|
196 |
+
def plot_2D_condition(
|
197 |
+
context_col,
|
198 |
+
save_path,
|
199 |
+
new_context,
|
200 |
+
generated_smiles,
|
201 |
+
temperature,
|
202 |
+
label: Union[str, None] = None,
|
203 |
+
):
|
204 |
+
save_path = os.path.join(
|
205 |
+
save_path, f"multicond2_{'-'.join(context_col)}_temp={temperature}"
|
206 |
+
)
|
207 |
+
if label is not None:
|
208 |
+
save_path = os.path.join(save_path, label)
|
209 |
+
|
210 |
+
os.makedirs(save_path, exist_ok=True)
|
211 |
+
delta_dict = {c: [] for c in context_col}
|
212 |
+
predicted_context_dict = {}
|
213 |
+
for con_col in context_col:
|
214 |
+
current_context = new_context[con_col].cpu().numpy()
|
215 |
+
if con_col == "mol_weight":
|
216 |
+
predicted_context = calcContextMolWeight(generated_smiles)
|
217 |
+
elif con_col == "logp":
|
218 |
+
predicted_context = calcContextLogP(generated_smiles)
|
219 |
+
elif con_col == "sascore":
|
220 |
+
predicted_context = calcContextSAScore(generated_smiles)
|
221 |
+
elif con_col == "energy":
|
222 |
+
# TODO: Change to something better
|
223 |
+
predicted_context = calcContextEnergy(generated_smiles)
|
224 |
+
|
225 |
+
predicted_context_dict[con_col] = np.array(predicted_context)
|
226 |
+
delta_dict[con_col] = np.abs(current_context - np.array(predicted_context))
|
227 |
+
|
228 |
+
# Create a DataFrame from delta_dict
|
229 |
+
df = pd.DataFrame(delta_dict)
|
230 |
+
real_values_prop1 = new_context[context_col[0]].cpu().numpy()
|
231 |
+
real_values_prop2 = new_context[context_col[1]].cpu().numpy()
|
232 |
+
# cmap = plt.get_cmap('Blues') # Choose a green color palette from Matplotlib
|
233 |
+
mse_vals_x = []
|
234 |
+
mad_vals_x = []
|
235 |
+
mse_vals_y = []
|
236 |
+
mad_vals_y = []
|
237 |
+
fig = plt.figure()
|
238 |
+
ax = plt.subplot(111)
|
239 |
+
for v1 in np.unique(real_values_prop1):
|
240 |
+
for v2 in np.unique(real_values_prop2):
|
241 |
+
mask = (real_values_prop1 == v1) & (real_values_prop2 == v2)
|
242 |
+
indices = np.nonzero(mask)[0]
|
243 |
+
# print("Indices", len(indices))
|
244 |
+
# Get the color from the color palette based on the v1 value
|
245 |
+
# color = cmap((v1 - np.min(real_values_prop1)) / (np.max(real_values_prop1) - np.min(real_values_prop1)))
|
246 |
+
color = np.random.rand(
|
247 |
+
3,
|
248 |
+
)
|
249 |
+
# # Plot scatter plot with the specified color and label
|
250 |
+
|
251 |
+
x_pred = predicted_context_dict[context_col[0]][indices].ravel()
|
252 |
+
y_pred = predicted_context_dict[context_col[1]][indices].ravel()
|
253 |
+
mse_vals_x.extend((x_pred - v1) ** 2)
|
254 |
+
mad_vals_x.extend(np.abs(x_pred - v1))
|
255 |
+
|
256 |
+
mse_vals_y.extend((y_pred - v2) ** 2)
|
257 |
+
mad_vals_y.extend(np.abs(y_pred - v2))
|
258 |
+
|
259 |
+
ax.scatter(x_pred, y_pred, color=color, alpha=0.5)
|
260 |
+
|
261 |
+
# Plot KDE plot with the specified color
|
262 |
+
# sns.kdeplot(
|
263 |
+
# data=pd.DataFrame(
|
264 |
+
# {
|
265 |
+
# f"x": x_pred,
|
266 |
+
# f"y": y_pred,
|
267 |
+
# }
|
268 |
+
# ),
|
269 |
+
# x=f"x",
|
270 |
+
# y=f"y",
|
271 |
+
# color=color,
|
272 |
+
# fill=False,
|
273 |
+
# bw_adjust=2.25,
|
274 |
+
# # label=f"({v1}, {v2})"
|
275 |
+
# )
|
276 |
+
|
277 |
+
ax.scatter(v1, v2, color=color, label=f"({v1}, {v2})", marker="^", s=20.0)
|
278 |
+
|
279 |
+
mse_x = np.mean(mse_vals_x)
|
280 |
+
mad_x = np.mean(mad_vals_x)
|
281 |
+
mse_y = np.mean(mse_vals_y)
|
282 |
+
mad_y = np.mean(mad_vals_y)
|
283 |
+
|
284 |
+
logger.info(f"MSE {context_col[0]}: {mse_x}")
|
285 |
+
logger.info(f"MAD {context_col[0]}: {mad_x}")
|
286 |
+
logger.info(f"MSE {context_col[1]}: {mse_y}")
|
287 |
+
logger.info(f"MAD {context_col[1]}: {mad_y}")
|
288 |
+
|
289 |
+
file_path = os.path.join(save_path, "metrics.txt")
|
290 |
+
|
291 |
+
with open(file_path, "w") as f:
|
292 |
+
f.write(f"MSE {context_col[0]}: {mse_x} \n")
|
293 |
+
f.write(f"MAD {context_col[0]}: {mad_x} \n")
|
294 |
+
f.write(f"MSE {context_col[1]}: {mse_y} \n")
|
295 |
+
f.write(f"MAD {context_col[1]}: {mad_y} \n")
|
296 |
+
|
297 |
+
ax.set_xlabel(COL_TO_DISPLAY_NAME[context_col[0]])
|
298 |
+
ax.set_ylabel(COL_TO_DISPLAY_NAME[context_col[1]])
|
299 |
+
box = ax.get_position()
|
300 |
+
ax.set_position([box.x0, box.y0, box.width * 0.8, box.height])
|
301 |
+
|
302 |
+
# Put a legend to the right of the current axis
|
303 |
+
ax.legend(loc="center left", bbox_to_anchor=(1, 0.5))
|
304 |
+
ax.set_title("Multi Property Distribution of Generated Molecules")
|
305 |
+
out_path = os.path.join(save_path, "graph.png")
|
306 |
+
logger.info(f"Saved to {out_path}")
|
307 |
+
plt.savefig(out_path)
|
308 |
+
plt.clf()
|
309 |
+
return save_path
|
310 |
+
|
311 |
+
|
312 |
+
def plot_3D_condition(
|
313 |
+
context_col, save_path, new_context, generated_smiles, temperature
|
314 |
+
):
|
315 |
+
save_path = os.path.join(
|
316 |
+
save_path, f"multicond3_{'-'.join(context_col)}_temp={temperature}"
|
317 |
+
)
|
318 |
+
os.makedirs(save_path, exist_ok=True)
|
319 |
+
predicted_context_dict = {}
|
320 |
+
for con_col in context_col:
|
321 |
+
predicted_context = calc_context_from_smiles(generated_smiles, con_col)
|
322 |
+
|
323 |
+
predicted_context_dict[con_col] = np.array(predicted_context)
|
324 |
+
|
325 |
+
real_values_prop1 = new_context[context_col[0]].cpu().numpy()
|
326 |
+
real_values_prop2 = new_context[context_col[1]].cpu().numpy()
|
327 |
+
real_values_prop3 = new_context[context_col[2]].cpu().numpy()
|
328 |
+
# cmap = plt.get_cmap('Blues') # Choose a green color palette from Matplotlib
|
329 |
+
|
330 |
+
mse_vals_x = []
|
331 |
+
mad_vals_x = []
|
332 |
+
mse_vals_y = []
|
333 |
+
mad_vals_y = []
|
334 |
+
mse_vals_z = []
|
335 |
+
mad_vals_z = []
|
336 |
+
|
337 |
+
fig = plt.figure()
|
338 |
+
ax = fig.add_subplot(projection="3d")
|
339 |
+
for v1 in np.unique(real_values_prop1):
|
340 |
+
for v2 in np.unique(real_values_prop2):
|
341 |
+
for v3 in np.unique(real_values_prop3):
|
342 |
+
mask = (
|
343 |
+
(real_values_prop1 == v1)
|
344 |
+
& (real_values_prop2 == v2)
|
345 |
+
& (real_values_prop3 == v3)
|
346 |
+
)
|
347 |
+
indices = np.nonzero(mask)[0]
|
348 |
+
# print("Indices", len(indices))
|
349 |
+
# Get the color from the color palette based on the v1 value
|
350 |
+
# color = cmap((v1 - np.min(real_values_prop1)) / (np.max(real_values_prop1) - np.min(real_values_prop1)))
|
351 |
+
color = np.random.rand(
|
352 |
+
3,
|
353 |
+
)
|
354 |
+
|
355 |
+
x_pred = predicted_context_dict[context_col[0]][indices].ravel()
|
356 |
+
y_pred = predicted_context_dict[context_col[1]][indices].ravel()
|
357 |
+
z_pred = predicted_context_dict[context_col[2]][indices].ravel()
|
358 |
+
|
359 |
+
mse_vals_x.extend((x_pred - v1) ** 2)
|
360 |
+
mad_vals_x.extend(np.abs(x_pred - v1))
|
361 |
+
|
362 |
+
mse_vals_y.extend((y_pred - v2) ** 2)
|
363 |
+
mad_vals_y.extend(np.abs(y_pred - v2))
|
364 |
+
|
365 |
+
mse_vals_z.extend((z_pred - v3) ** 2)
|
366 |
+
mad_vals_z.extend(np.abs(z_pred - v3))
|
367 |
+
|
368 |
+
# # Plot scatter plot with the specified color and label
|
369 |
+
ax.scatter(v1, v2, v3, color=color, label=f"({v1}, {v2}, {v3})", s=20.0)
|
370 |
+
ax.scatter(
|
371 |
+
x_pred,
|
372 |
+
y_pred,
|
373 |
+
z_pred,
|
374 |
+
color=color,
|
375 |
+
)
|
376 |
+
|
377 |
+
mse_x = np.mean(mse_vals_x)
|
378 |
+
mad_x = np.mean(mad_vals_x)
|
379 |
+
mse_y = np.mean(mse_vals_y)
|
380 |
+
mad_y = np.mean(mad_vals_y)
|
381 |
+
mse_z = np.mean(mse_vals_z)
|
382 |
+
mad_z = np.mean(mad_vals_z)
|
383 |
+
|
384 |
+
logger.info(f"MSE {context_col[0]}: {mse_x}")
|
385 |
+
logger.info(f"MAD {context_col[0]}: {mad_x}")
|
386 |
+
logger.info(f"MSE {context_col[1]}: {mse_y}")
|
387 |
+
logger.info(f"MAD {context_col[1]}: {mad_y}")
|
388 |
+
logger.info(f"MSE {context_col[2]}: {mse_z}")
|
389 |
+
logger.info(f"MAD {context_col[2]}: {mad_z}")
|
390 |
+
|
391 |
+
file_path = os.path.join(save_path, "metrics.txt")
|
392 |
+
|
393 |
+
with open(file_path, "w") as f:
|
394 |
+
f.write(f"MSE {context_col[0]}: {mse_x} \n")
|
395 |
+
f.write(f"MAD {context_col[0]}: {mad_x} \n")
|
396 |
+
|
397 |
+
f.write(f"MSE {context_col[1]}: {mse_y} \n")
|
398 |
+
f.write(f"MAD {context_col[1]}: {mad_y} \n")
|
399 |
+
|
400 |
+
f.write(f"MSE {context_col[2]}: {mse_z} \n")
|
401 |
+
f.write(f"MAD {context_col[2]}: {mad_z} \n")
|
402 |
+
|
403 |
+
ax.set_xlabel(COL_TO_DISPLAY_NAME[context_col[0]])
|
404 |
+
ax.set_ylabel(COL_TO_DISPLAY_NAME[context_col[1]])
|
405 |
+
ax.set_zlabel(COL_TO_DISPLAY_NAME[context_col[2]])
|
406 |
+
# plt.legend(
|
407 |
+
# bbox_to_anchor=(1.0, 0.5),
|
408 |
+
# loc="center right",
|
409 |
+
# bbox_transform=plt.gcf().transFigure,
|
410 |
+
# )
|
411 |
+
# plt.subplots_adjust(left=0.05, bottom=0.1, right=0.8)
|
412 |
+
plt.legend(
|
413 |
+
bbox_to_anchor=(1.035, 0.5),
|
414 |
+
loc="center right",
|
415 |
+
bbox_transform=plt.gcf().transFigure,
|
416 |
+
)
|
417 |
+
plt.subplots_adjust(left=0.05, bottom=0.1, right=0.775)
|
418 |
+
|
419 |
+
plt.title("Multi Property Distribution of Generated Molecules")
|
420 |
+
out_path = os.path.join(save_path, "graph.png")
|
421 |
+
print(f"Saved to {out_path}")
|
422 |
+
plt.savefig(out_path)
|
423 |
+
plt.clf()
|
424 |
+
|
425 |
+
return save_path
|
426 |
+
|
427 |
+
|
428 |
+
def calc_context_from_smiles(generated_smiles, con_col):
|
429 |
+
if con_col == "mol_weight":
|
430 |
+
predicted_context = calcContextMolWeight(generated_smiles)
|
431 |
+
elif con_col == "logp":
|
432 |
+
predicted_context = calcContextLogP(generated_smiles)
|
433 |
+
elif con_col == "sascore":
|
434 |
+
predicted_context = calcContextSAScore(generated_smiles)
|
435 |
+
elif con_col == "energy":
|
436 |
+
# TODO: Change to something better
|
437 |
+
predicted_context = calcContextEnergy(generated_smiles)
|
438 |
+
return predicted_context
|
439 |
+
|
440 |
+
|
441 |
+
def plot_unconditional(
|
442 |
+
out_path: str = os.getcwd(),
|
443 |
+
smiles: List[str] = [],
|
444 |
+
temperature: float = 0.8,
|
445 |
+
cmp_context_dict: Union[Dict[str, np.array], None] = None,
|
446 |
+
context_cols: List[str] = ["logp", "sascore", "mol_weight"],
|
447 |
+
):
|
448 |
+
out_path = os.path.join(out_path, "unconditional")
|
449 |
+
os.makedirs(out_path, exist_ok=True)
|
450 |
+
|
451 |
+
for c in context_cols:
|
452 |
+
plt.clf()
|
453 |
+
|
454 |
+
context_cal = calc_context_from_smiles(smiles, c)
|
455 |
+
|
456 |
+
if cmp_context_dict is not None:
|
457 |
+
sns.histplot(
|
458 |
+
cmp_context_dict[c],
|
459 |
+
stat="density",
|
460 |
+
label="Dataset Distribution",
|
461 |
+
alpha=0.75,
|
462 |
+
color="blue",
|
463 |
+
)
|
464 |
+
sns.histplot(
|
465 |
+
context_cal,
|
466 |
+
stat="density",
|
467 |
+
label="Generated Molecules Distribution",
|
468 |
+
alpha=0.5,
|
469 |
+
color="orange",
|
470 |
+
)
|
471 |
+
|
472 |
+
if c == "logp":
|
473 |
+
plt.xlim((-6, 8))
|
474 |
+
else:
|
475 |
+
plt.xlim((0, 10))
|
476 |
+
|
477 |
+
plt.xlabel(COL_TO_DISPLAY_NAME[c])
|
478 |
+
plt.title(
|
479 |
+
f"Unconditional Distribution {COL_TO_DISPLAY_NAME[c]} \nwith Temperature {temperature}"
|
480 |
+
)
|
481 |
+
plt.legend()
|
482 |
+
|
483 |
+
out_file = os.path.join(out_path, f"unc_{c}_temp={temperature}.png")
|
484 |
+
plt.savefig(out_file)
|
485 |
+
logger.info(f"Saved Unconditional to {out_file}")
|
486 |
+
|
487 |
+
|
488 |
+
def novelty(gen, train):
|
489 |
+
gen_smiles_set = set(gen) - {None}
|
490 |
+
train_set = set(train)
|
491 |
+
return len(gen_smiles_set - train_set) / len(gen_smiles_set)
|
492 |
+
|
493 |
+
|
494 |
+
def unique_at(gen, k=1000):
|
495 |
+
gen = gen[:k]
|
496 |
+
|
497 |
+
return len(set(gen)) / len(gen)
|
498 |
+
|
499 |
+
|
500 |
+
def check_metrics(generated_smiles: List[str], dataset_smiles: List[str]):
|
501 |
+
len_before = len(generated_smiles)
|
502 |
+
generated_smiles = [g for g in generated_smiles if g is not None]
|
503 |
+
len_after = len(generated_smiles)
|
504 |
+
|
505 |
+
novel = novelty(generated_smiles, dataset_smiles)
|
506 |
+
unique_at_1k = unique_at(generated_smiles, k=1000)
|
507 |
+
unique_at_10k = unique_at(generated_smiles, k=10000)
|
508 |
+
return dict(
|
509 |
+
novelty=novel,
|
510 |
+
unique_at_1k=unique_at_1k,
|
511 |
+
unique_at_10k=unique_at_10k,
|
512 |
+
validity=len_after / float(len_before),
|
513 |
+
)
|
preprocess_dataset.py
ADDED
@@ -0,0 +1,370 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import json
|
3 |
+
import os
|
4 |
+
import pickle
|
5 |
+
import random
|
6 |
+
from functools import partial
|
7 |
+
|
8 |
+
import pandas as pd
|
9 |
+
import numpy as np
|
10 |
+
import requests
|
11 |
+
import torch
|
12 |
+
import torch.distributed as dist
|
13 |
+
from tqdm import tqdm
|
14 |
+
import multiprocessing
|
15 |
+
from multiprocessing import Pool
|
16 |
+
from fragment_creator import BaseFragmentCreator, BricksFragmentCreator, Fragment
|
17 |
+
from tokenizer import SmilesTokenizer
|
18 |
+
from torch.utils.data.distributed import DistributedSampler
|
19 |
+
from rdkit import Chem
|
20 |
+
from rdkit.Chem.Scaffolds.MurckoScaffold import MurckoScaffoldSmiles
|
21 |
+
from tqdm.contrib.concurrent import process_map, thread_map
|
22 |
+
from typing import List
|
23 |
+
import swifter
|
24 |
+
|
25 |
+
DATA_CACHE_DIR = "data"
|
26 |
+
|
27 |
+
|
28 |
+
def _tokenize_smiles(
|
29 |
+
smi: List[str],
|
30 |
+
tokenizer: SmilesTokenizer = None,
|
31 |
+
max_smiles_len=256,
|
32 |
+
log_output=True,
|
33 |
+
):
|
34 |
+
# try:
|
35 |
+
tokens = tokenizer.encode(smi)
|
36 |
+
if len(tokens) > max_smiles_len:
|
37 |
+
if log_output:
|
38 |
+
print(f"Removing to long {smi} with smiles len of {len(tokens)} ")
|
39 |
+
return None
|
40 |
+
|
41 |
+
return tokens
|
42 |
+
|
43 |
+
# except Exception as e:
|
44 |
+
# print(e)
|
45 |
+
# return None
|
46 |
+
|
47 |
+
|
48 |
+
def _tokenize_scaffolds(smi: str, tokenizer=None, max_smiles_len=256, log_output=True):
|
49 |
+
# try:
|
50 |
+
|
51 |
+
smi = MurckoScaffoldSmiles(smi)
|
52 |
+
tokens = tokenizer.encode(smi)
|
53 |
+
tokens = tokens[1:-1] # remove [SEP] and [CLS] tokens
|
54 |
+
if len(tokens) > max_smiles_len:
|
55 |
+
if log_output:
|
56 |
+
print(f"Removing to long {smi} with smiles len of {len(tokens)} ")
|
57 |
+
return None
|
58 |
+
|
59 |
+
return tokens
|
60 |
+
|
61 |
+
# except Exception as e:
|
62 |
+
# print(e)
|
63 |
+
# return None
|
64 |
+
|
65 |
+
|
66 |
+
def pad_batch(src, pad_idx):
|
67 |
+
max_len = max([len(d) for d in src])
|
68 |
+
# src = [d["src_input_ids"] for d in data]
|
69 |
+
padded_src = np.ones([len(src), max_len]) * pad_idx
|
70 |
+
|
71 |
+
for i, j in enumerate(src):
|
72 |
+
padded_src[i][0 : len(j)] = j
|
73 |
+
|
74 |
+
# try to predict the next token from the previouse tokens
|
75 |
+
# essentially reconstructing the src sentence from the embeddings and the previous sentence
|
76 |
+
padded_src = padded_src.T
|
77 |
+
return padded_src
|
78 |
+
|
79 |
+
|
80 |
+
def pretokenize(
|
81 |
+
data_file=os.path.join(
|
82 |
+
DATA_CACHE_DIR, "FULL_combined_zinc_pubchemqc_qm9_pc9_reddb_chembl.parquet"
|
83 |
+
),
|
84 |
+
tokenizer=SmilesTokenizer(),
|
85 |
+
limit=None,
|
86 |
+
context=["logp", "sascore", "mol_weight"],
|
87 |
+
out_name: str = "processed_dataset",
|
88 |
+
remove_nan_context_rows: bool = False,
|
89 |
+
):
|
90 |
+
df = pd.read_parquet(data_file)
|
91 |
+
|
92 |
+
if limit is not None:
|
93 |
+
# smiles_list = df.smiles[:limit]
|
94 |
+
df = df.sample(n=limit) # df[:limit]
|
95 |
+
# NOTE: Set here if necessary, but for memory efficiency not duplicating millions of smiles
|
96 |
+
# smiles_list = df.smiles
|
97 |
+
else:
|
98 |
+
# shuffle the rows
|
99 |
+
df = df.sample(frac=1.0)
|
100 |
+
|
101 |
+
cpu_count = (
|
102 |
+
multiprocessing.cpu_count()
|
103 |
+
) # min(int(multiprocessing.cpu_count() * 0.8), 8)
|
104 |
+
print(f"Running on {cpu_count} CPUs ")
|
105 |
+
|
106 |
+
tqdm.pandas()
|
107 |
+
|
108 |
+
df["scaffolds"] = df["smiles"].progress_map(lambda s: None if "." in s else s)
|
109 |
+
df["smiles"] = df["scaffolds"].copy()
|
110 |
+
orig_len = len(df)
|
111 |
+
if context is not None:
|
112 |
+
if df.get("origin") is not None:
|
113 |
+
origins = df["origin"].unique()
|
114 |
+
origin_dics = {}
|
115 |
+
for i, o in enumerate(origins):
|
116 |
+
df.loc[df["origin"] == o, "origin"] = i
|
117 |
+
origin_dics[o] = i
|
118 |
+
df["origin"] = df["origin"].astype(float)
|
119 |
+
with open(
|
120 |
+
os.path.join(
|
121 |
+
DATA_CACHE_DIR, os.path.basename(data_file) + "_origins.json"
|
122 |
+
),
|
123 |
+
"w",
|
124 |
+
) as f:
|
125 |
+
json.dump(origin_dics, f)
|
126 |
+
|
127 |
+
mask = (
|
128 |
+
~df["smiles"].isna()
|
129 |
+
& (
|
130 |
+
(~df[context].isna()).all(axis=1)
|
131 |
+
if remove_nan_context_rows
|
132 |
+
else np.ones(len(df["smiles"]), dtype=bool)
|
133 |
+
)
|
134 |
+
& ~df["scaffolds"].isna()
|
135 |
+
)
|
136 |
+
else:
|
137 |
+
mask = ~df["smiles"].isna()
|
138 |
+
error_count = np.count_nonzero(~mask)
|
139 |
+
df = df[mask]
|
140 |
+
# print("HELLO")
|
141 |
+
# print("***"*10)
|
142 |
+
|
143 |
+
# tokenizer.batch_encode_plus()
|
144 |
+
|
145 |
+
# df["scaffolds"] = df["scaffolds"].swifter.apply(
|
146 |
+
# partial(_tokenize_scaffolds, tokenizer=tokenizer, log_output=False)
|
147 |
+
# )
|
148 |
+
# df["scaffolds"] = df["scaffolds"].swifter.apply(
|
149 |
+
# partial(_tokenize_scaffolds, tokenizer=tokenizer, log_output=False)
|
150 |
+
# )
|
151 |
+
df["tokens"] = df["smiles"].swifter.apply(
|
152 |
+
partial(_tokenize_smiles, tokenizer=tokenizer, log_output=False)
|
153 |
+
)
|
154 |
+
df["scaffolds"] = df["tokens"].copy()
|
155 |
+
|
156 |
+
mask = ~df["tokens"].isna() & ~df["scaffolds"].isna()
|
157 |
+
df = df[mask]
|
158 |
+
error_count += np.count_nonzero(~mask)
|
159 |
+
|
160 |
+
# Shuffle the data
|
161 |
+
df = df.sample(frac=1).reset_index(drop=True)
|
162 |
+
# with Pool(cpu_count) as p:
|
163 |
+
# df["scaffolds"] = list(
|
164 |
+
|
165 |
+
# p.map(partial( _tokenize_scaffolds ,tokenizer=tokenizer, log_output=False), tqdm(df.smiles.to_numpy(),total=len(df)), chunksize=1000),
|
166 |
+
|
167 |
+
# )
|
168 |
+
|
169 |
+
# df["smiles"] = list(
|
170 |
+
# p.map(partial( _tokenize_smiles ,tokenizer=tokenizer, log_output=False), tqdm(df.smiles.to_numpy(),total=len(df)), chunksize=1000),
|
171 |
+
# )
|
172 |
+
|
173 |
+
if context is not None:
|
174 |
+
context_list = df[context].to_numpy()
|
175 |
+
context_dict = {k: context_list[:, i] for i, k in enumerate(context)}
|
176 |
+
else:
|
177 |
+
context_dict = {}
|
178 |
+
|
179 |
+
print(f"Error count: {error_count} / {orig_len} = {error_count/orig_len}")
|
180 |
+
|
181 |
+
cache_path = os.path.join(os.path.dirname(__file__), ".cache")
|
182 |
+
os.makedirs(cache_path, exist_ok=True)
|
183 |
+
out_path = os.path.join(cache_path, f"{out_name}_{limit}.pkl")
|
184 |
+
with open(out_path, "wb") as f:
|
185 |
+
pickle.dump(
|
186 |
+
{
|
187 |
+
"tokens": df["tokens"].tolist(),
|
188 |
+
"smiles": df["smiles"].tolist(),
|
189 |
+
"scaf": df["scaffolds"].tolist(),
|
190 |
+
**context_dict,
|
191 |
+
},
|
192 |
+
f,
|
193 |
+
)
|
194 |
+
print(f"Saved to {out_path}")
|
195 |
+
print("Done.")
|
196 |
+
|
197 |
+
|
198 |
+
class PretokDataset(torch.utils.data.Dataset):
|
199 |
+
"""Loads pretokenized example from disk and returns them as PyTorch tensors."""
|
200 |
+
|
201 |
+
def __init__(self, split, pad_token_id, dataset="processed_dataset.pkl"):
|
202 |
+
super().__init__()
|
203 |
+
self.split = split
|
204 |
+
self.dataset = dataset
|
205 |
+
self.pad_token_id = pad_token_id
|
206 |
+
cache_path = os.path.join(os.path.dirname(__file__), ".cache")
|
207 |
+
with open(os.path.join(cache_path, self.dataset), "rb") as f:
|
208 |
+
self.data_dict = pickle.load(f)
|
209 |
+
|
210 |
+
# split out 10% of the data for validation
|
211 |
+
split_ix = int(len(self.data_dict["tokens"]) * 0.9)
|
212 |
+
if self.split == "train":
|
213 |
+
self.data_dict = {k: self.data_dict[k][:split_ix] for k in self.data_dict}
|
214 |
+
elif self.split == "val":
|
215 |
+
self.data_dict = {k: self.data_dict[k][split_ix:] for k in self.data_dict}
|
216 |
+
else:
|
217 |
+
raise RuntimeError(f"Could not find split for: self.split={self.split}")
|
218 |
+
|
219 |
+
def __len__(self):
|
220 |
+
return len(self.data_dict["tokens"])
|
221 |
+
|
222 |
+
def __getitem__(self, idx):
|
223 |
+
m = self.data_dict
|
224 |
+
|
225 |
+
start = idx
|
226 |
+
end = idx + 1
|
227 |
+
|
228 |
+
# calling .astype will copy the data into a new numpy array, now in RAM
|
229 |
+
padded_tokens = pad_batch(m["tokens"][start:end], self.pad_token_id)
|
230 |
+
chunk = torch.from_numpy((padded_tokens).astype(np.int64))
|
231 |
+
|
232 |
+
padded_scaffolds = torch.from_numpy(
|
233 |
+
pad_batch(m["scaf"][start:end], self.pad_token_id).astype(np.int64)
|
234 |
+
)
|
235 |
+
|
236 |
+
item = {
|
237 |
+
"seq": chunk,
|
238 |
+
"scaf": padded_scaffolds,
|
239 |
+
"smiles": m["smiles"][start:end],
|
240 |
+
**{
|
241 |
+
k: torch.tensor(m[k][start:end], dtype=torch.float32)
|
242 |
+
for k in m
|
243 |
+
if k != "scaf" and k != "tokens" and k != "smiles"
|
244 |
+
},
|
245 |
+
}
|
246 |
+
|
247 |
+
return item
|
248 |
+
|
249 |
+
|
250 |
+
def padding_collate_fn(
|
251 |
+
data, tokenizer: SmilesTokenizer, fragment_creator: BaseFragmentCreator
|
252 |
+
):
|
253 |
+
# data = list of dicts
|
254 |
+
pad_idx = tokenizer.pad_token_id
|
255 |
+
|
256 |
+
src = [d["seq"] for d in data]
|
257 |
+
|
258 |
+
max_len = max([len(d) for d in src])
|
259 |
+
padded_src = np.ones([len(src), max_len]) * pad_idx
|
260 |
+
for i, j in enumerate(src):
|
261 |
+
padded_src[i][0 : len(j)] = j.ravel()
|
262 |
+
|
263 |
+
if fragment_creator is None:
|
264 |
+
smiles_context = [d["scaf"] for d in data]
|
265 |
+
else:
|
266 |
+
# Remove start and end token after tokenization with [1:-1 ]
|
267 |
+
smiles_context = []
|
268 |
+
for d in data:
|
269 |
+
s = d["smiles"][0]
|
270 |
+
tokens = d["seq"]
|
271 |
+
frag = fragment_creator.create_fragment(Fragment(smiles=s, tokens=tokens))
|
272 |
+
if frag.tokens is not None:
|
273 |
+
smiles_context.append(frag.tokens)
|
274 |
+
else:
|
275 |
+
smiles_context.append(
|
276 |
+
torch.tensor(
|
277 |
+
tokenizer.encode(frag.smiles)[1:-1],
|
278 |
+
dtype=torch.long,
|
279 |
+
device=tokens.device,
|
280 |
+
)
|
281 |
+
)
|
282 |
+
|
283 |
+
max_len_ctx = max([len(d) for d in smiles_context])
|
284 |
+
padded_smiles_context = np.ones([len(smiles_context), max_len_ctx]) * pad_idx
|
285 |
+
for i, j in enumerate(smiles_context):
|
286 |
+
padded_smiles_context[i][0 : len(j)] = j.ravel()
|
287 |
+
# try to predict the next token from the previouse tokens
|
288 |
+
# essentially reconstructing the src sentence from the embeddings and the previous sentence
|
289 |
+
padded_src = padded_src.T
|
290 |
+
|
291 |
+
original_context_keys = [
|
292 |
+
k for k in data[0].keys() if k != "seq" and k != "scaf" and k != "smiles"
|
293 |
+
]
|
294 |
+
context_out_dict = {k: [] for k in original_context_keys}
|
295 |
+
|
296 |
+
for k in original_context_keys:
|
297 |
+
val_list = []
|
298 |
+
for d in data:
|
299 |
+
val_list.append(d[k])
|
300 |
+
|
301 |
+
context_out_dict[k] = torch.concat(val_list, dim=0)
|
302 |
+
|
303 |
+
return {
|
304 |
+
"src": torch.tensor(padded_src, dtype=torch.long), # for (seq_len, batch_size)
|
305 |
+
"fragment": torch.tensor(padded_smiles_context.T, dtype=torch.long),
|
306 |
+
"context": context_out_dict,
|
307 |
+
}
|
308 |
+
|
309 |
+
|
310 |
+
class SmilesTask:
|
311 |
+
@staticmethod
|
312 |
+
def iter_batches(
|
313 |
+
split,
|
314 |
+
batch_size,
|
315 |
+
device,
|
316 |
+
context_keys: List[str],
|
317 |
+
num_workers=0,
|
318 |
+
dataset="processed_dataset.pkl",
|
319 |
+
fragment_creator: BaseFragmentCreator = BricksFragmentCreator(),
|
320 |
+
):
|
321 |
+
tokenizer = SmilesTokenizer()
|
322 |
+
ds = PretokDataset(split, tokenizer.pad_token_id, dataset=dataset)
|
323 |
+
is_ddp = int(os.environ.get("RANK", -1)) != -1
|
324 |
+
dl = torch.utils.data.DataLoader(
|
325 |
+
ds,
|
326 |
+
batch_size=batch_size,
|
327 |
+
pin_memory=True,
|
328 |
+
num_workers=num_workers,
|
329 |
+
shuffle=False,
|
330 |
+
sampler=DistributedSampler(ds) if is_ddp else None,
|
331 |
+
collate_fn=lambda batch: padding_collate_fn(
|
332 |
+
batch, tokenizer, fragment_creator
|
333 |
+
),
|
334 |
+
)
|
335 |
+
|
336 |
+
for data in dl:
|
337 |
+
data["src"] = data["src"].to(device, non_blocking=True)
|
338 |
+
data["tgt"] = data["src"].to(device, non_blocking=True)
|
339 |
+
|
340 |
+
data["src"] = data["src"][:-1, :].T # batch_size, seq_len
|
341 |
+
data["tgt"] = data["tgt"][1:, :].T # batch_size, seq_len
|
342 |
+
|
343 |
+
data["fragment"] = (
|
344 |
+
data["fragment"].to(device, non_blocking=True).T
|
345 |
+
) # batch_size, seq_len
|
346 |
+
keys = list(data["context"].keys())
|
347 |
+
for d in keys:
|
348 |
+
if d not in context_keys:
|
349 |
+
del data["context"][d]
|
350 |
+
else:
|
351 |
+
data["context"][d] = data["context"][d].to(
|
352 |
+
device, non_blocking=True
|
353 |
+
)
|
354 |
+
|
355 |
+
yield data
|
356 |
+
|
357 |
+
|
358 |
+
if __name__ == "__main__":
|
359 |
+
|
360 |
+
pretokenize(
|
361 |
+
data_file=os.path.join(
|
362 |
+
DATA_CACHE_DIR,
|
363 |
+
"OrganiX13.parquet",
|
364 |
+
),
|
365 |
+
limit=None, # Set how many molecules should be processed, if None all molecules will be processed,
|
366 |
+
context=["logp", "sascore", "mol_weight"],
|
367 |
+
out_name="processed_dataset",
|
368 |
+
remove_nan_context_rows=False,
|
369 |
+
)
|
370 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy==1.23.5
|
2 |
+
pytest==7.4.0
|
3 |
+
Requests==2.31.0
|
4 |
+
sentencepiece==0.1.99
|
5 |
+
tiktoken==0.3.3
|
6 |
+
torch==2.0.1
|
7 |
+
tqdm==4.64.1
|
8 |
+
wandb==0.15.5
|
sample.py
ADDED
@@ -0,0 +1,616 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from contextlib import nullcontext
|
3 |
+
import sys
|
4 |
+
import time
|
5 |
+
import pandas as pd
|
6 |
+
import torch
|
7 |
+
from tqdm.auto import tqdm
|
8 |
+
|
9 |
+
# from tqdm.notebook import tqdm
|
10 |
+
from model import Transformer
|
11 |
+
from plot_utils import (
|
12 |
+
check_metrics,
|
13 |
+
plot_1D_condition,
|
14 |
+
plot_2D_condition,
|
15 |
+
plot_3D_condition,
|
16 |
+
plot_unconditional,
|
17 |
+
)
|
18 |
+
from tokenizer import SmilesTokenizer
|
19 |
+
import numpy as np
|
20 |
+
from typing import Dict, List, Tuple, Union
|
21 |
+
import re
|
22 |
+
|
23 |
+
from rdkit import Chem
|
24 |
+
from rdkit import DataStructs
|
25 |
+
from rdkit.Chem.Fingerprints import FingerprintMols
|
26 |
+
|
27 |
+
import logging
|
28 |
+
|
29 |
+
logger = logging.getLogger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
class Sampler:
|
33 |
+
def __init__(
|
34 |
+
self,
|
35 |
+
load_path: str,
|
36 |
+
device: str = "cpu",
|
37 |
+
seed: int = 1337,
|
38 |
+
dtype: str = "float16",
|
39 |
+
compile: bool = True,
|
40 |
+
quantize: bool = False,
|
41 |
+
) -> None:
|
42 |
+
self.load_path = load_path
|
43 |
+
self.device = device
|
44 |
+
self.dtype = dtype
|
45 |
+
self.compile = compile
|
46 |
+
self.quantize = quantize
|
47 |
+
self.seed = seed
|
48 |
+
self._init_model()
|
49 |
+
|
50 |
+
def _init_model(self):
|
51 |
+
np.random.seed(self.seed)
|
52 |
+
torch.cuda.manual_seed(self.seed)
|
53 |
+
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
54 |
+
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
55 |
+
self.device_type = (
|
56 |
+
"cuda" if "cuda" in self.device else "cpu"
|
57 |
+
) # for later use in torch.autocast
|
58 |
+
ptdtype = {
|
59 |
+
"float32": torch.float32,
|
60 |
+
"bfloat16": torch.bfloat16,
|
61 |
+
"float16": torch.float16,
|
62 |
+
}[self.dtype]
|
63 |
+
self.ptdtype = ptdtype
|
64 |
+
|
65 |
+
self.ctx = self._autocast()
|
66 |
+
# init from a model saved in a specific directory
|
67 |
+
# ckpt_path = os.path.join(out_dir, "ckpt_full_dim=256.pt")
|
68 |
+
self.model = Transformer.load(self.load_path, device=self.device)
|
69 |
+
|
70 |
+
self.model.eval()
|
71 |
+
if self.quantize:
|
72 |
+
raise NotImplementedError("Not properly implemented for CPU / GPU")
|
73 |
+
self.model = torch.ao.quantization.quantize_dynamic(
|
74 |
+
self.model, # the original model
|
75 |
+
{torch.nn.Linear}, # a set of layers to dynamically quantize
|
76 |
+
dtype=torch.qint8,
|
77 |
+
)
|
78 |
+
|
79 |
+
if self.compile:
|
80 |
+
logger.info("Compiling the model...")
|
81 |
+
self.model = torch.compile(self.model) # requires PyTorch 2.0 (optional)
|
82 |
+
|
83 |
+
self.model = self.model.to(self.device)
|
84 |
+
# load the tokenizer
|
85 |
+
self.tokenizer = SmilesTokenizer()
|
86 |
+
|
87 |
+
def get_context(
|
88 |
+
self,
|
89 |
+
context_col: List[str],
|
90 |
+
context_smi: str,
|
91 |
+
num_examples: int = 50,
|
92 |
+
):
|
93 |
+
"""
|
94 |
+
Returns a dictionary in the form of
|
95 |
+
{
|
96 |
+
"fragment": torch.tensor,
|
97 |
+
"context": {
|
98 |
+
"logp": torch.tensor,
|
99 |
+
"sascore": torch.tensor,
|
100 |
+
"mol_weight": torch.tensor
|
101 |
+
}
|
102 |
+
}
|
103 |
+
|
104 |
+
|
105 |
+
When context_smi is set to a string, then the "fragment" field is populated.
|
106 |
+
All of the properties listed in the context_col list is set to the keys and the values are set to a resonable range for each property.
|
107 |
+
|
108 |
+
num_examples indicates how many values are sampled for each property.
|
109 |
+
"""
|
110 |
+
output_dict = {"context": {}, "fragment": None}
|
111 |
+
|
112 |
+
if context_smi is not None:
|
113 |
+
logger.debug(
|
114 |
+
f"context_smiles: {context_smi}",
|
115 |
+
)
|
116 |
+
# NOTE: Remove beginning [CLS] and end token [SEP]
|
117 |
+
incorporate_selfie = self.tokenizer.encode(context_smi)[1:-1]
|
118 |
+
|
119 |
+
context = torch.tensor(
|
120 |
+
[incorporate_selfie] * num_examples,
|
121 |
+
dtype=torch.long,
|
122 |
+
device=self.device,
|
123 |
+
)
|
124 |
+
|
125 |
+
output_dict["fragment"] = context
|
126 |
+
|
127 |
+
if context_col is None:
|
128 |
+
return output_dict
|
129 |
+
|
130 |
+
if "logp" in context_col:
|
131 |
+
# context = 0.5 * torch.randint(
|
132 |
+
# -8, 14, (num_examples,), device=self.device, dtype=torch.float
|
133 |
+
# )
|
134 |
+
# context = 0.5 * torch.randint(
|
135 |
+
# -6, 6, (num_examples, 1), device=device, dtype=torch.float
|
136 |
+
# )
|
137 |
+
context = torch.tensor(
|
138 |
+
np.random.choice([-2, 0, 2], (num_examples,)),
|
139 |
+
device=self.device,
|
140 |
+
dtype=self.ptdtype,
|
141 |
+
)
|
142 |
+
# context = 2.0 * torch.ones(
|
143 |
+
# (num_examples,1), device=device, dtype=torch.float
|
144 |
+
# )
|
145 |
+
# context = -2.0*torch.ones((num_examples,2),device=device,dtype=torch.float)
|
146 |
+
# context, _ = torch.sort(context, 0)
|
147 |
+
output_dict["context"]["logp"] = context
|
148 |
+
|
149 |
+
if "energy" in context_col:
|
150 |
+
context = 0.1 * torch.randint(
|
151 |
+
-15, 15, (num_examples,), device=self.device, dtype=torch.float
|
152 |
+
)
|
153 |
+
# context = -2.0*torch.ones((num_examples,2),device=device,dtype=torch.float)
|
154 |
+
context, _ = torch.sort(context, 0)
|
155 |
+
output_dict["context"]["energy"] = context
|
156 |
+
|
157 |
+
if "sascore" in context_col:
|
158 |
+
# context = 0.5 * torch.randint(
|
159 |
+
# 2, 20, (num_examples, ), device=self.device, dtype=torch.float
|
160 |
+
# )
|
161 |
+
context = torch.tensor(
|
162 |
+
np.random.choice([2, 3, 4], (num_examples,)),
|
163 |
+
device=self.device,
|
164 |
+
dtype=torch.float,
|
165 |
+
)
|
166 |
+
# context = 0.5 * torch.randint(
|
167 |
+
# 4, 8, (num_examples, 1), device=device, dtype=torch.float
|
168 |
+
# )
|
169 |
+
# context = 2.0*torch.ones((num_examples,1),device=device,dtype=torch.float)
|
170 |
+
# context, _ = torch.sort(context, 0)
|
171 |
+
output_dict["context"]["sascore"] = context
|
172 |
+
|
173 |
+
if "mol_weight" in context_col:
|
174 |
+
# context = 0.5 * torch.randint(
|
175 |
+
# 2, 20, (num_examples,), device=self.device, dtype=torch.float
|
176 |
+
# )
|
177 |
+
context = torch.tensor(
|
178 |
+
np.random.choice([2.0, 3.0, 4.0], (num_examples,)),
|
179 |
+
device=self.device,
|
180 |
+
dtype=torch.float,
|
181 |
+
)
|
182 |
+
|
183 |
+
# context = 0.5 * torch.randint(
|
184 |
+
# 2, 20, (num_examples, 1), device=device, dtype=torch.float
|
185 |
+
# )
|
186 |
+
# context = 2.5*torch.ones((num_examples,1),device=device,dtype=torch.float)
|
187 |
+
# context, _ = torch.sort(context, 0)
|
188 |
+
output_dict["context"]["mol_weight"] = context
|
189 |
+
|
190 |
+
return output_dict
|
191 |
+
|
192 |
+
def _autocast(self):
|
193 |
+
if "cuda" in self.device:
|
194 |
+
if self.dtype == "bfloat16" and torch.cuda.is_bf16_supported():
|
195 |
+
return torch.cuda.amp.autocast(dtype=torch.bfloat16)
|
196 |
+
elif self.dtype == "float16":
|
197 |
+
return torch.cuda.amp.autocast(dtype=torch.float16)
|
198 |
+
else:
|
199 |
+
return torch.cuda.amp.autocast(dtype=torch.float32)
|
200 |
+
else: # cpu
|
201 |
+
return nullcontext()
|
202 |
+
|
203 |
+
@torch.no_grad()
|
204 |
+
def generate(
|
205 |
+
self,
|
206 |
+
context_cols: Union[List[str], None, Dict[str, torch.Tensor]] = None,
|
207 |
+
context_smi: Union[str, None] = None,
|
208 |
+
start_smiles: Union[str, None] = None,
|
209 |
+
num_samples: int = 50,
|
210 |
+
max_new_tokens: int = 256,
|
211 |
+
temperature: float = 1.0,
|
212 |
+
top_k: Union[int, None] = None,
|
213 |
+
return_context: bool = False,
|
214 |
+
total_gen_steps: int = 1,
|
215 |
+
use_kv_cache: bool = False,
|
216 |
+
) -> Union[List[str], Tuple[List[str], List[float]]]:
|
217 |
+
"""
|
218 |
+
Generates a list of SMILES. With the default options it would generate them unconditionally.
|
219 |
+
Params:
|
220 |
+
- context_cols : When a list the context is randomly sampled from the get_context method, when given a dictionary the
|
221 |
+
context values are taken from the dictionary instead.
|
222 |
+
- context_smi : Further conditioning by the usage of a molecular fragment
|
223 |
+
. start_smiles : Can be used to start the SMILES with a specific string, the model then generates the next tokens including that start sequence.
|
224 |
+
- num_samples : Controlls how many SMILES in total will be generated be the model.
|
225 |
+
- max_new_tokens : Controlls the maximum length of each SMILES (in tokens) that is generated.
|
226 |
+
- temperature: Controlls the randomness of the model. A temperature = 1.0 means it is the trained distribution. A temperature < 1 is more deterministic and temperature > 1 is more random
|
227 |
+
- top_k : Clamps the probability distribution to the top k tokens. From these the next token is then sampled from.
|
228 |
+
- return_context : Whether the context that was given to the model should be returned.
|
229 |
+
- total_gen_steps : In how many sub steps the generation should be split up to. Useful when generation 10k + SMILES and wanting to chunk these into for example 10 * 1k generations with total_gen_steps = 10.
|
230 |
+
- use_kv_cache: Runs the generation using kv-caching. It is faster, but takes more memory.
|
231 |
+
"""
|
232 |
+
|
233 |
+
with self.ctx:
|
234 |
+
gens_per_step = num_samples // total_gen_steps
|
235 |
+
|
236 |
+
logger.debug(f"Gens per Step: {gens_per_step}")
|
237 |
+
context = None # {"context": None, "fragment" : None}
|
238 |
+
out_smiles = []
|
239 |
+
with tqdm(total=total_gen_steps, desc="Batch") as pbar:
|
240 |
+
for i in range(total_gen_steps):
|
241 |
+
if isinstance(context_cols, dict):
|
242 |
+
# TODO: Test if same length
|
243 |
+
cd = {
|
244 |
+
c: context_cols[c][
|
245 |
+
i * gens_per_step : (i + 1) * gens_per_step
|
246 |
+
]
|
247 |
+
for c in context_cols.keys()
|
248 |
+
}
|
249 |
+
|
250 |
+
context_dict = {"context": cd, "fragment": None}
|
251 |
+
if context_smi is not None:
|
252 |
+
logger.debug(
|
253 |
+
f"context_smiles: {context_smi}",
|
254 |
+
)
|
255 |
+
# NOTE: Remove beginning [CLS] and end token [SEP]
|
256 |
+
incorporate_selfie = self.tokenizer.encode(context_smi)[
|
257 |
+
1:-1
|
258 |
+
]
|
259 |
+
|
260 |
+
context_tensor = torch.tensor(
|
261 |
+
[incorporate_selfie] * gens_per_step,
|
262 |
+
dtype=torch.long,
|
263 |
+
device=self.device,
|
264 |
+
)
|
265 |
+
|
266 |
+
context_dict["fragment"] = context_tensor
|
267 |
+
context_cols = list(context_cols.keys())
|
268 |
+
|
269 |
+
else:
|
270 |
+
context_dict = self.get_context(
|
271 |
+
context_cols, context_smi, num_examples=gens_per_step
|
272 |
+
)
|
273 |
+
|
274 |
+
# for k in range(num_samples):
|
275 |
+
y = self.model.generate(
|
276 |
+
self.tokenizer,
|
277 |
+
context=context_dict["context"],
|
278 |
+
fragments=context_dict["fragment"],
|
279 |
+
start_smiles=start_smiles,
|
280 |
+
num_gen=gens_per_step,
|
281 |
+
temperature=temperature,
|
282 |
+
top_k=top_k,
|
283 |
+
max_length=max_new_tokens,
|
284 |
+
device=self.device,
|
285 |
+
cache_kv=use_kv_cache,
|
286 |
+
)
|
287 |
+
|
288 |
+
new_context = {k: [] for k in context_dict["context"]}
|
289 |
+
for i, sample in enumerate(y):
|
290 |
+
# print(sample)
|
291 |
+
mol = Chem.MolFromSmiles(sample)
|
292 |
+
if mol is not None:
|
293 |
+
out_smiles.append(sample)
|
294 |
+
for k in new_context:
|
295 |
+
new_context[k].append(
|
296 |
+
context_dict["context"][k][i].unsqueeze(-1)
|
297 |
+
)
|
298 |
+
|
299 |
+
for k in new_context:
|
300 |
+
new_context[k] = torch.concat(new_context[k], dim=0)
|
301 |
+
|
302 |
+
if context is None:
|
303 |
+
context = new_context
|
304 |
+
else:
|
305 |
+
for k in context:
|
306 |
+
context[k] = torch.concat(
|
307 |
+
[context[k], new_context[k]], dim=0
|
308 |
+
)
|
309 |
+
|
310 |
+
pbar.update(1)
|
311 |
+
|
312 |
+
logger.info(
|
313 |
+
f"Number valid generated: {len(out_smiles) / num_samples * 100} %"
|
314 |
+
)
|
315 |
+
logger.info("---------------")
|
316 |
+
|
317 |
+
if return_context:
|
318 |
+
return (out_smiles, context)
|
319 |
+
|
320 |
+
else:
|
321 |
+
return out_smiles
|
322 |
+
|
323 |
+
@torch.no_grad()
|
324 |
+
def generate_with_evaluation(
|
325 |
+
self,
|
326 |
+
context_cols: Union[List[str], None] = None,
|
327 |
+
context_smi: Union[str, None] = None,
|
328 |
+
start_smiles: Union[str, None] = None,
|
329 |
+
num_samples: int = 50,
|
330 |
+
max_new_tokens: int = 256,
|
331 |
+
temperature: float = 1.0,
|
332 |
+
top_k: Union[int, None] = None,
|
333 |
+
cmp_context_dict: Union[Dict[str, torch.Tensor], None] = None,
|
334 |
+
total_gen_steps: int = 1,
|
335 |
+
use_kv_cache: bool = False,
|
336 |
+
):
|
337 |
+
out_smiles, new_context = self.generate(
|
338 |
+
context_cols=context_cols,
|
339 |
+
context_smi=context_smi,
|
340 |
+
start_smiles=start_smiles,
|
341 |
+
num_samples=num_samples,
|
342 |
+
max_new_tokens=max_new_tokens,
|
343 |
+
temperature=temperature,
|
344 |
+
top_k=top_k,
|
345 |
+
return_context=True,
|
346 |
+
total_gen_steps=total_gen_steps,
|
347 |
+
use_kv_cache=use_kv_cache,
|
348 |
+
)
|
349 |
+
|
350 |
+
out_dir = os.path.dirname(self.load_path)
|
351 |
+
|
352 |
+
if context_cols is not None:
|
353 |
+
if len(context_cols) == 1:
|
354 |
+
plot_1D_condition(
|
355 |
+
context_cols,
|
356 |
+
os.path.join(out_dir, "plots"),
|
357 |
+
new_context,
|
358 |
+
out_smiles,
|
359 |
+
temperature,
|
360 |
+
cmp_context_dict,
|
361 |
+
context_scaler=None,
|
362 |
+
)
|
363 |
+
|
364 |
+
elif len(context_cols) == 2:
|
365 |
+
plot_2D_condition(
|
366 |
+
context_cols,
|
367 |
+
os.path.join(out_dir, "plots"),
|
368 |
+
new_context,
|
369 |
+
out_smiles,
|
370 |
+
temperature,
|
371 |
+
label=context_smi,
|
372 |
+
)
|
373 |
+
|
374 |
+
elif len(context_cols) == 3:
|
375 |
+
plot_3D_condition(
|
376 |
+
context_cols,
|
377 |
+
os.path.join(out_dir, "plots"),
|
378 |
+
new_context,
|
379 |
+
out_smiles,
|
380 |
+
temperature,
|
381 |
+
)
|
382 |
+
|
383 |
+
else:
|
384 |
+
raise NotImplementedError(
|
385 |
+
"Currently not implemented for len(context_col) > 3"
|
386 |
+
)
|
387 |
+
|
388 |
+
else:
|
389 |
+
# Unconditional Case
|
390 |
+
plot_unconditional(
|
391 |
+
out_path=os.path.join(out_dir, "plots"),
|
392 |
+
smiles=out_smiles,
|
393 |
+
temperature=temperature,
|
394 |
+
cmp_context_dict=cmp_context_dict,
|
395 |
+
)
|
396 |
+
|
397 |
+
if context_smi is not None:
|
398 |
+
pattern = r"\[\d+\*\]"
|
399 |
+
# replace [14*] etc
|
400 |
+
context_smi = re.sub(pattern, "", context_smi)
|
401 |
+
|
402 |
+
context_mol = Chem.MolFromSmiles(context_smi)
|
403 |
+
context_smarts = Chem.MolToSmarts(context_mol)
|
404 |
+
|
405 |
+
pattern = r"(?<!\[)([:-=#])(?!\])(?![^\[]*?\])"
|
406 |
+
|
407 |
+
context_smarts = re.sub(pattern, "~", context_smarts)
|
408 |
+
logger.info(f"context_smarts {context_smarts}")
|
409 |
+
out_mols = [Chem.MolFromSmiles(smi) for smi in out_smiles]
|
410 |
+
|
411 |
+
context_fingerprint = FingerprintMols.FingerprintMol(context_mol)
|
412 |
+
out_fingerprints = [FingerprintMols.FingerprintMol(fi) for fi in out_mols]
|
413 |
+
all_sim = []
|
414 |
+
all_sub = []
|
415 |
+
for out_fing, out_mol in zip(out_fingerprints, out_mols):
|
416 |
+
similarity = DataStructs.TanimotoSimilarity(
|
417 |
+
context_fingerprint, out_fing
|
418 |
+
)
|
419 |
+
|
420 |
+
has_sub = out_mol.HasSubstructMatch(Chem.MolFromSmarts(context_smarts))
|
421 |
+
all_sub.append(has_sub)
|
422 |
+
all_sim.append(similarity)
|
423 |
+
|
424 |
+
# print(similarity,has_sub)
|
425 |
+
logger.info(f"Mean sim {np.mean(all_sim)}")
|
426 |
+
logger.info(
|
427 |
+
f"Has Sub: {np.count_nonzero(all_sub)} or {round(np.count_nonzero(all_sub) / len(all_sub) * 100, 4)} %"
|
428 |
+
)
|
429 |
+
|
430 |
+
return out_smiles, new_context
|
431 |
+
|
432 |
+
|
433 |
+
if __name__ == "__main__":
|
434 |
+
import argparse
|
435 |
+
import rdkit.rdBase as rkrb
|
436 |
+
import rdkit.RDLogger as rkl
|
437 |
+
|
438 |
+
logger = rkl.logger()
|
439 |
+
logger.setLevel(rkl.ERROR)
|
440 |
+
rkrb.DisableLog("rdApp.error")
|
441 |
+
|
442 |
+
torch.set_num_threads(8)
|
443 |
+
logging.basicConfig(level=logging.INFO)
|
444 |
+
logger = logging.getLogger(__name__)
|
445 |
+
|
446 |
+
parser = argparse.ArgumentParser(
|
447 |
+
description="Generate SMILES strings using a trained model."
|
448 |
+
)
|
449 |
+
# parser.add_argument('--context_cols', type=str, nargs='+', default=None)
|
450 |
+
parser.add_argument(
|
451 |
+
"--context_cols",
|
452 |
+
type=str,
|
453 |
+
nargs="+",
|
454 |
+
default=None,
|
455 |
+
help="The given conditions are sampled from a fixed interval and given to the modeĺ.",
|
456 |
+
)
|
457 |
+
parser.add_argument(
|
458 |
+
"--context_smi",
|
459 |
+
type=str,
|
460 |
+
default=None,
|
461 |
+
help="This SMILES is given as context to the model and should be integrated in the generated molecules.",
|
462 |
+
)
|
463 |
+
parser.add_argument(
|
464 |
+
"--start_smiles",
|
465 |
+
type=str,
|
466 |
+
default=None,
|
467 |
+
help="This SMILES is placed at the front of each sample, from which on the generation continues.",
|
468 |
+
)
|
469 |
+
parser.add_argument(
|
470 |
+
"--ckpt_path",
|
471 |
+
type=str,
|
472 |
+
default=os.path.join(os.path.dirname(__file__), "out", "llama2-M-Full-RSS.pt"),
|
473 |
+
help="Which model should be used in the generation",
|
474 |
+
)
|
475 |
+
parser.add_argument(
|
476 |
+
"--num_samples",
|
477 |
+
type=int,
|
478 |
+
default=50,
|
479 |
+
help="Controls how many samples should be generated",
|
480 |
+
)
|
481 |
+
parser.add_argument(
|
482 |
+
"--num_samples_per_step",
|
483 |
+
type=int,
|
484 |
+
default=1000,
|
485 |
+
help="Works in conjunction with num_samples, by splitting the total into num_samples_per_step jobs. When num_samples > num_samples_per_step then it is split up into multiple seperate generation steps.",
|
486 |
+
)
|
487 |
+
|
488 |
+
parser.add_argument(
|
489 |
+
"--max_new_tokens",
|
490 |
+
type=int,
|
491 |
+
default=256,
|
492 |
+
help="Sets how many tokens should be generated from the model. We only trained with a max size of 256, but it is possible to generate longer molecules. However, these might be worse in quality.",
|
493 |
+
)
|
494 |
+
parser.add_argument(
|
495 |
+
"--temperature",
|
496 |
+
type=float,
|
497 |
+
default=0.8,
|
498 |
+
help="Sets the randomness of the generation - A temperature of 0 would be deterministic and a temperature of > 1 is more random.",
|
499 |
+
)
|
500 |
+
parser.add_argument(
|
501 |
+
"--top_k",
|
502 |
+
type=int,
|
503 |
+
default=None,
|
504 |
+
help="The top_k of the sampling. Per default it is None, but can be set to an integer to have a more focused generation.",
|
505 |
+
)
|
506 |
+
parser.add_argument(
|
507 |
+
"--seed",
|
508 |
+
type=int,
|
509 |
+
default=1234,
|
510 |
+
help="Random number generator seed, to make sampling consistent.",
|
511 |
+
)
|
512 |
+
parser.add_argument(
|
513 |
+
"--cmp_dataset_path",
|
514 |
+
type=str,
|
515 |
+
default=None,
|
516 |
+
help="A dataset in parquet or csv format to be used in the sample plots and to compute the metrics such as the novelty.",
|
517 |
+
)
|
518 |
+
device = "cuda" if torch.cuda.is_available() else "cpu"
|
519 |
+
parser.add_argument(
|
520 |
+
"--device",
|
521 |
+
type=str,
|
522 |
+
default=device,
|
523 |
+
help="Change the device the model and generation is run on",
|
524 |
+
)
|
525 |
+
|
526 |
+
if "cuda" in device:
|
527 |
+
# dtype = 'bfloat16' if torch.cuda.is_available() and torch.cuda.is_bf16_supported() else 'float16'
|
528 |
+
dtype = "float16" if torch.cuda.is_available() else "float32"
|
529 |
+
else:
|
530 |
+
dtype = "float32"
|
531 |
+
|
532 |
+
parser.add_argument(
|
533 |
+
"--dtype",
|
534 |
+
type=str,
|
535 |
+
default=dtype,
|
536 |
+
help="Change the datatype of the computation. Per default it is float32 on CPU and float16 on GPU",
|
537 |
+
)
|
538 |
+
parser.add_argument(
|
539 |
+
"--compile",
|
540 |
+
type=bool,
|
541 |
+
default=True,
|
542 |
+
help="Use torch.compile to compile the model. Only works on torch>=2.0, but should make the inference faster.",
|
543 |
+
)
|
544 |
+
parser.add_argument(
|
545 |
+
"--quantize",
|
546 |
+
type=bool,
|
547 |
+
default=False,
|
548 |
+
help="(CURRENTLY NOT WORKING) Enable quantization to in8.",
|
549 |
+
)
|
550 |
+
parser.add_argument(
|
551 |
+
"--kv_caching",
|
552 |
+
action="store_true",
|
553 |
+
default=False,
|
554 |
+
help="Makes the attention mechanism linear, because the old keys and values are cached. The drawback is higher memory consumption.",
|
555 |
+
)
|
556 |
+
args = parser.parse_args()
|
557 |
+
|
558 |
+
logger.info("Sampling with the following parameters:")
|
559 |
+
logger.info(f"Checkpoint: {args.ckpt_path}")
|
560 |
+
logger.info(f"Context columns: {args.context_cols}")
|
561 |
+
logger.info(f"Context SMILES: {args.context_smi}")
|
562 |
+
logger.info(f"Start SMILES: {args.start_smiles}")
|
563 |
+
logger.info(f"Number of samples: {args.num_samples}")
|
564 |
+
logger.info(f"Max new tokens: {args.max_new_tokens}")
|
565 |
+
logger.info(f"Temperature: {args.temperature}")
|
566 |
+
logger.info(f"Top k: {args.top_k}")
|
567 |
+
logger.info(f"Seed: {args.seed}")
|
568 |
+
logger.info(f"Device: {args.device}")
|
569 |
+
logger.info(f"Data type: {args.dtype}")
|
570 |
+
logger.info(f"Compile: {args.compile}")
|
571 |
+
logger.info(f"Comparison dataset path: {args.cmp_dataset_path}")
|
572 |
+
logger.info(f"Quantize: {args.quantize}")
|
573 |
+
logger.info(f"Key Value Caching Enabled: {args.kv_caching}")
|
574 |
+
|
575 |
+
sampler = Sampler(
|
576 |
+
load_path=os.path.join(os.path.dirname(__file__), args.ckpt_path),
|
577 |
+
device=args.device,
|
578 |
+
seed=args.seed,
|
579 |
+
dtype=args.dtype,
|
580 |
+
compile=args.compile,
|
581 |
+
quantize=args.quantize,
|
582 |
+
)
|
583 |
+
|
584 |
+
comp_context_dict = None
|
585 |
+
comp_smiles = None
|
586 |
+
if args.cmp_dataset_path is not None:
|
587 |
+
df_comp = pd.read_parquet(args.cmp_dataset_path)
|
588 |
+
df_comp = df_comp.sample(n=2_500_000)
|
589 |
+
comp_context_dict = {
|
590 |
+
c: df_comp[c].to_numpy() for c in ["logp", "sascore", "mol_weight"]
|
591 |
+
}
|
592 |
+
comp_smiles = df_comp["smiles"]
|
593 |
+
|
594 |
+
measure_time = True
|
595 |
+
start_time = time.time()
|
596 |
+
smiles, context = sampler.generate_with_evaluation(
|
597 |
+
context_cols=args.context_cols,
|
598 |
+
context_smi=args.context_smi,
|
599 |
+
start_smiles=args.start_smiles,
|
600 |
+
num_samples=args.num_samples,
|
601 |
+
max_new_tokens=args.max_new_tokens,
|
602 |
+
temperature=args.temperature,
|
603 |
+
top_k=args.top_k,
|
604 |
+
cmp_context_dict=comp_context_dict,
|
605 |
+
total_gen_steps=int(np.ceil(args.num_samples / args.num_samples_per_step)),
|
606 |
+
use_kv_cache=args.kv_caching,
|
607 |
+
)
|
608 |
+
end_time = time.time()
|
609 |
+
if measure_time:
|
610 |
+
logger.info(f"Generation took: {end_time - start_time} sec")
|
611 |
+
if comp_smiles is not None:
|
612 |
+
res_metrics = check_metrics(smiles, comp_smiles)
|
613 |
+
logger.info(f"Metrics: {res_metrics}")
|
614 |
+
logger.info("Generated Molecules:")
|
615 |
+
for s in smiles:
|
616 |
+
print(s)
|
tokenizer.py
ADDED
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Requriments - transformers, tokenizers
|
2 |
+
# Right now, the Smiles Tokenizer uses an exiesting vocab file from rxnfp that is fairly comprehensive and from the USPTO dataset.
|
3 |
+
# The vocab may be expanded in the near future
|
4 |
+
|
5 |
+
# Code taken from here: https://github.com/deepchem/deepchem/blob/2.4.0/deepchem/feat/smiles_tokenizer.py#L39-L282
|
6 |
+
import collections
|
7 |
+
import os
|
8 |
+
import re
|
9 |
+
import pkg_resources
|
10 |
+
from typing import List
|
11 |
+
from transformers import BertTokenizer
|
12 |
+
from logging import getLogger
|
13 |
+
|
14 |
+
logger = getLogger(__name__)
|
15 |
+
"""
|
16 |
+
SMI_REGEX_PATTERN: str
|
17 |
+
SMILES regex pattern for tokenization. Designed by Schwaller et. al.
|
18 |
+
|
19 |
+
References
|
20 |
+
|
21 |
+
.. [1] Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee
|
22 |
+
ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction
|
23 |
+
1572-1583 DOI: 10.1021/acscentsci.9b00576
|
24 |
+
|
25 |
+
"""
|
26 |
+
|
27 |
+
SMI_REGEX_PATTERN = r"""(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"""
|
28 |
+
|
29 |
+
# add vocab_file dict
|
30 |
+
VOCAB_FILES_NAMES = {"vocab_file": "vocab.txt"}
|
31 |
+
|
32 |
+
|
33 |
+
def get_default_tokenizer():
|
34 |
+
default_vocab_path = pkg_resources.resource_filename(
|
35 |
+
"deepchem", "feat/tests/vocab.txt"
|
36 |
+
)
|
37 |
+
return SmilesTokenizer(default_vocab_path)
|
38 |
+
|
39 |
+
|
40 |
+
class SmilesTokenizer(BertTokenizer):
|
41 |
+
"""
|
42 |
+
Creates the SmilesTokenizer class. The tokenizer heavily inherits from the BertTokenizer
|
43 |
+
implementation found in Huggingface's transformers library. It runs a WordPiece tokenization
|
44 |
+
algorithm over SMILES strings using the tokenisation SMILES regex developed by Schwaller et. al.
|
45 |
+
|
46 |
+
Please see https://github.com/huggingface/transformers
|
47 |
+
and https://github.com/rxn4chemistry/rxnfp for more details.
|
48 |
+
|
49 |
+
Examples
|
50 |
+
--------
|
51 |
+
>>> from deepchem.feat.smiles_tokenizer import SmilesTokenizer
|
52 |
+
>>> current_dir = os.path.dirname(os.path.realpath(__file__))
|
53 |
+
>>> vocab_path = os.path.join(current_dir, 'tests/data', 'vocab.txt')
|
54 |
+
>>> tokenizer = SmilesTokenizer(vocab_path)
|
55 |
+
>>> print(tokenizer.encode("CC(=O)OC1=CC=CC=C1C(=O)O"))
|
56 |
+
[12, 16, 16, 17, 22, 19, 18, 19, 16, 20, 22, 16, 16, 22, 16, 16, 22, 16, 20, 16, 17, 22, 19, 18, 19, 13]
|
57 |
+
|
58 |
+
|
59 |
+
References
|
60 |
+
----------
|
61 |
+
.. [1] Schwaller, Philippe; Probst, Daniel; Vaucher, Alain C.; Nair, Vishnu H; Kreutter, David;
|
62 |
+
Laino, Teodoro; et al. (2019): Mapping the Space of Chemical Reactions using Attention-Based Neural
|
63 |
+
Networks. ChemRxiv. Preprint. https://doi.org/10.26434/chemrxiv.9897365.v3
|
64 |
+
|
65 |
+
Notes
|
66 |
+
----
|
67 |
+
This class requires huggingface's transformers and tokenizers libraries to be installed.
|
68 |
+
"""
|
69 |
+
|
70 |
+
vocab_files_names = VOCAB_FILES_NAMES
|
71 |
+
|
72 |
+
def __init__(
|
73 |
+
self,
|
74 |
+
# unk_token="[UNK]",
|
75 |
+
# sep_token="[SEP]",
|
76 |
+
# pad_token="[PAD]",
|
77 |
+
# cls_token="[CLS]",
|
78 |
+
# mask_token="[MASK]",
|
79 |
+
**kwargs
|
80 |
+
):
|
81 |
+
"""Constructs a SmilesTokenizer.
|
82 |
+
|
83 |
+
Parameters
|
84 |
+
----------
|
85 |
+
vocab_file: str
|
86 |
+
Path to a SMILES character per line vocabulary file.
|
87 |
+
Default vocab file is found in deepchem/feat/tests/data/vocab.txt
|
88 |
+
"""
|
89 |
+
|
90 |
+
vocab_file = os.path.join(os.path.dirname(__file__), "data", "vocab.txt")
|
91 |
+
|
92 |
+
super().__init__(vocab_file, **kwargs)
|
93 |
+
|
94 |
+
self.sos = "[SOS]"
|
95 |
+
self.eos = "[EOS]"
|
96 |
+
|
97 |
+
if not os.path.isfile(vocab_file):
|
98 |
+
raise ValueError("Can't find a vocab file at path '{}'.".format(vocab_file))
|
99 |
+
self.vocab = load_vocab(vocab_file)
|
100 |
+
self.highest_unused_index = max(
|
101 |
+
[i for i, v in enumerate(self.vocab.keys()) if v.startswith("[unused")]
|
102 |
+
)
|
103 |
+
self.ids_to_tokens = collections.OrderedDict(
|
104 |
+
[(ids, tok) for tok, ids in self.vocab.items()]
|
105 |
+
)
|
106 |
+
self.basic_tokenizer = BasicSmilesTokenizer()
|
107 |
+
|
108 |
+
@property
|
109 |
+
def vocab_size(self):
|
110 |
+
return len(self.vocab)
|
111 |
+
|
112 |
+
@property
|
113 |
+
def vocab_list(self):
|
114 |
+
return list(self.vocab.keys())
|
115 |
+
|
116 |
+
def _tokenize(self, text: str):
|
117 |
+
"""
|
118 |
+
Tokenize a string into a list of tokens.
|
119 |
+
|
120 |
+
Parameters
|
121 |
+
----------
|
122 |
+
text: str
|
123 |
+
Input string sequence to be tokenized.
|
124 |
+
"""
|
125 |
+
|
126 |
+
split_tokens = [token for token in self.basic_tokenizer.tokenize(text)]
|
127 |
+
return split_tokens
|
128 |
+
|
129 |
+
def _convert_token_to_id(self, token):
|
130 |
+
"""
|
131 |
+
Converts a token (str/unicode) in an id using the vocab.
|
132 |
+
|
133 |
+
Parameters
|
134 |
+
----------
|
135 |
+
token: str
|
136 |
+
String token from a larger sequence to be converted to a numerical id.
|
137 |
+
"""
|
138 |
+
|
139 |
+
return self.vocab.get(token, self.vocab.get(self.unk_token))
|
140 |
+
|
141 |
+
def _convert_id_to_token(self, index):
|
142 |
+
"""
|
143 |
+
Converts an index (integer) in a token (string/unicode) using the vocab.
|
144 |
+
|
145 |
+
Parameters
|
146 |
+
----------
|
147 |
+
index: int
|
148 |
+
Integer index to be converted back to a string-based token as part of a larger sequence.
|
149 |
+
"""
|
150 |
+
|
151 |
+
return self.ids_to_tokens.get(index, self.unk_token)
|
152 |
+
|
153 |
+
def convert_tokens_to_string(self, tokens: List[str]):
|
154 |
+
"""Converts a sequence of tokens (string) in a single string.
|
155 |
+
|
156 |
+
Parameters
|
157 |
+
----------
|
158 |
+
tokens: List[str]
|
159 |
+
List of tokens for a given string sequence.
|
160 |
+
|
161 |
+
Returns
|
162 |
+
-------
|
163 |
+
out_string: str
|
164 |
+
Single string from combined tokens.
|
165 |
+
"""
|
166 |
+
|
167 |
+
out_string: str = " ".join(tokens).replace(" ##", "").strip()
|
168 |
+
return out_string
|
169 |
+
|
170 |
+
def add_special_tokens_ids_single_sequence(self, token_ids: List[int]):
|
171 |
+
"""
|
172 |
+
Adds special tokens to the a sequence for sequence classification tasks.
|
173 |
+
A BERT sequence has the following format: [CLS] X [SEP]
|
174 |
+
|
175 |
+
Parameters
|
176 |
+
----------
|
177 |
+
|
178 |
+
token_ids: list[int]
|
179 |
+
list of tokenized input ids. Can be obtained using the encode or encode_plus methods.
|
180 |
+
"""
|
181 |
+
|
182 |
+
return [self.cls_token_id] + token_ids + [self.sep_token_id]
|
183 |
+
|
184 |
+
def add_special_tokens_single_sequence(self, tokens: List[str]):
|
185 |
+
"""
|
186 |
+
Adds special tokens to the a sequence for sequence classification tasks.
|
187 |
+
A BERT sequence has the following format: [CLS] X [SEP]
|
188 |
+
|
189 |
+
Parameters
|
190 |
+
----------
|
191 |
+
tokens: List[str]
|
192 |
+
List of tokens for a given string sequence.
|
193 |
+
|
194 |
+
"""
|
195 |
+
return [self.cls_token] + tokens + [self.sep_token]
|
196 |
+
|
197 |
+
def add_special_tokens_ids_sequence_pair(
|
198 |
+
self, token_ids_0: List[int], token_ids_1: List[int]
|
199 |
+
) -> List[int]:
|
200 |
+
"""
|
201 |
+
Adds special tokens to a sequence pair for sequence classification tasks.
|
202 |
+
A BERT sequence pair has the following format: [CLS] A [SEP] B [SEP]
|
203 |
+
|
204 |
+
Parameters
|
205 |
+
----------
|
206 |
+
token_ids_0: List[int]
|
207 |
+
List of ids for the first string sequence in the sequence pair (A).
|
208 |
+
|
209 |
+
token_ids_1: List[int]
|
210 |
+
List of tokens for the second string sequence in the sequence pair (B).
|
211 |
+
"""
|
212 |
+
|
213 |
+
sep = [self.sep_token_id]
|
214 |
+
cls = [self.cls_token_id]
|
215 |
+
|
216 |
+
return cls + token_ids_0 + sep + token_ids_1 + sep
|
217 |
+
|
218 |
+
def add_padding_tokens(
|
219 |
+
self, token_ids: List[int], length: int, right: bool = True
|
220 |
+
) -> List[int]:
|
221 |
+
"""
|
222 |
+
Adds padding tokens to return a sequence of length max_length.
|
223 |
+
By default padding tokens are added to the right of the sequence.
|
224 |
+
|
225 |
+
Parameters
|
226 |
+
----------
|
227 |
+
token_ids: list[int]
|
228 |
+
list of tokenized input ids. Can be obtained using the encode or encode_plus methods.
|
229 |
+
|
230 |
+
length: int
|
231 |
+
|
232 |
+
right: bool (True by default)
|
233 |
+
|
234 |
+
Returns
|
235 |
+
----------
|
236 |
+
token_ids :
|
237 |
+
list of tokenized input ids. Can be obtained using the encode or encode_plus methods.
|
238 |
+
|
239 |
+
padding: int
|
240 |
+
Integer to be added as padding token
|
241 |
+
|
242 |
+
"""
|
243 |
+
padding = [self.pad_token_id] * (length - len(token_ids))
|
244 |
+
|
245 |
+
if right:
|
246 |
+
return token_ids + padding
|
247 |
+
else:
|
248 |
+
return padding + token_ids
|
249 |
+
|
250 |
+
def save_vocabulary(
|
251 |
+
self, vocab_path: str
|
252 |
+
): # -> tuple[str]: doctest issue raised with this return type annotation
|
253 |
+
"""
|
254 |
+
Save the tokenizer vocabulary to a file.
|
255 |
+
|
256 |
+
Parameters
|
257 |
+
----------
|
258 |
+
vocab_path: obj: str
|
259 |
+
The directory in which to save the SMILES character per line vocabulary file.
|
260 |
+
Default vocab file is found in deepchem/feat/tests/data/vocab.txt
|
261 |
+
|
262 |
+
Returns
|
263 |
+
----------
|
264 |
+
vocab_file: :obj:`Tuple(str)`:
|
265 |
+
Paths to the files saved.
|
266 |
+
typle with string to a SMILES character per line vocabulary file.
|
267 |
+
Default vocab file is found in deepchem/feat/tests/data/vocab.txt
|
268 |
+
|
269 |
+
"""
|
270 |
+
index = 0
|
271 |
+
if os.path.isdir(vocab_path):
|
272 |
+
vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
|
273 |
+
else:
|
274 |
+
vocab_file = vocab_path
|
275 |
+
with open(vocab_file, "w", encoding="utf-8") as writer:
|
276 |
+
for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
|
277 |
+
if index != token_index:
|
278 |
+
logger.warning(
|
279 |
+
"Saving vocabulary to {}: vocabulary indices are not consecutive."
|
280 |
+
" Please check that the vocabulary is not corrupted!".format(
|
281 |
+
vocab_file
|
282 |
+
)
|
283 |
+
)
|
284 |
+
index = token_index
|
285 |
+
writer.write(token + "\n")
|
286 |
+
index += 1
|
287 |
+
return (vocab_file,)
|
288 |
+
|
289 |
+
|
290 |
+
class BasicSmilesTokenizer(object):
|
291 |
+
"""
|
292 |
+
|
293 |
+
Run basic SMILES tokenization using a regex pattern developed by Schwaller et. al. This tokenizer is to be used
|
294 |
+
when a tokenizer that does not require the transformers library by HuggingFace is required.
|
295 |
+
|
296 |
+
Examples
|
297 |
+
--------
|
298 |
+
>>> from deepchem.feat.smiles_tokenizer import BasicSmilesTokenizer
|
299 |
+
>>> tokenizer = BasicSmilesTokenizer()
|
300 |
+
>>> print(tokenizer.tokenize("CC(=O)OC1=CC=CC=C1C(=O)O"))
|
301 |
+
['C', 'C', '(', '=', 'O', ')', 'O', 'C', '1', '=', 'C', 'C', '=', 'C', 'C', '=', 'C', '1', 'C', '(', '=', 'O', ')', 'O']
|
302 |
+
|
303 |
+
|
304 |
+
References
|
305 |
+
----------
|
306 |
+
.. [1] Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee
|
307 |
+
ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction
|
308 |
+
1572-1583 DOI: 10.1021/acscentsci.9b00576
|
309 |
+
|
310 |
+
"""
|
311 |
+
|
312 |
+
def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN):
|
313 |
+
"""Constructs a BasicSMILESTokenizer.
|
314 |
+
Parameters
|
315 |
+
----------
|
316 |
+
|
317 |
+
regex: string
|
318 |
+
SMILES token regex
|
319 |
+
|
320 |
+
"""
|
321 |
+
self.regex_pattern = regex_pattern
|
322 |
+
self.regex = re.compile(self.regex_pattern)
|
323 |
+
|
324 |
+
def tokenize(self, text):
|
325 |
+
"""Basic Tokenization of a SMILES."""
|
326 |
+
tokens = [token for token in self.regex.findall(text)]
|
327 |
+
return tokens
|
328 |
+
|
329 |
+
|
330 |
+
def load_vocab(vocab_file):
|
331 |
+
"""Loads a vocabulary file into a dictionary."""
|
332 |
+
vocab = collections.OrderedDict()
|
333 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
334 |
+
tokens = reader.readlines()
|
335 |
+
for index, token in enumerate(tokens):
|
336 |
+
token = token.rstrip("\n")
|
337 |
+
vocab[token] = index
|
338 |
+
return vocab
|
339 |
+
|
340 |
+
|
341 |
+
class BasicSmilesTokenizer(object):
|
342 |
+
"""
|
343 |
+
|
344 |
+
Run basic SMILES tokenization using a regex pattern developed by Schwaller et. al. This tokenizer is to be used
|
345 |
+
when a tokenizer that does not require the transformers library by HuggingFace is required.
|
346 |
+
|
347 |
+
Examples
|
348 |
+
--------
|
349 |
+
>>> from deepchem.feat.smiles_tokenizer import BasicSmilesTokenizer
|
350 |
+
>>> tokenizer = BasicSmilesTokenizer()
|
351 |
+
>>> print(tokenizer.tokenize("CC(=O)OC1=CC=CC=C1C(=O)O"))
|
352 |
+
['C', 'C', '(', '=', 'O', ')', 'O', 'C', '1', '=', 'C', 'C', '=', 'C', 'C', '=', 'C', '1', 'C', '(', '=', 'O', ')', 'O']
|
353 |
+
|
354 |
+
|
355 |
+
References
|
356 |
+
----------
|
357 |
+
.. [1] Philippe Schwaller, Teodoro Laino, Théophile Gaudin, Peter Bolgar, Christopher A. Hunter, Costas Bekas, and Alpha A. Lee
|
358 |
+
ACS Central Science 2019 5 (9): Molecular Transformer: A Model for Uncertainty-Calibrated Chemical Reaction Prediction
|
359 |
+
1572-1583 DOI: 10.1021/acscentsci.9b00576
|
360 |
+
|
361 |
+
"""
|
362 |
+
|
363 |
+
def __init__(self, regex_pattern: str = SMI_REGEX_PATTERN):
|
364 |
+
"""Constructs a BasicSMILESTokenizer.
|
365 |
+
Parameters
|
366 |
+
----------
|
367 |
+
|
368 |
+
regex: string
|
369 |
+
SMILES token regex
|
370 |
+
|
371 |
+
"""
|
372 |
+
self.regex_pattern = regex_pattern
|
373 |
+
self.regex = re.compile(self.regex_pattern)
|
374 |
+
|
375 |
+
def tokenize(self, text):
|
376 |
+
"""Basic Tokenization of a SMILES."""
|
377 |
+
tokens = [token for token in self.regex.findall(text)]
|
378 |
+
return tokens
|
379 |
+
|
380 |
+
|
381 |
+
def load_vocab(vocab_file):
|
382 |
+
"""Loads a vocabulary file into a dictionary."""
|
383 |
+
vocab = collections.OrderedDict()
|
384 |
+
with open(vocab_file, "r", encoding="utf-8") as reader:
|
385 |
+
tokens = reader.readlines()
|
386 |
+
for index, token in enumerate(tokens):
|
387 |
+
token = token.rstrip("\n")
|
388 |
+
vocab[token] = index
|
389 |
+
return vocab
|
390 |
+
|
391 |
+
|
392 |
+
if __name__ == "__main__":
|
393 |
+
current_dir = os.path.dirname(os.path.realpath(__file__))
|
394 |
+
vocab_path = os.path.join(current_dir, "tests/data", "vocab.txt")
|
395 |
+
tokenizer = SmilesTokenizer()
|
396 |
+
|
397 |
+
tokens = tokenizer.encode(
|
398 |
+
"CN1CC[C@]23[C@@H]4[C@H]1CC5=C2C(=C(C=C5)O)O[C@H]3[C@H](C=C4)O"
|
399 |
+
)
|
400 |
+
print([tokenizer._convert_id_to_token(t) for t in tokens])
|
401 |
+
|
402 |
+
enc = tokenizer.encode("CC=O")
|
403 |
+
print(enc)
|
404 |
+
print(tokenizer.decode(enc, skip_special_tokens=True).replace(" ", ""))
|
torch2-env.yaml
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
name: torch2-llamol
|
2 |
+
channels:
|
3 |
+
- pytorch
|
4 |
+
- nvidia
|
5 |
+
- conda-forge
|
6 |
+
- defaults
|
7 |
+
dependencies:
|
8 |
+
- python=3.8
|
9 |
+
- torchaudio
|
10 |
+
- pytorch
|
11 |
+
- torchvision
|
12 |
+
- pytorch-cuda
|
13 |
+
- rdkit
|
14 |
+
- ca-certificates
|
15 |
+
- certifi
|
16 |
+
- openssl
|
17 |
+
- openbabel
|
18 |
+
- ipykernel
|
19 |
+
pip:
|
20 |
+
- tqdm
|
21 |
+
- transformers
|
22 |
+
- pandas
|
23 |
+
- matplotlib
|
24 |
+
- seaborn
|
25 |
+
- hydra-core
|
26 |
+
- swifter
|
27 |
+
- pyarrow
|
28 |
+
- ipywidgets
|
29 |
+
- dask
|
train.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from trainer import (
|
2 |
+
IOConfig,
|
3 |
+
LoaderConfig,
|
4 |
+
Trainer,
|
5 |
+
TrainerArgs,
|
6 |
+
ModelArgs,
|
7 |
+
ContextArgs,
|
8 |
+
OptimizerConfig,
|
9 |
+
)
|
10 |
+
from torch.distributed.elastic.multiprocessing.errors import record
|
11 |
+
|
12 |
+
import hydra
|
13 |
+
from omegaconf import DictConfig, OmegaConf
|
14 |
+
import logging
|
15 |
+
import sys
|
16 |
+
import os
|
17 |
+
import torch
|
18 |
+
|
19 |
+
|
20 |
+
def setup_logger(run_name: str, log_path: str):
|
21 |
+
ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
22 |
+
if ddp:
|
23 |
+
ddp_rank = int(os.environ["RANK"])
|
24 |
+
ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
25 |
+
ddp_world_size = int(os.environ["WORLD_SIZE"])
|
26 |
+
|
27 |
+
formatter = logging.Formatter(
|
28 |
+
f"[%(levelname)s] DDP[{ddp_rank},{ddp_local_rank},{ddp_world_size}] %(asctime)s - [%(filename)s:%(lineno)d]: %(message)s",
|
29 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
30 |
+
)
|
31 |
+
else:
|
32 |
+
formatter = logging.Formatter(
|
33 |
+
r"[%(levelname)s] %(asctime)s - [%(filename)s:%(lineno)d]: %(message)s",
|
34 |
+
datefmt="%Y-%m-%d %H:%M:%S",
|
35 |
+
)
|
36 |
+
|
37 |
+
stream_handler = logging.StreamHandler(sys.stdout)
|
38 |
+
stream_handler.setFormatter(formatter)
|
39 |
+
|
40 |
+
os.makedirs(log_path, exist_ok=True)
|
41 |
+
file_handler = logging.FileHandler(os.path.join(log_path, f"train_{run_name}.log"))
|
42 |
+
file_handler.setFormatter(formatter)
|
43 |
+
|
44 |
+
logging.basicConfig(level=logging.INFO, handlers=[stream_handler, file_handler])
|
45 |
+
|
46 |
+
return logging.getLogger()
|
47 |
+
|
48 |
+
|
49 |
+
@record
|
50 |
+
@hydra.main(version_base=None, config_path="config", config_name="config")
|
51 |
+
def main(cfg: DictConfig) -> None:
|
52 |
+
logger = setup_logger(
|
53 |
+
cfg.get("run_name", "default"), cfg.get("io", {"out_dir": "out"})["out_dir"]
|
54 |
+
)
|
55 |
+
|
56 |
+
logger.info("Using config")
|
57 |
+
logger.info(cfg)
|
58 |
+
|
59 |
+
cfg = cfg["train"]
|
60 |
+
io_conf = IOConfig(**cfg.get("io", {}))
|
61 |
+
loader_conf = LoaderConfig(**cfg.get("loader", {}))
|
62 |
+
model_args = ModelArgs(**cfg.get("model", {}))
|
63 |
+
ctx_args = ContextArgs(**cfg.get("context", {}))
|
64 |
+
optmizer_conf = OptimizerConfig(**cfg.get("optimizer", {}))
|
65 |
+
train_args = TrainerArgs(
|
66 |
+
io_conf=io_conf,
|
67 |
+
loader_conf=loader_conf,
|
68 |
+
model_conf=model_args,
|
69 |
+
context_conf=ctx_args,
|
70 |
+
optimizer_conf=optmizer_conf,
|
71 |
+
run_name=cfg.get("label", "train_run"),
|
72 |
+
)
|
73 |
+
|
74 |
+
# When training on cpu / testing to not max out all cpu cores
|
75 |
+
torch.set_num_threads(8)
|
76 |
+
|
77 |
+
trainer = Trainer(
|
78 |
+
train_args=train_args,
|
79 |
+
dtype=cfg.get("dtype", "float16"),
|
80 |
+
compile=cfg.get("compile", False),
|
81 |
+
)
|
82 |
+
should_profile = cfg.get("profile", False)
|
83 |
+
|
84 |
+
if should_profile:
|
85 |
+
with torch.profiler.profile(
|
86 |
+
activities=[
|
87 |
+
torch.profiler.ProfilerActivity.CPU,
|
88 |
+
torch.profiler.ProfilerActivity.CUDA,
|
89 |
+
]
|
90 |
+
) as p:
|
91 |
+
trainer.train()
|
92 |
+
|
93 |
+
print(p.key_averages().table(sort_by="self_cuda_time_total", row_limit=-1))
|
94 |
+
|
95 |
+
else:
|
96 |
+
trainer.train()
|
97 |
+
|
98 |
+
|
99 |
+
if __name__ == "__main__":
|
100 |
+
# python train.py train=llama2-M-Full train.model.dim=1024
|
101 |
+
main()
|
trainLLamaMol.sh
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --mem=32gb # Total memory limit
|
3 |
+
#SBATCH --nodes=1
|
4 |
+
#SBATCH --ntasks-per-node=1
|
5 |
+
#SBATCH --cpus-per-task=2
|
6 |
+
#SBATCH --partition=<YOUR PARTITION>
|
7 |
+
#SBATCH --gres=gpu:a100:1
|
8 |
+
#SBATCH --time=2-00:00:00 # Time limit 2-hrs:min:sec days
|
9 |
+
|
10 |
+
export CUDA_VISIBLE_DEVICES=0
|
11 |
+
|
12 |
+
# TODO: Change FULL_PATH_TO_CONDA to the binary where the conda folder is: see https://github.com/conda/conda/issues/8536
|
13 |
+
conda activate FULL_PATH_TO_CONDA/torch2-llamol
|
14 |
+
module load CUDA/11.7.0
|
15 |
+
module load GCC/7.1.0-2.28
|
16 |
+
|
17 |
+
cd ~/llama2-mol
|
18 |
+
|
19 |
+
srun python train.py train=llama2-M-Full-RSS > "train_runs/run_$SLURM_JOB_ID.out"
|
trainLLamaMolDDPSingleNode.sh
ADDED
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
#SBATCH --mem=32gb # Total memory limit
|
3 |
+
#SBATCH --nodes=1
|
4 |
+
#SBATCH --ntasks-per-node=<HOW MANY GPUS>
|
5 |
+
#SBATCH --cpus-per-task=2
|
6 |
+
#SBATCH --partition=<YOUR PARTITION>
|
7 |
+
#SBATCH --gres=gpu:a100:<HOW MANY GPUS>
|
8 |
+
#SBATCH --time=2-00:00:00 # Time limit 2-hrs:min:sec days
|
9 |
+
|
10 |
+
export WORLD_SIZE=2
|
11 |
+
export OMP_NUM_THREADS=8
|
12 |
+
### get the first node name as master address - customized for vgg slurm
|
13 |
+
### e.g. master(gnodee[2-5],gnoded1) == gnodee2
|
14 |
+
echo "NODELIST="${SLURM_NODELIST}
|
15 |
+
master_addr=$(scontrol show hostnames "$SLURM_JOB_NODELIST" | head -n 1)
|
16 |
+
PORT=54357
|
17 |
+
export MASTER_ADDR="$master_addr:$PORT"
|
18 |
+
|
19 |
+
|
20 |
+
# TODO: Change FULL_PATH_TO_CONDA to the binary where the conda folder is: see https://github.com/conda/conda/issues/8536
|
21 |
+
conda activate FULL_PATH_TO_CONDA/torch2-llamol
|
22 |
+
module load CUDA/11.7.0
|
23 |
+
module load GCC/8.3.0
|
24 |
+
|
25 |
+
# TODO: Change this to the folder you cloned the repo in
|
26 |
+
cd ~/llamol
|
27 |
+
|
28 |
+
srun torchrun --standalone --max_restarts=3 --nnodes=1 --nproc_per_node=2 --rdzv-id=$SLURM_JOB_ID --rdzv-backend=c10d --rdzv-endpoint="$master_addr:$PORT" train.py train=llama2-M-Full > "train_runs/run_$SLURM_JOB_ID.out"
|
trainer.py
ADDED
@@ -0,0 +1,513 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from dataclasses import dataclass
|
2 |
+
from typing import Any, Dict, Optional, Tuple, List, Union
|
3 |
+
from fragment_creator import fragment_creator_factory
|
4 |
+
|
5 |
+
from model import ContextArgs, ModelArgs
|
6 |
+
from tqdm import tqdm
|
7 |
+
import math
|
8 |
+
import os
|
9 |
+
import time
|
10 |
+
from contextlib import nullcontext
|
11 |
+
from datetime import datetime
|
12 |
+
from functools import partial
|
13 |
+
|
14 |
+
import torch
|
15 |
+
import numpy as np
|
16 |
+
from model import ContextArgs, Transformer, ModelArgs
|
17 |
+
from torch.distributed import destroy_process_group, init_process_group
|
18 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
19 |
+
|
20 |
+
from preprocess_dataset import SmilesTask
|
21 |
+
from tokenizer import SmilesTokenizer
|
22 |
+
|
23 |
+
import logging
|
24 |
+
|
25 |
+
logger = logging.getLogger(__name__)
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class IOConfig:
|
30 |
+
# I/O
|
31 |
+
out_dir: str = "out"
|
32 |
+
eval_interval: int = 500
|
33 |
+
log_interval: int = 10
|
34 |
+
eval_iters: int = 25
|
35 |
+
eval_only: bool = False # if True, script exits right after the first eval
|
36 |
+
always_save_checkpoint: bool = (
|
37 |
+
False # if True, always save a checkpoint after each eval
|
38 |
+
)
|
39 |
+
init_from: str = "scratch" # 'scratch' or 'resume'
|
40 |
+
resume_when_snapshot_available: bool = True
|
41 |
+
|
42 |
+
|
43 |
+
@dataclass
|
44 |
+
class LoaderConfig:
|
45 |
+
# data
|
46 |
+
batch_size: int = (
|
47 |
+
384 # if gradient_accumulation_steps > 1, this is the micro-batch size
|
48 |
+
)
|
49 |
+
max_seq_len: int = 768
|
50 |
+
dataset: str = "smiles"
|
51 |
+
processed_dataset_ckpt: str = "processed_dataset_None.pkl"
|
52 |
+
fragment_creator: Union[str, None] = None
|
53 |
+
|
54 |
+
|
55 |
+
# dim = 256
|
56 |
+
# n_layers = 8
|
57 |
+
# n_heads = 8
|
58 |
+
# multiple_of = 128
|
59 |
+
# dropout = 0.1
|
60 |
+
|
61 |
+
|
62 |
+
@dataclass
|
63 |
+
class OptimizerConfig:
|
64 |
+
# adamw optimizer
|
65 |
+
gradient_accumulation_steps: int = 4 # used to simulate larger batch sizes
|
66 |
+
learning_rate: float = 1e-4 # max learning rate
|
67 |
+
max_iters: int = 100000 # total number of training iterations
|
68 |
+
weight_decay: float = 1e-1
|
69 |
+
beta1: float = 0.9
|
70 |
+
beta2: float = 0.95
|
71 |
+
grad_clip: float = 1.0 # clip gradients at this value, or disable if == 0.0
|
72 |
+
# learning rate decay settings
|
73 |
+
decay_lr: bool = True # whether to decay the learning rate
|
74 |
+
warmup_iters: int = 1000 # how many steps to warm up for
|
75 |
+
|
76 |
+
lr_decay_iters: int = 100000 # should be ~= max_iters per Chinchilla
|
77 |
+
min_lr: float = (
|
78 |
+
0.0 # minimum learning rate, should be ~= learning_rate/10 per Chinchilla
|
79 |
+
)
|
80 |
+
|
81 |
+
|
82 |
+
@dataclass
|
83 |
+
class TrainerArgs:
|
84 |
+
# Input / Output
|
85 |
+
io_conf: IOConfig
|
86 |
+
|
87 |
+
# Loader Configs
|
88 |
+
loader_conf: LoaderConfig
|
89 |
+
|
90 |
+
# Transformer Args
|
91 |
+
model_conf: ModelArgs
|
92 |
+
context_conf: ContextArgs
|
93 |
+
|
94 |
+
# Optimizer
|
95 |
+
optimizer_conf: OptimizerConfig
|
96 |
+
|
97 |
+
run_name: str
|
98 |
+
|
99 |
+
|
100 |
+
class Trainer:
|
101 |
+
def __init__(
|
102 |
+
self, train_args: TrainerArgs, dtype: str = "float16", compile: bool = False
|
103 |
+
) -> None:
|
104 |
+
self.train_conf = train_args
|
105 |
+
self.dtype = dtype
|
106 |
+
self.compile = compile
|
107 |
+
# system
|
108 |
+
self.run_name = train_args.run_name
|
109 |
+
self.device = (
|
110 |
+
"cuda:0" if torch.cuda.is_available() else "cpu"
|
111 |
+
) # "cuda" # examples: 'cpu', 'cuda', 'cuda:0', 'cuda:1' etc., or try 'mps' on macbooks
|
112 |
+
|
113 |
+
self.CKPT_PT = f"{self.run_name}.pt"
|
114 |
+
self.SNAPSHOT_PT = f"snapshot_{self.run_name}.pt"
|
115 |
+
|
116 |
+
def _init_ddp_if_possible(self):
|
117 |
+
# various inits, derived attributes, I/O setup
|
118 |
+
self.ddp = int(os.environ.get("RANK", -1)) != -1 # is this a ddp run?
|
119 |
+
if self.ddp:
|
120 |
+
logger.info(f"Using ddp!")
|
121 |
+
init_process_group(backend="nccl")
|
122 |
+
self.ddp_rank = int(os.environ["RANK"])
|
123 |
+
self.ddp_local_rank = int(os.environ["LOCAL_RANK"])
|
124 |
+
self.ddp_world_size = int(os.environ["WORLD_SIZE"])
|
125 |
+
logger.info(f"{self.ddp_rank}, {self.ddp_local_rank},{self.ddp_world_size}")
|
126 |
+
|
127 |
+
self.device = f"cuda:{self.ddp_local_rank}"
|
128 |
+
torch.cuda.set_device(self.device)
|
129 |
+
self.master_process = (
|
130 |
+
self.ddp_rank == 0
|
131 |
+
) # this process will do logging, checkpointing etc.
|
132 |
+
|
133 |
+
logger.info(f"Is master process {self.device}? {self.master_process}")
|
134 |
+
self.seed_offset = self.ddp_rank # each process gets a different seed
|
135 |
+
# world_size number of processes will be training simultaneously, so we can scale
|
136 |
+
# down the desired gradient accumulation iterations per process proportionally
|
137 |
+
assert (
|
138 |
+
self.train_conf.optimizer_conf.gradient_accumulation_steps
|
139 |
+
% self.ddp_world_size
|
140 |
+
== 0
|
141 |
+
)
|
142 |
+
self.train_conf.optimizer_conf.gradient_accumulation_steps //= (
|
143 |
+
self.ddp_world_size
|
144 |
+
)
|
145 |
+
else:
|
146 |
+
# if not ddp, we are running on a single gpu, and one process
|
147 |
+
self.master_process = True
|
148 |
+
self.seed_offset = 0
|
149 |
+
self.ddp_world_size = 1
|
150 |
+
|
151 |
+
def _init_train(self):
|
152 |
+
self.tokens_per_iter = (
|
153 |
+
self.train_conf.optimizer_conf.gradient_accumulation_steps
|
154 |
+
* self.ddp_world_size
|
155 |
+
* self.train_conf.loader_conf.batch_size
|
156 |
+
* self.train_conf.loader_conf.max_seq_len
|
157 |
+
)
|
158 |
+
if self.master_process:
|
159 |
+
logger.info(f"tokens per iteration will be: {self.tokens_per_iter:,}")
|
160 |
+
logger.info(
|
161 |
+
f"breaks down as: {self.train_conf.optimizer_conf.gradient_accumulation_steps} grad accum steps * {self.ddp_world_size} processes * {self.train_conf.loader_conf.batch_size} batch size * {self.train_conf.loader_conf.max_seq_len } max seq len"
|
162 |
+
)
|
163 |
+
|
164 |
+
if self.master_process:
|
165 |
+
os.makedirs(self.train_conf.io_conf.out_dir, exist_ok=True)
|
166 |
+
|
167 |
+
torch.manual_seed(1337 + self.seed_offset)
|
168 |
+
np.random.seed(1337 + self.seed_offset)
|
169 |
+
torch.backends.cuda.matmul.allow_tf32 = True # allow tf32 on matmul
|
170 |
+
torch.backends.cudnn.allow_tf32 = True # allow tf32 on cudnn
|
171 |
+
self.device_type = (
|
172 |
+
"cuda" if "cuda" in self.device else "cpu"
|
173 |
+
) # for later use in torch.autocast
|
174 |
+
# note: float16 data type will automatically use a GradScaler
|
175 |
+
ptdtype = {
|
176 |
+
"float32": torch.float32,
|
177 |
+
"bfloat16": torch.bfloat16,
|
178 |
+
"float16": torch.float16,
|
179 |
+
}[self.dtype]
|
180 |
+
self.ctx = (
|
181 |
+
nullcontext()
|
182 |
+
if self.device_type == "cpu"
|
183 |
+
else torch.amp.autocast(device_type=self.device_type, dtype=ptdtype)
|
184 |
+
)
|
185 |
+
# task-specific setup
|
186 |
+
task = {"smiles": SmilesTask}[self.train_conf.loader_conf.dataset]
|
187 |
+
self.iter_batches = partial(
|
188 |
+
task.iter_batches,
|
189 |
+
batch_size=self.train_conf.loader_conf.batch_size,
|
190 |
+
device=self.device,
|
191 |
+
context_keys=self.train_conf.context_conf.context_keys,
|
192 |
+
num_workers=0,
|
193 |
+
dataset=self.train_conf.loader_conf.processed_dataset_ckpt,
|
194 |
+
fragment_creator=fragment_creator_factory(
|
195 |
+
self.train_conf.loader_conf.fragment_creator
|
196 |
+
),
|
197 |
+
)
|
198 |
+
# init these up here, can override if init_from='resume' (i.e. from a checkpoint)
|
199 |
+
self.iter_num = 0
|
200 |
+
self.best_val_loss = 1e9
|
201 |
+
self.epoch = 1
|
202 |
+
|
203 |
+
self.tokenizer = SmilesTokenizer()
|
204 |
+
|
205 |
+
has_resumed = False
|
206 |
+
if (
|
207 |
+
self.train_conf.io_conf.init_from == "resume"
|
208 |
+
or self.train_conf.io_conf.resume_when_snapshot_available
|
209 |
+
):
|
210 |
+
snapshot_path = os.path.join(
|
211 |
+
self.train_conf.io_conf.out_dir, self.SNAPSHOT_PT
|
212 |
+
)
|
213 |
+
if os.path.exists(snapshot_path):
|
214 |
+
has_resumed = True
|
215 |
+
logger.info(f"Resuming training from {self.train_conf.io_conf.out_dir}")
|
216 |
+
# resume training from a checkpoint.
|
217 |
+
ckpt_path = os.path.join(self.train_conf.io_conf.out_dir, self.CKPT_PT)
|
218 |
+
self.model = Transformer.load(ckpt_path, device=self.device)
|
219 |
+
snapshot = torch.load(snapshot_path, map_location=self.device)
|
220 |
+
self.iter_num = snapshot["iter_num"]
|
221 |
+
self.best_val_loss = snapshot["best_val_loss"]
|
222 |
+
self.epoch = snapshot["epoch"]
|
223 |
+
|
224 |
+
if self.train_conf.io_conf.init_from == "scratch" and not has_resumed:
|
225 |
+
# init a new model from scratch
|
226 |
+
logger.info("Initializing a new model from scratch")
|
227 |
+
logger.info(self.device)
|
228 |
+
|
229 |
+
model_conf = self.train_conf.model_conf
|
230 |
+
model_conf.vocab_size = self.tokenizer.vocab_size
|
231 |
+
|
232 |
+
self.model = Transformer(model_conf, self.train_conf.context_conf).to(
|
233 |
+
self.device
|
234 |
+
)
|
235 |
+
logger.info(
|
236 |
+
f"Number of params: {self.model.getNumberParams()} Number Trainable Params: {self.model.getNumberTrainableParams()}"
|
237 |
+
)
|
238 |
+
|
239 |
+
# else:
|
240 |
+
# raise ValueError(
|
241 |
+
# f"Could not find option: {self.train_conf.io_conf.init_from}. Use either 'scratch' or 'resume'"
|
242 |
+
# )
|
243 |
+
|
244 |
+
self.model = self.model.to(self.device)
|
245 |
+
|
246 |
+
# initialize a GradScaler. If enabled=False scaler is a no-op
|
247 |
+
self.scaler = torch.cuda.amp.GradScaler(enabled=(self.dtype == "float16"))
|
248 |
+
|
249 |
+
# optimizer
|
250 |
+
self.optimizer = self.model.configure_optimizers(
|
251 |
+
self.train_conf.optimizer_conf.weight_decay,
|
252 |
+
self.train_conf.optimizer_conf.learning_rate,
|
253 |
+
(
|
254 |
+
self.train_conf.optimizer_conf.beta1,
|
255 |
+
self.train_conf.optimizer_conf.beta2,
|
256 |
+
),
|
257 |
+
self.device_type,
|
258 |
+
)
|
259 |
+
|
260 |
+
if (
|
261 |
+
self.train_conf.io_conf.init_from == "resume"
|
262 |
+
and "optimizer_state" in snapshot
|
263 |
+
):
|
264 |
+
logger.info("Loading optimizer state from snapshot")
|
265 |
+
self.optimizer.load_state_dict(snapshot["optimizer_state"])
|
266 |
+
snapshot = None # free up memory
|
267 |
+
|
268 |
+
# compile the model
|
269 |
+
if self.compile:
|
270 |
+
logger.info("compiling the model... (takes a ~minute)")
|
271 |
+
self.unoptimized_model = self.model
|
272 |
+
# NOTE: This is REALLY REALLY slow in our case, as the shapes are different in each epoch.
|
273 |
+
# So it recompiles every batch ._.
|
274 |
+
self.model = torch.compile(
|
275 |
+
self.model, dynamic=False
|
276 |
+
) # requires PyTorch 2.0
|
277 |
+
|
278 |
+
# wrap model into DDP container
|
279 |
+
if self.ddp:
|
280 |
+
# Ignore the `freqs_cis` buffer so that DDP does not broadcast it at
|
281 |
+
# construction time since NCCL does not support `ComplexFloat`
|
282 |
+
prefix = "_orig_mod." if compile else ""
|
283 |
+
self.model._ddp_params_and_buffers_to_ignore = {prefix + "freqs_cis"}
|
284 |
+
self.model = DDP(self.model, device_ids=[self.ddp_local_rank])
|
285 |
+
|
286 |
+
# helps estimate an arbitrarily accurate loss over either split using many batches
|
287 |
+
@torch.no_grad()
|
288 |
+
def estimate_loss(self):
|
289 |
+
out = {}
|
290 |
+
self.model.eval()
|
291 |
+
for split in ["train", "val"]:
|
292 |
+
batch_iter = self.iter_batches(split)
|
293 |
+
losses = torch.zeros(self.train_conf.io_conf.eval_iters) # keep on CPU
|
294 |
+
for k in tqdm(
|
295 |
+
range(self.train_conf.io_conf.eval_iters),
|
296 |
+
total=self.train_conf.io_conf.eval_iters,
|
297 |
+
desc="Eval",
|
298 |
+
):
|
299 |
+
try:
|
300 |
+
X = next(batch_iter)
|
301 |
+
with self.ctx:
|
302 |
+
# logger.info(model)
|
303 |
+
# logger.info(X["src"].device)
|
304 |
+
|
305 |
+
logits = self.model(
|
306 |
+
X["src"],
|
307 |
+
targets=X["tgt"],
|
308 |
+
context=X["context"],
|
309 |
+
fragment=X["fragment"],
|
310 |
+
)
|
311 |
+
|
312 |
+
loss = self.raw_model.last_loss
|
313 |
+
losses[k] = loss.item()
|
314 |
+
except StopIteration:
|
315 |
+
logger.info("Early Eval Stop")
|
316 |
+
|
317 |
+
out[split] = losses.mean()
|
318 |
+
self.model.train()
|
319 |
+
return out
|
320 |
+
|
321 |
+
# learning rate decay scheduler (cosine with warmup)
|
322 |
+
def get_lr(self, it: int):
|
323 |
+
warmup_iters = self.train_conf.optimizer_conf.warmup_iters
|
324 |
+
learning_rate = self.train_conf.optimizer_conf.learning_rate
|
325 |
+
lr_decay_iters = self.train_conf.optimizer_conf.lr_decay_iters
|
326 |
+
min_lr = self.train_conf.optimizer_conf.min_lr
|
327 |
+
|
328 |
+
# 1) linear warmup for warmup_iters steps
|
329 |
+
if it < warmup_iters:
|
330 |
+
return learning_rate * it / warmup_iters
|
331 |
+
# 2) if it > lr_decay_iters, return min learning rate
|
332 |
+
if it > lr_decay_iters:
|
333 |
+
return min_lr
|
334 |
+
# 3) in between, use cosine decay down to min learning rate
|
335 |
+
decay_ratio = (it - warmup_iters) / (lr_decay_iters - warmup_iters)
|
336 |
+
assert 0 <= decay_ratio <= 1
|
337 |
+
coeff = 0.5 * (1.0 + math.cos(math.pi * decay_ratio)) # coeff ranges 0..1
|
338 |
+
return min_lr + coeff * (learning_rate - min_lr)
|
339 |
+
|
340 |
+
def train(self):
|
341 |
+
self._init_ddp_if_possible()
|
342 |
+
self._init_train()
|
343 |
+
|
344 |
+
# training loop
|
345 |
+
train_batch_iter = self.iter_batches("train")
|
346 |
+
X = next(train_batch_iter) # fetch the very first batch
|
347 |
+
t0 = time.time()
|
348 |
+
local_iter_num = 0 # number of iterations in the lifetime of this process
|
349 |
+
self.raw_model = (
|
350 |
+
self.model.module if self.ddp else self.model
|
351 |
+
) # unwrap DDP container if needed
|
352 |
+
running_mfu = -1.0
|
353 |
+
|
354 |
+
gradient_accumulation_steps = (
|
355 |
+
self.train_conf.optimizer_conf.gradient_accumulation_steps
|
356 |
+
)
|
357 |
+
while True:
|
358 |
+
# determine and set the learning rate for this iteration
|
359 |
+
lr = (
|
360 |
+
self.get_lr(self.iter_num)
|
361 |
+
if self.train_conf.optimizer_conf.decay_lr
|
362 |
+
else self.train_conf.optimizer_conf.learning_rate
|
363 |
+
)
|
364 |
+
for param_group in self.optimizer.param_groups:
|
365 |
+
param_group["lr"] = lr
|
366 |
+
|
367 |
+
# evaluate the loss on train/val sets and write checkpoints
|
368 |
+
if (
|
369 |
+
self.iter_num % self.train_conf.io_conf.eval_interval == 0
|
370 |
+
and self.master_process
|
371 |
+
and self.iter_num != 0
|
372 |
+
):
|
373 |
+
logger.info(
|
374 |
+
f"Estimating loss for master_process({self.master_process}) on iter {self.iter_num}"
|
375 |
+
)
|
376 |
+
losses = self.estimate_loss()
|
377 |
+
logger.info(
|
378 |
+
f"step {self.iter_num}: train loss {losses['train']:.4f}, val loss {losses['val']:.4f}"
|
379 |
+
)
|
380 |
+
log_dict = {
|
381 |
+
"iter": self.iter_num,
|
382 |
+
"tokens": self.iter_num * self.tokens_per_iter,
|
383 |
+
"loss/train": losses["train"],
|
384 |
+
"loss/val": losses["val"],
|
385 |
+
"lr": lr,
|
386 |
+
"mfu": running_mfu * 100, # convert to percentage
|
387 |
+
}
|
388 |
+
logger.info(f"{log_dict}")
|
389 |
+
|
390 |
+
if (
|
391 |
+
losses["val"] < self.best_val_loss
|
392 |
+
or self.train_conf.io_conf.always_save_checkpoint
|
393 |
+
):
|
394 |
+
self.best_val_loss = losses["val"]
|
395 |
+
if self.iter_num > 0:
|
396 |
+
logger.info(
|
397 |
+
f"saving checkpoint to {self.train_conf.io_conf.out_dir}"
|
398 |
+
)
|
399 |
+
self.raw_model.save(
|
400 |
+
os.path.join(self.train_conf.io_conf.out_dir, self.CKPT_PT)
|
401 |
+
)
|
402 |
+
|
403 |
+
torch.save(
|
404 |
+
{
|
405 |
+
"iter_num": self.iter_num,
|
406 |
+
"epoch": self.epoch,
|
407 |
+
"best_val_loss": self.best_val_loss,
|
408 |
+
"optimizer_state": self.optimizer.state_dict(),
|
409 |
+
},
|
410 |
+
os.path.join(
|
411 |
+
self.train_conf.io_conf.out_dir, self.SNAPSHOT_PT
|
412 |
+
),
|
413 |
+
)
|
414 |
+
|
415 |
+
if self.iter_num == 0 and self.train_conf.io_conf.eval_only:
|
416 |
+
break
|
417 |
+
|
418 |
+
# forward backward update, with optional gradient accumulation to simulate larger batch size
|
419 |
+
# and using the GradScaler if data type is float16
|
420 |
+
for micro_step in range(gradient_accumulation_steps):
|
421 |
+
if self.ddp:
|
422 |
+
# in DDP training we only need to sync gradients at the last micro step.
|
423 |
+
# the official way to do this is with model.no_sync() context manager, but
|
424 |
+
# I really dislike that this bloats the code and forces us to repeat code
|
425 |
+
# looking at the source of that context manager, it just toggles this variable
|
426 |
+
self.model.require_backward_grad_sync = (
|
427 |
+
micro_step == gradient_accumulation_steps - 1
|
428 |
+
)
|
429 |
+
with self.ctx:
|
430 |
+
context = X["context"]
|
431 |
+
|
432 |
+
fragment = X["fragment"]
|
433 |
+
|
434 |
+
# SCL (Stochastic context learning) algorithm
|
435 |
+
if np.random.random() < 0.15 or fragment is None:
|
436 |
+
fragment = None
|
437 |
+
|
438 |
+
# NOTE: random delete one context or more context columns
|
439 |
+
current_context_keys = list(context.keys())
|
440 |
+
for k in current_context_keys:
|
441 |
+
if np.random.random() < 0.15:
|
442 |
+
del context[k]
|
443 |
+
|
444 |
+
logits = self.model(
|
445 |
+
X["src"], targets=X["tgt"], context=context, fragment=fragment
|
446 |
+
)
|
447 |
+
loss = self.raw_model.last_loss
|
448 |
+
loss = loss / gradient_accumulation_steps
|
449 |
+
# immediately async prefetch next batch while model is doing the forward pass on the GPU
|
450 |
+
try:
|
451 |
+
X = next(train_batch_iter)
|
452 |
+
|
453 |
+
except StopIteration:
|
454 |
+
# StopIteration is thrown if dataset ends
|
455 |
+
# reinitialize data loader
|
456 |
+
logger.info(f"Done Epoch {self.epoch}")
|
457 |
+
train_batch_iter = self.iter_batches("train")
|
458 |
+
X = next(train_batch_iter)
|
459 |
+
self.epoch += 1
|
460 |
+
|
461 |
+
# backward pass, with gradient scaling if training in fp16
|
462 |
+
self.scaler.scale(loss).backward()
|
463 |
+
# logger.info(loss)
|
464 |
+
# clip the gradient
|
465 |
+
if self.train_conf.optimizer_conf.grad_clip != 0.0:
|
466 |
+
self.scaler.unscale_(self.optimizer)
|
467 |
+
torch.nn.utils.clip_grad_norm_(
|
468 |
+
self.model.parameters(), self.train_conf.optimizer_conf.grad_clip
|
469 |
+
)
|
470 |
+
# step the optimizer and scaler if training in fp16
|
471 |
+
self.scaler.step(self.optimizer)
|
472 |
+
self.scaler.update()
|
473 |
+
# flush the gradients as soon as we can, no need for this memory anymore
|
474 |
+
self.optimizer.zero_grad(set_to_none=True)
|
475 |
+
|
476 |
+
# timing and logging
|
477 |
+
t1 = time.time()
|
478 |
+
dt = t1 - t0
|
479 |
+
t0 = t1
|
480 |
+
|
481 |
+
if (
|
482 |
+
self.iter_num % self.train_conf.io_conf.log_interval == 0
|
483 |
+
and self.master_process
|
484 |
+
):
|
485 |
+
# get loss as float, scale up due to the divide above. note: this is a CPU-GPU sync point
|
486 |
+
lossf = loss.item() * gradient_accumulation_steps
|
487 |
+
if local_iter_num >= 5: # let the training loop settle a bit
|
488 |
+
mfu = self.raw_model.estimate_mfu(
|
489 |
+
self.train_conf.loader_conf.batch_size
|
490 |
+
* gradient_accumulation_steps,
|
491 |
+
dt,
|
492 |
+
)
|
493 |
+
running_mfu = (
|
494 |
+
mfu if running_mfu == -1.0 else 0.9 * running_mfu + 0.1 * mfu
|
495 |
+
)
|
496 |
+
logger.info(
|
497 |
+
f"{self.iter_num} | loss {lossf:.4f} | lr {lr:e} | {dt*1000:.2f}ms | mfu {running_mfu*100:.2f}%"
|
498 |
+
)
|
499 |
+
self.iter_num += 1
|
500 |
+
local_iter_num += 1
|
501 |
+
|
502 |
+
# termination conditions
|
503 |
+
|
504 |
+
if self.iter_num > self.train_conf.optimizer_conf.max_iters:
|
505 |
+
logger.info("Done with training iters!")
|
506 |
+
break
|
507 |
+
|
508 |
+
if self.ddp:
|
509 |
+
destroy_process_group()
|
510 |
+
|
511 |
+
|
512 |
+
if __name__ == "__main__":
|
513 |
+
pass
|