msun415 commited on
Commit
13362e2
·
verified ·
1 Parent(s): c7c8ad4

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +4 -35
  2. .gitignore +172 -0
  3. LICENSE +201 -0
  4. README.md +184 -6
  5. assets/llamole.png +3 -0
  6. assets/ui_example.png +3 -0
  7. config/generate/llama_drug.yaml +30 -0
  8. config/generate/llama_material.yaml +30 -0
  9. config/generate/mistral_drug.yaml +30 -0
  10. config/generate/mistral_material.yaml +30 -0
  11. config/generate/qwen_drug.yaml +30 -0
  12. config/generate/qwen_material.yaml +30 -0
  13. config/train/llama_lora.yaml +49 -0
  14. config/train/mistral_lora.yaml +48 -0
  15. config/train/qwen_lora.yaml +48 -0
  16. data/dataset_info.json +20 -0
  17. data/molqa_drug_examples.json +75 -0
  18. data/molqa_material_examples.json +85 -0
  19. data/molqa_train_examples.json +62 -0
  20. data/property_ranges.json +42 -0
  21. install_environment.sh +85 -0
  22. launch.py +238 -0
  23. main.py +66 -0
  24. requirements.txt +30 -0
  25. src/__init__.py +0 -0
  26. src/data/__init__.py +29 -0
  27. src/data/aligner.py +233 -0
  28. src/data/collator.py +165 -0
  29. src/data/data_utils.py +82 -0
  30. src/data/formatter.py +139 -0
  31. src/data/loader.py +149 -0
  32. src/data/parser.py +113 -0
  33. src/data/processors/__init__.py +0 -0
  34. src/data/processors/mmsupervised.py +335 -0
  35. src/data/template.py +889 -0
  36. src/data/tool_utils.py +140 -0
  37. src/eval/__init__.py +0 -0
  38. src/eval/dataset.py +78 -0
  39. src/eval/workflow.py +219 -0
  40. src/extras/__init__.py +0 -0
  41. src/extras/constants.py +1550 -0
  42. src/extras/env.py +75 -0
  43. src/extras/logging.py +82 -0
  44. src/extras/misc.py +215 -0
  45. src/extras/packages.py +83 -0
  46. src/extras/ploting.py +101 -0
  47. src/hparams/__init__.py +31 -0
  48. src/hparams/data_args.py +110 -0
  49. src/hparams/evaluation_args.py +62 -0
  50. src/hparams/finetuning_args.py +402 -0
.gitattributes CHANGED
@@ -1,35 +1,4 @@
1
- *.7z filter=lfs diff=lfs merge=lfs -text
2
- *.arrow filter=lfs diff=lfs merge=lfs -text
3
- *.bin filter=lfs diff=lfs merge=lfs -text
4
- *.bz2 filter=lfs diff=lfs merge=lfs -text
5
- *.ckpt filter=lfs diff=lfs merge=lfs -text
6
- *.ftz filter=lfs diff=lfs merge=lfs -text
7
- *.gz filter=lfs diff=lfs merge=lfs -text
8
- *.h5 filter=lfs diff=lfs merge=lfs -text
9
- *.joblib filter=lfs diff=lfs merge=lfs -text
10
- *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
- *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
- *.model filter=lfs diff=lfs merge=lfs -text
13
- *.msgpack filter=lfs diff=lfs merge=lfs -text
14
- *.npy filter=lfs diff=lfs merge=lfs -text
15
- *.npz filter=lfs diff=lfs merge=lfs -text
16
- *.onnx filter=lfs diff=lfs merge=lfs -text
17
- *.ot filter=lfs diff=lfs merge=lfs -text
18
- *.parquet filter=lfs diff=lfs merge=lfs -text
19
- *.pb filter=lfs diff=lfs merge=lfs -text
20
- *.pickle filter=lfs diff=lfs merge=lfs -text
21
- *.pkl filter=lfs diff=lfs merge=lfs -text
22
- *.pt filter=lfs diff=lfs merge=lfs -text
23
- *.pth filter=lfs diff=lfs merge=lfs -text
24
- *.rar filter=lfs diff=lfs merge=lfs -text
25
- *.safetensors filter=lfs diff=lfs merge=lfs -text
26
- saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
- *.tar.* filter=lfs diff=lfs merge=lfs -text
28
- *.tar filter=lfs diff=lfs merge=lfs -text
29
- *.tflite filter=lfs diff=lfs merge=lfs -text
30
- *.tgz filter=lfs diff=lfs merge=lfs -text
31
- *.wasm filter=lfs diff=lfs merge=lfs -text
32
- *.xz filter=lfs diff=lfs merge=lfs -text
33
- *.zip filter=lfs diff=lfs merge=lfs -text
34
- *.zst filter=lfs diff=lfs merge=lfs -text
35
- *tfevents* filter=lfs diff=lfs merge=lfs -text
 
1
+ # Auto detect text files and perform LF normalization
2
+ * text=auto
3
+ assets/llamole.png filter=lfs diff=lfs merge=lfs -text
4
+ assets/ui_example.png filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
.gitignore ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+
74
+ # PyBuilder
75
+ .pybuilder/
76
+ target/
77
+
78
+ # Jupyter Notebook
79
+ .ipynb_checkpoints
80
+
81
+ # IPython
82
+ profile_default/
83
+ ipython_config.py
84
+
85
+ # pyenv
86
+ # For a library or package, you might want to ignore these files since the code is
87
+ # intended to run in multiple environments; otherwise, check them in:
88
+ # .python-version
89
+
90
+ # pipenv
91
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
92
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
93
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
94
+ # install all needed dependencies.
95
+ #Pipfile.lock
96
+
97
+ # poetry
98
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
99
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
100
+ # commonly ignored for libraries.
101
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
102
+ #poetry.lock
103
+
104
+ # pdm
105
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
106
+ #pdm.lock
107
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
108
+ # in version control.
109
+ # https://pdm.fming.dev/#use-with-ide
110
+ .pdm.toml
111
+
112
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
113
+ __pypackages__/
114
+
115
+ # Celery stuff
116
+ celerybeat-schedule
117
+ celerybeat.pid
118
+
119
+ # SageMath parsed files
120
+ *.sage.py
121
+
122
+ # Environments
123
+ .env
124
+ .venv
125
+ env/
126
+ venv/
127
+ ENV/
128
+ env.bak/
129
+ venv.bak/
130
+
131
+ # Spyder project settings
132
+ .spyderproject
133
+ .spyproject
134
+
135
+ # Rope project settings
136
+ .ropeproject
137
+
138
+ # mkdocs documentation
139
+ /site
140
+
141
+ # mypy
142
+ .mypy_cache/
143
+ .dmypy.json
144
+ dmypy.json
145
+
146
+ # Pyre type checker
147
+ .pyre/
148
+
149
+ # pytype static type analyzer
150
+ .pytype/
151
+
152
+ # Cython debug symbols
153
+ cython_debug/
154
+
155
+ # PyCharm
156
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
157
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
158
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
159
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
160
+ .idea/
161
+
162
+ # custom .gitignore
163
+ cache/
164
+ output/
165
+ wandb/
166
+ # Ignore everything in data directory
167
+ run_script*
168
+
169
+ saves/*
170
+ !saves/README.md
171
+
172
+ .gradio
LICENSE ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "[]"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright [yyyy] [name of copyright owner]
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
README.md CHANGED
@@ -1,12 +1,190 @@
1
  ---
2
  title: Llamole
3
- emoji: 🐢
4
- colorFrom: gray
5
- colorTo: purple
6
  sdk: gradio
7
  sdk_version: 5.15.0
8
- app_file: app.py
9
- pinned: false
10
  ---
 
11
 
12
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  ---
2
  title: Llamole
3
+ app_file: launch.py
 
 
4
  sdk: gradio
5
  sdk_version: 5.15.0
 
 
6
  ---
7
+ <!-- # 🧬 Llamole: Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning -->
8
 
9
+ <div align="center">
10
+
11
+ # 🔬 🧪 Llamole: Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning
12
+
13
+ [![](https://img.shields.io/badge/Paper-pink?style=plastic&logo=arXiv)](https://arxiv.org/abs/2410.04223)
14
+ [![](https://img.shields.io/badge/GitHub-blue?style=plastic&logo=github)](https://github.com/liugangcode/Llamole)
15
+ [![](https://img.shields.io/badge/Huggingface-orange?style=plastic&logo=huggingface)](https://huggingface.co/collections/liuganghuggingface/llamole-collection-67073a2e76b47d5fea909434)
16
+
17
+ </div>
18
+
19
+ ## 🌟 Introduction
20
+
21
+ **Llamole** is a multimodal Large Language Model (LLM) that integrates a base LLM with the Graph Diffusion Transformer and Graph Neural Networks for multi-conditional molecular generation and multi-step reaction inference within texts.
22
+
23
+ 📄 **[Paper](https://arxiv.org/abs/2410.04223)**: *Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning*
24
+
25
+ <div align="center">
26
+
27
+ <details>
28
+ <summary>🔍 Abstract</summary>
29
+ <p align="left">
30
+ While large language models (LLMs) have integrated images, adapting them to graphs remains challenging, limiting their applications in materials and drug design. This difficulty stems from the need for coherent autoregressive generation across texts and graphs. To address this, we introduce Llamole, the first multimodal LLM capable of interleaved text and graph generation, enabling molecular inverse design with retrosynthetic planning. Llamole integrates a base LLM with the Graph Diffusion Transformer and Graph Neural Networks for multi-conditional molecular generation and reaction inference within texts, while the LLM, with enhanced molecular understanding, flexibly controls activation among the different graph modules. Additionally, Llamole integrates A* search with LLM-based cost functions for efficient retrosynthetic planning. We create benchmarking datasets and conduct extensive experiments to evaluate Llamole against in-context learning and supervised fine-tuning. Llamole significantly outperforms 14 adapted LLMs across 12 metrics for controllable molecular design and retrosynthetic planning.
31
+ </p>
32
+ </details>
33
+
34
+ <img src="assets/llamole.png" alt="Llamole Design" width="80%"/>
35
+
36
+ </div>
37
+
38
+ ---
39
+
40
+ ## 🛠️ Environment Setup
41
+
42
+ Initialize the environment by following these steps:
43
+
44
+ ```bash
45
+ conda create --name llamole python=3.11 -y
46
+ conda activate llamole
47
+ ./install_environment.sh
48
+ ```
49
+
50
+ Alternatively, you can install all required dependencies using the `requirements.sh` script.
51
+
52
+ ---
53
+
54
+ ## 🚀 Model Usage
55
+
56
+ ### 🔧 Requirements
57
+
58
+ - **Hardware**: A single V100 or A6000 GPU for inference.
59
+ - **Configuration Files**:
60
+ - `config/train/{model}_lora.yaml`
61
+ - `config/generate/{model}_{task}.yaml`
62
+
63
+ ### 📥 Automatic Model Download
64
+
65
+ On the first run, the necessary models will be automatically downloaded, including:
66
+
67
+ 1. **Base LLMs** (Please ensure you have access to the model):
68
+ - [Llama-3.1-8b-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B)
69
+ - [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct)
70
+ - [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
71
+
72
+ 2. **Pretrained Graph Models**:
73
+ - **Graph Decoder**: [Graph Diffusion Transformer](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphDiT)
74
+ - **Graph Encoder**: [GIN-based Encoder](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphEncoder)
75
+ - **Graph Predictor**: [GIN-based Predictor](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GNNPredictor)
76
+
77
+ 3. **Adapters and Connectors** for integrating the base LLM with pretrained graph models.
78
+
79
+ ### 🗂️ Manual Model Download
80
+
81
+ If you prefer to download the models manually, refer to [![](https://img.shields.io/badge/Huggingface-orange?style=plastic&logo=huggingface)](https://huggingface.co/collections/liuganghuggingface/llamole-collection-67073a2e76b47d5fea909434) and place them in the following directories:
82
+
83
+ - `saves/graph_decoder`
84
+ - `saves/graph_encoder`
85
+ - `saves/graph_predictor`
86
+ - `saves/{model_name}-Adapter`
87
+
88
+ ---
89
+
90
+ ### 🕸️ Gradio-based Web UI
91
+
92
+ Launch the web interface using Gradio:
93
+
94
+ ```bash
95
+ python launch.py
96
+ ```
97
+
98
+ The default base LLM is **Qwen2-7B-Instruct**. If you wish to change this, please modify the `args_dict` variable accordingly. Upon launch, the web UI will appear as shown below:
99
+
100
+ <div align="center">
101
+ <img src="assets/ui_example.png" alt="Web UI Example" width="60%">
102
+ </div>
103
+
104
+
105
+ ---
106
+
107
+ ### 💻 Command Line Usage
108
+
109
+ For command-line evaluation, specify the path to the configuration file:
110
+
111
+ ```bash
112
+ python main.py eval config/generate/qwen_material.yaml
113
+ ```
114
+
115
+ You can modify the configuration files to suit your custom datasets.
116
+
117
+ **Note**: Examples of training and evaluation datasets are available in the `data` folder. For more details, refer to `data/dataset_info.json`. To test generation on all MolQA questions, first download the dataset by running:
118
+
119
+ ```bash
120
+ python main.py download_data
121
+ ```
122
+
123
+ Then, update the configuration files to point to the downloaded dataset based on the names from `data/dataset_info.json`.
124
+
125
+ ---
126
+
127
+ ## 📚 Supervised Fine-Tuning
128
+
129
+ The codebase supports multimodal graph-text supervised fine-tuning. Follow these steps:
130
+
131
+ 1. **Download MolQA Training Data**:
132
+
133
+ ```bash
134
+ python main.py download_data
135
+ ```
136
+ Then you may need to modify the configuration files in the `config` folder to point to the downloaded training data. Skipping this step and directly using the command from step 2 will result in training only on the example training set.
137
+
138
+ 2. **Run Fine-Tuning**:
139
+
140
+ ```bash
141
+ python main.py train config/train/mistral_lora.yaml
142
+ ```
143
+
144
+ During the first run, pretrained graph models will be downloaded in the `saves` folder. Modify the configuration files as needed for your setup. An 80G A100 GPU is recommended for supervised fine-tuning.
145
+
146
+ ---
147
+ ## 📖 Citation
148
+
149
+ If you find this repository useful, please cite our paper:
150
+
151
+ ```
152
+ @misc{liu2024llamole,
153
+ title={Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning},
154
+ author={Gang Liu and Michael Sun and Wojciech Matusik and Meng Jiang and Jie Chen},
155
+ year={2024},
156
+ eprint={2410.04223},
157
+ archivePrefix={arXiv},
158
+ primaryClass={cs.LG},
159
+ url={https://arxiv.org/abs/2410.04223},
160
+ }
161
+ ```
162
+
163
+ ```
164
+ @article{liu2024graphdit,
165
+ title={Graph Diffusion Transformers for Multi-Conditional Molecular Generation},
166
+ author={Liu, Gang and Xu, Jiaxin and Luo, Tengfei and Jiang, Meng},
167
+ journal={Thirty-Eighth Annual Conference on Neural Information Processing Systems},
168
+ year={2024}
169
+ }
170
+ ```
171
+
172
+ ---
173
+
174
+ ## 📄 Acknowledgments
175
+
176
+ This codebase is built upon **[Llama-Factory](https://github.com/hiyouga/LLaMA-Factory)**. We extend our gratitude for their open-source contributions.
177
+
178
+ ---
179
+
180
+ ## 📂 Additional Resources
181
+
182
+ 🔗 **Huggingface Models**: Llamole is developed with three variants (adapters) and three pretrained graph modules (encoder, decoder, predictor):
183
+ - **Base LLM Variant 1**: [Llama-3.1-8b-Instruct](https://huggingface.co/liuganghuggingface/Llamole-Llama-3.1-8B-Instruct-Adapter)
184
+ - **Base LLM Variant 2**: [Qwen2-7B-Instruct](https://huggingface.co/liuganghuggingface/Llamole-Qwen2-7B-Instruct-Adapter)
185
+ - **Base LLM Variant 3**: [Mistral-7B-Instruct-v0.3](https://huggingface.co/liuganghuggingface/Llamole-Mistral-7B-Instruct-v0.3-Adapter)
186
+ - **Pretrained Graph Decoder** for multi-conditional molecular generation: [Graph Diffusion Transformer](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphDiT)
187
+ - **Pretrained Graph Predictor** for one-step reaction prediction: [GNN Predictor](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GNNPredictor)
188
+ - **Pretrained Graph Encoder** for enhanced molecule understanding: [Graph Encoder](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphEncoder)
189
+
190
+ ---
assets/llamole.png ADDED

Git LFS Details

  • SHA256: 8f98ad05d897399cd4db1c618f78509e19b0b832f3aa2d938121bbc5a6e77166
  • Pointer size: 131 Bytes
  • Size of remote file: 771 kB
assets/ui_example.png ADDED

Git LFS Details

  • SHA256: 36a5f766fa0f9da927fea34d046939d88fc07ddbdaa3f35244f5bf6ea873b246
  • Pointer size: 131 Bytes
  • Size of remote file: 151 kB
config/generate/llama_drug.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct
3
+ new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
4
+ graph_decoder_path: saves/graph_decoder
5
+ graph_encoder_path: saves/graph_encoder
6
+ graph_predictor_path: saves/graph_predictor
7
+ adapter_name_or_path: saves/Llama-3.1-8B-Instruct-Adapter
8
+ graph_lm_connector_path: saves/Llama-3.1-8B-Instruct-Adapter/connector
9
+ ### generation
10
+ max_length: 512
11
+ max_new_tokens: 128
12
+ temperature: 0.6
13
+ top_p: 0.9
14
+ ### method
15
+ finetuning_type: lora
16
+ do_train: false
17
+ flash_attn: disabled
18
+ learned_query_size: 8
19
+ ### dataset
20
+ dataset: molqa_drug_examples
21
+ template: llama3
22
+ cutoff_len: 128
23
+ overwrite_cache: true
24
+ preprocessing_num_workers: 16
25
+ output_dir: null
26
+ bf16: true
27
+ pure_bf16: true
28
+ ### eval
29
+ per_device_eval_batch_size: 6
30
+ report_to: 'none'
config/generate/llama_material.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct
3
+ new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
4
+ graph_decoder_path: saves/graph_decoder
5
+ graph_encoder_path: saves/graph_encoder
6
+ graph_predictor_path: saves/graph_predictor
7
+ adapter_name_or_path: saves/Llama-3.1-8B-Instruct-Adapter
8
+ graph_lm_connector_path: saves/Llama-3.1-8B-Instruct-Adapter/connector
9
+ ### generation
10
+ max_length: 512
11
+ max_new_tokens: 128
12
+ temperature: 0.6
13
+ top_p: 0.9
14
+ ### method
15
+ finetuning_type: lora
16
+ do_train: false
17
+ flash_attn: disabled
18
+ learned_query_size: 8
19
+ ### dataset
20
+ dataset: molqa_material_examples
21
+ template: llama3
22
+ cutoff_len: 128
23
+ overwrite_cache: true
24
+ preprocessing_num_workers: 16
25
+ output_dir: null
26
+ bf16: true
27
+ pure_bf16: true
28
+ ### eval
29
+ per_device_eval_batch_size: 6
30
+ report_to: 'none'
config/generate/mistral_drug.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: mistralai/Mistral-7B-Instruct-v0.3
3
+ new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
4
+ graph_decoder_path: saves/graph_decoder
5
+ graph_encoder_path: saves/graph_encoder
6
+ graph_predictor_path: saves/graph_predictor
7
+ adapter_name_or_path: saves/Mistral-7B-Instruct-v0.3-Adapter
8
+ graph_lm_connector_path: saves/Mistral-7B-Instruct-v0.3-Adapter/connector
9
+ ### generation
10
+ max_length: 512
11
+ max_new_tokens: 128
12
+ temperature: 0.6
13
+ top_p: 0.9
14
+ ### method
15
+ finetuning_type: lora
16
+ do_train: false
17
+ flash_attn: disabled
18
+ learned_query_size: 8
19
+ ### dataset
20
+ dataset: molqa_drug_examples
21
+ template: mistral
22
+ cutoff_len: 128
23
+ overwrite_cache: true
24
+ preprocessing_num_workers: 16
25
+ output_dir: null
26
+ bf16: true
27
+ pure_bf16: true
28
+ ### eval
29
+ per_device_eval_batch_size: 6
30
+ report_to: 'none'
config/generate/mistral_material.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: mistralai/Mistral-7B-Instruct-v0.3
3
+ new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
4
+ graph_decoder_path: saves/graph_decoder
5
+ graph_encoder_path: saves/graph_encoder
6
+ graph_predictor_path: saves/graph_predictor
7
+ adapter_name_or_path: saves/Mistral-7B-Instruct-v0.3-Adapter
8
+ graph_lm_connector_path: saves/Mistral-7B-Instruct-v0.3-Adapter/connector
9
+ ### generation
10
+ max_length: 512
11
+ max_new_tokens: 128
12
+ temperature: 0.6
13
+ top_p: 0.9
14
+ ### method
15
+ finetuning_type: lora
16
+ do_train: false
17
+ flash_attn: disabled
18
+ learned_query_size: 8
19
+ ### dataset
20
+ dataset: molqa_material_examples
21
+ template: mistral
22
+ cutoff_len: 128
23
+ overwrite_cache: true
24
+ preprocessing_num_workers: 16
25
+ output_dir: null
26
+ bf16: true
27
+ pure_bf16: true
28
+ ### eval
29
+ per_device_eval_batch_size: 6
30
+ report_to: 'none'
config/generate/qwen_drug.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: Qwen/Qwen2-7B-Instruct
3
+ new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
4
+ graph_decoder_path: saves/graph_decoder
5
+ graph_encoder_path: saves/graph_encoder
6
+ graph_predictor_path: saves/graph_predictor
7
+ adapter_name_or_path: saves/Qwen2-7B-Instruct-Adapter
8
+ graph_lm_connector_path: saves/Qwen2-7B-Instruct-Adapter/connector
9
+ ### generation
10
+ max_length: 512
11
+ max_new_tokens: 128
12
+ temperature: 0.6
13
+ top_p: 0.9
14
+ ### method
15
+ finetuning_type: lora
16
+ do_train: false
17
+ flash_attn: disabled
18
+ learned_query_size: 8
19
+ ### dataset
20
+ dataset: molqa_drug_examples
21
+ template: qwen
22
+ cutoff_len: 128
23
+ overwrite_cache: true
24
+ preprocessing_num_workers: 16
25
+ output_dir: null
26
+ bf16: true
27
+ pure_bf16: true
28
+ ### eval
29
+ per_device_eval_batch_size: 6
30
+ report_to: 'none'
config/generate/qwen_material.yaml ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: Qwen/Qwen2-7B-Instruct
3
+ new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
4
+ graph_decoder_path: saves/graph_decoder
5
+ graph_encoder_path: saves/graph_encoder
6
+ graph_predictor_path: saves/graph_predictor
7
+ adapter_name_or_path: saves/Qwen2-7B-Instruct-Adapter
8
+ graph_lm_connector_path: saves/Qwen2-7B-Instruct-Adapter/connector
9
+ ### generation
10
+ max_length: 512
11
+ max_new_tokens: 128
12
+ temperature: 0.6
13
+ top_p: 0.9
14
+ ### method
15
+ finetuning_type: lora
16
+ do_train: false
17
+ flash_attn: disabled
18
+ learned_query_size: 8
19
+ ### dataset
20
+ dataset: molqa_material_examples
21
+ template: qwen
22
+ cutoff_len: 128
23
+ overwrite_cache: true
24
+ preprocessing_num_workers: 16
25
+ output_dir: null
26
+ bf16: true
27
+ pure_bf16: true
28
+ ### eval
29
+ per_device_eval_batch_size: 6
30
+ report_to: 'none'
config/train/llama_lora.yaml ADDED
@@ -0,0 +1,49 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct
3
+ new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
4
+ graph_decoder_path: saves/graph_decoder
5
+ graph_encoder_path: saves/graph_encoder
6
+ graph_predictor_path: saves/graph_predictor
7
+ graph_lm_connector_path: null #specify when resume
8
+ ### method
9
+ stage: mmsft
10
+ do_train: true
11
+ finetuning_type: lora
12
+ lora_target: all
13
+ flash_attn: disabled
14
+ learned_query_size: 8
15
+ ### dataset
16
+ dataset: molqa_train_examples
17
+ template: llama3
18
+ cutoff_len: 2048
19
+ overwrite_cache: true
20
+ preprocessing_num_workers: 16
21
+ ### output
22
+ output_dir: saves/Llama-3.1-8B-Instruct-Adapter
23
+ logging_steps: 10
24
+ save_steps: 2000
25
+ plot_loss: true
26
+ overwrite_output_dir: true
27
+ ### to resume
28
+ # overwrite_output_dir: false
29
+ ### train
30
+ per_device_train_batch_size: 10
31
+ gradient_accumulation_steps: 2
32
+ learning_rate: 1.0e-4
33
+ num_train_epochs: 4.0
34
+ lr_scheduler_type: cosine
35
+ warmup_ratio: 0.1
36
+ ddp_timeout: 180000000
37
+ bf16: true
38
+ pure_bf16: true
39
+ ### train loss
40
+ loss_weight_retro: 1
41
+ loss_weight_design: 1
42
+ loss_weight_lm: 1
43
+ ### eval
44
+ val_size: 0.1
45
+ per_device_eval_batch_size: 6
46
+ eval_strategy: steps
47
+ eval_steps: 2000
48
+ ### specify if connected to wandb
49
+ report_to: 'none'
config/train/mistral_lora.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: mistralai/Mistral-7B-Instruct-v0.3
3
+ new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
4
+ graph_decoder_path: saves/graph_decoder
5
+ graph_encoder_path: saves/graph_encoder
6
+ graph_predictor_path: saves/graph_predictor
7
+ graph_lm_connector_path: null #specify when resume
8
+ ### method
9
+ stage: mmsft
10
+ do_train: true
11
+ finetuning_type: lora
12
+ lora_target: all
13
+ flash_attn: disabled
14
+ learned_query_size: 8
15
+ ### dataset
16
+ dataset: molqa_train_examples
17
+ template: mistral
18
+ cutoff_len: 2048
19
+ overwrite_cache: true
20
+ preprocessing_num_workers: 16
21
+ ### output
22
+ output_dir: saves/Mistral-7B-Instruct-v0.3-Adapter
23
+ logging_steps: 10
24
+ save_steps: 2000
25
+ plot_loss: true
26
+ overwrite_output_dir: true
27
+ ### to resume
28
+ # overwrite_output_dir: false
29
+ ### train
30
+ per_device_train_batch_size: 6
31
+ gradient_accumulation_steps: 2
32
+ learning_rate: 1.0e-4
33
+ num_train_epochs: 4.0
34
+ lr_scheduler_type: cosine
35
+ warmup_ratio: 0.1
36
+ ddp_timeout: 180000000
37
+ bf16: true
38
+ pure_bf16: true
39
+ loss_weight_retro: 1
40
+ loss_weight_design: 1
41
+ loss_weight_lm: 1
42
+ ### eval
43
+ val_size: 0.1
44
+ per_device_eval_batch_size: 6
45
+ eval_strategy: steps
46
+ eval_steps: 2000
47
+ ### specify if connected to wandb
48
+ report_to: 'none'
config/train/qwen_lora.yaml ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ### model
2
+ model_name_or_path: Qwen/Qwen2-7B-Instruct
3
+ new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
4
+ graph_decoder_path: saves/graph_decoder
5
+ graph_encoder_path: saves/graph_encoder
6
+ graph_predictor_path: saves/graph_predictor
7
+ graph_lm_connector_path: null #specify when resume
8
+ ### method
9
+ stage: mmsft
10
+ do_train: true
11
+ finetuning_type: lora
12
+ lora_target: all
13
+ flash_attn: disabled
14
+ learned_query_size: 8
15
+ ### dataset
16
+ dataset: molqa_train_examples
17
+ template: qwen
18
+ cutoff_len: 2048
19
+ overwrite_cache: true
20
+ preprocessing_num_workers: 16
21
+ ### output
22
+ output_dir: saves/Qwen2-7B-Instruct-Adapter
23
+ logging_steps: 10
24
+ save_steps: 2000
25
+ plot_loss: true
26
+ overwrite_output_dir: true
27
+ ### to resume
28
+ # overwrite_output_dir: false
29
+ ### train
30
+ per_device_train_batch_size: 6
31
+ gradient_accumulation_steps: 2
32
+ learning_rate: 1.0e-4
33
+ num_train_epochs: 4.0
34
+ lr_scheduler_type: cosine
35
+ warmup_ratio: 0.1
36
+ ddp_timeout: 180000000
37
+ bf16: true
38
+ pure_bf16: true
39
+ loss_weight_retro: 1
40
+ loss_weight_design: 1
41
+ loss_weight_lm: 1
42
+ ### eval
43
+ val_size: 0.1
44
+ per_device_eval_batch_size: 6
45
+ eval_strategy: steps
46
+ eval_steps: 2000
47
+ ### specify if connected to wandb
48
+ report_to: 'none'
data/dataset_info.json ADDED
@@ -0,0 +1,20 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "molqa_train_examples": {
3
+ "file_name": "molqa_train_examples.json"
4
+ },
5
+ "molqa_material_examples": {
6
+ "file_name": "molqa_material_examples.json"
7
+ },
8
+ "molqa_drug_examples": {
9
+ "file_name": "molqa_drug_examples.json"
10
+ },
11
+ "molqa_train": {
12
+ "file_name": "molqa_train.json"
13
+ },
14
+ "molqa_material": {
15
+ "file_name": "molqa_material.json"
16
+ },
17
+ "molqa_drug": {
18
+ "file_name": "molqa_drug.json"
19
+ }
20
+ }
data/molqa_drug_examples.json ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "instruction": "Can you design and synthesize a molecule that inhibits both Human Immunodeficiency Virus and Beta-Secretase 1, with a molecular weight of around 335 and two aromatic rings connected to aliphatic chains?",
4
+ "input": "",
5
+ "output": "To satisfy the requirements: The molecule has a moderately complex structure, consisting of aromatic rings connected to aliphatic chains, and featuring carboxyl and hydroxyl functional groups. Therefore, the designed molecule is: <design_start><design_end><mol_start>CC(=O)OCCC(C)Oc1ccc(Cl)cc1Oc1ccccc1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CC(=O)OCCC(C)Oc1ccc(Cl)cc1Oc1ccccc1<mol_end>, follow these procedures: A solution of (R)-acetic acid 3-(toluene-4-sulfonyloxy)-butyl ester (5.89 g, 21 mmol) and 4-chloro-2-phenoxy-phenol (5.0 g, 23 mmol) in DMF (50 mL) is treated with cesium carbonate (7.4 g, 23 mmol). The solution is heated to 60\u00b0 C. and stirred overnight. The reaction is cooled and quenched with 1N HCl. The solution is partitioned in EtOAc and water. The organic is separated, washed with brine, and dried over sodium sulfate. The organic is filtered, and the solvent is removed to afford acetic acid 3-(4-chloro-2-phenoxy-phenoxy)-butyl ester, which is then diluted in methanol (100 mL) and treated with potassium carbonate (5.68 g, 40 mmol). The reaction is stirred for 2 hours at rt. The reaction is then partitioned in EtOAc and water. The organic layer is separated, washed with brine, and dried over sodium sulfate. The organic is filtered and the solvent is removed to afford the crude product. The crude is purified by silica gel column chromatography using 1/1 hexanes/EtOAc to elute the pure product. The solvent is removed to afford 4.35 g (72%) of the desired product. 1H NMR (400 MHz, CDCl3); MS (ES+) m/z mass calcd for C16H17ClO3 292, found 293 (M+1, 100%). 
The applied reaction is: <retro_start><retro_end>CC(=O)OCCC(C)Oc1ccc(Cl)cc1Oc1ccccc1>>CC(=O)OCCC(C)OS(=O)(=O)c1ccc(C)cc1.O=C(O)O.Oc1ccc(Cl)cc1Oc1ccccc1.[Cs+].[Cs+], which requires the reactants: CC(=O)OCCC(C)OS(=O)(=O)c1ccc(C)cc1 (available), O=C(O)O (available), Oc1ccc(Cl)cc1Oc1ccccc1 (not available, needs to be synthesized), [Cs+] (available), [Cs+] (available). This is step 2 in the retrosynthesis process. To synthesize <mol_start>Oc1ccc(Cl)cc1Oc1ccccc1<mol_end>, follow these procedures: A \u221240\u00b0 C. solution of 4-chloro-2-phenoxy-1-methoxy-benzene (7.16 g, 30.5 mmol) in dry CH2Cl2 (70 mL) is treated dropwise with borontribromide (22.9 g, 91.5 mmol) and then warmed to 0\u00b0 C. and stirred for 3 h under N2. The reaction is diluted with Et2O and quenched with water. The organic layer is dried (Na2SO4), and the solvent is removed in vacuo to afford 7.11 g (100%) of the title compound. Rf=0.30 (4/1 hexanes/acetone). 1H NMR (400 MHz, CDCl3); MS (ES\u2212) m/z mass calcd for C12H9O2Cl 220, found 219 (M\u22121, 100%). The applied reaction is: <retro_start><retro_end>Oc1ccc(Cl)cc1Oc1ccccc1>>COc1ccc(Cl)cc1Oc1ccccc1, which requires the reactants: COc1ccc(Cl)cc1Oc1ccccc1 (not available, needs to be synthesized). This is step 3 in the retrosynthesis process. To synthesize <mol_start>COc1ccc(Cl)cc1Oc1ccccc1<mol_end>, follow these procedures: A mixture of 2-bromo-4-chloro-1-methoxy-benzene (8.0 g, 36.1 mmol), phenol (6.80 g, 72.2 mmol), cesium carbonate (23.54 g, 72.2 mmol), copper (I) chloride (1.79 g, 18.1 mmol) and 2,2,6,6-tetramethyl-3,5-heptanedione (1.66 g, 9.00 mmol) in dry 1-methyl-2-pyrrolidinone (80 mL) is heated to 120\u00b0 C. for 20 hours under N2. The reaction is cooled, filtered and the filtrate quenched with 1 N HCl (50 mL). The filtrate is diluted with Et2O and extracted with water. 
The organic layer is dried (Na2SO4), and the solvent is removed in vacuo to afford crude product that is absorbed on silica gel and purified by flash chromatography using 9/1 hexanes/EtOAc to afford 7.42 g (88%) of the title compound. Rf=0.37 (4/1 hexanes/EtOAc). The applied reaction is: <retro_start><retro_end>COc1ccc(Cl)cc1Oc1ccccc1>>CC(C)(C)C(=O)CC(=O)C(C)(C)C.COc1ccc(Cl)cc1Br.O=C(O)O.[Cs+].[Cs+].[O-]c1ccccc1, which requires the reactants: CC(C)(C)C(=O)CC(=O)C(C)(C)C (available), COc1ccc(Cl)cc1Br (available), O=C(O)O (available), [Cs+] (available), [Cs+] (available), [O-]c1ccccc1 (available).",
6
+ "property": {
7
+ "HIV": 1.0,
8
+ "BACE": 1.0,
9
+ "SC": 3.1,
10
+ "SA": 2.46
11
+ },
12
+ "retro": [
13
+ null,
14
+ null,
15
+ null
16
+ ]
17
+ },
18
+ {
19
+ "instruction": "What is a possible molecular structure that can penetrate the blood-brain barrier and inhibit Human Immunodeficiency Virus, given a molecular weight of 209.249 and specific ring and bond characteristics?",
20
+ "input": "",
21
+ "output": "To satisfy the requirements: The molecule has an aromatic ring and a common functional group of carboxyl, which suggests a certain level of structural complexity, likely influencing its physical and chemical properties, and indicating a moderate synthetic complexity score. Therefore, the designed molecule is: <design_start><design_end><mol_start>CC(C)(C)OC(=O)Nc1ccc(N)cn1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CC(C)(C)OC(=O)Nc1ccc(N)cn1<mol_end>, follow these procedures: To a suspension of (tert-butoxy)-N-(5-nitro(2-pyridyl))carboxamide (0.27 g, 1.13 mmol) in methanol (2 mL), ethyl acetate (4 mL) and TEA (0.16 mL) was added 10% Pd/C (60 mg, 0.056 mmol) under argon. The reaction mixture was hydrogenated under 1 atm H2 for 20 hr, filtered through Celite and concentrated in vacuo to give N-(5-amino(2-pyridyl))(tert-butoxy)carboxamide (0.226 g, 97%). 1H-NMR (DMSO-d6): \u03b4 1.40 (s, 9H), 4.92 (br s, 2H), 6.89-6.91 (dd, 1H), 7.35-7.37 (d, 1H), 7.58 (d, 1H), 9.06 (s, 1H). The applied reaction is: <retro_start><retro_end>CC(C)(C)OC(=O)Nc1ccc(N)cn1>>CC(C)(C)OC(=O)Nc1ccc([N+](=O)[O-])cn1 with the template [#7;a:4]:[c:3]:[c:2]-[NH2;D1;+0:1]>>O=[N+;H0;D3:1](-[O-])-[c:2]:[c:3]:[#7;a:4], which requires the reactants: CC(C)(C)OC(=O)Nc1ccc([N+](=O)[O-])cn1 (available).",
22
+ "property": {
23
+ "BBBP": 1.0,
24
+ "HIV": 1.0,
25
+ "SC": 2.03,
26
+ "SA": 2.08
27
+ },
28
+ "retro": [
29
+ 2965
30
+ ]
31
+ },
32
+ {
33
+ "instruction": "What molecular structure can be designed to exhibit Blood-Brain Barrier Penetration and Human Immunodeficiency Virus Inhibition while featuring a specific combination of molecular weight, ring count, and functional groups?",
34
+ "input": "",
35
+ "output": "To satisfy the requirements: The molecule features a complex structure with two aromatic rings and a high number of rotatable bonds, indicating potential difficulties in synthesis. Its chemical properties are influenced by the presence of functional groups such as carboxyl and amino groups. The predicted properties suggest high substrate specificity for CYP2C9 and moderate substrate specificity for CYP3A4, while also indicating a low risk of carcinogenicity and high clearance in microsomes. The molecule's hydration free energy and skin reaction predictions are also noteworthy. Overall, the molecule's unique combination of structural features and predicted chemical properties makes it an interesting candidate for further investigation. Therefore, the designed molecule is: <design_start><design_end><mol_start>CC(=O)c1sc(NS(=O)(=O)c2cc(Br)c(Cl)s2)nc1C<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CC(=O)c1sc(NS(=O)(=O)c2cc(Br)c(Cl)s2)nc1C<mol_end>, follow these procedures: The title compound was prepared from 5-acety1-2-amino-4-methylthiazole (42 mg) and 4-bromo-5-chlorothiophene-2-sulfonyl chloride (80 mg) as described in the synthetic METHOD B to give a white solid (11.7 mg) with purity >90%: MS (pos) m/z 415.3, 417.3. The applied reaction is: <retro_start><retro_end>CC(=O)c1sc(NS(=O)(=O)c2cc(Br)c(Cl)s2)nc1C>>CC(=O)c1sc(N)nc1C.O=S(=O)(Cl)c1cc(Br)c(Cl)s1 with the template [#16;a:5]:[c:4]-[S;H0;D4;+0:1](=[O;D1;H0:2])(=[O;D1;H0:3])-[NH;D2;+0:9]-[c:7](:[#16;a:6]):[#7;a:8]>>Cl-[S;H0;D4;+0:1](=[O;D1;H0:2])(=[O;D1;H0:3])-[c:4]:[#16;a:5].[#16;a:6]:[c:7](:[#7;a:8])-[NH2;D1;+0:9], which requires the reactants: CC(=O)c1sc(N)nc1C (available), O=S(=O)(Cl)c1cc(Br)c(Cl)s1 (available).",
36
+ "property": {
37
+ "BBBP": 1.0,
38
+ "HIV": 1.0,
39
+ "SC": 2.83,
40
+ "SA": 2.74
41
+ },
42
+ "retro": [
43
+ 145329
44
+ ]
45
+ },
46
+ {
47
+ "instruction": "Can you design a molecule that penetrates the blood-brain barrier, inhibits Human Immunodeficiency Virus, and doesn't inhibit Beta-Secretase 1, while having a specific set of structural properties?",
48
+ "input": "",
49
+ "output": "To satisfy the requirements: The molecule exhibits a moderate level of structural complexity, featuring three aromatic rings and an aliphatic chain, with a molecular weight of 216.27. Its synthetic complexity score is 2.88, indicating it may be moderately challenging to synthesize. This structure is predicted to have a relatively low likelihood of mutagenicity, with an AMES score of 0.608. It also shows moderate to low likelihood of blood-brain barrier penetration, with a BBB_Martins score of 0.576. The molecule's oral bioavailability is predicted to be moderate, with a Bioavailability_Ma score of 0.797. Additionally, it is predicted to have a low to moderate interaction with various cytochrome P450 enzymes. Overall, the molecule's drug utility is influenced by its structural features, bioactivity, and pharmacokinetic properties, which suggest it may have potential as a therapeutic agent. Therefore, the designed molecule is: <design_start><design_end><mol_start>c1csc(Nc2nccn3ccnc23)c1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>c1csc(Nc2nccn3ccnc23)c1<mol_end>, follow these procedures: To a solution of (3-bromo-imidazo[1,2-a]pyrazin-8-yl)-methyl-amine (50 mg, 0.2 mmol) in tetrahydrofuran (2 ml) under nitrogen was added 2-thiopheneboronic acid (41 mg, 0.3 mmol), K2CO3 (1.1 ml of a 1 M solution in water) and 16 mg (0.1 mmol) of PdCl2(Dppf)CH2Cl2 (16 mg, 0.1 eq). The mixture was heated at 70\u00b0 C. in a sealed tube overnight. The product was precipitated by adding methanol. The filtrate was evaporated and purified by preparative thin layer chromatography on silica gel to give additional title compound for a total of 45 mg (100% yield). 
The applied reaction is: <retro_start><retro_end>c1csc(Nc2nccn3ccnc23)c1>>CNc1nccn2c(Br)cnc12.O=C(O)O.OB(O)c1cccs1.[K+].[K+] with the template [#7;a:4]:[c:3]:[cH;D2;+0:1]:[#7;a:2].[#7:5]-[c;H0;D3;+0:6]1:[cH;D2;+0:9]:[c:8]:[cH;D2;+0:7]:[s;H0;D2;+0:10]:1>>Br-[c;H0;D3;+0:1](:[#7;a:2]):[c:3]:[#7;a:4].[#7:5]-[CH3;D1;+0:6].O-B(-O)-[c;H0;D3;+0:7]1:[c:8]:[cH;D2;+0:9]:c:[s;H0;D2;+0:10]:1, which requires the reactants: CNc1nccn2c(Br)cnc12 (available), O=C(O)O (available), OB(O)c1cccs1 (available), [K+] (available), [K+] (available).",
50
+ "property": {
51
+ "BBBP": 1.0,
52
+ "HIV": 1.0,
53
+ "BACE": 0.0,
54
+ "SC": 2.88,
55
+ "SA": 2.75
56
+ },
57
+ "retro": [
58
+ 155241
59
+ ]
60
+ },
61
+ {
62
+ "instruction": "What molecule can be designed and synthesized to penetrate the blood-brain barrier and inhibit Human Immunodeficiency Virus, given the constraints of a molecular weight around 209.245 and a single aromatic ring?",
63
+ "input": "",
64
+ "output": "To satisfy the requirements: The molecule featuring an aromatic ring and an aliphatic chain, with functional groups including hydroxyl and carboxyl, exhibits moderate synthetic complexity and moderate lipophilicity, which may influence its physical and chemical properties. Therefore, the designed molecule is: <design_start><design_end><mol_start>CC(C)(C)OC(=O)Nc1ccccc1O<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CC(C)(C)OC(=O)Nc1ccccc1O<mol_end>, follow these procedures: To a solution of 2-aminophenol (2.84 g) in dichloromethane (120 ml) was added di-tert-butyl dicarbonate (6.55 g). The mixture was stirred at room temperature for 18 hours. The mixture was partitioned between water and dichloromethane. The organic extracts were dried (MgSO4) and evaporated to give a solid which was purified by column chromatography on silica eluting with a mixture of ethyl acetate and hexane (20:80) to give 2-tert-butyloxycarbonylaminophenol (1.80 g);. The applied reaction is: <retro_start><retro_end>CC(C)(C)OC(=O)Nc1ccccc1O>>CC(C)(C)OC(=O)OC(=O)OC(C)(C)C.Nc1ccccc1[O-] with the template [C;D1;H3:1]-[C:2](-[C;D1;H3:3])(-[C;D1;H3:4])-[#8:5]-[C;H0;D3;+0:6](=[O;D1;H0:7])-[NH;D2;+0:8]-[c:9]>>[C;D1;H3:1]-[C:2](-[C;D1;H3:3])(-[C;D1;H3:4])-[#8:5]-[C;H0;D3;+0:6](=[O;D1;H0:7])-O-[C;H0;D3;+0:6](=[O;D1;H0:7])-[#8:5]-[C:2](-[C;D1;H3:1])(-[C;D1;H3:3])-[C;D1;H3:4].[NH2;D1;+0:8]-[c:9], which requires the reactants: CC(C)(C)OC(=O)OC(=O)OC(C)(C)C (available), Nc1ccccc1[O-] (available).",
65
+ "property": {
66
+ "BBBP": 1.0,
67
+ "HIV": 1.0,
68
+ "SC": 1.7,
69
+ "SA": 1.73
70
+ },
71
+ "retro": [
72
+ 38065
73
+ ]
74
+ }
75
+ ]
data/molqa_material_examples.json ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "instruction": "What is the optimal molecular design and synthesis route for a polymer with high CO2 gas permeability and low permeability to N2 and O2, featuring an aromatic ring and specific functional groups?",
4
+ "input": "",
5
+ "output": "To satisfy the requirements: The polymer, featuring an aromatic ring and functional groups such as carboxyl and hydroxyl, exhibits a moderate synthetic complexity score and high gas permeability, indicating potential for use in gas separation membranes with good efficiency. Therefore, the designed molecule is: <design_start><design_end><mol_start>*c1cc(C(=O)OCC)c(*)[nH]1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CCOC(=O)c1cc[nH]c1<mol_end>, follow these procedures: Ethanol (3.2 mL, 54 mmol) and 4-dimethylamino pyridine (DMAP) (54 mg, 0.45 mmol) were added to a mixture of pyrrole-3-carboxylic acid (500 mg, 4.5 mmol) and dicyclohexylcarbodiimide (1.11 g, 5.4 mmol) (Aldrich, Milwaukee, Wis.) in tetrahydrofuran (THF) (15 mL). After heating at 60\u00b0 C. for 10 hours, the reaction was cooled. The precipitate was filtered off, washed with ethyl acetate, the combined filtrate was concentrated and purified on a silica gel column to give 500 mg (81%) of 1H-pyrrole-3-carboxylic acid ethyl ester as a colorless oil. The applied reaction is: <retro_start><retro_end>CCOC(=O)c1cc[nH]c1>>C(=NC1CCCCC1)=NC1CCCCC1.CCO.O=C(O)c1cc[nH]c1 with the template [#7;a:5]:[c:4]:[c:3]-[C;H0;D3;+0:1](=[O;D1;H0:2])-[O;H0;D2;+0:7]-[C:6]>>O-[C;H0;D3;+0:1](=[O;D1;H0:2])-[c:3]:[c:4]:[#7;a:5].[C:6]-[OH;D1;+0:7], which requires the reactants: C(=NC1CCCCC1)=NC1CCCCC1 (available), CCO (available), O=C(O)c1cc[nH]c1 (available).",
6
+ "property": {
7
+ "CO2": 0.94,
8
+ "N2": 0.0,
9
+ "O2": 0.0,
10
+ "FFV": 0.381,
11
+ "SC": 2.28,
12
+ "SA": 4.21
13
+ },
14
+ "retro": [
15
+ 36155
16
+ ]
17
+ },
18
+ {
19
+ "instruction": "Can a polymer with a molecular weight of 184.22, one aromatic ring, and specific gas permeability, thermal conductivity, and electrical properties be designed and synthesized for gas separation applications?",
20
+ "input": "",
21
+ "output": "To satisfy the requirements: The polymer has an aromatic ring incorporated within its structure, which can contribute to its overall rigidity and mechanical strength. The presence of the sulfonamide group can also enhance the polymer's chemical stability and resistance to degradation. The predicted thermal properties, such as melting point and thermal conductivity, suggest that the polymer may exhibit suitable thermal behavior for various applications. Additionally, the predicted dielectric constant and electrical conductivity indicate that the polymer may possess suitable electrical properties. The gas permeability coefficient and gas diffusion coefficient suggest that the polymer may be suitable for gas separation applications. Overall, the polymer's material utility seems promising for creating polymeric materials with desirable properties. Therefore, the designed molecule is: <design_start><design_end><mol_start>*Nc1cc(*)ccc1S(=O)(=O)NC<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CNS(=O)(=O)c1ccccc1N<mol_end>, follow these procedures: N-Methyl-2-nitrobenzenesulfonamide (1.0 g, 4.6 mmol), 10% palladium on carbon (100 mg) and absolute ethanol (30 mL) were placed in a Parr bottle, the bottle placed on a Parr apparatus under a hydrogen atmosphere (45 psi), and shaken for 1 hour, The Parr bottle was evacuated and flushed with nitrogen, then the mixture was filtered through a disposable syringe filter. The filtrate was concentrated in vacuo to yield 0.81 g of 2-amino-N-methylbenzenesulfonamide. The applied reaction is: <retro_start><retro_end>CNS(=O)(=O)c1ccccc1N>>CNS(=O)(=O)c1ccccc1[N+](=O)[O-] with the template [NH2;D1;+0:1]-[c:2]>>O=[N+;H0;D3:1](-[O-])-[c:2], which requires the reactants: CNS(=O)(=O)c1ccccc1[N+](=O)[O-] (available).",
22
+ "property": {
23
+ "CO2": 1.746,
24
+ "O2": 1.247,
25
+ "FFV": 0.37,
26
+ "TC": 0.192,
27
+ "SC": 2.8,
28
+ "SA": 4.17
29
+ },
30
+ "retro": [
31
+ 3
32
+ ]
33
+ },
34
+ {
35
+ "instruction": "Can a molecule be designed and synthesized to exhibit high CO2 permeability, moderate fractional free volume, and specific thermal conductivity, while incorporating aromatic and carboxyl functional groups?",
36
+ "input": "",
37
+ "output": "To satisfy the requirements: The polymer, with its aromatic ring and carboxyl functional group, exhibits a moderate synthetic complexity score, suggesting a relatively accessible molecular structure for synthesis. Its molecular structure and functional groups are expected to impart specific physical and chemical properties, influencing its material utility for polymeric materials. Therefore, the designed molecule is: <design_start><design_end><mol_start>*CC(*)OC(=O)c1cc(C)cc(Cl)c1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CCOC(=O)c1cc(C)cc(Cl)c1<mol_end>, follow these procedures: The title compound was synthesized from 3-chloro-5-methyl-benzoic acid under the same conditions as for Compound b1. The applied reaction is: <retro_start><retro_end>CCOC(=O)c1cc(C)cc(Cl)c1>>CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1.Cc1cc(Cl)cc(C(=O)O)c1, which requires the reactants: CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1 (not available, needs to be synthesized), Cc1cc(Cl)cc(C(=O)O)c1 (available). This is step 2 in the retrosynthesis process. To synthesize <mol_start>CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1<mol_end>, follow these procedures: Potassium carbonate (1.5 g, 11.2 mmol) and ethyl iodide (1.2 g, 7.4 mmol) were added to a solution of 4-bromo-3-trifluoromethyl-benzoic acid (1.0 g, 3.7 mmol) in DMF (5 ml), and the mixture was stirred at room temperature for 24 hours. The reaction mixture was diluted with ethyl acetate, and the organic layer was washed with water and saturated-saline, and dried over anhydrous sodium sulfate. The drying agent was removed by filtration. After concentration under reduced pressure, the resulting residue was purified by silica gel column chromatography (ethyl acetate/hexane) to yield the title compound (1.03 g, 94%) as a brown oily substance. 
The applied reaction is: <retro_start><retro_end>CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1>>CCI.O=C(O)c1ccc(Br)c(C(F)(F)F)c1.O=C([O-])[O-].[K+].[K+], which requires the reactants: CCI (available), O=C(O)c1ccc(Br)c(C(F)(F)F)c1 (available), O=C([O-])[O-] (available), [K+] (available), [K+] (available).",
38
+ "property": {
39
+ "CO2": 0.94,
40
+ "N2": 0.0,
41
+ "O2": 0.0,
42
+ "FFV": 0.375,
43
+ "TC": 0.223,
44
+ "SC": 2.58,
45
+ "SA": 4.01
46
+ },
47
+ "retro": [
48
+ null,
49
+ null
50
+ ]
51
+ },
52
+ {
53
+ "instruction": "What molecular structure can be designed and synthesized to achieve CO2 permeability of 1.743 Barrer, N2 impermeability, and specific thermal conductivity and fractional free volume properties, while incorporating an aromatic ring and limited rotatable bonds?",
54
+ "input": "",
55
+ "output": "To satisfy the requirements: The polymer is a complex molecular structure incorporating an aromatic ring and a bromine functional group, which may provide enhanced thermal and mechanical properties, potentially useful for gas separation applications. Therefore, the designed molecule is: <design_start><design_end><mol_start>*c1ccc(*)c(CBr)c1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>BrCc1ccccc1<mol_end>, follow these procedures: 4.65 g (10 mM) of compound 1a were treated, while stirring, with 40 ml of 2 N HBr in glacial acetic acid for 45 min. at 20\u00b0 in the absence of moisture. The amino acid derivative dissolved with CO2 evolution. The reaction solution was added dropwise with vigorous stirring to 250 ml of absolute ether which resulted in the precipitation of 2HBr.H-Arg-pNA. The ethereal phase was sucked off, whereupon the solid phase was washed 4 times with portions of 100 ml of abs. ether in order to substantially remove benzyl bromide which had formed as a by-product as well as excess HBr and AcOH. The residue was dissolved in 50 ml of MeOH, the pH was adjusted to 4.5 by the addition of Et3N, and the solution was concentrated to dryness in vacuo at 30\u00b0. The resulting product was dissolved in 75 ml of MeOH and passed through a column of \"Sephadex\" LH-20 (cross-linked dextran gel) equilibrated with MeOH. From a fraction of the eluate there were obtained 4.18 g (91.6% of the theory) of amorphous compound 1b which was homogeneous in the SS as shown by TLC. Elementary analysis and calculation from the empirical formula C12H20N6O3Br2 gave the following values: C=31.15% (31.60%), H=4.35% (4.42%), N=18.84% (18.43%) and Br=34.81% (35.03%). The applied reaction is: <retro_start><retro_end>BrCc1ccccc1>>CC(C)(C#N)/N=N/C(C)(C)C#N.Cc1ccc(S(=O)(=O)Cl)cc1.O=C1CCC(=O)N1Br, which requires the reactants: CC(C)(C#N)/N=N/C(C)(C)C#N (available), Cc1ccc(S(=O)(=O)Cl)cc1 (available), O=C1CCC(=O)N1Br (available).",
56
+ "property": {
57
+ "CO2": 1.743,
58
+ "N2": 0.0,
59
+ "O2": 4.692,
60
+ "FFV": 0.411,
61
+ "TC": 0.256,
62
+ "SC": 2.25,
63
+ "SA": 5.14
64
+ },
65
+ "retro": [
66
+ null
67
+ ]
68
+ },
69
+ {
70
+ "instruction": "Can you design a molecule with moderate synthetic complexity, featuring a carboxyl functional group and an aromatic ring, that exhibits CO2 gas permeability of 9.176 Barrer, O2 gas permeability of 1.494 Barrer, fractional free volume of 0.374, and thermal conductivity of 0.228 W m^{-1} K^{-1}?",
71
+ "input": "",
72
+ "output": "To satisfy the requirements: The polymer exhibits a moderate synthetic complexity score, indicating that its synthesis is relatively feasible. The presence of an aromatic ring and a carboxyl functional group in its scaffold contributes to its moderate complexity. Therefore, the designed molecule is: <design_start><design_end><mol_start>*OC(=O)c1ccc(*)s1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>O=C([O-])c1cccs1<mol_end>, follow these procedures: To a 15-mL polypropylene centrifuge tube was added 3.86 mL of 50 mM potassium phosphate buffer (pH 7.0), 1.0 mL of a suspension of 22.1 mg dry cell weight E. coli SW132 wet cells (prepared as described in Example 10) in 50 mM potassium phosphate buffer (pH 7.0), and 0.1691 g of 2-thiophenecarbonitrile. The final concentration of 2-thiophencarbonitrile was 0.307 M. The reaction mixture was mixed on a rotating platform at 27\u00b0 C. After 30 min, 7.50 mL of 95:5 acetonitrile/water containing 0.30 M N,N-dimethylbenzamide (HPLC external standard) was added to the reaction, the resulting mixture centrifuged, and a 0.100 mL of the supernatant mixed with 0.900 mL of acetonitrile and analyzed by HPLC. The conversion of 2-thiophenecarbonitrile was 99.5%, and the yields of 2-thiophenecarboxamide and 2-thiophenecarboxylic acid were 98% and 0%, respectively. The applied reaction is: <retro_start><retro_end>O=C([O-])c1cccs1>>CN(C)C(=O)c1ccccc1.N#Cc1cccs1.O=P([O-])([O-])O.[2H]C([2H])([2H])C#N.[K+].[K+].[K+].[OH-], which requires the reactants: CN(C)C(=O)c1ccccc1 (available), N#Cc1cccs1 (available), O=P([O-])([O-])O (available), [2H]C([2H])([2H])C#N (available), [K+] (available), [K+] (available), [K+] (available), [OH-] (available).",
73
+ "property": {
74
+ "CO2": 9.176,
75
+ "O2": 1.494,
76
+ "FFV": 0.374,
77
+ "TC": 0.228,
78
+ "SC": 2.23,
79
+ "SA": 4.78
80
+ },
81
+ "retro": [
82
+ null
83
+ ]
84
+ }
85
+ ]
data/molqa_train_examples.json ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ [
2
+ {
3
+ "instruction": "What is a feasible molecular design that meets the criteria of a molecular weight of 243.306, two rings including one aromatic and one aliphatic, two rotatable bonds, and two hydrogen bond acceptors, with a balanced hydrophobic and hydrophilic profile?",
4
+ "input": "",
5
+ "output": "To satisfy the requirements: The molecule has a scaffold consisting of a central aromatic ring linked to an amino-functionalized aliphatic chain, featuring a carboxyl group. Its structural complexity is moderate, with a balance of hydrophobic and hydrophilic regions, which could influence its physical and chemical properties. Therefore, the designed molecule is: <design_start><design_end><mol_start>C#CCC1CCN(C(=O)Oc2ccccc2)CC1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>C#CCC1CCN(C(=O)Oc2ccccc2)CC1<mol_end>, follow these procedures: Phenyl chloroformate (6.2 g, 40.2 mmol) was added to a solution of tert-butyl 4-(prop-2-ynyl)piperidine-1-carboxylate (1.65 g, 13.4 mmol) according to general procedure 1. Yield=0.600 g, 34%. m/z MH+=244.08. HPLC rt=10.3 min. The applied reaction is: <retro_start><retro_end>C#CCC1CCN(C(=O)Oc2ccccc2)CC1>>C#CCC1CCN(C(=O)OC(C)(C)C)CC1.O=C(Cl)Oc1ccccc1 with the template [#8:5]-[C;H0;D3;+0:4](=[O;D1;H0:6])-[N;H0;D3;+0:1](-[C:2])-[C:3]>>C-C(-C)(-C)-O-C(=O)-[N;H0;D3;+0:1](-[C:2])-[C:3].Cl-[C;H0;D3;+0:4](-[#8:5])=[O;D1;H0:6], which requires the reactants: C#CCC1CCN(C(=O)OC(C)(C)C)CC1 (available), O=C(Cl)Oc1ccccc1 (available).",
6
+ "property": {
7
+ "SC": 2.42,
8
+ "SA": 2.22
9
+ },
10
+ "retro": [
11
+ 180575
12
+ ]
13
+ },
14
+ {
15
+ "instruction": "Can you design and synthesize a molecule with a molecular weight of around 206, containing 2 rings including 1 aromatic and 1 aliphatic, and 3 rotatable bonds, with no hydrogen bond donors and 2 acceptors?",
16
+ "input": "",
17
+ "output": "To satisfy the requirements: The molecule has a relatively low synthetic complexity score and a moderate synthetic accessibility score, indicating that it is relatively easy to synthesize. Its molecular structure is composed of an aromatic ring and an aliphatic chain, which may influence its physical and chemical properties. Therefore, the designed molecule is: <design_start><design_end><mol_start>Cc1ccc(C)n1CCN1CCCC1=O<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>Cc1ccc(C)n1CCN1CCCC1=O<mol_end>, follow these procedures: 5.7 g (0.05 mmol) of acetonylacetone and 6.4 g (0.05 mmol) of 1-(2-aminoethyl)-2-pyrrolidinone are heated under reflux in 70 ml of methanol for 2 h. The reaction mixture is then concentrated and, after the concentration, the crude reaction product is then precipitated by addition of petroleum ether. The crude product is recrystallized from diethyl ether. Yield: 3.2 g (31% of theory), Melting point: 66\u00b0-68\u00b0 C. Elemental analysis: C12H18N2O (206.29) calculated: C 69.9 H 8.8 N 13.6 0 7.8 found: C 69.4 H 8.7 N 13.7 0 8.11. The applied reaction is: <retro_start><retro_end>Cc1ccc(C)n1CCN1CCCC1=O>>CC(=O)CCC(C)=O.NCCN1CCCC1=O with the template [C:7]-[n;H0;D3;+0:8]1:[c;H0;D3;+0:1](-[C;D1;H3:2]):[cH;D2;+0:3]:[cH;D2;+0:4]:[c;H0;D3;+0:5]:1-[C;D1;H3:6]>>O=[C;H0;D3;+0:1](-[C;D1;H3:2])-[CH2;D2;+0:3]-[CH2;D2;+0:4]-[C;H0;D3;+0:5](=O)-[C;D1;H3:6].[C:7]-[NH2;D1;+0:8], which requires the reactants: CC(=O)CCC(C)=O (available), NCCN1CCCC1=O (available).",
18
+ "property": {
19
+ "SC": 2.69,
20
+ "SA": 2.34
21
+ },
22
+ "retro": [
23
+ 18636
24
+ ]
25
+ },
26
+ {
27
+ "instruction": "Can you design and synthesize a molecule with a molecular weight around 377, incorporating four rings including two aromatic and two aliphatic rings, and six rotatable bonds, with no hydrogen bond donors and four acceptors?",
28
+ "input": "",
29
+ "output": "To satisfy the requirements: The molecule has a moderately complex structure built around a core scaffold that incorporates aromatic rings and aliphatic chains, with the presence of functional groups and a moderate number of rotatable bonds, contributing to its moderate synthetic complexity score. Therefore, the designed molecule is: <design_start><design_end><mol_start>Cc1ccccc1N1CCN(CCCCN2C(=O)c3ccccc3C2=O)CC1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>Cc1ccccc1N1CCN(CCCCN2C(=O)c3ccccc3C2=O)CC1<mol_end>, follow these procedures: The title compound was synthesized applying the general procedure III step A using 1-(2-methylphenyl)piperazine (1.99 mmol, 350 mg), N-(4-bromobutyl)phthalimide (1.99 mmol, 560 mg), and K2CO3 (4.97 mmol, 686 mg) in 6 mL of acetonitrile. White solid 750 mg (99%). 1H NMR (400 MHz, Chloroform-d) \u03b4 7.92-7.79 (m, 2H), 7.72 (dt, J=5.5, 3.1 Hz, 2H), 7.16 (t, J=7.6 Hz, 2H), 7.08-6.88 (m, 2H), 3.74 (t, J=7.1 Hz, 2H), 2.93 (t, J=4.8 Hz, 4H), 2.60 (s, 4H), 2.52-2.39 (m, 2H), 2.30 (s, 3H), 1.75 (p, J=7.3 Hz, 2H), 1.59 (tt, J=9.7, 5.9 Hz, 2H). The applied reaction is: <retro_start><retro_end>Cc1ccccc1N1CCN(CCCCN2C(=O)c3ccccc3C2=O)CC1>>Cc1ccccc1[NH+]1CC[NH2+]CC1.O=C(O)O.O=C1c2ccccc2C(=O)N1CCCCBr.[K+].[K+], which requires the reactants: Cc1ccccc1[NH+]1CC[NH2+]CC1 (available), O=C(O)O (available), O=C1c2ccccc2C(=O)N1CCCCBr (available), [K+] (available), [K+] (available).",
30
+ "property": {
31
+ "SC": 2.93,
32
+ "SA": 1.97
33
+ },
34
+ "retro": [
35
+ null
36
+ ]
37
+ },
38
+ {
39
+ "instruction": "Can you design a molecule with a molecular weight around 242, one aromatic ring, three rotatable bonds, and three hydrogen bond acceptors, and describe its synthesis process?",
40
+ "input": "",
41
+ "output": "To satisfy the requirements: The molecule has an aromatic ring and an aliphatic chain, with the presence of carboxyl, amino, and halogen groups, which contribute to its moderate synthetic complexity score, indicating a relatively challenging synthesis process. Therefore, the designed molecule is: <design_start><design_end><mol_start>COC(=O)c1ccccc1N(C)C(=O)CCl<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>COC(=O)c1ccccc1N(C)C(=O)CCl<mol_end>, follow these procedures: 25 Grams of methyl o-methylaminobenzoate was dissolved in 125 ml of acetone, to this solution was added a solution prepared by dissolving 20.92 g of potassium carbonate in 60 ml of water. Under an ice-cooled condition, 17.95 g of chloroacetyl chloride was added dropwise thereto, then the reaction mixture was stired at a room temperature for 1 hour. Acetone was removed by distillation, the residue obtained was extracted with chloroform, then the chloroform layer was washed with water, dried and the solvent was removed by distillation. The residue thus obtained was purified by a silica gel column chromatography to obtain 16.80 g of N-methyl-o-methoxycarbonyl-\u03b1-chloroacetanilide in the form of colorless oily substance. The applied reaction is: <retro_start><retro_end>COC(=O)c1ccccc1N(C)C(=O)CCl>>CNc1ccccc1C(=O)OC.O=C(Cl)CCl.O=C(O)O.[K+].[K+] with the template [C:2]-[C;H0;D3;+0:1](=[O;D1;H0:3])-[N;H0;D3;+0:5](-[C;D1;H3:4])-[c:6]>>Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3].[C;D1;H3:4]-[NH;D2;+0:5]-[c:6], which requires the reactants: CNc1ccccc1C(=O)OC (available), O=C(Cl)CCl (available), O=C(O)O (available), [K+] (available), [K+] (available).",
42
+ "property": {
43
+ "SC": 2.13,
44
+ "SA": 2.07
45
+ },
46
+ "retro": [
47
+ 39003
48
+ ]
49
+ },
50
+ {
51
+ "instruction": "Can you design a molecule with a molecular weight of 317.186, three aromatic rings, and two rotatable bonds, and outline a synthesis pathway for it?",
52
+ "input": "",
53
+ "output": "To satisfy the requirements: The molecule features a complex aromatic scaffold with multiple rings, exhibiting moderate synthetic complexity and potential bioavailability, while also showing possible interactions with nuclear receptors and stress response pathways. Therefore, the designed molecule is: <design_start><design_end><mol_start>COc1ccccc1-c1cn2cc(Br)cc(C)c2n1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>COc1ccccc1-c1cn2cc(Br)cc(C)c2n1<mol_end>, follow these procedures: A mixture of 2-bromo-2\u2032-methoxyacetophenone (0.500 g, 2.183 mmol) and 2-amino-5-bromo-3-methylpyridine (0.408 g, 2.183 mmol) in ethanol (8 mL) was heated to 80\u00b0 C. for 16 hours in a pressure vessel. A yellow thick precipitate formed. Water was added and the solid was extracted with ethyl acetate (some DCM and MeOH were added to help in the dissolution). The organic extracts were then washed with brine, dried over Na2SO4, filtered and concentrated and the residue was purified by silica gel chromatography (Biotage SNAP 100 g, 10 to 20% ethyl acetate in hexanes) to give 220 (0.375 g, 1.182 mmol, 54.2% yield) as a white solid. LRMS (ESI): calc. 316.02 found 317.1 (MH)+. The applied reaction is: <retro_start><retro_end>COc1ccccc1-c1cn2cc(Br)cc(C)c2n1>>COc1ccccc1C(=O)CBr.Cc1cc(Br)cnc1N.[OH-] with the template [c:4]:[c;H0;D3;+0:3](:[c:5])-[c;H0;D3;+0:2]1:[cH;D2;+0:1]:[n;H0;D3;+0:8](:[c:9]):[c:7]:[n;H0;D2;+0:6]:1>>Br-[CH2;D2;+0:1]-[C;H0;D3;+0:2](=O)-[c;H0;D3;+0:3](:[c:4]):[c:5].[NH2;D1;+0:6]-[c:7]:[n;H0;D2;+0:8]:[c:9], which requires the reactants: COc1ccccc1C(=O)CBr (available), Cc1cc(Br)cnc1N (available), [OH-] (available).",
54
+ "property": {
55
+ "SC": 2.94,
56
+ "SA": 2.19
57
+ },
58
+ "retro": [
59
+ 10919
60
+ ]
61
+ }
62
+ ]
data/property_ranges.json ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "BBBP": {
3
+ "min": 0.0,
4
+ "max": 1.0
5
+ },
6
+ "HIV": {
7
+ "min": 0.0,
8
+ "max": 1.0
9
+ },
10
+ "BACE": {
11
+ "min": 0.0,
12
+ "max": 1.0
13
+ },
14
+ "CO2": {
15
+ "min": 0.94,
16
+ "max": 1019.265
17
+ },
18
+ "N2": {
19
+ "min": 0.0,
20
+ "max": 73.417
21
+ },
22
+ "O2": {
23
+ "min": 0.0,
24
+ "max": 122.94
25
+ },
26
+ "FFV": {
27
+ "min": 0.324,
28
+ "max": 0.434
29
+ },
30
+ "TC": {
31
+ "min": 0.117,
32
+ "max": 0.38
33
+ },
34
+ "SC": {
35
+ "min": 1.0,
36
+ "max": 5.0
37
+ },
38
+ "SA": {
39
+ "min": 1.0,
40
+ "max": 8.48
41
+ }
42
+ }
install_environment.sh ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
#!/bin/bash

# Set non-interactive frontend so package tools never prompt.
export DEBIAN_FRONTEND=noninteractive

# Uncomment the following lines if you need to create the environment first:
# conda create --name llamole python=3.11 -y
# conda activate llamole

# Get the name of the currently active conda environment.
current_env=$(basename "$CONDA_PREFIX")

# Accept either supported environment. The original check only allowed
# 'llamole' even though its own error message told users that
# 'llama_factory' was acceptable; the unused get_current_conda_env helper
# has been removed.
if [ "$current_env" != "llamole" ] && [ "$current_env" != "llama_factory" ]; then
    echo "current: $CONDA_DEFAULT_ENV"
    echo "Current conda environment is neither 'llama_factory' nor 'llamole'."
    echo "Please activate one of these environments before running this script."
    echo "You can activate an environment using one of these commands:"
    echo "conda activate llama_factory"
    echo "conda activate llamole"
    exit 1
fi

echo "Running in conda environment: $current_env"

# Install general-purpose packages using pip
pip install --no-cache-dir \
    pyarrow \
    "pandas>=1.5.3" \
    "rdkit==2023.9.6" \
    pyyaml \
    ipykernel \
    packaging \
    gdown \
    "fcd_torch==1.0.7" \
    "omegaconf==2.3.0" \
    "imageio==2.26.0" \
    wandb \
    pandarallel \
    scipy \
    einops \
    sentencepiece \
    tiktoken \
    protobuf \
    uvicorn \
    pydantic \
    fastapi \
    sse-starlette \
    "matplotlib>=3.7.0" \
    fire \
    "numpy<2.0.0" \
    gradio

pip install --no-cache-dir hydra-core --upgrade

# Install PyTorch
pip install --no-cache-dir torch

# Install PyTorch Geometric and related packages
pip install --no-cache-dir torch_geometric

# For retrosynthesis reaction processing
pip install rdchiral
pip install nltk

# Install transformers and related packages
pip install --no-cache-dir \
    "transformers>=4.41.3" \
    "datasets>=2.16.0" \
    "accelerate>=0.30.1" \
    "peft>=0.11.1" \
    "trl>=0.8.6" \
    "gradio>=4.0.0"

# Install mini-moses from GitHub
pip install --no-cache-dir git+https://github.com/igor-krawczuk/mini-moses

echo "Installation complete!"
launch.py ADDED
@@ -0,0 +1,238 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Llamole Team
2
+ #
3
+ # This code is inspired by the HuggingFace's transformers library.
4
+ # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import json
19
+ import yaml
20
+ import numpy as np
21
+ import gradio as gr
22
+ import random
23
+ from rdkit import Chem
24
+ from rdkit.Chem import Draw
25
+ from rdkit.Chem import AllChem
26
+
27
+ from src.webui.workflow import load_model_and_tokenizer, process_input, generate
28
+ from src.webui.elements import create_input_components
29
+
30
# Load the example prompts for both molecule classes.
with open('data/molqa_material_examples.json', 'r') as f:
    material_examples = json.load(f)

with open('data/molqa_drug_examples.json', 'r') as f:
    drug_examples = json.load(f)

# Function to process property values
def process_property(value):
    """Map an exact-zero property value to a tiny epsilon.

    Downstream code treats 0 as the "unset" sentinel for material
    properties, so a genuine zero measurement is nudged to 1e-8.
    """
    return 1e-8 if value == 0 else value

# Tag each example with its type; material examples additionally get their
# gas-permeability/FFV values sanitized. The original code tagged every
# example twice in two redundant loop passes — collapsed into one pass here.
for example in material_examples:
    example['type'] = 'Material'
    for prop in ['CO2', 'N2', 'O2', 'FFV']:
        if prop in example['property']:
            example['property'][prop] = process_property(example['property'][prop])

for example in drug_examples:
    example['type'] = 'Drug'

# Combine examples
all_examples = material_examples + drug_examples
57
+
58
# Get default values from the first DRUG example (the original comment said
# "material", but the code reads drug_examples[0]).
default_values = drug_examples[0]

# Load the allowed min/max range for each property (drives the UI sliders).
with open('data/property_ranges.json', 'r') as f:
    property_ranges = json.load(f)

# Generation arguments; switch configs by uncommenting the qwen line.
# with open('config/generate/qwen_material.yaml', 'r') as file:
with open('config/generate/llama_material.yaml', 'r') as file:
    args_dict = yaml.safe_load(file)

# Load model and tokenizer once at import time so every request reuses them.
model, tokenizer, generating_args = load_model_and_tokenizer(args_dict)
71
+
72
def format_example(example):
    """Flatten an example dict into the ordered list of UI input values.

    Output order: instruction, type, HIV, BBBP, BACE, CO2, N2, O2, FFV,
    TC, SC, SA. Missing drug properties become the string "NAN"; missing
    material properties become 0 (the UI's "unset" sentinel); missing
    synthetic scores remain NaN floats.
    """
    props = example.get('property', {})

    # An example counts as a drug if any drug-specific property is present.
    kind = "Drug" if any(p in props for p in ("HIV", "BBBP", "BACE")) else "Material"
    row = [example['instruction'], kind]

    # Drug properties: "NAN" marks an absent value.
    for name in ("HIV", "BBBP", "BACE"):
        value = props.get(name, float('nan'))
        row.append("NAN" if np.isnan(value) else value)

    # Material properties: 0 marks an absent value.
    for name in ("CO2", "N2", "O2", "FFV", "TC"):
        value = props.get(name, float('nan'))
        row.append(0 if np.isnan(value) else value)

    # Synthetic scores: keep NaN as-is when missing.
    for name in ("SC", "SA"):
        value = props.get(name, float('nan'))
        row.append(value if not np.isnan(value) else float('nan'))

    return row
95
+
96
# Precompute the flattened UI rows for every example once at import time.
formatted_examples = [format_example(example) for example in all_examples]
98
+
99
def random_example(examples):
    """Pick a random example and expand it into the ordered UI input values.

    Returns values in the same order the input widgets expect:
    instruction, type, 3 drug props ("NAN" when absent), 5 material props
    (0 when absent), then SC/SA (NaN when absent).
    """
    chosen = random.choice(examples)
    props = chosen['property']

    values = [chosen['instruction'], chosen['type']]
    values += [props.get(name, "NAN") for name in ("HIV", "BBBP", "BACE")]
    values += [props.get(name, 0) for name in ("CO2", "N2", "O2", "FFV", "TC")]
    values += [props.get(name, float('nan')) for name in ("SC", "SA")]
    return values
115
+
116
def generate_and_visualize(instruction, property_type, HIV, BBBP, BACE, CO2, N2, O2, FFV, TC, SC, SA):
    """Run the molQA pipeline for one UI request and assemble all outputs.

    Sentinel handling: drug properties use the string "NAN" for "unset",
    material properties use 0; both are mapped to NaN here and dropped.

    Returns a 5-tuple matching the module-level ``outputs`` components:
    (LLM response text, SMILES text, molecule image, reaction-step JSON,
    reaction image gallery). ``property_type`` is accepted for signature
    compatibility with the input widgets but is not used directly.
    """
    properties = {
        "HIV": float('nan') if HIV == "NAN" else HIV,
        "BBBP": float('nan') if BBBP == "NAN" else BBBP,
        "BACE": float('nan') if BACE == "NAN" else BACE,
        "CO2": float('nan') if CO2 == 0 else CO2,
        "N2": float('nan') if N2 == 0 else N2,
        "O2": float('nan') if O2 == 0 else O2,
        "FFV": float('nan') if FFV == 0 else FFV,
        "TC": float('nan') if TC == 0 else TC,
        "SC": SC,
        "SA": SA
    }

    # Filter out NaN values (unset properties)
    properties = {k: v for k, v in properties.items() if not np.isnan(v)}

    print('instruction', instruction)
    print('properties', properties)
    results = run_molqa(instruction, **properties)

    llm_response = results.get('llm_response', 'No response generated')
    llm_smiles = results.get('llm_smiles')
    # Fixed: use .get() like the sibling fields above; the original indexed
    # results['llm_reactions'] directly and raised KeyError when absent.
    llm_reactions = results.get('llm_reactions', [])

    molecule_img = visualize_molecule(llm_smiles) if llm_smiles else None

    reaction_steps = []
    reaction_imgs = []
    if llm_reactions:
        for i, reaction_dict in enumerate(llm_reactions):
            reaction = reaction_dict.get('reaction')
            if reaction:
                reaction_steps.append(f"Step {i+1}: {reaction}")
                reaction_imgs.append(visualize_reaction(reaction))

    return (
        llm_response,
        llm_smiles if llm_smiles else "No SMILES generated",
        molecule_img,
        gr.JSON(value=reaction_steps, visible=bool(reaction_steps)),
        gr.Gallery(value=reaction_imgs, visible=bool(reaction_imgs))
    )
159
+
160
def run_molqa(instruction: str, **properties) -> dict:
    """Package one request and run the generation pipeline.

    NaN-valued properties are dropped before the request is built; the raw
    result dict from ``generate`` is returned unchanged.
    """
    usable = {name: value for name, value in properties.items() if not np.isnan(value)}

    payload = {
        "instruction": instruction,
        "input": "",
        "property": usable,
    }

    dataloader, gen_kwargs = process_input(payload, model, tokenizer, generating_args)
    return generate(model, dataloader, gen_kwargs)
174
+
175
def visualize_molecule(smiles: str) -> np.ndarray:
    """Render a SMILES string to an RGB image array.

    Returns a black 300x300 placeholder when RDKit cannot parse the SMILES.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        return np.zeros((300, 300, 3), dtype=np.uint8)
    return np.array(Draw.MolToImage(mol))
181
+
182
def visualize_reaction(reaction: str) -> np.ndarray:
    """Render a reaction SMILES to an RGB image array.

    Returns a black 300x300 placeholder when RDKit cannot parse the reaction.
    """
    rxn = AllChem.ReactionFromSmarts(reaction, useSmiles=True)
    if rxn is None:
        return np.zeros((300, 300, 3), dtype=np.uint8)
    return np.array(Draw.ReactionToImage(rxn))
188
+
189
# Define property names and their full UI descriptions.
property_names = {
    "HIV": "HIV virus replication inhibition",
    "BBBP": "Blood-brain barrier permeability",
    "BACE": "Human β-secretase 1 inhibition",
    "CO2": "CO2 Perm",
    "N2": "N2 Perm",
    "O2": "O2 Perm",
    "FFV": "Fractional free volume",
    "TC": "Thermal conductivity",
    # Fixed garbled labels: SCScore is the Synthetic Complexity score and
    # SAScore is the Synthetic Accessibility score (originals read
    # "Heuristic Synthetic Scores" / "Synthetic Synthetic Scores").
    "SC": "Synthetic Complexity Score (SCScore)",
    "SA": "Synthetic Accessibility Score (SAScore)"
}
202
+
203
# Output components, in the exact order returned by generate_and_visualize().
outputs = [
    gr.Textbox(label="Overall LLM Response"),
    gr.Textbox(label="Generated SMILES"),
    gr.Image(label="Generated Molecule"),
    gr.JSON(label="Reaction Steps"),
    gr.Gallery(label="Reaction Visualizations")
]
211
+
212
with gr.Blocks() as iface:
    gr.Markdown("# Llamole Demo Interface")
    gr.Markdown("Enter an instruction and property values to generate a molecule design.")

    interface, instruction, property_type, drug_properties, material_properties, synthetic_properties = create_input_components(default_values, property_names, property_ranges)

    random_btn = gr.Button("Random Example")
    generate_btn = gr.Button("Generate")

    # Render the pre-built output components inside the Blocks context.
    for output in outputs:
        output.render()

    # Input list must match generate_and_visualize's positional signature:
    # instruction, type, then the drug / material / synthetic property widgets.
    all_inputs = [instruction, property_type]
    all_inputs.extend(drug_properties.values())
    all_inputs.extend(material_properties.values())
    all_inputs.extend(synthetic_properties.values())

    generate_btn.click(generate_and_visualize, inputs=all_inputs, outputs=outputs)
    # gr.State carries the example list into the callback without a UI widget.
    random_btn.click(
        random_example,
        inputs=gr.State(all_examples),
        outputs=all_inputs
    )

if __name__ == "__main__":
    iface.launch(share=True)
main.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Llamole Team
2
+ #
3
+ # This code is inspired by the HuggingFace's transformers library.
4
+ # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import sys
19
+ import os
20
+ from src.train.tuner import run_train, merge_adapter
21
+ from src.eval.workflow import run_eval
22
+ from tqdm import tqdm
23
+
24
+ from huggingface_hub import hf_hub_download
25
+
26
def download_data():
    """Fetch the three MolQA dataset splits from the Hugging Face Hub into data/."""
    repo_id = "liuganghuggingface/Llamole-MolQA"
    filenames = [
        "molqa_drug.json",
        "molqa_material.json",
        "molqa_train.json",
    ]
    local_dir = "data"

    # Create the data directory if it doesn't exist
    os.makedirs(local_dir, exist_ok=True)

    print(f"Downloading files from {repo_id} to {local_dir}/")
    for file in tqdm(filenames, desc="Downloading files"):
        try:
            hf_hub_download(
                repo_id=repo_id,
                filename=file,
                repo_type="dataset",
                local_dir=local_dir,
                local_dir_use_symlinks=False,
            )
        except Exception as e:
            print(f"Error downloading {file}: {e}")
        else:
            print(f"Successfully downloaded: {file}")

    print("Download complete!")
53
+
54
if __name__ == "__main__":
    # Default to 'train' when no subcommand was given; otherwise consume argv[1].
    if len(sys.argv) == 1:
        command = 'train'
    else:
        command = sys.argv.pop(1)

    if command == 'train':
        run_train()
    elif command == 'export':
        merge_adapter()
    elif command == 'eval':
        run_eval()
    elif command == 'download_data':
        download_data()
    else:
        print(f"Invalid command: {command}. Please use 'train', 'export', 'eval', or 'download_data'.")
        sys.exit(1)
requirements.txt ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ accelerate==0.33.0
2
+ badam==1.2.2
3
+ bitsandbytes==0.44.1
4
+ datasets==2.21.0
5
+ deepspeed==0.15.2
6
+ galore_torch==1.0
7
+ gradio==5.0.1
8
+ huggingface_hub==0.24.5
9
+ jieba==0.42.1
10
+ matplotlib==3.9.2
11
+ MoD==0.3.0
12
+ modelscope==1.18.1
13
+ nltk==3.9.1
14
+ numpy==2.1.2
15
+ optuna==3.6.1
16
+ packaging==24.1
17
+ pandas==1.5.3
18
+ peft==0.12.0
19
+ PyYAML==6.0.2
20
21
+ rdchiral==1.1.0
22
+ rdkit==2023.9.6
23
+ rouge_chinese==1.0.3
24
+ safetensors==0.4.5
25
+ torch==2.4.0
26
+ torch_geometric==2.6.1
27
+ transformers==4.44.0
28
+ trl==0.9.6
29
+ typing_extensions==4.12.2
30
+ vllm==0.6.2
src/__init__.py ADDED
File without changes
src/data/__init__.py ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
from .data_utils import Role, split_dataset
from .loader import get_dataset
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
from .collator import DataCollatorForSeqGraph

# Public API of the data package; quoting normalized (the last entry used
# single quotes and lacked a trailing comma).
__all__ = [
    "Role",
    "split_dataset",
    "get_dataset",
    "TEMPLATES",
    "Template",
    "get_template_and_fix_tokenizer",
    "DataCollatorForSeqGraph",
]
src/data/aligner.py ADDED
@@ -0,0 +1,233 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ import re
17
+ from functools import partial
18
+ from typing import TYPE_CHECKING, Any, Dict, List, Union, Tuple
19
+
20
+ from datasets import Features
21
+
22
+ from ..extras.logging import get_logger
23
+ from .data_utils import Role
24
+
25
+ if TYPE_CHECKING:
26
+ from datasets import Dataset, IterableDataset
27
+ from transformers import Seq2SeqTrainingArguments
28
+
29
+ from ..hparams import DataArguments
30
+ from .parser import DatasetAttr
31
+
32
+
33
+ logger = get_logger(__name__)
34
+
35
def extract_all_smiles(text):
    """Return every SMILES string wrapped in <mol_start>...<mol_end>, in order."""
    return re.findall(r'<mol_start>(.*?)<mol_end>', text)
38
+
39
def replace_all_smiles(text):
    """Collapse every <mol_start>...<mol_end> span to the <molecule> placeholder."""
    return re.sub(r'<mol_start>.*?<mol_end>', '<molecule>', text)
42
+
43
def replace_smiles_with_callback(text):
    """Rewrite SMILES spans into placeholder tokens.

    A molecule that immediately follows a <design_start><design_end> pair
    becomes '<molecule>' followed by a <rollback_start>SMILES<rollback_end>
    span (preserving the SMILES there); every other molecule span collapses
    to '<molecule>'.
    """
    def _rewrite_design(match):
        marker, smiles = match.group(1), match.group(2)
        return f'{marker}<molecule><rollback_start>{smiles}<rollback_end>'

    text = re.sub(
        r'(<design_start><design_end>)<mol_start>(.*?)<mol_end>',
        _rewrite_design,
        text,
    )

    # Collapse the molecules that were not attached to a design marker.
    return re.sub(r'<mol_start>.*?<mol_end>', '<molecule>', text)
58
+
59
def dict_to_list(data_dict, mol_properties):
    """Project *data_dict* onto the ordered property list; missing keys become None."""
    return [data_dict.get(name) for name in mol_properties]
61
+
62
def insert_bodies(text, num_insertions, retro_labels):
    """Expand design/retro placeholder spans with learned-query body tokens.

    Rewrites every <design_start>...<design_end> span to hold exactly
    ``num_insertions`` <design_body> tokens, and — for each retrosynthesis
    step whose entry in ``retro_labels`` is not None — rewrites the step's
    <retro_start>...<retro_end> span to hold ``num_insertions``
    <retro_body> tokens.
    """
    design_pattern = r'<design_start>(.*?)<design_end>'
    # Group 1: one step from "This is step N ..." through its <retro_end>.
    # Group 2: the trailing text up to (not including) the next step marker.
    retro_pattern = r'(This is step \d+ in the retrosynthesis process\..*?<retro_start>.*?<retro_end>)(.*?)(?=This is step \d+|$)'

    def replace_design(match):
        # The original design content is discarded and replaced by body tokens.
        return f'<design_start>' + ''.join(['<design_body>'] * num_insertions) + f'<design_end>'

    def replace_retro(match, label):
        step_content = match.group(1)
        remaining_text = match.group(2)
        retro_match = re.search(r'<retro_start>(.*?)<retro_end>', step_content)
        if retro_match and label is not None:
            modified_content = f'<retro_start>' + ''.join(['<retro_body>'] * num_insertions) + f'<retro_end>'
            # NOTE(review): this branch returns only the rewritten step content
            # and drops ``remaining_text`` — presumably intentional (trailing
            # text is discarded for labeled steps), but worth confirming.
            return re.sub(r'<retro_start>.*?<retro_end>', modified_content, step_content)
        return step_content + remaining_text

    text = re.sub(design_pattern, replace_design, text)

    # Rebuild the text step by step so each step can consult its own label.
    steps = re.finditer(retro_pattern, text)
    modified_text = ""
    last_end = 0

    for i, step in enumerate(steps):
        # Steps beyond the end of retro_labels are treated as unlabeled.
        label = retro_labels[i] if i < len(retro_labels) else None
        modified_text += text[last_end:step.start()] + replace_retro(step, label)
        last_end = step.end()

    modified_text += text[last_end:]
    return modified_text
91
+
92
def extract_retro_products(text):
    """Collect the (whitespace-stripped) product SMILES between <retro_end> and '>>'."""
    return [product.strip() for product in re.findall(r'<retro_end>(.*?)>>', text)]
96
+
97
def convert_molqa(
    examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
    r"""
    Converts alpaca format dataset to the standard format.

    For each example this builds the chat-style ``prompt``/``response`` pair
    and extracts molecule metadata:

    - ``molecules``: SMILES found in ``<mol_start>...<mol_end>`` spans
    - ``property``: property values ordered per ``mol_properties``
    - ``retro_labels``: retrosynthesis template labels (``dataset_attr.retro``)
    - ``retro_products``: products parsed from ``<retro_end>...>>`` spans
    """
    outputs = {"prompt": [], "response": [], "system": [], "molecules": [], "property": [], "retro_labels": [], "retro_products": []}

    mol_properties = ['BBBP', 'HIV', 'BACE', 'CO2', 'N2', 'O2', 'FFV', 'TC', 'SC', 'SA']
    for i in range(len(examples[dataset_attr.prompt])):
        prompt = []
        if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
            for old_prompt, old_response in examples[dataset_attr.history][i]:
                prompt.append({"role": Role.USER.value, "content": old_prompt})
                prompt.append({"role": Role.ASSISTANT.value, "content": old_response})

        content = []
        if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
            content.append(examples[dataset_attr.prompt][i])

        if dataset_attr.query and examples[dataset_attr.query][i]:
            content.append(examples[dataset_attr.query][i])

        prompt.append({"role": Role.USER.value, "content": "\n".join(content)})  # "prompt\nquery"

        # Defaults for the unsupervised branch. The original only assigned
        # these inside the supervised branch, so an unsupervised example
        # raised NameError (or silently reused the previous iteration's
        # values) when the per-example fields were appended below.
        smiles_list: List[str] = []
        retro_labels: List[Any] = []
        retro_products: List[str] = []

        if dataset_attr.response and isinstance(examples[dataset_attr.response][i], str):  # normal example
            current_response = examples[dataset_attr.response][i]
            smiles_list = extract_all_smiles(current_response)
            modified_response = replace_smiles_with_callback(current_response)
            retro_labels = examples[dataset_attr.retro][i] if dataset_attr.retro else []
            retro_products = extract_retro_products(current_response)
            modified_response = insert_bodies(modified_response, data_args.learned_query_size, retro_labels)
            response = [{"role": Role.ASSISTANT.value, "content": modified_response}]
        else:  # unsupervised
            response = []

        outputs["prompt"].append(prompt)
        outputs["response"].append(response)
        outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
        outputs["molecules"].append(smiles_list)
        outputs["property"].append(dict_to_list(examples[dataset_attr.property][i], mol_properties))
        outputs["retro_labels"].append(retro_labels)
        outputs["retro_products"].append(retro_products)

    return outputs
143
+
144
def map_smiles_to_id(example, smiles_to_id):
    """Replace each SMILES string in ``example['molecules']`` with its integer ID."""
    mapped_ids = []
    for smiles in example['molecules']:
        mapped_ids.append(smiles_to_id[smiles])
    example['molecules'] = mapped_ids
    return example
147
+
148
def align_dataset(
    dataset: Union["Dataset", "IterableDataset"],
    dataset_attr: "DatasetAttr",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Tuple[Union["Dataset", "IterableDataset"], Dict[int, str]]:
    r"""
    Aligns the dataset and maps unique SMILES strings to molecule IDs.

    This function performs the following operations:
    1. Converts the dataset to the required format (molqa).
    2. Extracts all unique SMILES strings from the dataset.
    3. Maps each unique SMILES string to a unique integer ID (0, 1, 2, ...).
    4. Update 'molecules' field to each example, containing the mapped IDs.

    The aligned dataset contains the following fields:
        prompt: [{"role": "user", "content": "..."}] * (2T - 1)
        response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
        system: "..."
        molecules: [List of SMILES string]
        property: [List of float values]
        retro_labels: [List of int values]
        retro_products: [List of SMILES string]

    Args:
        dataset (Union["Dataset", "IterableDataset"]): The input dataset.
        dataset_attr (DatasetAttr): Attributes of the dataset.
        data_args (DataArguments): Arguments for data processing.
        training_args (Seq2SeqTrainingArguments): Arguments for training.

    Returns:
        Tuple[Union["Dataset", "IterableDataset"], Dict[int, str]]:
            - The aligned and converted dataset with molecule IDs.
            - A dictionary mapping molecule IDs to their SMILES strings.
    """
    assert dataset_attr.formatting == "molqa"

    # Explicit output schema so `datasets` does not have to infer the nested
    # prompt/response structure from the first batch.
    features = Features.from_dict(
        {
            "prompt": [
                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "response": [
                {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
            ],
            "system": {"dtype": "string", "_type": "Value"},
            "molecules": [{'dtype': "string", "_type": "Value"}],
            "property": [{"dtype": "float", "_type": "Value"}],
            "retro_labels": [{"dtype": "int32", "_type": "Value"}],
            "retro_products": [{'dtype': "string", "_type": "Value"}],
        }
    )

    convert_func = partial(convert_molqa, dataset_attr=dataset_attr, data_args=data_args)
    aligned = dataset.map(
        convert_func,
        batched=True,
        remove_columns=['instruction', 'input', 'output', 'property', 'retro'],
        features=features,
        num_proc=data_args.preprocessing_num_workers,
        # Only the main local process may rebuild the cache; others reuse it.
        load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
        desc="Converting molqa format of dataset"
    )

    # Extract all unique SMILES strings and map them to molecule IDs
    all_smiles = set()
    for item in aligned:
        all_smiles.update(item['molecules'])
        all_smiles.update(item['retro_products'])

    # Sorting makes the SMILES -> ID assignment deterministic across runs.
    smiles_to_id = {smiles: idx for idx, smiles in enumerate(sorted(all_smiles))}
    id_to_smiles = {idx: smiles for smiles, idx in smiles_to_id.items()}

    # NOTE: intentionally shadows the module-level `map_smiles_to_id`;
    # this local variant additionally remaps 'retro_products'.
    def map_smiles_to_id(example, smiles_to_id):
        example['molecules'] = [smiles_to_id[smiles] for smiles in example['molecules']]
        example['retro_products'] = [smiles_to_id[smiles] for smiles in example['retro_products']]
        return example

    smiles_convert_func = partial(map_smiles_to_id, smiles_to_id=smiles_to_id)

    aligned = aligned.map(
        smiles_convert_func,
        desc="Mapping SMILES to molecule IDs",
    )

    return aligned, id_to_smiles
src/data/collator.py ADDED
@@ -0,0 +1,165 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import numpy as np
3
+ from dataclasses import dataclass
4
+ from typing import Any, Dict, List, Optional, Union
5
+
6
+ from torch_geometric.data import Batch as PyGBatch
7
+ from transformers.tokenization_utils_base import PreTrainedTokenizerBase
8
+ from transformers.utils import PaddingStrategy
9
+
10
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
    """
    Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
    """
    # Feature extractors have no deprecation-warning bookkeeping; pad directly.
    if not hasattr(tokenizer, "deprecation_warnings"):
        return tokenizer.pad(*pad_args, **pad_kwargs)

    # Remember whether the warning has already fired, then mark it as fired
    # so the upcoming pad() call stays silent.
    previous_flag = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
    tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
    try:
        return tokenizer.pad(*pad_args, **pad_kwargs)
    finally:
        # Put the original flag back no matter what pad() did.
        tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = previous_flag
29
+
30
@dataclass
class DataCollatorForSeqGraph:
    """
    Data collator that will dynamically pad the inputs received, as well as the labels.

    Besides token padding, molecule / retro-product IDs carried by the features
    are resolved to PyG ``Data`` objects through ``mol_id_to_pyg`` and batched,
    and ``retro_labels`` are padded with ``label_pad_token_id``.
    """
    tokenizer: PreTrainedTokenizerBase
    # Mapping from molecule ID to its torch_geometric ``Data`` object.
    mol_id_to_pyg: Dict[str, Any]
    model: Optional[Any] = None
    padding: Union[bool, str, PaddingStrategy] = True
    max_length: Optional[int] = None
    pad_to_multiple_of: Optional[int] = None
    label_pad_token_id: int = -100
    return_tensors: str = "pt"

    def __call__(self, features, return_tensors=None):
        if return_tensors is None:
            return_tensors = self.return_tensors

        label_name = "label" if "label" in features[0].keys() else "labels"
        labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
        if labels is not None and all(label is None for label in labels):
            labels = None

        # Store molecule_ids, retro_labels, and retro_product_ids separately and
        # remove them from the features handed to the tokenizer pad() call.
        molecule_ids_list = []
        retro_labels_list = []
        retro_products_list = []
        non_labels_features = []
        for feature in features:
            new_feature = {k: v for k, v in feature.items() if k != label_name}
            if 'molecule_ids' in new_feature:
                molecule_ids_list.append(new_feature['molecule_ids'])
                del new_feature['molecule_ids']
            else:
                molecule_ids_list.append(None)
            if 'retro_labels' in new_feature:
                retro_labels_list.append(new_feature['retro_labels'])
                del new_feature['retro_labels']
            else:
                retro_labels_list.append(None)
            if 'retro_product_ids' in new_feature:
                retro_products_list.append(new_feature['retro_product_ids'])
                del new_feature['retro_product_ids']
            else:
                retro_products_list.append(None)
            non_labels_features.append(new_feature)

        # Convert molecule IDs to PyG Data objects. The first molecule of each
        # sequence is additionally collected as that sequence's "design" graph.
        molecule_graphs_list = []
        design_graphs_list = []
        for seq_idx, molecule_ids in enumerate(molecule_ids_list):
            if molecule_ids is not None and len(molecule_ids) > 0:
                for pos, mol_id in enumerate(molecule_ids):
                    if pos == 0:
                        design_graphs_list.append(self.mol_id_to_pyg[mol_id])
                    if mol_id != self.label_pad_token_id and mol_id in self.mol_id_to_pyg:
                        molecule_graphs_list.append(self.mol_id_to_pyg[mol_id])

        # Convert retro_product_ids to PyG Data objects
        retro_product_graphs_list = []
        for seq_idx, retro_product_ids in enumerate(retro_products_list):
            if retro_product_ids is not None and len(retro_product_ids) > 0:
                for pos, mol_id in enumerate(retro_product_ids):
                    if mol_id != self.label_pad_token_id and mol_id in self.mol_id_to_pyg:
                        retro_product_graphs_list.append(self.mol_id_to_pyg[mol_id])

        # Batch the PyG Data objects (None when the batch carries no graphs).
        if molecule_graphs_list:
            batched_graphs = PyGBatch.from_data_list(molecule_graphs_list)
        else:
            batched_graphs = None

        if design_graphs_list:
            batched_design_graphs = PyGBatch.from_data_list(design_graphs_list)
        else:
            batched_design_graphs = None

        if retro_product_graphs_list:
            batched_retro_products = PyGBatch.from_data_list(retro_product_graphs_list)
        else:
            batched_retro_products = None

        # Pad retro_labels
        if retro_labels_list and any(retro_labels is not None for retro_labels in retro_labels_list):
            max_retro_length = max(len(retro_labels) for retro_labels in retro_labels_list if retro_labels is not None)
            padded_retro_labels = [
                retro_labels + [self.label_pad_token_id] * (max_retro_length - len(retro_labels)) if retro_labels is not None else [self.label_pad_token_id] * max_retro_length
                for retro_labels in retro_labels_list
            ]
        else:
            padded_retro_labels = None

        # Pad other features
        batch = pad_without_fast_tokenizer_warning(
            self.tokenizer,
            non_labels_features,
            padding=self.padding,
            max_length=self.max_length,
            pad_to_multiple_of=self.pad_to_multiple_of,
            return_tensors=return_tensors,
        )

        batch["molecule_graphs"] = batched_graphs
        batch["design_graphs"] = batched_design_graphs
        batch["retro_product_graphs"] = batched_retro_products
        # BUGFIX: torch.tensor(None) raises TypeError; keep None (consistent
        # with the graph entries above) when no feature carried retro labels.
        batch["retro_labels"] = (
            torch.tensor(padded_retro_labels, dtype=torch.int64) if padded_retro_labels is not None else None
        )

        # Pad labels
        if labels is not None:
            max_label_length = max(len(l) for l in labels)
            if self.pad_to_multiple_of is not None:
                max_label_length = (
                    (max_label_length + self.pad_to_multiple_of - 1)
                    // self.pad_to_multiple_of
                    * self.pad_to_multiple_of
                )

            padding_side = self.tokenizer.padding_side
            padded_labels = [
                label + [self.label_pad_token_id] * (max_label_length - len(label))
                if padding_side == "right"
                else [self.label_pad_token_id] * (max_label_length - len(label)) + label
                for label in labels
            ]
            batch["labels"] = torch.tensor(padded_labels, dtype=torch.int64)

        # Prepare decoder_input_ids for encoder-decoder models.
        if (
            labels is not None
            and self.model is not None
            and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
        ):
            decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
            batch["decoder_input_ids"] = decoder_input_ids

        return batch
src/data/data_utils.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from enum import Enum, unique
16
+ from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
17
+
18
+ from datasets import concatenate_datasets, interleave_datasets
19
+
20
+ from ..extras.logging import get_logger
21
+
22
+
23
+ if TYPE_CHECKING:
24
+ from datasets import Dataset, IterableDataset
25
+ from transformers import Seq2SeqTrainingArguments
26
+
27
+ from ..hparams import DataArguments
28
+
29
+
30
logger = get_logger(__name__)


# A template "slot" is either a literal string, a set of strings, or a
# string-to-string mapping (rendered by the formatters in formatter.py).
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
34
+
35
+
36
@unique
class Role(str, Enum):
    """Speaker roles used to tag conversation turns in aligned datasets."""

    USER = "user"
    ASSISTANT = "assistant"
    SYSTEM = "system"
    FUNCTION = "function"
    OBSERVATION = "observation"
43
+
44
+
45
def merge_dataset(
    all_datasets: List[Union["Dataset", "IterableDataset"]],
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    """Combine several datasets according to ``data_args.mix_strategy``.

    A single dataset is returned unchanged; otherwise datasets are either
    concatenated or interleaved (over- or under-sampling).
    """
    if len(all_datasets) == 1:
        return all_datasets[0]

    strategy = data_args.mix_strategy
    if strategy == "concat":
        if data_args.streaming:
            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
        return concatenate_datasets(all_datasets)

    if strategy.startswith("interleave"):
        if not data_args.streaming:
            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
        # "*_under" stops at the shortest dataset; otherwise exhaust them all.
        stopping = "first_exhausted" if strategy.endswith("under") else "all_exhausted"
        return interleave_datasets(
            datasets=all_datasets,
            probabilities=data_args.interleave_probs,
            seed=training_args.seed,
            stopping_strategy=stopping,
        )

    raise ValueError("Unknown mixing strategy.")
67
+
68
+
69
def split_dataset(
    dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
) -> Dict[str, "Dataset"]:
    """Build the train/eval dataset mapping expected by the trainer."""
    if not training_args.do_train:
        # do_eval or do_predict: everything goes to evaluation.
        return {"eval_dataset": dataset}

    if data_args.val_size > 1e-6:
        # val_size > 1 is an absolute sample count, otherwise a fraction.
        holdout = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
        split = dataset.train_test_split(test_size=holdout, seed=training_args.seed)
        return {"train_dataset": split["train"], "eval_dataset": split["test"]}

    if data_args.streaming:
        dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
    return {"train_dataset": dataset}
src/data/formatter.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import re
17
+ from abc import ABC, abstractmethod
18
+ from dataclasses import dataclass, field
19
+ from typing import List, Literal, Optional, Tuple, Union
20
+
21
+ from .data_utils import SLOTS
22
+ from .tool_utils import DefaultToolUtils, GLM4ToolUtils
23
+
24
+
25
@dataclass
class Formatter(ABC):
    """Abstract base for template formatters that render a list of slots."""

    # Template pieces: literal strings (may contain {{placeholders}}), sets, or dicts.
    slots: SLOTS = field(default_factory=list)
    # Tool-calling convention, for formatters that deal with tool prompts/calls.
    tool_format: Optional[Literal["default", "glm4"]] = None

    @abstractmethod
    def apply(self, **kwargs) -> SLOTS: ...

    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
        # Only meaningful for formatters that can parse tool calls back out.
        raise NotImplementedError
35
+
36
+
37
@dataclass
class EmptyFormatter(Formatter):
    """A formatter whose slots are returned verbatim; placeholders are forbidden."""

    def __post_init__(self):
        # Reject any string slot that carries a {{placeholder}} token.
        string_slots = (slot for slot in self.slots if isinstance(slot, str))
        if any(re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot) for slot in string_slots):
            raise ValueError("Empty formatter should not contain any placeholder.")

    def apply(self, **kwargs) -> SLOTS:
        return self.slots
50
+
51
+
52
@dataclass
class StringFormatter(Formatter):
    """A formatter that substitutes ``{{name}}`` placeholders inside string slots."""

    def __post_init__(self):
        # At least one string slot must carry a placeholder, otherwise this
        # formatter would be a no-op and EmptyFormatter should be used instead.
        placeholder_found = any(
            re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot)
            for slot in self.slots
            if isinstance(slot, str)
        )
        if not placeholder_found:
            raise ValueError("A placeholder is required in the string formatter.")

    def apply(self, **kwargs) -> SLOTS:
        rendered = []
        for slot in self.slots:
            if isinstance(slot, str):
                # Replace the first occurrence of each {{name}} with its value.
                for name, value in kwargs.items():
                    if not isinstance(value, str):
                        raise RuntimeError("Expected a string, got {}".format(value))

                    slot = slot.replace("{{" + name + "}}", value, 1)
                rendered.append(slot)
            elif isinstance(slot, (dict, set)):
                rendered.append(slot)
            else:
                raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

        return rendered
79
+
80
@dataclass
class FunctionFormatter(Formatter):
    """Renders assistant tool calls (JSON content) into function slots."""

    def __post_init__(self):
        if self.tool_format == "default":
            self.slots = DefaultToolUtils.get_function_slots() + self.slots
        elif self.tool_format == "glm4":
            self.slots = GLM4ToolUtils.get_function_slots() + self.slots
        else:
            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))

    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        parsed_calls: List[Tuple[str, str]] = []
        try:
            payload = json.loads(content)
            if not isinstance(payload, list):  # single call -> parallel list of one
                payload = [payload]

            for call in payload:
                parsed_calls.append((call["name"], json.dumps(call["arguments"], ensure_ascii=False)))

        except json.JSONDecodeError:
            # Content that is not valid JSON produces no function calls.
            parsed_calls = []

        elements = []
        for name, arguments in parsed_calls:
            for slot in self.slots:
                if isinstance(slot, str):
                    elements.append(slot.replace("{{name}}", name).replace("{{arguments}}", arguments))
                elif isinstance(slot, (dict, set)):
                    elements.append(slot)
                else:
                    raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))

        return elements
116
+
117
+
118
@dataclass
class ToolFormatter(Formatter):
    """Formats the tool-description prompt and extracts tool calls from text."""

    def __post_init__(self):
        # Bind the formatter/extractor pair matching the configured convention.
        if self.tool_format == "default":
            utils = DefaultToolUtils
        elif self.tool_format == "glm4":
            utils = GLM4ToolUtils
        else:
            raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
        self._tool_formatter = utils.tool_formatter
        self._tool_extractor = utils.tool_extractor

    def apply(self, **kwargs) -> SLOTS:
        content = kwargs.pop("content")
        try:
            tools = json.loads(content)
            return [self._tool_formatter(tools) if len(tools) != 0 else ""]
        except json.JSONDecodeError:
            return [""]

    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
        return self._tool_extractor(content)
src/data/loader.py ADDED
@@ -0,0 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ import os
17
+ import sys
18
+ from typing import TYPE_CHECKING, Literal, Optional, Union
19
+ from functools import partial
20
+
21
+ import numpy as np
22
+ from datasets import load_dataset, load_from_disk
23
+
24
+ # from ..extras.constants import FILEEXT2TYPE
25
+ from ..extras.logging import get_logger
26
+ from ..extras.misc import has_tokenized_data
27
+ from .aligner import align_dataset
28
+ from .data_utils import merge_dataset
29
+ from .parser import get_dataset_attr
30
+ # from .preprocess import get_preprocess_and_print_func
31
+ from .template import get_template_and_fix_tokenizer
32
+
33
+ from .processors.mmsupervised import (
34
+ preprocess_mmsupervised_dataset,
35
+ print_supervised_dataset_example,
36
+ encode_graph_pyg
37
+ )
38
+
39
+ if TYPE_CHECKING:
40
+ from datasets import Dataset, IterableDataset
41
+ from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
42
+
43
+ from ..hparams import DataArguments, ModelArguments
44
+ from .parser import DatasetAttr
45
+
46
+
47
+ logger = get_logger(__name__)
48
+
49
+
50
def load_single_dataset(
    dataset_attr: "DatasetAttr",
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]:
    """Load one local data file and align it to the molqa format.

    Returns the aligned dataset together with the molecule-ID -> SMILES map.
    """
    logger.info("Loading dataset {}...".format(dataset_attr))

    assert dataset_attr.load_from == "file"

    local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
    data_files = [local_path]
    # The file extension doubles as the HF datasets builder name (e.g. "json").
    builder_name = local_path.split(".")[-1]

    # datasets>=2.16 requires trust_remote_code to be passed explicitly.
    if "trust_remote_code" in inspect.signature(load_dataset).parameters:
        extra_kwargs = {"trust_remote_code": True}
    else:
        extra_kwargs = {}

    dataset = load_dataset(
        path=builder_name,
        name=None,
        data_dir=None,
        data_files=data_files,
        split=data_args.split,
        cache_dir=model_args.cache_dir,
        token=model_args.hf_hub_token,
        streaming=False,
        **extra_kwargs,
    )

    converted_dataset, mol_id_to_smiles = align_dataset(dataset, dataset_attr, data_args, training_args)
    return converted_dataset, mol_id_to_smiles
84
+
85
def get_dataset(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    tokenizer: "PreTrainedTokenizer",
) -> Union["Dataset", "IterableDataset"]:
    # NOTE(review): every return path actually yields the tuple
    # (mol_id_to_pyg, dataset) — the annotation above looks stale; confirm
    # with callers before changing it.

    template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
    if data_args.train_on_prompt and template.efficient_eos:
        raise ValueError("Current template does not support `train_on_prompt`.")
    print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)

    # Fast path: reuse a previously tokenized dataset saved on disk.
    if data_args.tokenized_path is not None:
        if has_tokenized_data(data_args.tokenized_path):
            mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path)
            logger.warning("Loading dataset from disk will ignore other data arguments.")
            dataset = load_from_disk(data_args.tokenized_path)
            logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
            # print_function(next(iter(dataset)))
            data_iter = iter(dataset)
            print_function(next(data_iter))
            return mol_id_to_pyg, dataset

    # Load tokenized dataset
    with training_args.main_process_first(desc="load dataset"):
        # current only support one dataset
        dataset_attr = get_dataset_attr(data_args)
        dataset, mol_id_to_smiles = load_single_dataset(dataset_attr, model_args, data_args, training_args)

    with training_args.main_process_first(desc="pre-process dataset"):
        preprocess_func = partial(
            preprocess_mmsupervised_dataset,
            template=template,
            tokenizer=tokenizer,
            data_args=data_args,
        )

        # All original columns are dropped after tokenization.
        column_names = list(next(iter(dataset)).keys())
        kwargs = {}
        kwargs = dict(
            num_proc=data_args.preprocessing_num_workers,
            load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
            desc="Running tokenizer on dataset",
        )

        dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)

        if data_args.tokenized_path is not None:
            # One-shot export: save the tokenized dataset and graphs, then ask
            # the user to restart pointing `tokenized_path` at the saved copy.
            if training_args.should_save:
                dataset.save_to_disk(data_args.tokenized_path)
                mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path, mol_id_to_smiles=mol_id_to_smiles)
                logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
                logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
            sys.exit(0)
        else:
            mol_id_to_pyg = encode_graph_pyg(mol_id_to_smiles=mol_id_to_smiles)

        if training_args.should_log:
            try:
                print_function(next(iter(dataset)))
            except StopIteration:
                raise RuntimeError("Cannot find valid samples.")

        return mol_id_to_pyg, dataset
src/data/parser.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import os
17
+ from dataclasses import dataclass
18
+ from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
19
+
20
+ from ..extras.constants import DATA_CONFIG
21
+ from ..extras.misc import use_modelscope
22
+
23
+
24
+ if TYPE_CHECKING:
25
+ from ..hparams import DataArguments
26
+
27
+
28
@dataclass
class DatasetAttr:
    r"""
    Dataset attributes.

    Describes where a dataset is loaded from and which source columns map onto
    the molqa/alpaca/sharegpt fields consumed by the aligner.
    """

    # basic configs
    load_from: Literal["hf_hub", "ms_hub", "script", "file"]
    dataset_name: str
    formatting: Literal["alpaca", "sharegpt", "molqa"] = "molqa"
    ranking: bool = False
    # extra configs
    subset: Optional[str] = None
    folder: Optional[str] = None
    num_samples: Optional[int] = None
    # common columns
    system: Optional[str] = None
    tools: Optional[str] = None
    images: Optional[str] = None
    # rlhf columns
    chosen: Optional[str] = None
    rejected: Optional[str] = None
    kto_tag: Optional[str] = None
    # alpaca columns
    prompt: Optional[str] = "instruction"
    query: Optional[str] = "input"
    response: Optional[str] = "output"
    history: Optional[str] = None
    # sharegpt columns
    messages: Optional[str] = "conversations"
    # sharegpt tags
    role_tag: Optional[str] = "from"
    content_tag: Optional[str] = "value"
    user_tag: Optional[str] = "human"
    assistant_tag: Optional[str] = "gpt"
    observation_tag: Optional[str] = "observation"
    function_tag: Optional[str] = "function_call"
    system_tag: Optional[str] = "system"
    # molqa columns
    property: Optional[str] = 'property'
    retro: Optional[str] = 'retro'
    # learned_query_size: Optional[int] = None

    def __repr__(self) -> str:
        return self.dataset_name

    def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
        # Copy ``obj[key]`` (falling back to ``default``) onto this instance.
        setattr(self, key, obj.get(key, default))
76
+
77
def get_dataset_attr(data_args: "DataArguments") -> "DatasetAttr":
    """Resolve the configured dataset name into a ``DatasetAttr``.

    Reads the dataset registry (``DATA_CONFIG``) from ``data_args.dataset_dir``
    and copies the matching entry's settings and column mappings.

    Raises:
        ValueError: if no dataset name is configured, the registry file cannot
            be read, or the name is not present in the registry.
    """
    if data_args.dataset is not None:
        dataset_name = data_args.dataset.strip()
    else:
        raise ValueError("Please specify the dataset name.")

    config_path = os.path.join(data_args.dataset_dir, DATA_CONFIG)
    try:
        with open(config_path, "r") as f:
            dataset_info = json.load(f)
    except Exception as err:
        # BUGFIX: the stray `dataset_info = None` that followed this raise has
        # been removed — it was unreachable dead code at best, and at module
        # level it clobbered the parsed registry.
        raise ValueError("Cannot open {} due to {}.".format(config_path, str(err))) from err

    if dataset_name not in dataset_info:
        raise ValueError("Undefined dataset {} in {}.".format(dataset_name, DATA_CONFIG))

    dataset_attr = DatasetAttr("file", dataset_name=dataset_info[dataset_name]["file_name"])

    dataset_attr.set_attr("formatting", dataset_info[dataset_name], default="molqa")
    dataset_attr.set_attr("ranking", dataset_info[dataset_name], default=False)
    dataset_attr.set_attr("subset", dataset_info[dataset_name])
    dataset_attr.set_attr("folder", dataset_info[dataset_name])
    dataset_attr.set_attr("num_samples", dataset_info[dataset_name])

    if "columns" in dataset_info[dataset_name]:
        column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
        assert dataset_attr.formatting == "molqa"
        column_names.extend(["prompt", "query", "response", "history", "property", "retro"])

        for column_name in column_names:
            dataset_attr.set_attr(column_name, dataset_info[dataset_name]["columns"])

    return dataset_attr
src/data/processors/__init__.py ADDED
File without changes
src/data/processors/mmsupervised.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Llamole Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import defaultdict
16
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
17
+
18
+ from ...extras.constants import IGNORE_INDEX, BOND_INDEX, NO_LABEL_INDEX
19
+ from ...extras.logging import get_logger
20
+
21
+ if TYPE_CHECKING:
22
+ from transformers import PreTrainedTokenizer, ProcessorMixin
23
+
24
+ from ...hparams import DataArguments
25
+ from ..template import Template
26
+
27
+ import os
28
+ from rdkit import Chem
29
+ import torch
30
+ from torch_geometric.data import Data, Batch
31
+ import pickle
32
+
33
+ logger = get_logger(__name__)
34
+
35
+ import os
36
+ import torch
37
+ from typing import Dict
38
+ from torch_geometric.data import Data
39
+ from rdkit import Chem
40
+ import pickle
41
+
42
+
43
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
    """Split ``cutoff_len`` tokens between a prompt and a response.

    A short response keeps only what it needs and leaves the rest to the
    prompt; a short prompt gives its unused budget to the response; when both
    are large, the budget is shared proportionally to their lengths.

    Returns:
        The (possibly reduced) source length and target length.
    """
    if target_len * 2 < cutoff_len:
        # Response is small: let it keep its full length, prompt takes the rest.
        budget_for_target = cutoff_len
    elif source_len * 2 < cutoff_len:
        # Prompt is small: response may use everything the prompt does not.
        budget_for_target = cutoff_len - source_len
    else:
        # Both sides are large: share the budget proportionally.
        budget_for_target = int(cutoff_len * (target_len / (source_len + target_len)))

    trimmed_target = min(budget_for_target, target_len)
    trimmed_source = max(cutoff_len - trimmed_target, 0)
    return trimmed_source, trimmed_target
54
+
55
def encode_graph_pyg(
    data_path: Optional[str] = None, mol_id_to_smiles: Optional[Dict[str, str]] = None
) -> Dict[str, Data]:
    """
    Converts molecule data to a dictionary of PyTorch Geometric Data objects, with caching functionality.
    Uses a sparse representation for efficiency.

    Args:
        data_path (Optional[str]): Path to the Hugging Face dataset folder.
        mol_id_to_smiles (Optional[Dict[str, str]]): Dictionary where keys are molecule IDs
            and values are SMILES strings.

    Returns:
        Dict[str, Data]: Dictionary where keys are molecule IDs and values are
        PyTorch Geometric Data objects.

    Raises:
        ValueError: If both data_path and mol_id_to_smiles are None, if the cache is
            missing and no SMILES mapping was provided, or if a SMILES string is invalid.
    """
    print(f"Current execution directory: {os.getcwd()}")

    if data_path is None and mol_id_to_smiles is None:
        raise ValueError("Either data_path or mol_id_to_smiles must be provided.")

    cache_file = None
    if data_path is not None:
        cache_file = os.path.join(data_path, "pyg_molecule.pickle")

        # Try to load cached data; fall back to rebuilding on any failure.
        if os.path.exists(cache_file):
            try:
                with open(cache_file, "rb") as f:
                    return pickle.load(f)
            except Exception as e:
                print(f"Failed to load cached data: {e}")

    if mol_id_to_smiles is None:
        # Previously this path fell through to `mol_id_to_smiles.items()` and
        # crashed with a TypeError; raise a clear error instead.
        raise ValueError(
            "No cached PyG data found at {} and no mol_id_to_smiles provided.".format(cache_file)
        )

    mol_id_to_pyg = {}

    for mol_id, smiles in mol_id_to_smiles.items():
        mol = Chem.MolFromSmiles(smiles)
        if mol is None:
            raise ValueError(f"Invalid SMILES string for molecule {mol_id}: {smiles}")

        # Node features: shifted atomic numbers; wildcard "*" atoms map to 117.
        type_idx = []
        heavy_atom_indices = []
        for atom in mol.GetAtoms():
            if atom.GetAtomicNum() != 1:  # Exclude hydrogen atoms
                type_idx.append(
                    119 - 2 if atom.GetSymbol() == "*" else atom.GetAtomicNum() - 2
                )
                heavy_atom_indices.append(atom.GetIdx())

        x = torch.LongTensor(type_idx)

        # O(1) remap from RDKit atom index to heavy-atom position
        # (replaces the former O(n) list.index lookups inside the bond loop).
        index_map = {old: new for new, old in enumerate(heavy_atom_indices)}

        edge_index = []
        edge_attr = []
        for bond in mol.GetBonds():
            start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
            if start in index_map and end in index_map:
                start_new, end_new = index_map[start], index_map[end]
                # Undirected graph: store both directions of each bond.
                edge_index.extend([[start_new, end_new], [end_new, start_new]])
                bond_type = BOND_INDEX[bond.GetBondType()]
                edge_attr.extend([bond_type, bond_type])

        if edge_index:
            edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
            edge_attr = torch.tensor(edge_attr, dtype=torch.long)
        else:
            # Bond-less molecules (a single heavy atom): keep the (2, 0) shape
            # PyG expects instead of an ill-shaped 1-D empty tensor.
            edge_index = torch.empty((2, 0), dtype=torch.long)
            edge_attr = torch.empty((0,), dtype=torch.long)

        # Create PyG Data object
        data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)

        mol_id_to_pyg[mol_id] = data

    # Save cached data if data_path is provided
    if data_path is not None:
        with open(cache_file, "wb") as f:
            pickle.dump(mol_id_to_pyg, f)

        print(f"Saved PyG data to {cache_file}")

    return mol_id_to_pyg
136
+
137
def encode_supervised_example(
    prompt: Sequence[Dict[str, str]],
    response: Sequence[Dict[str, str]],
    system: Optional[str],
    molecule_ids: List[int],
    retro_product_ids: List[int],
    retro_labels: List[int],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    data_args: "DataArguments",
) -> Tuple[List[int], List[int], List[int], List[int], List[int]]:
    """Encode one multi-turn example into token ids plus a label mask, keeping
    the molecule / retrosynthesis annotations aligned with whatever tokens
    survive truncation.

    Returns a 5-tuple: (input_ids, labels, kept molecule_ids, kept
    retro_product_ids, kept retro_labels). Special tokens are masked out of
    the labels except <retro_start> and <design_start>, which are kept as
    supervision targets.
    """
    messages = prompt + response
    input_ids, labels = [], []
    final_molecule_ids = []
    final_product_ids = []
    final_retro_labels = []

    encoded_pairs = template.encode_multiturn(tokenizer, messages, system)
    special_tokens = [
        "<design_start>",
        "<design_end>",
        "<design_body>",
        "<molecule>",
        "<retro_start>",
        "<retro_end>",
        "<retro_body>",
    ]
    special_token_ids = template._convert_elements_to_ids(tokenizer, special_tokens)
    special_token_dict = dict(zip(special_tokens, special_token_ids))

    # Reserve one slot for the trailing EOS appended below when efficient_eos.
    total_length = 1 if template.efficient_eos else 0
    for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
        if total_length >= data_args.cutoff_len:
            break

        source_len, target_len = infer_seqlen(
            len(source_ids), len(target_ids), data_args.cutoff_len - total_length
        )
        source_ids = source_ids[:source_len]

        # Ensure balanced retro tags when truncating
        retro_start_indices = [
            i
            for i, id in enumerate(target_ids)
            if id == special_token_dict["<retro_start>"]
        ]
        retro_end_indices = [
            i
            for i, id in enumerate(target_ids)
            if id == special_token_dict["<retro_end>"]
        ]

        if retro_start_indices and retro_end_indices:
            # Find the last matching pair that fits within target_len
            last_pair_index = -1
            for start, end in zip(retro_start_indices, retro_end_indices):
                if end < target_len:
                    last_pair_index = end
                else:
                    break

            if last_pair_index >= 0:
                target_len = last_pair_index + 1
            else:
                # If no complete pair fits, truncate before the first start tag
                target_len = (
                    min(target_len, retro_start_indices[0])
                    if retro_start_indices
                    else target_len
                )

        target_ids = target_ids[:target_len]

        # Calculate the number of molecules in this turn
        molecules_in_turn = target_ids.count(special_token_dict["<molecule>"])
        retro_start_in_turn = target_ids.count(special_token_dict["<retro_start>"])
        retro_end_in_turn = target_ids.count(special_token_dict["<retro_end>"])

        # Guaranteed by the balancing logic above.
        assert retro_start_in_turn == retro_end_in_turn

        # NOTE(review): annotations are consumed from the *front* of the lists
        # each turn — assumes one turn's annotations always come first; confirm
        # against the data aligner.
        retro_product_ids_in_turn = retro_product_ids[:retro_end_in_turn]
        retro_labels_in_turn = retro_labels[:retro_end_in_turn]

        # Add corresponding retro_labels and retro_product_ids
        final_molecule_ids.extend(molecule_ids[:molecules_in_turn])
        final_product_ids.extend(retro_product_ids_in_turn)
        final_retro_labels.extend(retro_labels_in_turn)

        total_length += source_len + target_len

        if data_args.train_on_prompt:
            source_mask = source_ids
        elif turn_idx != 0 and template.efficient_eos:
            source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
                len(source_ids) - 1
            )
        else:
            source_mask = [IGNORE_INDEX] * len(source_ids)

        # Special tokens never contribute to the LM loss on the source side.
        source_mask = [
            IGNORE_INDEX if id in special_token_dict.values() else id
            for id in source_mask
        ]
        # On the target side, keep <retro_start>/<design_start> as labels so
        # the model learns to emit them; mask all other special tokens.
        target_ids_mask = [
            id if id in [special_token_dict["<retro_start>"], special_token_dict["<design_start>"]]
            else (IGNORE_INDEX if id in special_token_dict.values() else id)
            for id in target_ids
        ]

        input_ids += source_ids + target_ids
        labels += source_mask + target_ids_mask

    if template.efficient_eos:
        input_ids += [tokenizer.eos_token_id]
        labels += [tokenizer.eos_token_id]

    return input_ids, labels, final_molecule_ids, final_product_ids, final_retro_labels
255
+
256
+
257
def preprocess_mmsupervised_dataset(
    examples: Dict[str, List[Any]],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    data_args: "DataArguments",
) -> Dict[str, List[List[int]]]:
    """Tokenize a batch of multimodal supervised examples.

    Invalid examples (a prompt that does not end on a user turn, or a response
    that is not exactly one message) are logged and dropped. ``None`` entries
    in the retro labels and property lists are replaced by ``NO_LABEL_INDEX``
    before encoding.
    """
    output_keys = (
        "input_ids",
        "attention_mask",
        "labels",
        "molecule_ids",
        "molecule_properties",
        "retro_labels",
        "retro_product_ids",
    )
    model_inputs = {key: [] for key in output_keys}

    for idx in range(len(examples["prompt"])):
        prompt_msgs = examples["prompt"][idx]
        response_msgs = examples["response"][idx]
        if len(prompt_msgs) % 2 != 1 or len(response_msgs) != 1:
            logger.warning(
                "Dropped invalid example: {}".format(prompt_msgs + response_msgs)
            )
            continue

        product_ids = examples["retro_products"][idx]
        cleaned_retro_labels = [
            NO_LABEL_INDEX if label is None else label
            for label in examples["retro_labels"][idx]
        ]
        cleaned_properties = [
            NO_LABEL_INDEX if prop is None else prop
            for prop in examples["property"][idx]
        ]

        (
            input_ids,
            labels,
            molecule_ids,
            product_ids,
            cleaned_retro_labels,
        ) = encode_supervised_example(
            prompt=prompt_msgs,
            response=response_msgs,
            system=examples["system"][idx],
            molecule_ids=examples["molecules"][idx],
            retro_product_ids=product_ids,
            retro_labels=cleaned_retro_labels,
            template=template,
            tokenizer=tokenizer,
            data_args=data_args,
        )

        model_inputs["input_ids"].append(input_ids)
        model_inputs["attention_mask"].append([1] * len(input_ids))
        model_inputs["labels"].append(labels)
        model_inputs["molecule_ids"].append(molecule_ids)
        model_inputs["molecule_properties"].append(cleaned_properties)
        model_inputs["retro_labels"].append(cleaned_retro_labels)
        model_inputs["retro_product_ids"].append(product_ids)

    return model_inputs
315
+
316
def print_supervised_dataset_example(
    example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
) -> None:
    """Pretty-print one encoded example (ids, decoded text, annotations)."""
    supervised_ids = [tok for tok in example["labels"] if tok != IGNORE_INDEX]
    print("Print_supervised_dataset_example")

    print(f"input_ids:\n{example['input_ids']}")
    decoded_inputs = tokenizer.decode(example["input_ids"], skip_special_tokens=False)
    print(f"inputs:\n{decoded_inputs}")
    print(f"label_ids:\n{example['labels']}")
    decoded_labels = tokenizer.decode(supervised_ids, skip_special_tokens=False)
    print(f"labels:\n{decoded_labels}")
    print(f"molecule_ids:\n{example['molecule_ids']}")
    print(f"molecule_properties:\n{example['molecule_properties']}")
    print(f"retro_labels:\n{example['retro_labels']}")
    print(f"retro_product_ids:\n{example['retro_product_ids']}")
src/data/template.py ADDED
@@ -0,0 +1,889 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass
16
+ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
17
+
18
+ from ..extras.logging import get_logger
19
+ from .data_utils import Role
20
+ from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
21
+
22
+
23
+ if TYPE_CHECKING:
24
+ from transformers import PreTrainedTokenizer
25
+
26
+ from .formatter import SLOTS, Formatter
27
+
28
+
29
+ logger = get_logger(__name__)
30
+
31
+
32
@dataclass
class Template:
    # Formatters render each message role / structural piece into "slots":
    # plain strings, {"token": ...} dicts, or {"bos_token"}/{"eos_token"} sets.
    format_user: "Formatter"
    format_assistant: "Formatter"
    format_system: "Formatter"
    format_function: "Formatter"
    format_observation: "Formatter"
    format_tools: "Formatter"
    format_separator: "Formatter"  # emitted between turns (from the second user turn on)
    format_prefix: "Formatter"  # emitted once at the very start of the conversation
    default_system: str  # fallback system prompt when the caller supplies none
    stop_words: List[str]
    image_token: str
    efficient_eos: bool  # when True, EOS placement is handled by the caller
    replace_eos: bool  # when True, the first stop word replaces the EOS token

    def encode_oneturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
    ) -> Tuple[List[int], List[int]]:
        r"""
        Returns a single pair of token ids representing prompt and response respectively.
        """
        encoded_messages = self._encode(tokenizer, messages, system, tools)
        # Everything except the last message forms the prompt.
        prompt_ids = []
        for encoded_ids in encoded_messages[:-1]:
            prompt_ids += encoded_ids

        answer_ids = encoded_messages[-1]
        return prompt_ids, answer_ids

    def encode_multiturn(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: Sequence[Dict[str, str]],
        system: Optional[str] = None,
        tools: Optional[str] = None,
    ) -> List[Tuple[List[int], List[int]]]:
        r"""
        Returns multiple pairs of token ids representing prompts and responses respectively.
        """
        encoded_messages = self._encode(tokenizer, messages, system, tools)
        # Pair up consecutive (user, assistant) messages.
        return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]

    def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
        r"""
        Extracts tool message.
        """
        return self.format_tools.extract(content)

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: Sequence[Dict[str, str]],
        system: Optional[str],
        tools: Optional[str],
    ) -> List[List[int]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: prefix + system + query resp
        Turn t: sep + query resp
        """
        system = system or self.default_system
        encoded_messages = []
        for i, message in enumerate(messages):
            elements = []

            if i == 0:
                # First message carries the prefix and the (system + tools) block.
                elements += self.format_prefix.apply()
                if system or tools:
                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                    elements += self.format_system.apply(content=(system + tool_text))

            if i > 0 and i % 2 == 0:
                # Separator before every user turn after the first.
                elements += self.format_separator.apply()

            if message["role"] == Role.USER.value:
                elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
            elif message["role"] == Role.ASSISTANT.value:
                elements += self.format_assistant.apply(content=message["content"])
            elif message["role"] == Role.OBSERVATION.value:
                elements += self.format_observation.apply(content=message["content"])
            elif message["role"] == Role.FUNCTION.value:
                elements += self.format_function.apply(content=message["content"])
            else:
                raise NotImplementedError("Unexpected role: {}".format(message["role"]))

            encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))

        return encoded_messages

    def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
        r"""
        Converts elements to token ids.
        """
        token_ids = []
        for elem in elements:
            if isinstance(elem, str):
                if len(elem) != 0:
                    token_ids += tokenizer.encode(elem, add_special_tokens=False)
            elif isinstance(elem, dict):
                # {"token": "..."} maps a literal special token to its id.
                token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
            elif isinstance(elem, set):
                if "bos_token" in elem and tokenizer.bos_token_id is not None:
                    token_ids += [tokenizer.bos_token_id]
                elif "eos_token" in elem and tokenizer.eos_token_id is not None:
                    token_ids += [tokenizer.eos_token_id]
            else:
                raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))

        return token_ids
146
+
147
+
148
@dataclass
class Llama2Template(Template):
    """Llama2-style template: the system (plus tools) block is folded into the
    first user message rather than emitted as a standalone element."""

    def _encode(
        self,
        tokenizer: "PreTrainedTokenizer",
        messages: Sequence[Dict[str, str]],
        system: str,
        tools: str,
    ) -> List[List[int]]:
        r"""
        Encodes formatted inputs to pairs of token ids.
        Turn 0: prefix + system + query resp
        Turn t: sep + query resp
        """
        system = system or self.default_system
        encoded: List[List[int]] = []
        for turn_idx, message in enumerate(messages):
            pieces = []

            injected_system = ""
            if turn_idx == 0:
                pieces += self.format_prefix.apply()
                if system or tools:
                    tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
                    # Rendered system text is prepended to the first user message below.
                    injected_system = self.format_system.apply(content=(system + tool_text))[0]
            elif turn_idx % 2 == 0:
                pieces += self.format_separator.apply()

            role = message["role"]
            content = message["content"]
            if role == Role.USER.value:
                pieces += self.format_user.apply(content=injected_system + content)
            elif role == Role.ASSISTANT.value:
                pieces += self.format_assistant.apply(content=content)
            elif role == Role.OBSERVATION.value:
                pieces += self.format_observation.apply(content=content)
            elif role == Role.FUNCTION.value:
                pieces += self.format_function.apply(content=content)
            else:
                raise NotImplementedError("Unexpected role: {}".format(role))

            encoded.append(self._convert_elements_to_ids(tokenizer, pieces))

        return encoded
191
+
192
+
193
# Global registry of chat templates, keyed by name and populated by the
# _register_template(...) calls below.
TEMPLATES: Dict[str, Template] = {}
194
+
195
+
196
def _register_template(
    name: str,
    format_user: Optional["Formatter"] = None,
    format_assistant: Optional["Formatter"] = None,
    format_system: Optional["Formatter"] = None,
    format_function: Optional["Formatter"] = None,
    format_observation: Optional["Formatter"] = None,
    format_tools: Optional["Formatter"] = None,
    format_separator: Optional["Formatter"] = None,
    format_prefix: Optional["Formatter"] = None,
    default_system: str = "",
    stop_words: Sequence[str] = [],
    image_token: str = "<image>",
    efficient_eos: bool = False,
    replace_eos: bool = False,
) -> None:
    r"""
    Builds a chat template and stores it in the global ``TEMPLATES`` registry.

    Any formatter left as ``None`` falls back to a plain-content default; the
    observation formatter additionally falls back to the user formatter.
    Names starting with ``llama2`` get the ``Llama2Template`` variant, which
    folds the system prompt into the first user message. When
    ``efficient_eos`` is set, no ``eos_token`` slot is appended to the
    assistant/function defaults (EOS handling is left to the caller).

    Example: to register

    ```
    [HUMAN]:
    user prompt here
    [AI]:
    model response here
    ```

    write

    ```
    _register_template(
        name="custom",
        format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
        format_separator=EmptyFormatter(slots=["\n\n"]),
        efficient_eos=True,
    )
    ```
    """
    eos_slots = [{"eos_token"}] if not efficient_eos else []
    template_class = Llama2Template if name.startswith("llama2") else Template
    plain_content = StringFormatter(slots=["{{content}}"])
    TEMPLATES[name] = template_class(
        format_user=format_user or plain_content,
        format_assistant=format_assistant or StringFormatter(slots=["{{content}}"] + eos_slots),
        format_system=format_system or plain_content,
        format_function=format_function or FunctionFormatter(slots=eos_slots, tool_format="default"),
        format_observation=format_observation or format_user or plain_content,
        format_tools=format_tools or ToolFormatter(tool_format="default"),
        format_separator=format_separator or EmptyFormatter(),
        format_prefix=format_prefix or EmptyFormatter(),
        default_system=default_system,
        stop_words=stop_words,
        image_token=image_token,
        efficient_eos=efficient_eos,
        replace_eos=replace_eos,
    )
261
+
262
+
263
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
    """Installs ``eos_token`` on the tokenizer, logging whether it was newly
    added or replaced an existing one, and warning if the vocab grew."""
    had_no_eos = tokenizer.eos_token_id is None
    num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

    action = "Add" if had_no_eos else "Replace"
    logger.info("{} eos token: {}".format(action, tokenizer.eos_token))

    if num_added_tokens > 0:
        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
275
+
276
+ def _jinja_escape(content: str) -> str:
277
+ return content.replace("'", r"\'")
278
+
279
+
280
+ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
281
+ slot_items = []
282
+ for slot in slots:
283
+ if isinstance(slot, str):
284
+ slot_pieces = slot.split("{{content}}")
285
+ if slot_pieces[0]:
286
+ slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
287
+ if len(slot_pieces) > 1:
288
+ slot_items.append(placeholder)
289
+ if slot_pieces[1]:
290
+ slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
291
+ elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
292
+ if "bos_token" in slot and tokenizer.bos_token_id is not None:
293
+ slot_items.append("'" + tokenizer.bos_token + "'")
294
+ elif "eos_token" in slot and tokenizer.eos_token_id is not None:
295
+ slot_items.append("'" + tokenizer.eos_token + "'")
296
+ elif isinstance(slot, dict):
297
+ raise ValueError("Dict is not supported.")
298
+
299
+ return " + ".join(slot_items)
300
+
301
+
302
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
    """Builds a Jinja chat template string equivalent to *template* so it can
    be assigned to ``tokenizer.chat_template``. Only user/assistant/system
    roles are represented; tool/function formatting is not included."""
    jinja_template = ""

    prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
    if prefix:
        jinja_template += "{{ " + prefix + " }}"

    # Seed system_message with the template default; a leading "system" message
    # in the conversation overrides it.
    if template.default_system:
        jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"

    jinja_template += (
        "{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
    )

    system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
    # Llama2 folds the system text into the first user message instead of
    # emitting it up front (mirrors Llama2Template._encode).
    if not isinstance(template, Llama2Template):
        jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"

    jinja_template += "{% for message in messages %}"
    jinja_template += "{% set content = message['content'] %}"
    if isinstance(template, Llama2Template):
        jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
        jinja_template += "{% set content = " + system_message + " + message['content'] %}"
        jinja_template += "{% endif %}"

    jinja_template += "{% if message['role'] == 'user' %}"
    user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
    jinja_template += "{{ " + user_message + " }}"

    jinja_template += "{% elif message['role'] == 'assistant' %}"
    assistant_message = _convert_slots_to_jinja(
        template.format_assistant.apply() + template.format_separator.apply(), tokenizer
    )
    jinja_template += "{{ " + assistant_message + " }}"
    jinja_template += "{% endif %}"
    jinja_template += "{% endfor %}"
    return jinja_template
339
+
340
+
341
def get_template_and_fix_tokenizer(
    tokenizer: "PreTrainedTokenizer",
    name: Optional[str] = None,
    tool_format: Optional[str] = None,
) -> Template:
    """Looks up the named template and mutates the tokenizer to match it:
    replaces/adds the EOS token when required, ensures a pad token exists,
    registers remaining stop words as special tokens and installs an
    equivalent Jinja chat template.

    Raises:
        ValueError: If ``name`` is given but not registered, or if
            ``replace_eos`` is set without any stop words.
    """
    if name is None:
        template = TEMPLATES["empty"]  # placeholder
    else:
        template = TEMPLATES.get(name, None)
        if template is None:
            raise ValueError("Template {} does not exist.".format(name))

    if tool_format is not None:
        # NOTE: this mutates the shared TEMPLATES entry, not a copy.
        logger.info("Using tool format: {}.".format(tool_format))
        eos_slots = [] if template.efficient_eos else [{"eos_token"}]
        template.format_tools = ToolFormatter(tool_format=tool_format)
        template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)

    stop_words = template.stop_words
    if template.replace_eos:
        if not stop_words:
            raise ValueError("Stop words are required to replace the EOS token.")

        # First stop word becomes the EOS token; the rest stay as stop words.
        _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
        stop_words = stop_words[1:]

    if tokenizer.eos_token_id is None:
        _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")

    if tokenizer.pad_token_id is None:
        tokenizer.pad_token = tokenizer.eos_token
        logger.info("Add pad token: {}".format(tokenizer.pad_token))

    if stop_words:
        num_added_tokens = tokenizer.add_special_tokens(
            dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
        )
        logger.info("Add {} to stop words.".format(",".join(stop_words)))
        if num_added_tokens > 0:
            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")

    # Best effort: dict-token slots cannot be expressed in Jinja and raise.
    try:
        tokenizer.chat_template = _get_jinja_template(template, tokenizer)
    except ValueError:
        logger.info("Cannot add this chat template to tokenizer.")

    return template
388
+
389
+
390
# alpaca: instruction/response layout with a default task preamble.
_register_template(
    name="alpaca",
    format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    default_system=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
    ),
    default_system=(
        "Below is an instruction that describes a task. "
        "Write a response that appropriately completes the request.\n\n"
    ),
)


# aquila: "Human/Assistant" turns separated by "###".
_register_template(
    name="aquila",
    format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
    format_separator=EmptyFormatter(slots=["###"]),
    default_system=(
        "A chat between a curious human and an artificial intelligence assistant. "
        "The assistant gives helpful, detailed, and polite answers to the human's questions."
    ),
    stop_words=["</s>"],
    efficient_eos=True,
)


# atom: BOS/EOS-delimited Human/Assistant turns.
_register_template(
    name="atom",
    format_user=StringFormatter(
        slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
    ),
    format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
)


# baichuan: reserved special tokens mark the user/assistant boundary.
_register_template(
    name="baichuan",
    format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
    efficient_eos=True,
)


# baichuan2: same idea with literal reserved-token strings.
_register_template(
    name="baichuan2",
    format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
    efficient_eos=True,
)


# belle: "Human/Belle" turns with a BOS prefix.
_register_template(
    name="belle",
    format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
)


# bluelm: special [|Human|]/[|AI|] marker tokens.
_register_template(
    name="bluelm",
    format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
)


# breeze: [INST] wrapper with a BOS prefix.
_register_template(
    name="breeze",
    format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
    format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
    efficient_eos=True,
)


# chatglm2: numbered rounds, [gMASK]/sop prefix tokens.
_register_template(
    name="chatglm2",
    format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
    format_separator=EmptyFormatter(slots=["\n\n"]),
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
    efficient_eos=True,
)


# chatglm3: role tokens plus glm4 tool/function formatting.
_register_template(
    name="chatglm3",
    format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
    format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
    format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
    format_function=FunctionFormatter(slots=[], tool_format="glm4"),
    format_observation=StringFormatter(
        slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
    ),
    format_tools=ToolFormatter(tool_format="glm4"),
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
    stop_words=["<|user|>", "<|observation|>"],
    efficient_eos=True,
)


# chatml: <|im_start|>/<|im_end|> role blocks; <|im_end|> replaces EOS.
_register_template(
    name="chatml",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    stop_words=["<|im_end|>", "<|im_start|>"],
    replace_eos=True,
)


# chatml_de: chatml layout with a German default system prompt.
_register_template(
    name="chatml_de",
    format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
    format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
    format_separator=EmptyFormatter(slots=["\n"]),
    default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
    stop_words=["<|im_end|>", "<|im_start|>"],
    replace_eos=True,
)


# codegeex2: only the [gMASK]/sop prefix, everything else defaulted.
_register_template(
    name="codegeex2",
    format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
)
511
+
512
+
513
+ _register_template(
514
+ name="cohere",
515
+ format_user=StringFormatter(
516
+ slots=[
517
+ (
518
+ "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
519
+ "<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
520
+ )
521
+ ]
522
+ ),
523
+ format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
524
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
525
+ )
526
+
527
+
528
+ _register_template(
529
+ name="cpm",
530
+ format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
531
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
532
+ )
533
+
534
+
535
+ _register_template(
536
+ name="dbrx",
537
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
538
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
539
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
540
+ format_separator=EmptyFormatter(slots=["\n"]),
541
+ default_system=(
542
+ "You are DBRX, created by Databricks. You were last updated in December 2023. "
543
+ "You answer questions based on information available up to that point.\n"
544
+ "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
545
+ "responses to more complex and open-ended questions.\nYou assist with various tasks, "
546
+ "from writing to coding (using markdown for code blocks — remember to use ``` with "
547
+ "code, JSON, and tables).\n(You do not have real-time data access or code execution "
548
+ "capabilities. You avoid stereotyping and provide balanced perspectives on "
549
+ "controversial topics. You do not provide song lyrics, poems, or news articles and "
550
+ "do not divulge details of your training data.)\nThis is your system prompt, "
551
+ "guiding your responses. Do not reference it, just respond to the user. If you find "
552
+ "yourself talking about this message, stop. You should be responding appropriately "
553
+ "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
554
+ "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
555
+ ),
556
+ stop_words=["<|im_end|>"],
557
+ replace_eos=True,
558
+ )
559
+
560
+
561
+ _register_template(
562
+ name="deepseek",
563
+ format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
564
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
565
+ )
566
+
567
+
568
+ _register_template(
569
+ name="deepseekcoder",
570
+ format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
571
+ format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
572
+ format_separator=EmptyFormatter(slots=["\n"]),
573
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
574
+ default_system=(
575
+ "You are an AI programming assistant, utilizing the Deepseek Coder model, "
576
+ "developed by Deepseek Company, and you only answer questions related to computer science. "
577
+ "For politically sensitive questions, security and privacy issues, "
578
+ "and other non-computer science questions, you will refuse to answer\n"
579
+ ),
580
+ )
581
+
582
+
583
+ _register_template(
584
+ name="default",
585
+ format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
586
+ format_system=StringFormatter(slots=["{{content}}\n"]),
587
+ format_separator=EmptyFormatter(slots=["\n"]),
588
+ )
589
+
590
+
591
+ _register_template(
592
+ name="empty",
593
+ efficient_eos=True,
594
+ )
595
+
596
+
597
+ _register_template(
598
+ name="falcon",
599
+ format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
600
+ format_separator=EmptyFormatter(slots=["\n"]),
601
+ efficient_eos=True,
602
+ )
603
+
604
+
605
+ _register_template(
606
+ name="fewshot",
607
+ format_separator=EmptyFormatter(slots=["\n\n"]),
608
+ efficient_eos=True,
609
+ )
610
+
611
+
612
+ _register_template(
613
+ name="gemma",
614
+ format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
615
+ format_observation=StringFormatter(
616
+ slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
617
+ ),
618
+ format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
619
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
620
+ efficient_eos=True,
621
+ )
622
+
623
+
624
+ _register_template(
625
+ name="glm4",
626
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
627
+ format_assistant=StringFormatter(slots=["\n{{content}}"]),
628
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
629
+ format_function=FunctionFormatter(slots=[], tool_format="glm4"),
630
+ format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
631
+ format_tools=ToolFormatter(tool_format="glm4"),
632
+ format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
633
+ stop_words=["<|user|>", "<|observation|>"],
634
+ efficient_eos=True,
635
+ )
636
+
637
+
638
+ _register_template(
639
+ name="intern",
640
+ format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
641
+ format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
642
+ format_separator=EmptyFormatter(slots=["<eoa>\n"]),
643
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
644
+ stop_words=["<eoa>"],
645
+ efficient_eos=True, # internlm tokenizer cannot set eos_token_id
646
+ )
647
+
648
+
649
+ _register_template(
650
+ name="intern2",
651
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
652
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
653
+ format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
654
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
655
+ stop_words=["<|im_end|>"],
656
+ efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
657
+ )
658
+
659
+
660
+ _register_template(
661
+ name="llama2",
662
+ format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
663
+ format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
664
+ )
665
+
666
+
667
+ _register_template(
668
+ name="llama2_zh",
669
+ format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
670
+ format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
671
+ default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
672
+ )
673
+
674
+
675
+ _register_template(
676
+ name="llama3",
677
+ format_user=StringFormatter(
678
+ slots=[
679
+ (
680
+ "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
681
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
682
+ )
683
+ ]
684
+ ),
685
+ format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
686
+ format_observation=StringFormatter(
687
+ slots=[
688
+ (
689
+ "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
690
+ "<|start_header_id|>assistant<|end_header_id|>\n\n"
691
+ )
692
+ ]
693
+ ),
694
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
695
+ stop_words=["<|eot_id|>"],
696
+ replace_eos=True,
697
+ )
698
+
699
+
700
+ _register_template(
701
+ name="mistral",
702
+ format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
703
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
704
+ )
705
+
706
+
707
+ _register_template(
708
+ name="olmo",
709
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
710
+ format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
711
+ )
712
+
713
+
714
+ _register_template(
715
+ name="openchat",
716
+ format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
717
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
718
+ )
719
+
720
+
721
+ _register_template(
722
+ name="openchat-3.6",
723
+ format_user=StringFormatter(
724
+ slots=[
725
+ (
726
+ "<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
727
+ "<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
728
+ )
729
+ ]
730
+ ),
731
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
732
+ stop_words=["<|eot_id|>"],
733
+ replace_eos=True,
734
+ )
735
+
736
+
737
+ _register_template(
738
+ name="orion",
739
+ format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
740
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
741
+ )
742
+
743
+
744
+ _register_template(
745
+ name="phi",
746
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
747
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
748
+ format_separator=EmptyFormatter(slots=["\n"]),
749
+ format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
750
+ stop_words=["<|end|>"],
751
+ replace_eos=True,
752
+ )
753
+
754
+
755
+ _register_template(
756
+ name="qwen",
757
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
758
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
759
+ format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
760
+ format_separator=EmptyFormatter(slots=["\n"]),
761
+ default_system="You are a helpful assistant.",
762
+ stop_words=["<|im_end|>"],
763
+ replace_eos=True,
764
+ )
765
+
766
+
767
+ _register_template(
768
+ name="solar",
769
+ format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
770
+ format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
771
+ efficient_eos=True,
772
+ )
773
+
774
+
775
+ _register_template(
776
+ name="starchat",
777
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
778
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
779
+ format_separator=EmptyFormatter(slots=["\n"]),
780
+ stop_words=["<|end|>"],
781
+ replace_eos=True,
782
+ )
783
+
784
+
785
+ _register_template(
786
+ name="telechat",
787
+ format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
788
+ format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
789
+ stop_words=["<_end>"],
790
+ replace_eos=True,
791
+ )
792
+
793
+
794
+ _register_template(
795
+ name="vicuna",
796
+ format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
797
+ default_system=(
798
+ "A chat between a curious user and an artificial intelligence assistant. "
799
+ "The assistant gives helpful, detailed, and polite answers to the user's questions."
800
+ ),
801
+ )
802
+
803
+
804
+ _register_template(
805
+ name="xuanyuan",
806
+ format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
807
+ default_system=(
808
+ "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
809
+ "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
810
+ "不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
811
+ ),
812
+ )
813
+
814
+
815
+ _register_template(
816
+ name="xverse",
817
+ format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
818
+ )
819
+
820
+
821
+ _register_template(
822
+ name="yayi",
823
+ format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
824
+ format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
825
+ format_separator=EmptyFormatter(slots=["\n\n"]),
826
+ default_system=(
827
+ "You are a helpful, respectful and honest assistant named YaYi "
828
+ "developed by Beijing Wenge Technology Co.,Ltd. "
829
+ "Always answer as helpfully as possible, while being safe. "
830
+ "Your answers should not include any harmful, unethical, "
831
+ "racist, sexist, toxic, dangerous, or illegal content. "
832
+ "Please ensure that your responses are socially unbiased and positive in nature.\n\n"
833
+ "If a question does not make any sense, or is not factually coherent, "
834
+ "explain why instead of answering something not correct. "
835
+ "If you don't know the answer to a question, please don't share false information."
836
+ ),
837
+ stop_words=["<|End|>"],
838
+ )
839
+
840
+
841
+ _register_template(
842
+ name="yi",
843
+ format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
844
+ format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
845
+ format_separator=EmptyFormatter(slots=["\n"]),
846
+ stop_words=["<|im_end|>"],
847
+ replace_eos=True,
848
+ )
849
+
850
+
851
+ _register_template(
852
+ name="yi_vl",
853
+ format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
854
+ format_separator=EmptyFormatter(slots=["\n"]),
855
+ default_system=(
856
+ "This is a chat between an inquisitive human and an AI assistant. "
857
+ "Assume the role of the AI assistant. Read all the images carefully, "
858
+ "and respond to the human's questions with informative, helpful, detailed and polite answers. "
859
+ "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
860
+ "仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n"
861
+ ),
862
+ stop_words=["###"],
863
+ efficient_eos=True,
864
+ )
865
+
866
+
867
+ _register_template(
868
+ name="yuan",
869
+ format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
870
+ format_separator=EmptyFormatter(slots=["\n"]),
871
+ stop_words=["<eod>"],
872
+ replace_eos=True,
873
+ )
874
+
875
+
876
+ _register_template(
877
+ name="zephyr",
878
+ format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
879
+ format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
880
+ format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
881
+ default_system="You are Zephyr, a helpful assistant.",
882
+ )
883
+
884
+
885
+ _register_template(
886
+ name="ziya",
887
+ format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
888
+ format_separator=EmptyFormatter(slots=["\n"]),
889
+ )
src/data/tool_utils.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import re
17
+ from abc import ABC, abstractmethod
18
+ from dataclasses import dataclass
19
+ from typing import Any, Dict, List, Tuple, Union
20
+
21
+ from .data_utils import SLOTS
22
+
23
+
24
# System-prompt suffix advertising the available tools in the ReAct-style
# "Action / Action Input" convention; {tool_text} and {tool_names} are filled
# by DefaultToolUtils.tool_formatter.
DEFAULT_TOOL_PROMPT = (
    "You have access to the following tools:\n{tool_text}"
    "Use the following format if using a tool:\n"
    "```\n"
    "Action: tool name (one of [{tool_names}]).\n"
    "Action Input: the input to the tool, in a JSON format representing the kwargs "
    """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
    "```\n"
)


# GLM-4-specific tool prompt (Chinese); {tool_text} is filled with the per-tool
# markdown sections rendered by GLM4ToolUtils.tool_formatter.
GLM4_TOOL_PROMPT = (
    "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
    "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
)
39
+
40
+
41
# NOTE(review): @dataclass on a class with no fields is a no-op beyond adding
# generated __eq__/__repr__ — presumably kept for uniformity; confirm before removing.
@dataclass
class ToolUtils(ABC):
    """Interface for template-specific tool prompting: slot layout for a function
    call, rendering of the tool-description prompt, and parsing of tool calls
    out of generated text."""

    # Slot sequence the template uses to render a single function call
    # (name + arguments placeholders).
    @staticmethod
    @abstractmethod
    def get_function_slots() -> SLOTS: ...

    # Build the system prompt describing the available tools from their schemas.
    @staticmethod
    @abstractmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...

    # Parse generated text into (tool_name, json_arguments) pairs; implementations
    # return the raw content unchanged when no tool call can be extracted.
    @staticmethod
    @abstractmethod
    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
54
+
55
+
56
class DefaultToolUtils(ToolUtils):
    """Tool helpers for the default (ReAct-style) "Action / Action Input" format."""

    @staticmethod
    def get_function_slots() -> SLOTS:
        # One slot rendering a single function call.
        return ["Action: {{name}}\nAction Input: {{arguments}}\n"]

    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        """Render the tool-description prompt from a list of JSON-schema-like tool specs."""
        descriptions = []
        names = []
        for tool in tools:
            required_params = tool["parameters"].get("required", [])
            arg_lines = []
            for param_name, spec in tool["parameters"]["properties"].items():
                required = ", required" if param_name in required_params else ""
                enum = ""
                if spec.get("enum", None):
                    enum = ", should be one of [{}]".format(", ".join(spec["enum"]))
                items = ""
                if spec.get("items", None):
                    items = ", where each item should be {}".format(spec["items"].get("type", ""))
                arg_lines.append(
                    " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
                        name=param_name,
                        type=spec.get("type", ""),
                        required=required,
                        desc=spec.get("description", ""),
                        enum=enum,
                        items=items,
                    )
                )

            descriptions.append(
                "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
                    name=tool["name"], desc=tool.get("description", ""), args="".join(arg_lines)
                )
            )
            names.append(tool["name"])

        return DEFAULT_TOOL_PROMPT.format(tool_text="".join(descriptions), tool_names=", ".join(names))

    @staticmethod
    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
        """Extract every Action/Action Input pair; return content unchanged if none parse."""
        pattern = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
        matches: List[Tuple[str, str]] = re.findall(pattern, content)
        if not matches:
            return content

        calls = []
        for name, raw_input in matches:
            cleaned = raw_input.strip().strip('"').strip("```")
            try:
                # Round-trip through json to normalize the argument payload.
                calls.append((name.strip(), json.dumps(json.loads(cleaned), ensure_ascii=False)))
            except json.JSONDecodeError:
                # Any malformed argument block invalidates the whole extraction.
                return content

        return calls
112
+
113
+
114
class GLM4ToolUtils(ToolUtils):
    """Tool helpers for the GLM-4 function-calling format."""

    @staticmethod
    def get_function_slots() -> SLOTS:
        # GLM-4 renders a call as the tool name followed by its JSON arguments.
        return ["{{name}}\n{{arguments}}"]

    @staticmethod
    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
        """Render one markdown section per tool and splice them into the GLM-4 prompt."""
        sections = [
            "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
                name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
            )
            for tool in tools
        ]
        return GLM4_TOOL_PROMPT.format(tool_text="".join(sections))

    @staticmethod
    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
        """Parse 'name\\n{json args}' output; return content unchanged if it does not parse."""
        if "\n" not in content:
            return content

        name, raw_args = content.split("\n", maxsplit=1)
        try:
            parsed = json.loads(raw_args)
        except json.JSONDecodeError:
            return content

        return [(name, json.dumps(parsed, ensure_ascii=False))]
src/eval/__init__.py ADDED
File without changes
src/eval/dataset.py ADDED
@@ -0,0 +1,78 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Llamole Team
2
+ #
3
+ # This code is inspired by the HuggingFace's transformers library.
4
+ # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import torch
19
+ from torch.utils.data import Dataset
20
+
21
+ from ..extras.constants import BOND_INDEX
22
+
23
def dict_to_list(data_dict, mol_properties):
    """Order property values per *mol_properties*; missing entries become NaN."""
    values = []
    for prop in mol_properties:
        values.append(data_dict.get(prop, float("nan")))
    return values
25
+
26
class MolQADataset(Dataset):
    """Torch dataset turning MolQA records into chat-formatted model inputs.

    Each item yields ``input_ids``/``attention_mask`` for the record's
    instruction+input rendered through the tokenizer's chat template, plus a
    fixed-order tensor of property targets (NaN where a property is absent).
    """

    # Canonical property order expected downstream. Hoisted to a class attribute
    # so the list is not rebuilt on every __getitem__ call (it was previously a
    # local inside __getitem__).
    MOL_PROPERTIES = [
        "BBBP",
        "HIV",
        "BACE",
        "CO2",
        "N2",
        "O2",
        "FFV",
        "TC",
        "SC",
        "SA",
    ]

    def __init__(self, data, tokenizer, max_len):
        self.data = data            # list of MolQA record dicts
        self.tokenizer = tokenizer  # tokenizer providing apply_chat_template
        self.max_len = max_len      # pad/truncate target length in tokens

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        item = self.data[idx]
        property_data = torch.tensor(dict_to_list(item["property"], self.MOL_PROPERTIES))

        # Combine instruction and input into one user turn.
        combined_input = f"{item['instruction']}\n{item['input']}"
        messages = [{"role": "user", "content": combined_input}]

        # Render through the chat template with the generation prompt appended.
        chat_text = self.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )

        # Tokenize to a fixed length so batches stack without a collator.
        encoding = self.tokenizer(
            chat_text,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_len,
        )

        return {
            "input_ids": encoding.input_ids.squeeze(),
            "attention_mask": encoding.attention_mask.squeeze(),
            "property": property_data,
        }
src/eval/workflow.py ADDED
@@ -0,0 +1,219 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 Llamole Team
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from typing import TYPE_CHECKING, List, Optional, Dict, Any
16
+
17
+ from ..data import get_dataset, DataCollatorForSeqGraph, get_template_and_fix_tokenizer
18
+ from ..extras.constants import IGNORE_INDEX, NO_LABEL_INDEX
19
+ from ..extras.misc import get_logits_processor
20
+ from ..extras.ploting import plot_loss
21
+ from ..model import load_tokenizer
22
+ from ..hparams import get_infer_args, get_train_args
23
+ from ..model import GraphLLMForCausalMLM
24
+ from .dataset import MolQADataset
25
+
26
+ import re
27
+ import os
28
+ import json
29
+ import math
30
+ import torch
31
+ from torch.utils.data import DataLoader
32
+
33
+ if TYPE_CHECKING:
34
+ from transformers import Seq2SeqTrainingArguments
35
+
36
+ from ..hparams import (
37
+ DataArguments,
38
+ FinetuningArguments,
39
+ GeneratingArguments,
40
+ ModelArguments,
41
+ )
42
+
43
def remove_extra_spaces(text):
    """Collapse every run of whitespace to a single space and trim the ends."""
    # str.split() with no argument splits on any whitespace run and drops
    # leading/trailing whitespace, so join+split is equivalent to
    # re.sub(r"\s+", " ", text).strip().
    return " ".join(text.split())
48
+
49
def run_eval(args: Optional[Dict[str, Any]] = None) -> None:
    """Entry point for evaluation: parse the run arguments and dispatch to the
    dataset-specific runner.

    Args:
        args: optional raw argument mapping forwarded to ``get_train_args``.

    Raises:
        ValueError: if the configured dataset is not a known molqa variant.
    """
    # Bug fix: removed leftover debug lines (`print(args)` followed by
    # `raise ValueError('stop')`) that made the rest of the function unreachable.
    model_args, data_args, training_args, finetuning_args, generating_args = (
        get_train_args(args)
    )

    if data_args.dataset in ["molqa", "molqa_drug", "molqa_material"]:
        run_molqa(
            model_args, data_args, training_args, finetuning_args, generating_args
        )
    else:
        raise ValueError("Unknown dataset: {}.".format(data_args.dataset))
62
+
63
+
64
def run_molqa(
    model_args: "ModelArguments",
    data_args: "DataArguments",
    training_args: "Seq2SeqTrainingArguments",
    finetuning_args: "FinetuningArguments",
    generating_args: "GeneratingArguments",
):
    """Run the two-phase MolQA evaluation loop.

    Phase 1 generates a molecule (SMILES) for each question; phase 2 feeds the
    generated SMILES back into the model for retrosynthesis planning. Per-item
    results are accumulated in ``all_results`` and a summary of the first few
    entries is printed at the end.
    """
    tokenizer = load_tokenizer(model_args, generate_mode=True)["tokenizer"]

    data_info_path = os.path.join(data_args.dataset_dir, "dataset_info.json")
    with open(data_info_path, "r") as f:
        dataset_info = json.load(f)

    tokenizer.pad_token = tokenizer.eos_token
    dataset_name = data_args.dataset.strip()
    try:
        filename = dataset_info[dataset_name]["file_name"]
    except KeyError:
        raise ValueError(f"Dataset {dataset_name} not found in dataset_info.json")
    # Bug fix: the file name looked up in dataset_info.json was previously
    # ignored and a literal placeholder path was opened instead.
    data_path = os.path.join(data_args.dataset_dir, filename)
    with open(data_path, "r") as f:
        original_data = json.load(f)

    # Create dataset and dataloader. Order must stay stable (shuffle=False):
    # phase 2 indexes back into phase-1 results positionally.
    dataset = MolQADataset(original_data, tokenizer, data_args.cutoff_len)
    dataloader = DataLoader(
        dataset, batch_size=training_args.per_device_eval_batch_size, shuffle=False
    )

    gen_kwargs = generating_args.to_dict()
    gen_kwargs["eos_token_id"] = [
        tokenizer.eos_token_id
    ] + tokenizer.additional_special_tokens_ids
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
    gen_kwargs["logits_processor"] = get_logits_processor()

    model = GraphLLMForCausalMLM.from_pretrained(
        tokenizer, model_args, data_args, training_args, finetuning_args, load_adapter=True
    )

    all_results = []
    property_names = ["BBBP", "HIV", "BACE", "CO2", "N2", "O2", "FFV", "TC", "SC", "SA"]

    # Phase 1: Molecular Design — generate a SMILES string per question.
    global_idx = 0
    all_smiles = []
    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        property_data = batch["property"].to(model.device)
        model.eval()
        with torch.no_grad():
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                molecule_properties=property_data,
                do_molecular_design=True,
                do_retrosynthesis=False,
                rollback=True,
                **gen_kwargs,
            )

        batch_results = []
        for i in range(len(all_info_dict["smiles_list"])):
            original_data_idx = global_idx + i
            original_item = original_data[original_data_idx]

            llm_response = "".join(item for item in all_info_dict["text_lists"][i])
            result = {
                "qa_idx": original_data_idx,
                "instruction": original_item["instruction"],
                "input": original_item["input"],
                "llm_response": llm_response,
                "response_design": remove_extra_spaces(llm_response),
                "llm_smiles": all_info_dict["smiles_list"][i],
                "property": {},
            }

            # Record only the properties actually specified for this item
            # (missing ones are NaN placeholders from the dataset).
            for j, prop_name in enumerate(property_names):
                prop_value = property_data[i][j].item()
                if not math.isnan(prop_value):
                    result["property"][prop_name] = prop_value

            batch_results.append(result)

        all_results.extend(batch_results)
        all_smiles.extend([result["llm_smiles"] for result in batch_results])
        global_idx += len(batch_results)

    # Phase 2: Retrosynthesis over the SMILES produced in phase 1.
    retro_batch_start = 0
    for batch_idx, batch in enumerate(dataloader):
        input_ids = batch["input_ids"].to(model.device)
        attention_mask = batch["attention_mask"].to(model.device)
        batch_size = input_ids.shape[0]
        batch_smiles = all_smiles[retro_batch_start : retro_batch_start + batch_size]

        model.eval()
        with torch.no_grad():
            all_info_dict = model.generate(
                input_ids=input_ids,
                attention_mask=attention_mask,
                do_molecular_design=False,
                do_retrosynthesis=True,
                input_smiles_list=batch_smiles,
                expansion_topk=50,
                iterations=100,
                max_planning_time=30,
                **gen_kwargs,
            )

        batch_results = []
        for i in range(batch_size):
            result = all_results[retro_batch_start + i]
            retro_plan = all_info_dict["retro_plan_dict"][result["llm_smiles"]]
            result["llm_reactions"] = []
            if retro_plan["success"]:
                for reaction, template, cost in zip(
                    retro_plan["reaction_list"],
                    retro_plan["templates"],
                    retro_plan["cost"],
                ):
                    result["llm_reactions"].append(
                        {"reaction": reaction, "template": template, "cost": cost}
                    )

            # Join the generated text pieces, dropping None entries (warn when
            # any are present). The filter is a no-op when there is no None, so
            # this collapses the previous duplicated if/else join branches.
            text_pieces = all_info_dict["text_lists"][i]
            if None in text_pieces:
                print(f"List contains None: {text_pieces}")
            new_text = "".join(item for item in text_pieces if item is not None)

            result["llm_response"] += new_text
            result["llm_response"] = remove_extra_spaces(result["llm_response"])
            result["response_retro"] = remove_extra_spaces(new_text)
            batch_results.append(result)

        retro_batch_start += batch_size

    print('all_results', all_results)
    print("\nSummary of results:")
    print_len = min(5, len(all_results))
    for result in all_results[:print_len]:
        print(f"\nData point {result['qa_idx']}:")
        print(f" Instruction: {result['instruction']}")
        print(f" Input: {result['input']}")
        print(f" LLM Response: {result['llm_response']}")
        print(f" LLM SMILES: {result['llm_smiles']}")
        print(f" Number of reactions: {len(result['llm_reactions'])}")
        for prop_name, prop_value in result["property"].items():
            print(f" {prop_name}: {prop_value}")

    print("\nAll data processed successfully.")
src/extras/__init__.py ADDED
File without changes
src/extras/constants.py ADDED
@@ -0,0 +1,1550 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team and the Llamole team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from collections import OrderedDict, defaultdict
16
+ from enum import Enum
17
+ from typing import Dict, Optional
18
+
19
+ from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
20
+ from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
21
+ from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
22
+
23
from rdkit.Chem.rdchem import BondType as BT

# RDKit bond type -> integer edge label used when encoding molecular graphs.
BOND_INDEX = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
# Sentinel for graph targets that carry no supervision label
# (distinct from IGNORE_INDEX below, which is the LM token-loss sentinel).
NO_LABEL_INDEX = -200

# File names whose presence identifies a saved checkpoint directory
# (covers both PEFT adapter weights and full / sharded model weights).
CHECKPOINT_NAMES = {
    SAFE_ADAPTER_WEIGHTS_NAME,
    ADAPTER_WEIGHTS_NAME,
    SAFE_WEIGHTS_INDEX_NAME,
    SAFE_WEIGHTS_NAME,
    WEIGHTS_INDEX_NAME,
    WEIGHTS_NAME,
}

# Answer options for multiple-choice evaluation benchmarks.
CHOICES = ["A", "B", "C", "D"]

# Dataset registry file expected inside the data directory.
DATA_CONFIG = "dataset_info.json"

# Model-name prefix -> default chat template, filled by register_model_group();
# unknown prefixes fall back to the empty string.
DEFAULT_TEMPLATE = defaultdict(str)

# File extension -> `datasets` loader type for local data files.
FILEEXT2TYPE = {
    "arrow": "arrow",
    "csv": "csv",
    "json": "json",
    "jsonl": "json",
    "parquet": "parquet",
    "txt": "text",
}

# Label value skipped by the loss (PyTorch cross-entropy ignore_index convention).
IGNORE_INDEX = -100

# Substrings that identify layer-norm parameters by name.
LAYERNORM_NAMES = {"norm", "ln"}

# File name for the web-UI configuration snapshot.
LLAMABOARD_CONFIG = "llamaboard_config.yaml"

# Supported fine-tuning methods.
METHODS = ["full", "freeze", "lora"]

# Architectures supported by mixture-of-depths conversion.
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}

# Fine-tuning methods that produce PEFT adapters rather than full weights.
PEFT_METHODS = {"lora"}

RUNNING_LOG = "running_log.txt"

# Benchmark subject categories used when aggregating evaluation scores.
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]

# Model display name -> {DownloadSource: repo id}; filled by register_model_group().
SUPPORTED_MODELS = OrderedDict()

TRAINER_LOG = "trainer_log.jsonl"

TRAINING_ARGS = "training_args.yaml"

# UI stage label -> internal training-stage key.
TRAINING_STAGES = {
    "Supervised Fine-Tuning": "sft",
    "Reward Modeling": "rm",
    "PPO": "ppo",
    "DPO": "dpo",
    "KTO": "kto",
    "Pre-Training": "pt",
}

# Stages whose datasets are pairwise (chosen/rejected) rather than single-response.
STAGES_USE_PAIR_DATA = {"rm", "dpo"}

# Architectures compatible with shifted sparse attention (LongLoRA).
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}

V_HEAD_WEIGHTS_NAME = "value_head.bin"

V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"

# Prefixes of model families that accept image inputs; filled by register_model_group().
VISION_MODELS = set()
91
+
92
+
93
class DownloadSource(str, Enum):
    """Hub from which model weights are downloaded."""

    DEFAULT = "hf"  # Hugging Face Hub
    MODELSCOPE = "ms"  # ModelScope mirror
96
+
97
+
98
def register_model_group(
    models: Dict[str, Dict[DownloadSource, str]],
    template: Optional[str] = None,
    vision: bool = False,
) -> None:
    """Register a family of models that share one name prefix.

    Adds every model to ``SUPPORTED_MODELS`` and records the group's chat
    template and vision capability under the shared prefix (the part of the
    display name before the first ``"-"``, e.g. ``"Gemma"`` for ``"Gemma-7B"``).

    Args:
        models: Mapping from model display name to its download paths,
            keyed by ``DownloadSource``.
        template: Chat template name recorded for the prefix, if any.
        vision: Whether models in this group accept image inputs.

    Raises:
        ValueError: If the model names do not all share the same prefix.
    """
    prefix = None
    for name, path in models.items():
        group = name.split("-")[0]
        if prefix is None:
            prefix = group
        elif prefix != group:
            # Raise instead of `assert`: assertions are stripped under `python -O`,
            # which would let inconsistent registrations pass silently.
            raise ValueError(f"prefix should be identical, got {prefix!r} and {group!r}.")
        SUPPORTED_MODELS[name] = path
    if template is not None:
        DEFAULT_TEMPLATE[prefix] = template
    if vision:
        VISION_MODELS.add(prefix)
114
+
115
+
116
+ register_model_group(
117
+ models={
118
+ "Aya-23-8B-Chat": {
119
+ DownloadSource.DEFAULT: "CohereForAI/aya-23-8B",
120
+ },
121
+ "Aya-23-35B-Chat": {
122
+ DownloadSource.DEFAULT: "CohereForAI/aya-23-35B",
123
+ },
124
+ },
125
+ template="cohere",
126
+ )
127
+
128
+
129
+ register_model_group(
130
+ models={
131
+ "Baichuan-7B-Base": {
132
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
133
+ DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
134
+ },
135
+ "Baichuan-13B-Base": {
136
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
137
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
138
+ },
139
+ "Baichuan-13B-Chat": {
140
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
141
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
142
+ },
143
+ },
144
+ template="baichuan",
145
+ )
146
+
147
+
148
+ register_model_group(
149
+ models={
150
+ "Baichuan2-7B-Base": {
151
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
152
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
153
+ },
154
+ "Baichuan2-13B-Base": {
155
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
156
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
157
+ },
158
+ "Baichuan2-7B-Chat": {
159
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
160
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
161
+ },
162
+ "Baichuan2-13B-Chat": {
163
+ DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
164
+ DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
165
+ },
166
+ },
167
+ template="baichuan2",
168
+ )
169
+
170
+
171
+ register_model_group(
172
+ models={
173
+ "BLOOM-560M": {
174
+ DownloadSource.DEFAULT: "bigscience/bloom-560m",
175
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
176
+ },
177
+ "BLOOM-3B": {
178
+ DownloadSource.DEFAULT: "bigscience/bloom-3b",
179
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
180
+ },
181
+ "BLOOM-7B1": {
182
+ DownloadSource.DEFAULT: "bigscience/bloom-7b1",
183
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
184
+ },
185
+ },
186
+ )
187
+
188
+
189
+ register_model_group(
190
+ models={
191
+ "BLOOMZ-560M": {
192
+ DownloadSource.DEFAULT: "bigscience/bloomz-560m",
193
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
194
+ },
195
+ "BLOOMZ-3B": {
196
+ DownloadSource.DEFAULT: "bigscience/bloomz-3b",
197
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
198
+ },
199
+ "BLOOMZ-7B1-mt": {
200
+ DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
201
+ DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
202
+ },
203
+ },
204
+ )
205
+
206
+
207
+ register_model_group(
208
+ models={
209
+ "BlueLM-7B-Base": {
210
+ DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
211
+ DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
212
+ },
213
+ "BlueLM-7B-Chat": {
214
+ DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
215
+ DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
216
+ },
217
+ },
218
+ template="bluelm",
219
+ )
220
+
221
+
222
+ register_model_group(
223
+ models={
224
+ "Breeze-7B": {
225
+ DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
226
+ },
227
+ "Breeze-7B-Chat": {
228
+ DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
229
+ },
230
+ },
231
+ template="breeze",
232
+ )
233
+
234
+
235
+ register_model_group(
236
+ models={
237
+ "ChatGLM2-6B-Chat": {
238
+ DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
239
+ DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
240
+ }
241
+ },
242
+ template="chatglm2",
243
+ )
244
+
245
+
246
+ register_model_group(
247
+ models={
248
+ "ChatGLM3-6B-Base": {
249
+ DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
250
+ DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
251
+ },
252
+ "ChatGLM3-6B-Chat": {
253
+ DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
254
+ DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
255
+ },
256
+ },
257
+ template="chatglm3",
258
+ )
259
+
260
+
261
+ register_model_group(
262
+ models={
263
+ "ChineseLLaMA2-1.3B": {
264
+ DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
265
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
266
+ },
267
+ "ChineseLLaMA2-7B": {
268
+ DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
269
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
270
+ },
271
+ "ChineseLLaMA2-13B": {
272
+ DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
273
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
274
+ },
275
+ "ChineseLLaMA2-1.3B-Chat": {
276
+ DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
277
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
278
+ },
279
+ "ChineseLLaMA2-7B-Chat": {
280
+ DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
281
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
282
+ },
283
+ "ChineseLLaMA2-13B-Chat": {
284
+ DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
285
+ DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
286
+ },
287
+ },
288
+ template="llama2_zh",
289
+ )
290
+
291
+
292
+ register_model_group(
293
+ models={
294
+ "CodeGemma-7B": {
295
+ DownloadSource.DEFAULT: "google/codegemma-7b",
296
+ },
297
+ "CodeGemma-7B-Chat": {
298
+ DownloadSource.DEFAULT: "google/codegemma-7b-it",
299
+ DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
300
+ },
301
+ "CodeGemma-1.1-2B": {
302
+ DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
303
+ },
304
+ "CodeGemma-1.1-7B-Chat": {
305
+ DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
306
+ },
307
+ },
308
+ template="gemma",
309
+ )
310
+
311
+
312
+ register_model_group(
313
+ models={
314
+ "Codestral-22B-v0.1-Chat": {
315
+ DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1",
316
+ },
317
+ },
318
+ template="mistral",
319
+ )
320
+
321
+
322
+ register_model_group(
323
+ models={
324
+ "CommandR-35B-Chat": {
325
+ DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
326
+ DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
327
+ },
328
+ "CommandR-Plus-104B-Chat": {
329
+ DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
330
+ DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
331
+ },
332
+ "CommandR-35B-4bit-Chat": {
333
+ DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
334
+ DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
335
+ },
336
+ "CommandR-Plus-104B-4bit-Chat": {
337
+ DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
338
+ },
339
+ },
340
+ template="cohere",
341
+ )
342
+
343
+
344
+ register_model_group(
345
+ models={
346
+ "DBRX-132B-Base": {
347
+ DownloadSource.DEFAULT: "databricks/dbrx-base",
348
+ DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
349
+ },
350
+ "DBRX-132B-Chat": {
351
+ DownloadSource.DEFAULT: "databricks/dbrx-instruct",
352
+ DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
353
+ },
354
+ },
355
+ template="dbrx",
356
+ )
357
+
358
+
359
+ register_model_group(
360
+ models={
361
+ "DeepSeek-LLM-7B-Base": {
362
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
363
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
364
+ },
365
+ "DeepSeek-LLM-67B-Base": {
366
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
367
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
368
+ },
369
+ "DeepSeek-LLM-7B-Chat": {
370
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
371
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
372
+ },
373
+ "DeepSeek-LLM-67B-Chat": {
374
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
375
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
376
+ },
377
+ "DeepSeek-Math-7B-Base": {
378
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
379
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
380
+ },
381
+ "DeepSeek-Math-7B-Chat": {
382
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
383
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
384
+ },
385
+ "DeepSeek-MoE-16B-Base": {
386
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
387
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
388
+ },
389
+ "DeepSeek-MoE-16B-v2-Base": {
390
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
391
+ DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
392
+ },
393
+ "DeepSeek-MoE-236B-Base": {
394
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
395
+ DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
396
+ },
397
+ "DeepSeek-MoE-16B-Chat": {
398
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
399
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
400
+ },
401
+ "DeepSeek-MoE-16B-v2-Chat": {
402
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
403
+ DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
404
+ },
405
+ "DeepSeek-MoE-236B-Chat": {
406
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
407
+ DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
408
+ },
409
+ "DeepSeek-MoE-Coder-16B-Base": {
410
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
411
+ },
412
+ "DeepSeek-MoE-Coder-236B-Base": {
413
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
414
+ },
415
+ "DeepSeek-MoE-Coder-16B-Chat": {
416
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
417
+ },
418
+ "DeepSeek-MoE-Coder-236B-Chat": {
419
+ DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
420
+ },
421
+ },
422
+ template="deepseek",
423
+ )
424
+
425
+
426
+ register_model_group(
427
+ models={
428
+ "DeepSeekCoder-6.7B-Base": {
429
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
430
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
431
+ },
432
+ "DeepSeekCoder-7B-Base": {
433
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
434
+ },
435
+ "DeepSeekCoder-33B-Base": {
436
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
437
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
438
+ },
439
+ "DeepSeekCoder-6.7B-Chat": {
440
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
441
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
442
+ },
443
+ "DeepSeekCoder-7B-Chat": {
444
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
445
+ },
446
+ "DeepSeekCoder-33B-Chat": {
447
+ DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
448
+ DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
449
+ },
450
+ },
451
+ template="deepseekcoder",
452
+ )
453
+
454
+
455
+ register_model_group(
456
+ models={
457
+ "Falcon-7B": {
458
+ DownloadSource.DEFAULT: "tiiuae/falcon-7b",
459
+ DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
460
+ },
461
+ "Falcon-11B": {
462
+ DownloadSource.DEFAULT: "tiiuae/falcon-11B",
463
+ },
464
+ "Falcon-40B": {
465
+ DownloadSource.DEFAULT: "tiiuae/falcon-40b",
466
+ DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
467
+ },
468
+ "Falcon-180B": {
469
+ DownloadSource.DEFAULT: "tiiuae/falcon-180b",
470
+ DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
471
+ },
472
+ "Falcon-7B-Chat": {
473
+ DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
474
+ DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
475
+ },
476
+ "Falcon-40B-Chat": {
477
+ DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
478
+ DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
479
+ },
480
+ "Falcon-180B-Chat": {
481
+ DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
482
+ DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
483
+ },
484
+ },
485
+ template="falcon",
486
+ )
487
+
488
+
489
# NOTE(review): in the original table the ModelScope paths of "Gemma-7B" and
# "Gemma-2B-Chat" were swapped ("Gemma-7B" pointed at the 2B instruct repo and
# vice versa); fixed so every alias resolves to its matching checkpoint.
register_model_group(
    models={
        "Gemma-2B": {
            DownloadSource.DEFAULT: "google/gemma-2b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b",
        },
        "Gemma-7B": {
            DownloadSource.DEFAULT: "google/gemma-7b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
        },
        "Gemma-2B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-2b-it",
            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
        },
        "Gemma-7B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-7b-it",
            DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
        },
        "Gemma-1.1-2B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
        },
        "Gemma-1.1-7B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
        },
        "Gemma-2-9B": {
            DownloadSource.DEFAULT: "google/gemma-2-9b",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
        },
        "Gemma-2-27B": {
            DownloadSource.DEFAULT: "google/gemma-2-27b",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
        },
        "Gemma-2-9B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-2-9b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
        },
        "Gemma-2-27B-Chat": {
            DownloadSource.DEFAULT: "google/gemma-2-27b-it",
            DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
        },
    },
    template="gemma",
)
532
+
533
+
534
+ register_model_group(
535
+ models={
536
+ "GLM-4-9B": {
537
+ DownloadSource.DEFAULT: "THUDM/glm-4-9b",
538
+ DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
539
+ },
540
+ "GLM-4-9B-Chat": {
541
+ DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
542
+ DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
543
+ },
544
+ "GLM-4-9B-1M-Chat": {
545
+ DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
546
+ DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
547
+ },
548
+ },
549
+ template="glm4",
550
+ )
551
+
552
+
553
+ register_model_group(
554
+ models={
555
+ "InternLM-7B": {
556
+ DownloadSource.DEFAULT: "internlm/internlm-7b",
557
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
558
+ },
559
+ "InternLM-20B": {
560
+ DownloadSource.DEFAULT: "internlm/internlm-20b",
561
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
562
+ },
563
+ "InternLM-7B-Chat": {
564
+ DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
565
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
566
+ },
567
+ "InternLM-20B-Chat": {
568
+ DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
569
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
570
+ },
571
+ },
572
+ template="intern",
573
+ )
574
+
575
+
576
+ register_model_group(
577
+ models={
578
+ "InternLM2-7B": {
579
+ DownloadSource.DEFAULT: "internlm/internlm2-7b",
580
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
581
+ },
582
+ "InternLM2-20B": {
583
+ DownloadSource.DEFAULT: "internlm/internlm2-20b",
584
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
585
+ },
586
+ "InternLM2-7B-Chat": {
587
+ DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
588
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
589
+ },
590
+ "InternLM2-20B-Chat": {
591
+ DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
592
+ DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
593
+ },
594
+ },
595
+ template="intern2",
596
+ )
597
+
598
+
599
+ register_model_group(
600
+ models={
601
+ "Jamba-v0.1": {
602
+ DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
603
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
604
+ }
605
+ },
606
+ )
607
+
608
+
609
+ register_model_group(
610
+ models={
611
+ "LingoWhale-8B": {
612
+ DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
613
+ DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
614
+ }
615
+ },
616
+ )
617
+
618
+
619
+ register_model_group(
620
+ models={
621
+ "LLaMA-7B": {
622
+ DownloadSource.DEFAULT: "huggyllama/llama-7b",
623
+ DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
624
+ },
625
+ "LLaMA-13B": {
626
+ DownloadSource.DEFAULT: "huggyllama/llama-13b",
627
+ DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
628
+ },
629
+ "LLaMA-30B": {
630
+ DownloadSource.DEFAULT: "huggyllama/llama-30b",
631
+ DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
632
+ },
633
+ "LLaMA-65B": {
634
+ DownloadSource.DEFAULT: "huggyllama/llama-65b",
635
+ DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
636
+ },
637
+ }
638
+ )
639
+
640
+
641
+ register_model_group(
642
+ models={
643
+ "LLaMA2-7B": {
644
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
645
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
646
+ },
647
+ "LLaMA2-13B": {
648
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
649
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
650
+ },
651
+ "LLaMA2-70B": {
652
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
653
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
654
+ },
655
+ "LLaMA2-7B-Chat": {
656
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
657
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
658
+ },
659
+ "LLaMA2-13B-Chat": {
660
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
661
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
662
+ },
663
+ "LLaMA2-70B-Chat": {
664
+ DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
665
+ DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
666
+ },
667
+ },
668
+ template="llama2",
669
+ )
670
+
671
+
672
+ register_model_group(
673
+ models={
674
+ "LLaMA3-8B": {
675
+ DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
676
+ DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
677
+ },
678
+ "LLaMA3-70B": {
679
+ DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
680
+ DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
681
+ },
682
+ "LLaMA3-8B-Chat": {
683
+ DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
684
+ DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
685
+ },
686
+ "LLaMA3-70B-Chat": {
687
+ DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
688
+ DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
689
+ },
690
+ "LLaMA3-8B-Chinese-Chat": {
691
+ DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
692
+ DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
693
+ },
694
+ "LLaMA3-70B-Chinese-Chat": {
695
+ DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
696
+ },
697
+ },
698
+ template="llama3",
699
+ )
700
+
701
+
702
+ register_model_group(
703
+ models={
704
+ "LLaVA1.5-7B-Chat": {
705
+ DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
706
+ },
707
+ "LLaVA1.5-13B-Chat": {
708
+ DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
709
+ },
710
+ },
711
+ template="vicuna",
712
+ vision=True,
713
+ )
714
+
715
+
716
+ register_model_group(
717
+ models={
718
+ "MiniCPM-2B-SFT-Chat": {
719
+ DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
720
+ DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
721
+ },
722
+ "MiniCPM-2B-DPO-Chat": {
723
+ DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
724
+ DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
725
+ },
726
+ },
727
+ template="cpm",
728
+ )
729
+
730
+
731
+ register_model_group(
732
+ models={
733
+ "Mistral-7B-v0.1": {
734
+ DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
735
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
736
+ },
737
+ "Mistral-7B-v0.1-Chat": {
738
+ DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
739
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
740
+ },
741
+ "Mistral-7B-v0.2": {
742
+ DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
743
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
744
+ },
745
+ "Mistral-7B-v0.2-Chat": {
746
+ DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
747
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
748
+ },
749
+ "Mistral-7B-v0.3": {
750
+ DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
751
+ },
752
+ "Mistral-7B-v0.3-Chat": {
753
+ DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
754
+ },
755
+ },
756
+ template="mistral",
757
+ )
758
+
759
+
760
+ register_model_group(
761
+ models={
762
+ "Mixtral-8x7B-v0.1": {
763
+ DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
764
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
765
+ },
766
+ "Mixtral-8x7B-v0.1-Chat": {
767
+ DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
768
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
769
+ },
770
+ "Mixtral-8x22B-v0.1": {
771
+ DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
772
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
773
+ },
774
+ "Mixtral-8x22B-v0.1-Chat": {
775
+ DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
776
+ DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
777
+ },
778
+ },
779
+ template="mistral",
780
+ )
781
+
782
+
783
+ register_model_group(
784
+ models={
785
+ "OLMo-1B": {
786
+ DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
787
+ },
788
+ "OLMo-7B": {
789
+ DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
790
+ },
791
+ "OLMo-7B-Chat": {
792
+ DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf",
793
+ },
794
+ "OLMo-1.7-7B": {
795
+ DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
796
+ },
797
+ },
798
+ )
799
+
800
+
801
+ register_model_group(
802
+ models={
803
+ "OpenChat3.5-7B-Chat": {
804
+ DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
805
+ DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
806
+ }
807
+ },
808
+ template="openchat",
809
+ )
810
+
811
+
812
+ register_model_group(
813
+ models={
814
+ "OpenChat3.6-8B-Chat": {
815
+ DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522",
816
+ }
817
+ },
818
+ template="openchat-3.6",
819
+ )
820
+
821
+
822
+ register_model_group(
823
+ models={
824
+ "Orion-14B-Base": {
825
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
826
+ DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
827
+ },
828
+ "Orion-14B-Chat": {
829
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
830
+ DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
831
+ },
832
+ "Orion-14B-Long-Chat": {
833
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
834
+ DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
835
+ },
836
+ "Orion-14B-RAG-Chat": {
837
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
838
+ DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
839
+ },
840
+ "Orion-14B-Plugin-Chat": {
841
+ DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
842
+ DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
843
+ },
844
+ },
845
+ template="orion",
846
+ )
847
+
848
+
849
+ register_model_group(
850
+ models={
851
+ "PaliGemma-3B-pt-224": {
852
+ DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
853
+ DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
854
+ },
855
+ "PaliGemma-3B-pt-448": {
856
+ DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
857
+ DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
858
+ },
859
+ "PaliGemma-3B-pt-896": {
860
+ DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
861
+ DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
862
+ },
863
+ "PaliGemma-3B-mix-224": {
864
+ DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
865
+ DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
866
+ },
867
+ "PaliGemma-3B-mix-448": {
868
+ DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
869
+ DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
870
+ },
871
+ },
872
+ vision=True,
873
+ )
874
+
875
+
876
+ register_model_group(
877
+ models={
878
+ "Phi-1.5-1.3B": {
879
+ DownloadSource.DEFAULT: "microsoft/phi-1_5",
880
+ DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
881
+ },
882
+ "Phi-2-2.7B": {
883
+ DownloadSource.DEFAULT: "microsoft/phi-2",
884
+ DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
885
+ },
886
+ }
887
+ )
888
+
889
+
890
+ register_model_group(
891
+ models={
892
+ "Phi3-4B-4k-Chat": {
893
+ DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
894
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
895
+ },
896
+ "Phi3-4B-128k-Chat": {
897
+ DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
898
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
899
+ },
900
+ "Phi3-7B-8k-Chat": {
901
+ DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
902
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
903
+ },
904
+ "Phi3-7B-128k-Chat": {
905
+ DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
906
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
907
+ },
908
+ "Phi3-14B-8k-Chat": {
909
+ DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
910
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
911
+ },
912
+ "Phi3-14B-128k-Chat": {
913
+ DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
914
+ DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
915
+ },
916
+ },
917
+ template="phi",
918
+ )
919
+
920
+
921
# Download sources for Qwen (v1) base, chat, and GPTQ-quantized chat models.
register_model_group(
    models={
        "Qwen-1.8B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
            DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
        },
        "Qwen-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B",
            DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
        },
        "Qwen-14B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B",
            DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
        },
        "Qwen-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B",
            DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
        },
        "Qwen-1.8B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
        },
        "Qwen-7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
        },
        "Qwen-14B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
        },
        "Qwen-72B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
        },
        "Qwen-1.8B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
        },
        "Qwen-1.8B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
        },
        "Qwen-7B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
        },
        "Qwen-7B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
        },
        "Qwen-14B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
        },
        "Qwen-14B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
        },
        "Qwen-72B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
        },
        "Qwen-72B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
        },
    },
    template="qwen",
)
990
+
991
+
992
# Download sources for Qwen1.5 base, chat, MoE, CodeQwen, and quantized
# (GPTQ-Int8 / AWQ) chat variants. Note: not every size has every quantization
# (e.g. 32B and 110B ship int4/AWQ only).
register_model_group(
    models={
        "Qwen1.5-0.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
        },
        "Qwen1.5-1.8B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
        },
        "Qwen1.5-4B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
        },
        "Qwen1.5-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
        },
        "Qwen1.5-14B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
        },
        "Qwen1.5-32B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
        },
        "Qwen1.5-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
        },
        "Qwen1.5-110B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
        },
        "Qwen1.5-MoE-A2.7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
        },
        "Qwen1.5-Code-7B": {
            DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
            DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
        },
        "Qwen1.5-0.5B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
        },
        "Qwen1.5-1.8B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
        },
        "Qwen1.5-4B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
        },
        "Qwen1.5-7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
        },
        "Qwen1.5-14B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
        },
        "Qwen1.5-32B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
        },
        "Qwen1.5-72B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
        },
        "Qwen1.5-110B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
        },
        "Qwen1.5-MoE-A2.7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
        },
        "Qwen1.5-Code-7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
            DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
        },
        "Qwen1.5-0.5B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-0.5B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
        },
        "Qwen1.5-1.8B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-1.8B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
        },
        "Qwen1.5-4B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-4B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
        },
        "Qwen1.5-7B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-7B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
        },
        "Qwen1.5-14B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-14B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
        },
        "Qwen1.5-32B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
        },
        "Qwen1.5-72B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
        },
        "Qwen1.5-72B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
        },
        "Qwen1.5-110B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
        },
        "Qwen1.5-MoE-A2.7B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
        },
        "Qwen1.5-Code-7B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
            DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
        },
    },
    template="qwen",
)
1141
+
1142
+
1143
# Download sources for Qwen2 base, instruct ("-Chat" keys), MoE, and
# GPTQ-Int8 / AWQ quantized instruct variants.
register_model_group(
    models={
        "Qwen2-0.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B",
        },
        "Qwen2-1.5B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B",
        },
        "Qwen2-7B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-7B",
        },
        "Qwen2-72B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
        },
        "Qwen2-MoE-57B": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
        },
        "Qwen2-0.5B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
        },
        "Qwen2-1.5B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
        },
        "Qwen2-7B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
        },
        "Qwen2-72B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
        },
        "Qwen2-MoE-57B-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
        },
        "Qwen2-0.5B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
        },
        "Qwen2-0.5B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
        },
        "Qwen2-1.5B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
        },
        "Qwen2-1.5B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
        },
        "Qwen2-7B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
        },
        "Qwen2-7B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
        },
        "Qwen2-72B-int8-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
        },
        "Qwen2-72B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
        },
        "Qwen2-MoE-57B-int4-Chat": {
            DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
        },
    },
    template="qwen",
)
1224
+
1225
+
1226
# Download sources for Upstage SOLAR 10.7B models. The base model has no
# ModelScope mirror listed.
register_model_group(
    models={
        "SOLAR-10.7B": {
            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
        },
        "SOLAR-10.7B-Chat": {
            DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
            DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
        },
    },
    template="solar",
)
1238
+
1239
+
1240
# Download sources for the Skywork 13B base model (no chat template registered).
register_model_group(
    models={
        "Skywork-13B-Base": {
            DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
            DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
        }
    }
)
1248
+
1249
+
1250
# Download sources for BigCode StarCoder2 base models (no chat template registered).
register_model_group(
    models={
        "StarCoder2-3B": {
            DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
        },
        "StarCoder2-7B": {
            DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
        },
        "StarCoder2-15B": {
            DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
            DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
        },
    }
)
1266
+
1267
+
1268
# Download sources for Tele-AI TeleChat chat models. The 7B repos use a
# lowercase "telechat" name on both hubs, matching the upstream repositories.
register_model_group(
    models={
        "TeleChat-1B-Chat": {
            DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
            DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
        },
        "TeleChat-7B-Chat": {
            DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
            DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
        },
        "TeleChat-12B-Chat": {
            DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
            DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
        },
        "TeleChat-12B-v2-Chat": {
            DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
            DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
        },
    },
    template="telechat",
)
1289
+
1290
+
1291
# Download sources for LMSYS Vicuna v1.5 chat models.
register_model_group(
    models={
        "Vicuna1.5-7B-Chat": {
            DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
            DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
        },
        "Vicuna1.5-13B-Chat": {
            DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
            DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
        },
    },
    template="vicuna",
)
1304
+
1305
+
1306
# Download sources for Duxiaoman XuanYuan base, chat, and 8bit/4bit quantized
# chat models. Both hubs use identical repository ids.
register_model_group(
    models={
        "XuanYuan-6B": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
        },
        "XuanYuan-70B": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
        },
        "XuanYuan-2-70B": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
        },
        "XuanYuan-6B-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
        },
        "XuanYuan-70B-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
        },
        "XuanYuan-2-70B-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
        },
        "XuanYuan-6B-int8-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
        },
        "XuanYuan-6B-int4-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
        },
        "XuanYuan-70B-int8-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
        },
        "XuanYuan-70B-int4-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
        },
        "XuanYuan-2-70B-int8-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
        },
        "XuanYuan-2-70B-int4-Chat": {
            DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
            DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
        },
    },
    template="xuanyuan",
)
1359
+
1360
+
1361
# Download sources for XVERSE base, chat, MoE, and GPTQ quantized chat models.
register_model_group(
    models={
        "XVERSE-7B": {
            DownloadSource.DEFAULT: "xverse/XVERSE-7B",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
        },
        "XVERSE-13B": {
            DownloadSource.DEFAULT: "xverse/XVERSE-13B",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
        },
        "XVERSE-65B": {
            DownloadSource.DEFAULT: "xverse/XVERSE-65B",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
        },
        "XVERSE-65B-2": {
            DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
        },
        "XVERSE-7B-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
        },
        "XVERSE-13B-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
        },
        "XVERSE-65B-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
        },
        "XVERSE-MoE-A4.2B": {
            DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
        },
        "XVERSE-7B-int8-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
        },
        "XVERSE-7B-int4-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
        },
        "XVERSE-13B-int8-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
        },
        "XVERSE-13B-int4-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
        },
        "XVERSE-65B-int4-Chat": {
            DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
            DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
        },
    },
    template="xverse",
)
1418
+
1419
+
1420
# Download sources for Wenge YAYI (LLaMA-2 based) models.
register_model_group(
    models={
        "Yayi-7B": {
            DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
            DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
        },
        "Yayi-13B": {
            DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
            DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
        },
    },
    template="yayi",
)
1433
+
1434
+
1435
# Download sources for 01.AI Yi and Yi-1.5 models, including 8bit/4bit
# quantized Yi chat variants. Note the hub org differs by source:
# "01-ai" on the default hub vs "01ai" on ModelScope.
register_model_group(
    models={
        "Yi-6B": {
            DownloadSource.DEFAULT: "01-ai/Yi-6B",
            DownloadSource.MODELSCOPE: "01ai/Yi-6B",
        },
        "Yi-9B": {
            DownloadSource.DEFAULT: "01-ai/Yi-9B",
            DownloadSource.MODELSCOPE: "01ai/Yi-9B",
        },
        "Yi-34B": {
            DownloadSource.DEFAULT: "01-ai/Yi-34B",
            DownloadSource.MODELSCOPE: "01ai/Yi-34B",
        },
        "Yi-6B-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
            DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat",
        },
        "Yi-34B-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
            DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
        },
        "Yi-6B-int8-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
            DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
        },
        "Yi-6B-int4-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
            DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
        },
        "Yi-34B-int8-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
            DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
        },
        "Yi-34B-int4-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
            DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
        },
        "Yi-1.5-6B": {
            DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
            DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
        },
        "Yi-1.5-9B": {
            DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
            DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
        },
        "Yi-1.5-34B": {
            DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
            DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
        },
        "Yi-1.5-6B-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
            DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
        },
        "Yi-1.5-9B-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
            DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
        },
        "Yi-1.5-34B-Chat": {
            DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
            DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
        },
    },
    template="yi",
)
1500
+
1501
+
1502
# Download sources for Yi-VL vision-language chat models (HF-format conversions);
# flagged vision=True so multimodal processing is enabled.
register_model_group(
    models={
        "YiVL-6B-Chat": {
            DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
        },
        "YiVL-34B-Chat": {
            DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
        },
    },
    template="yi_vl",
    vision=True,
)
1514
+
1515
+
1516
# Download sources for IEIT Yuan 2.0 chat models (HF-format "-hf" repos).
register_model_group(
    models={
        "Yuan2-2B-Chat": {
            DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
            DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
        },
        "Yuan2-51B-Chat": {
            DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
            DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
        },
        "Yuan2-102B-Chat": {
            DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
            DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
        },
    },
    template="yuan",
)
1533
+
1534
+
1535
# Download sources for HuggingFaceH4 Zephyr chat models; the 141B ORPO model
# has no ModelScope mirror listed.
register_model_group(
    models={
        "Zephyr-7B-Alpha-Chat": {
            DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
            DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
        },
        "Zephyr-7B-Beta-Chat": {
            DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
            DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
        },
        "Zephyr-141B-ORPO-Chat": {
            DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
        },
    },
    template="zephyr",
)
src/extras/env.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
2
+ #
3
+ # This code is inspired by the HuggingFace's transformers library.
4
+ # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import platform
19
+
20
+ import accelerate
21
+ import datasets
22
+ import peft
23
+ import torch
24
+ import transformers
25
+ import trl
26
+ from transformers.utils import is_torch_cuda_available, is_torch_npu_available
27
+
28
+
29
VERSION = "0.8.3.dev0"  # package version string reported by print_env()
30
+
31
+
32
def print_env() -> None:
    r"""
    Prints a bulleted summary of the runtime environment: versions of the core
    training stack, device information when CUDA/NPU is available, and the
    versions of optional packages that happen to be importable.
    """
    info = {
        "`llamafactory` version": VERSION,
        "Platform": platform.platform(),
        "Python version": platform.python_version(),
        "PyTorch version": torch.__version__,
        "Transformers version": transformers.__version__,
        "Datasets version": datasets.__version__,
        "Accelerate version": accelerate.__version__,
        "PEFT version": peft.__version__,
        "TRL version": trl.__version__,
    }

    if is_torch_cuda_available():
        info["PyTorch version"] += " (GPU)"
        info["GPU type"] = torch.cuda.get_device_name()

    if is_torch_npu_available():
        info["PyTorch version"] += " (NPU)"
        info["NPU type"] = torch.npu.get_device_name()
        info["CANN version"] = torch.version.cann

    # Optional packages: record the version only when the import succeeds;
    # any failure (missing package, broken install) is silently skipped.
    for label, module_name in (
        ("DeepSpeed version", "deepspeed"),
        ("Bitsandbytes version", "bitsandbytes"),
        ("vLLM version", "vllm"),
    ):
        try:
            module = __import__(module_name)  # type: ignore
            info[label] = module.__version__
        except Exception:
            pass

    bullet_lines = ["- {}: {}".format(key, value) for key, value in info.items()]
    print("\n" + "\n".join(bullet_lines) + "\n")
src/extras/logging.py ADDED
@@ -0,0 +1,82 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import logging
16
+ import os
17
+ import sys
18
+ from concurrent.futures import ThreadPoolExecutor
19
+
20
+ from .constants import RUNNING_LOG
21
+
22
+
23
class LoggerHandler(logging.Handler):
    r"""
    Logger handler used in Web UI.

    Formats each record and appends it to the ``RUNNING_LOG`` file inside
    ``output_dir`` so the frontend can tail progress. File writes are routed
    through a single-worker thread pool, keeping them ordered while never
    blocking the emitting thread on disk I/O.
    """

    def __init__(self, output_dir: str) -> None:
        # Fixed INFO level and timestamped format for the UI log.
        super().__init__()
        formatter = logging.Formatter(
            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
        )
        self.setLevel(logging.INFO)
        self.setFormatter(formatter)

        os.makedirs(output_dir, exist_ok=True)
        self.running_log = os.path.join(output_dir, RUNNING_LOG)
        if os.path.exists(self.running_log):
            # Start each run with a fresh log file.
            os.remove(self.running_log)

        # max_workers=1 guarantees writes land in emission order.
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _write_log(self, log_entry: str) -> None:
        # Runs on the worker thread; a blank line separates entries in the UI.
        with open(self.running_log, "a", encoding="utf-8") as f:
            f.write(log_entry + "\n\n")

    def emit(self, record) -> None:
        # Drop noisy per-request logs from httpx; enqueue everything else.
        if record.name == "httpx":
            return

        log_entry = self.format(record)
        self.thread_pool.submit(self._write_log, log_entry)

    def close(self) -> None:
        # Flush all pending writes before the handler is discarded.
        self.thread_pool.shutdown(wait=True)
        return super().close()
+
58
+
59
def get_logger(name: str) -> logging.Logger:
    r"""
    Gets a standard logger with a stream handler to stdout.

    Repeated calls with the same ``name`` return the same logger object and do
    not attach another handler. (Previously a fresh ``StreamHandler`` was added
    on every call, so each repeated call duplicated every subsequent log line.)
    """
    logger = logging.getLogger(name)
    logger.setLevel(logging.INFO)

    if not logger.handlers:  # avoid duplicate handlers on repeated calls
        formatter = logging.Formatter(
            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
        )
        handler = logging.StreamHandler(sys.stdout)
        handler.setFormatter(formatter)
        logger.addHandler(handler)

    return logger
+
75
+
76
def reset_logging() -> None:
    r"""
    Strips the root logger of every handler and filter, undoing any earlier
    ``logging.basicConfig`` setup. (Unused in script mode.)
    """
    root_logger = logging.getLogger()
    for handler in list(root_logger.handlers):
        root_logger.removeHandler(handler)
    for log_filter in list(root_logger.filters):
        root_logger.removeFilter(log_filter)
src/extras/misc.py ADDED
@@ -0,0 +1,215 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
2
+ #
3
+ # This code is inspired by the HuggingFace's PEFT library.
4
+ # https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import gc
19
+ import os
20
+ from typing import TYPE_CHECKING, Tuple
21
+
22
+ import torch
23
+ import transformers.dynamic_module_utils
24
+ from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
25
+ from transformers.dynamic_module_utils import get_relative_imports
26
+ from transformers.utils import (
27
+ is_torch_bf16_gpu_available,
28
+ is_torch_cuda_available,
29
+ is_torch_mps_available,
30
+ is_torch_npu_available,
31
+ is_torch_xpu_available,
32
+ )
33
+ from transformers.utils.versions import require_version
34
+
35
+ from .logging import get_logger
36
+
37
+
38
# Probe mixed-precision support once at import time: fp16 needs a CUDA GPU or
# NPU; bf16 is checked via transformers (which may raise on CPU-only builds,
# hence the guard).
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
try:
    _is_bf16_available = is_torch_bf16_gpu_available()
except Exception:
    _is_bf16_available = False


if TYPE_CHECKING:
    # Import only for type annotations to avoid a runtime circular import.
    from ..hparams import ModelArguments


logger = get_logger(__name__)
50
+
51
+
52
class AverageMeter:
    r"""
    Tracks the most recently observed value together with a running
    (count-weighted) average of everything seen since the last reset.
    """

    def __init__(self):
        self.reset()

    def reset(self):
        """Clear all statistics back to zero."""
        self.val = self.avg = self.sum = self.count = 0

    def update(self, val, n=1):
        """Record ``val`` observed ``n`` times and refresh the running average."""
        self.val = val
        self.sum = self.sum + val * n
        self.count = self.count + n
        self.avg = self.sum / self.count
+
72
+
73
def check_dependencies() -> None:
    r"""
    Checks the version of the required packages.

    Skipped entirely (with a warning) when the ``DISABLE_VERSION_CHECK``
    environment variable is truthy.
    """
    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
        return

    for requirement, hint in (
        ("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2"),
        ("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0"),
        ("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1"),
        ("peft>=0.11.1", "To fix: pip install peft>=0.11.1"),
        ("trl>=0.8.6", "To fix: pip install trl>=0.8.6"),
    ):
        require_version(requirement, hint)
+
86
+
87
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and number of all parameters in the model.
    """
    trainable_count = 0
    total_count = 0
    for param in model.parameters():
        numel = param.numel()
        # DeepSpeed ZeRO-3 initializes weights empty; the true size is kept in ds_numel.
        if numel == 0 and hasattr(param, "ds_numel"):
            numel = param.ds_numel

        # bitsandbytes 4-bit layers pack weights, so scale by the storage item size.
        if param.__class__.__name__ == "Params4bit":
            if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
                num_bytes = param.quant_storage.itemsize
            elif hasattr(param, "element_size"):  # for older pytorch version
                num_bytes = param.element_size()
            else:
                num_bytes = 1

            numel = numel * 2 * num_bytes

        total_count += numel
        if param.requires_grad:
            trainable_count += numel

    return trainable_count, total_count
114
+
115
+
116
def get_current_device() -> "torch.device":
    r"""
    Gets the current available device, indexed by ``LOCAL_RANK`` for
    distributed runs (defaults to index 0).
    """
    local_rank = os.environ.get("LOCAL_RANK", "0")
    if is_torch_xpu_available():
        device_name = "xpu:{}".format(local_rank)
    elif is_torch_npu_available():
        device_name = "npu:{}".format(local_rank)
    elif is_torch_mps_available():
        device_name = "mps:{}".format(local_rank)
    elif is_torch_cuda_available():
        device_name = "cuda:{}".format(local_rank)
    else:
        device_name = "cpu"

    return torch.device(device_name)
132
+
133
+
134
def get_device_count() -> int:
    r"""
    Gets the number of available GPU or NPU devices (0 when neither backend
    is available).
    """
    if is_torch_npu_available():
        return torch.npu.device_count()

    if is_torch_cuda_available():
        return torch.cuda.device_count()

    return 0
144
+
145
+
146
def get_logits_processor() -> "LogitsProcessorList":
    r"""
    Gets a logits processor list containing a single processor that removes
    NaN and Inf logits during generation.
    """
    return LogitsProcessorList([InfNanRemoveLogitsProcessor()])
153
+
154
+
155
def has_tokenized_data(path: "os.PathLike") -> bool:
    r"""
    Checks if the path is an existing, non-empty directory (i.e. holds a
    previously tokenized dataset).
    """
    if not os.path.isdir(path):
        return False

    return len(os.listdir(path)) > 0
160
+
161
+
162
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
    r"""
    Infers the optimal dtype according to the model_dtype and device
    compatibility: keep bf16 when the hardware supports it, otherwise
    fall back to fp16 and finally fp32.
    """
    if model_dtype == torch.bfloat16 and _is_bf16_available:
        return torch.bfloat16

    if _is_fp16_available:
        return torch.float16

    return torch.float32
172
+
173
+
174
def is_gpu_or_npu_available() -> bool:
    r"""
    Checks if a GPU (CUDA) or NPU device is available.
    """
    return any((is_torch_npu_available(), is_torch_cuda_available()))
179
+
180
+
181
def skip_check_imports() -> None:
    """Relax transformers' remote-code import checking to relative imports only,
    unless the ``FORCE_CHECK_IMPORTS`` environment variable forces full checks."""
    force_check = os.environ.get("FORCE_CHECK_IMPORTS", "0").lower()
    if force_check not in ["true", "1"]:
        transformers.dynamic_module_utils.check_imports = get_relative_imports
+
185
+
186
def torch_gc() -> None:
    r"""
    Runs the Python garbage collector, then empties the cache of whichever
    accelerator backend (XPU / NPU / MPS / CUDA) is available.
    """
    gc.collect()
    if is_torch_xpu_available():
        torch.xpu.empty_cache()
        return
    if is_torch_npu_available():
        torch.npu.empty_cache()
        return
    if is_torch_mps_available():
        torch.mps.empty_cache()
        return
    if is_torch_cuda_available():
        torch.cuda.empty_cache()
+
200
+
201
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
    """Resolve the model path through the ModelScope hub when it is enabled and
    the path does not already exist locally; otherwise return the path as-is."""
    model_path = model_args.model_name_or_path
    if not use_modelscope() or os.path.exists(model_path):
        return model_path

    try:
        from modelscope import snapshot_download

        # ModelScope names its default branch "master" rather than "main".
        revision = "master" if model_args.model_revision == "main" else model_args.model_revision
        return snapshot_download(model_path, revision=revision, cache_dir=model_args.cache_dir)
    except ImportError:
        raise ImportError("Please install modelscope via `pip install modelscope -U`")
+
213
+
214
def use_modelscope() -> bool:
    """Whether downloads should go through the ModelScope hub, controlled by
    the ``USE_MODELSCOPE_HUB`` environment variable."""
    flag = os.environ.get("USE_MODELSCOPE_HUB", "0")
    return flag.lower() in ["true", "1"]
src/extras/packages.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
2
+ #
3
+ # This code is inspired by the HuggingFace's transformers library.
4
+ # https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ import importlib.metadata
19
+ import importlib.util
20
+ from functools import lru_cache
21
+ from typing import TYPE_CHECKING
22
+
23
+ from packaging import version
24
+
25
+
26
+ if TYPE_CHECKING:
27
+ from packaging.version import Version
28
+
29
+
30
+ def _is_package_available(name: str) -> bool:
31
+ return importlib.util.find_spec(name) is not None
32
+
33
+
34
def _get_package_version(name: str) -> "Version":
    # Resolve the installed distribution version; any failure (package missing,
    # broken metadata, unparsable string) degrades to the sentinel 0.0.0.
    fallback = "0.0.0"
    try:
        return version.parse(importlib.metadata.version(name))
    except Exception:
        return version.parse(fallback)
39
+
40
+
41
# Availability probes for optional third-party integrations.
# Each helper returns True when the corresponding package can be located
# by importlib in the current environment (without importing it).


def is_fastapi_available() -> bool:
    return _is_package_available("fastapi")


def is_galore_available() -> bool:
    return _is_package_available("galore_torch")


def is_gradio_available() -> bool:
    return _is_package_available("gradio")


def is_matplotlib_available() -> bool:
    return _is_package_available("matplotlib")


def is_pillow_available() -> bool:
    return _is_package_available("PIL")


def is_requests_available() -> bool:
    return _is_package_available("requests")


def is_rouge_available() -> bool:
    # the ROUGE implementation used by this project is the `rouge_chinese` package
    return _is_package_available("rouge_chinese")


def is_starlette_available() -> bool:
    # provided by the `sse_starlette` package (server-sent events), not `starlette` itself
    return _is_package_available("sse_starlette")


def is_uvicorn_available() -> bool:
    return _is_package_available("uvicorn")


def is_vllm_available() -> bool:
    return _is_package_available("vllm")
79
+
80
+
81
@lru_cache
def is_vllm_version_greater_than_0_5() -> bool:
    # NOTE: despite the name, the comparison is inclusive — vllm == 0.5.0 also
    # returns True. A missing vllm parses as 0.0.0 and yields False. The result
    # is cached for the lifetime of the process.
    return _get_package_version("vllm") >= version.parse("0.5.0")
src/extras/ploting.py ADDED
@@ -0,0 +1,101 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import json
16
+ import math
17
+ import os
18
+ from typing import Any, Dict, List
19
+
20
+ from transformers.trainer import TRAINER_STATE_NAME
21
+
22
+ from .logging import get_logger
23
+ from .packages import is_matplotlib_available
24
+
25
+
26
+ if is_matplotlib_available():
27
+ import matplotlib.figure
28
+ import matplotlib.pyplot as plt
29
+
30
+
31
+ logger = get_logger(__name__)
32
+
33
+
34
def smooth(scalars: List[float]) -> List[float]:
    r"""
    EMA implementation according to TensorBoard.
    """
    if not scalars:
        return []

    # Smoothing weight grows with the series length via a sigmoid, staying below 0.9.
    weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5)  # a sigmoid function
    smoothed: List[float] = []
    previous = scalars[0]
    for value in scalars:
        previous = previous * weight + (1 - weight) * value
        smoothed.append(previous)
    return smoothed
49
+
50
+
51
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
    r"""
    Plots loss curves in LlamaBoard.

    Args:
        trainer_log: log records; every entry carrying a "loss" value contributes
            a point at x = entry["current_steps"], y = entry["loss"].

    Returns:
        A matplotlib Figure with the raw and EMA-smoothed loss curves.
    """
    plt.close("all")
    plt.switch_backend("agg")
    fig = plt.figure()
    ax = fig.add_subplot(111)
    steps, losses = [], []
    for log in trainer_log:
        # Explicit None check: a loss of exactly 0.0 is a valid point and must
        # not be skipped by truthiness (the previous `if log.get("loss")` dropped it).
        if log.get("loss") is not None:
            steps.append(log["current_steps"])
            losses.append(log["loss"])

    ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
    ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
    ax.legend()
    ax.set_xlabel("step")
    ax.set_ylabel("loss")
    return fig
71
+
72
+
73
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
    r"""
    Plots loss curves and saves the image.

    Reads the trainer state JSON under `save_dictionary` and writes one
    "training_<key>.png" per requested metric key. Keys absent from the log
    history are skipped with a warning.
    """
    plt.switch_backend("agg")
    with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
        data = json.load(f)

    for key in keys:
        steps, metrics = [], []
        for log_entry in data["log_history"]:
            if key in log_entry:
                steps.append(log_entry["step"])
                metrics.append(log_entry[key])

        if len(metrics) == 0:
            logger.warning(f"No metric {key} to plot.")
            continue

        plt.figure()
        plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
        plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
        plt.title("training {} of {}".format(key, save_dictionary))
        plt.xlabel("step")
        plt.ylabel(key)
        plt.legend()
        figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
        plt.savefig(figure_path, format="png", dpi=100)
        # Release the figure: without this, one figure per key accumulates until
        # process exit and matplotlib warns after 20 open figures.
        plt.close()
        print("Figure saved at:", figure_path)
src/hparams/__init__.py ADDED
@@ -0,0 +1,31 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
from .data_args import DataArguments
from .evaluation_args import EvaluationArguments
from .finetuning_args import FinetuningArguments
from .generating_args import GeneratingArguments
from .model_args import ModelArguments
from .parser import get_infer_args, get_train_args


# Public API of the hparams package: the argument dataclasses plus the two
# parser entry points that assemble them from CLI/config inputs.
__all__ = [
    "DataArguments",
    "EvaluationArguments",
    "FinetuningArguments",
    "GeneratingArguments",
    "ModelArguments",
    "get_infer_args",
    "get_train_args",
]
src/hparams/data_args.py ADDED
@@ -0,0 +1,110 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 HuggingFace Inc., the LlamaFactory team, and the Llamole team.
2
+ #
3
+ # This code is inspired by the HuggingFace's transformers library.
4
+ # https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
5
+ #
6
+ # Licensed under the Apache License, Version 2.0 (the "License");
7
+ # you may not use this file except in compliance with the License.
8
+ # You may obtain a copy of the License at
9
+ #
10
+ # http://www.apache.org/licenses/LICENSE-2.0
11
+ #
12
+ # Unless required by applicable law or agreed to in writing, software
13
+ # distributed under the License is distributed on an "AS IS" BASIS,
14
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
15
+ # See the License for the specific language governing permissions and
16
+ # limitations under the License.
17
+
18
+ from dataclasses import dataclass, field
19
+ from typing import Literal, Optional
20
+
21
+
22
@dataclass
class DataArguments:
    r"""
    Arguments pertaining to what data we are going to input our model for training and evaluation.
    """

    template: Optional[str] = field(
        default=None,
        metadata={"help": "Which template to use for constructing prompts in training and inference."},
    )
    # Comma-separated dataset name(s), resolved against `dataset_dir`.
    dataset: Optional[str] = field(
        default=None,
        metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
    )
    dataset_dir: str = field(
        default="data",
        metadata={"help": "Path to the folder containing the datasets."},
    )
    split: str = field(
        default="train",
        metadata={"help": "Which dataset split to use for training and evaluation."},
    )
    cutoff_len: int = field(
        default=1024,
        metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
    )
    train_on_prompt: bool = field(
        default=False,
        metadata={"help": "Whether to disable the mask on the prompt or not."},
    )
    # number of learnable queries
    learned_query_size: int = field(
        default=8,
        metadata={"help": "Number of queries to use in training."},
    )

    # used for evaluation with specific starting idx
    start_idx: int = field(
        default=0,
        metadata={"help": "Start index of the evaluation data for drug"},
    )
    buffer_size: int = field(
        default=16384,
        metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
    )
    mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
        default="concat",
        metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets."},
    )
    preprocessing_num_workers: Optional[int] = field(
        default=None,
        metadata={"help": "The number of processes to use for the pre-processing."},
    )
    max_samples: Optional[int] = field(
        default=None,
        metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
    )
    eval_num_beams: Optional[int] = field(
        default=None,
        metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
    )
    ignore_pad_token_for_loss: bool = field(
        default=True,
        metadata={
            "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
        },
    )
    # NOTE(review): the help text says "integer or a float" but the annotation is
    # float — confirm the consumer also accepts absolute integer counts.
    val_size: float = field(
        default=0.0,
        metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
    )
    packing: Optional[bool] = field(
        default=None,
        metadata={
            "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
        },
    )
    tool_format: Optional[str] = field(
        default=None,
        metadata={"help": "Tool format to use for constructing function calling examples."},
    )
    tokenized_path: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save or load the tokenized datasets."},
    )
src/hparams/evaluation_args.py ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import os
16
+ from dataclasses import dataclass, field
17
+ from typing import Literal, Optional
18
+
19
+ from datasets import DownloadMode
20
+
21
+
22
@dataclass
class EvaluationArguments:
    r"""
    Arguments pertaining to specify the evaluation parameters.
    """

    # Required: the only field without a default.
    task: str = field(
        metadata={"help": "Name of the evaluation task."},
    )
    task_dir: str = field(
        default="evaluation",
        metadata={"help": "Path to the folder containing the evaluation datasets."},
    )
    batch_size: int = field(
        default=4,
        metadata={"help": "The batch size per GPU for evaluation."},
    )
    seed: int = field(
        default=42,
        metadata={"help": "Random seed to be used with data loaders."},
    )
    lang: Literal["en", "zh"] = field(
        default="en",
        metadata={"help": "Language used at evaluation."},
    )
    n_shot: int = field(
        default=5,
        metadata={"help": "Number of examplars for few-shot learning."},
    )
    save_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Path to save the evaluation results."},
    )
    download_mode: DownloadMode = field(
        default=DownloadMode.REUSE_DATASET_IF_EXISTS,
        metadata={"help": "Download mode used for the evaluation datasets."},
    )

    def __post_init__(self):
        # Refuse to overwrite earlier results: a pre-existing save_dir is an error.
        if self.save_dir is not None and os.path.exists(self.save_dir):
            raise ValueError("`save_dir` already exists, use another one.")
src/hparams/finetuning_args.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2024 the LlamaFactory team.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ from dataclasses import dataclass, field
16
+ from typing import List, Literal, Optional
17
+
18
+
19
@dataclass
class FreezeArguments:
    r"""
    Arguments pertaining to the freeze (partial-parameter) training.
    """

    # Sign selects which end of the network is trainable: >0 -> last n layers, <0 -> first n layers.
    freeze_trainable_layers: int = field(
        default=2,
        metadata={
            "help": (
                "The number of trainable layers for freeze (partial-parameter) fine-tuning. "
                "Positive numbers mean the last n layers are set as trainable, "
                "negative numbers mean the first n layers are set as trainable."
            )
        },
    )
    # Comma-separated string; normalized to a list in FinetuningArguments.__post_init__.
    freeze_trainable_modules: str = field(
        default="all",
        metadata={
            "help": (
                "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
                "Use commas to separate multiple modules. "
                "Use `all` to specify all the available modules."
            )
        },
    )
    freeze_extra_modules: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Name(s) of modules apart from hidden layers to be set as trainable "
                "for freeze (partial-parameter) fine-tuning. "
                "Use commas to separate multiple modules."
            )
        },
    )
55
+
56
+
57
@dataclass
class LoraArguments:
    r"""
    Arguments pertaining to the LoRA training.
    """

    additional_target: Optional[str] = field(
        default=None,
        metadata={
            "help": (
                "Name(s) of modules apart from LoRA layers to be set as trainable "
                "and saved in the final checkpoint. "
                "Use commas to separate multiple modules."
            )
        },
    )
    # None means "derive from rank": resolved to lora_rank * 2 in __post_init__.
    lora_alpha: Optional[int] = field(
        default=None,
        metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
    )
    lora_dropout: float = field(
        default=0.0,
        metadata={"help": "Dropout rate for the LoRA fine-tuning."},
    )
    lora_rank: int = field(
        default=8,
        metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
    )
    lora_target: str = field(
        default="all",
        metadata={
            "help": (
                "Name(s) of target modules to apply LoRA. "
                "Use commas to separate multiple modules. "
                "Use `all` to specify all the linear modules."
            )
        },
    )
    # Only valid with finetuning_type == "lora" (enforced in __post_init__).
    loraplus_lr_ratio: Optional[float] = field(
        default=None,
        metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
    )
    loraplus_lr_embedding: float = field(
        default=1e-6,
        metadata={"help": "LoRA plus learning rate for lora embedding layers."},
    )
    use_rslora: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
    )
    use_dora: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
    )
    pissa_init: bool = field(
        default=False,
        metadata={"help": "Whether or not to initialize a PiSSA adapter."},
    )
    pissa_iter: int = field(
        default=16,
        metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
    )
    pissa_convert: bool = field(
        default=False,
        metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
    )
    create_new_adapter: bool = field(
        default=False,
        metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
    )
127
+
128
+
129
@dataclass
class RLHFArguments:
    r"""
    Arguments pertaining to the PPO, DPO and KTO training.
    """

    # --- preference-loss (DPO family) knobs ---
    pref_beta: float = field(
        default=0.1,
        metadata={"help": "The beta parameter in the preference loss."},
    )
    pref_ftx: float = field(
        default=0.0,
        metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
    )
    pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
        default="sigmoid",
        metadata={"help": "The type of DPO loss to use."},
    )
    # Only valid together with pref_loss == "sigmoid" (checked in __post_init__).
    dpo_label_smoothing: float = field(
        default=0.0,
        metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
    )
    kto_chosen_weight: float = field(
        default=1.0,
        metadata={"help": "The weight factor of the desirable losses in KTO training."},
    )
    kto_rejected_weight: float = field(
        default=1.0,
        metadata={"help": "The weight factor of the undesirable losses in KTO training."},
    )
    simpo_gamma: float = field(
        default=0.5,
        metadata={"help": "The target reward margin term in SimPO loss."},
    )
    # --- PPO-specific knobs ---
    ppo_buffer_size: int = field(
        default=1,
        metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
    )
    ppo_epochs: int = field(
        default=4,
        metadata={"help": "The number of epochs to perform in a PPO optimization step."},
    )
    ppo_score_norm: bool = field(
        default=False,
        metadata={"help": "Use score normalization in PPO training."},
    )
    ppo_target: float = field(
        default=6.0,
        metadata={"help": "Target KL value for adaptive KL control in PPO training."},
    )
    ppo_whiten_rewards: bool = field(
        default=False,
        metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
    )
    # --- reference / reward model configuration ---
    ref_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the reference model used for the PPO or DPO training."},
    )
    ref_model_adapters: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the adapters of the reference model."},
    )
    # Restricted to None/8/4 by an assertion in __post_init__.
    ref_model_quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the reference model."},
    )
    # Required when stage == "ppo" (checked in __post_init__).
    reward_model: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the reward model used for the PPO training."},
    )
    reward_model_adapters: Optional[str] = field(
        default=None,
        metadata={"help": "Path to the adapters of the reward model."},
    )
    reward_model_quantization_bit: Optional[int] = field(
        default=None,
        metadata={"help": "The number of bits to quantize the reward model."},
    )
    reward_model_type: Literal["lora", "full", "api"] = field(
        default="lora",
        metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
    )
211
+
212
+
213
@dataclass
class GaloreArguments:
    r"""
    Arguments pertaining to the GaLore algorithm.
    """

    # Mutually exclusive with LoRA and BAdam (enforced in __post_init__).
    use_galore: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
    )
    # Comma-separated string; normalized to a list in __post_init__.
    galore_target: str = field(
        default="all",
        metadata={
            "help": (
                "Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
                "Use `all` to specify all the linear modules."
            )
        },
    )
    galore_rank: int = field(
        default=16,
        metadata={"help": "The rank of GaLore gradients."},
    )
    galore_update_interval: int = field(
        default=200,
        metadata={"help": "Number of steps to update the GaLore projection."},
    )
    galore_scale: float = field(
        default=0.25,
        metadata={"help": "GaLore scaling coefficient."},
    )
    galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
        default="std",
        metadata={"help": "Type of GaLore projection."},
    )
    galore_layerwise: bool = field(
        default=False,
        metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
    )
252
+
253
+
254
@dataclass
class BAdamArgument:
    r"""
    Arguments pertaining to the BAdam optimizer.
    """

    # Mutually exclusive with LoRA and GaLore (enforced in __post_init__).
    use_badam: bool = field(
        default=False,
        metadata={"help": "Whether or not to use the BAdam optimizer."},
    )
    badam_mode: Literal["layer", "ratio"] = field(
        default="layer",
        metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
    )
    badam_start_block: Optional[int] = field(
        default=None,
        metadata={"help": "The starting block index for layer-wise BAdam."},
    )
    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
        default="ascending",
        metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
    )
    badam_switch_interval: Optional[int] = field(
        default=50,
        metadata={
            "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
        },
    )
    # Only meaningful when badam_mode == "ratio".
    badam_update_ratio: float = field(
        default=0.05,
        metadata={"help": "The ratio of the update for ratio-wise BAdam."},
    )
    badam_mask_mode: Literal["adjacent", "scatter"] = field(
        default="adjacent",
        metadata={
            "help": (
                "The mode of the mask for BAdam optimizer. "
                "`adjacent` means that the trainable parameters are adjacent to each other, "
                "`scatter` means that trainable parameters are randomly choosed from the weight."
            )
        },
    )
    badam_verbose: int = field(
        default=0,
        metadata={
            "help": (
                "The verbosity level of BAdam optimizer. "
                "0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
            )
        },
    )
305
+
306
+
307
@dataclass
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
    r"""
    Arguments pertaining to which techniques we are going to fine-tuning with.

    Aggregates all method-specific argument groups via dataclass inheritance and
    validates their cross-field constraints in __post_init__.
    """

    # Multi-task loss weights: language modeling, molecular design, retrosynthesis.
    # NOTE(review): typed int, but fractional weights seem plausible — confirm
    # whether float weights should be accepted.
    loss_weight_lm: int = field(
        default=1,
        metadata={"help": "Loss weight for the language next token prediction."},
    )
    loss_weight_design: int = field(
        default=1,
        metadata={"help": "Loss weight for molecular design."},
    )
    loss_weight_retro: int = field(
        default=1,
        metadata={"help": "Loss weight for retrosynthesis."},
    )

    pure_bf16: bool = field(
        default=False,
        metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
    )
    stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
        default="sft",
        metadata={"help": "Which stage will be performed in training."},
    )
    finetuning_type: Literal["lora", "freeze", "full"] = field(
        default="lora",
        metadata={"help": "Which fine-tuning method to use."},
    )
    use_llama_pro: bool = field(
        default=False,
        metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
    )
    freeze_vision_tower: bool = field(
        default=True,
        metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
    )
    train_mm_proj_only: bool = field(
        default=False,
        metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
    )
    plot_loss: bool = field(
        default=False,
        metadata={"help": "Whether or not to save the training loss curves."},
    )

    def __post_init__(self):
        # Normalize comma-separated CLI strings into lists of module names.
        def split_arg(arg):
            if isinstance(arg, str):
                return [item.strip() for item in arg.split(",")]
            return arg

        self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
        self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
        self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
        self.lora_target: List[str] = split_arg(self.lora_target)
        self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
        self.galore_target: List[str] = split_arg(self.galore_target)
        # Training only the multimodal projector implies the vision tower stays frozen.
        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
        # ORPO and SimPO are reference-free preference losses.
        self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]

        # NOTE(review): assert is stripped under `python -O`; raising ValueError
        # would be more robust for input validation — kept as-is to preserve the
        # exception types callers may rely on.
        assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
        assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
        assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."

        if self.stage == "ppo" and self.reward_model is None:
            raise ValueError("`reward_model` is necessary for PPO training.")

        if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
            raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")

        if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
            raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")

        if self.use_llama_pro and self.finetuning_type == "full":
            raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")

        if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
            raise ValueError("Cannot use LoRA with GaLore or BAdam together.")

        if self.use_galore and self.use_badam:
            raise ValueError("Cannot use GaLore with BAdam together.")

        if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
            raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")

        if self.pissa_init and self.finetuning_type != "lora":
            raise ValueError("`pissa_init` is only valid for LoRA training.")

        if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
            raise ValueError("Cannot use PiSSA for current training stage.")

        if self.train_mm_proj_only and self.finetuning_type != "full":
            raise ValueError("`train_mm_proj_only` is only valid for full training.")