Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +4 -35
- .gitignore +172 -0
- LICENSE +201 -0
- README.md +184 -6
- assets/llamole.png +3 -0
- assets/ui_example.png +3 -0
- config/generate/llama_drug.yaml +30 -0
- config/generate/llama_material.yaml +30 -0
- config/generate/mistral_drug.yaml +30 -0
- config/generate/mistral_material.yaml +30 -0
- config/generate/qwen_drug.yaml +30 -0
- config/generate/qwen_material.yaml +30 -0
- config/train/llama_lora.yaml +49 -0
- config/train/mistral_lora.yaml +48 -0
- config/train/qwen_lora.yaml +48 -0
- data/dataset_info.json +20 -0
- data/molqa_drug_examples.json +75 -0
- data/molqa_material_examples.json +85 -0
- data/molqa_train_examples.json +62 -0
- data/property_ranges.json +42 -0
- install_environment.sh +85 -0
- launch.py +238 -0
- main.py +66 -0
- requirements.txt +30 -0
- src/__init__.py +0 -0
- src/data/__init__.py +29 -0
- src/data/aligner.py +233 -0
- src/data/collator.py +165 -0
- src/data/data_utils.py +82 -0
- src/data/formatter.py +139 -0
- src/data/loader.py +149 -0
- src/data/parser.py +113 -0
- src/data/processors/__init__.py +0 -0
- src/data/processors/mmsupervised.py +335 -0
- src/data/template.py +889 -0
- src/data/tool_utils.py +140 -0
- src/eval/__init__.py +0 -0
- src/eval/dataset.py +78 -0
- src/eval/workflow.py +219 -0
- src/extras/__init__.py +0 -0
- src/extras/constants.py +1550 -0
- src/extras/env.py +75 -0
- src/extras/logging.py +82 -0
- src/extras/misc.py +215 -0
- src/extras/packages.py +83 -0
- src/extras/ploting.py +101 -0
- src/hparams/__init__.py +31 -0
- src/hparams/data_args.py +110 -0
- src/hparams/evaluation_args.py +62 -0
- src/hparams/finetuning_args.py +402 -0
.gitattributes
CHANGED
@@ -1,35 +1,4 @@
|
|
1 |
-
|
2 |
-
|
3 |
-
|
4 |
-
|
5 |
-
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
6 |
-
*.ftz filter=lfs diff=lfs merge=lfs -text
|
7 |
-
*.gz filter=lfs diff=lfs merge=lfs -text
|
8 |
-
*.h5 filter=lfs diff=lfs merge=lfs -text
|
9 |
-
*.joblib filter=lfs diff=lfs merge=lfs -text
|
10 |
-
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
11 |
-
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
12 |
-
*.model filter=lfs diff=lfs merge=lfs -text
|
13 |
-
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
14 |
-
*.npy filter=lfs diff=lfs merge=lfs -text
|
15 |
-
*.npz filter=lfs diff=lfs merge=lfs -text
|
16 |
-
*.onnx filter=lfs diff=lfs merge=lfs -text
|
17 |
-
*.ot filter=lfs diff=lfs merge=lfs -text
|
18 |
-
*.parquet filter=lfs diff=lfs merge=lfs -text
|
19 |
-
*.pb filter=lfs diff=lfs merge=lfs -text
|
20 |
-
*.pickle filter=lfs diff=lfs merge=lfs -text
|
21 |
-
*.pkl filter=lfs diff=lfs merge=lfs -text
|
22 |
-
*.pt filter=lfs diff=lfs merge=lfs -text
|
23 |
-
*.pth filter=lfs diff=lfs merge=lfs -text
|
24 |
-
*.rar filter=lfs diff=lfs merge=lfs -text
|
25 |
-
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
26 |
-
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
27 |
-
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
28 |
-
*.tar filter=lfs diff=lfs merge=lfs -text
|
29 |
-
*.tflite filter=lfs diff=lfs merge=lfs -text
|
30 |
-
*.tgz filter=lfs diff=lfs merge=lfs -text
|
31 |
-
*.wasm filter=lfs diff=lfs merge=lfs -text
|
32 |
-
*.xz filter=lfs diff=lfs merge=lfs -text
|
33 |
-
*.zip filter=lfs diff=lfs merge=lfs -text
|
34 |
-
*.zst filter=lfs diff=lfs merge=lfs -text
|
35 |
-
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
1 |
+
# Auto detect text files and perform LF normalization
|
2 |
+
* text=auto
|
3 |
+
assets/llamole.png filter=lfs diff=lfs merge=lfs -text
|
4 |
+
assets/ui_example.png filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
.gitignore
ADDED
@@ -0,0 +1,172 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Byte-compiled / optimized / DLL files
|
2 |
+
__pycache__/
|
3 |
+
*.py[cod]
|
4 |
+
*$py.class
|
5 |
+
|
6 |
+
# C extensions
|
7 |
+
*.so
|
8 |
+
|
9 |
+
# Distribution / packaging
|
10 |
+
.Python
|
11 |
+
build/
|
12 |
+
develop-eggs/
|
13 |
+
dist/
|
14 |
+
downloads/
|
15 |
+
eggs/
|
16 |
+
.eggs/
|
17 |
+
lib/
|
18 |
+
lib64/
|
19 |
+
parts/
|
20 |
+
sdist/
|
21 |
+
var/
|
22 |
+
wheels/
|
23 |
+
share/python-wheels/
|
24 |
+
*.egg-info/
|
25 |
+
.installed.cfg
|
26 |
+
*.egg
|
27 |
+
MANIFEST
|
28 |
+
|
29 |
+
# PyInstaller
|
30 |
+
# Usually these files are written by a python script from a template
|
31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
32 |
+
*.manifest
|
33 |
+
*.spec
|
34 |
+
|
35 |
+
# Installer logs
|
36 |
+
pip-log.txt
|
37 |
+
pip-delete-this-directory.txt
|
38 |
+
|
39 |
+
# Unit test / coverage reports
|
40 |
+
htmlcov/
|
41 |
+
.tox/
|
42 |
+
.nox/
|
43 |
+
.coverage
|
44 |
+
.coverage.*
|
45 |
+
.cache
|
46 |
+
nosetests.xml
|
47 |
+
coverage.xml
|
48 |
+
*.cover
|
49 |
+
*.py,cover
|
50 |
+
.hypothesis/
|
51 |
+
.pytest_cache/
|
52 |
+
cover/
|
53 |
+
|
54 |
+
# Translations
|
55 |
+
*.mo
|
56 |
+
*.pot
|
57 |
+
|
58 |
+
# Django stuff:
|
59 |
+
*.log
|
60 |
+
local_settings.py
|
61 |
+
db.sqlite3
|
62 |
+
db.sqlite3-journal
|
63 |
+
|
64 |
+
# Flask stuff:
|
65 |
+
instance/
|
66 |
+
.webassets-cache
|
67 |
+
|
68 |
+
# Scrapy stuff:
|
69 |
+
.scrapy
|
70 |
+
|
71 |
+
# Sphinx documentation
|
72 |
+
docs/_build/
|
73 |
+
|
74 |
+
# PyBuilder
|
75 |
+
.pybuilder/
|
76 |
+
target/
|
77 |
+
|
78 |
+
# Jupyter Notebook
|
79 |
+
.ipynb_checkpoints
|
80 |
+
|
81 |
+
# IPython
|
82 |
+
profile_default/
|
83 |
+
ipython_config.py
|
84 |
+
|
85 |
+
# pyenv
|
86 |
+
# For a library or package, you might want to ignore these files since the code is
|
87 |
+
# intended to run in multiple environments; otherwise, check them in:
|
88 |
+
# .python-version
|
89 |
+
|
90 |
+
# pipenv
|
91 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
92 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
93 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
94 |
+
# install all needed dependencies.
|
95 |
+
#Pipfile.lock
|
96 |
+
|
97 |
+
# poetry
|
98 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
99 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
100 |
+
# commonly ignored for libraries.
|
101 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
102 |
+
#poetry.lock
|
103 |
+
|
104 |
+
# pdm
|
105 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
106 |
+
#pdm.lock
|
107 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
108 |
+
# in version control.
|
109 |
+
# https://pdm.fming.dev/#use-with-ide
|
110 |
+
.pdm.toml
|
111 |
+
|
112 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
113 |
+
__pypackages__/
|
114 |
+
|
115 |
+
# Celery stuff
|
116 |
+
celerybeat-schedule
|
117 |
+
celerybeat.pid
|
118 |
+
|
119 |
+
# SageMath parsed files
|
120 |
+
*.sage.py
|
121 |
+
|
122 |
+
# Environments
|
123 |
+
.env
|
124 |
+
.venv
|
125 |
+
env/
|
126 |
+
venv/
|
127 |
+
ENV/
|
128 |
+
env.bak/
|
129 |
+
venv.bak/
|
130 |
+
|
131 |
+
# Spyder project settings
|
132 |
+
.spyderproject
|
133 |
+
.spyproject
|
134 |
+
|
135 |
+
# Rope project settings
|
136 |
+
.ropeproject
|
137 |
+
|
138 |
+
# mkdocs documentation
|
139 |
+
/site
|
140 |
+
|
141 |
+
# mypy
|
142 |
+
.mypy_cache/
|
143 |
+
.dmypy.json
|
144 |
+
dmypy.json
|
145 |
+
|
146 |
+
# Pyre type checker
|
147 |
+
.pyre/
|
148 |
+
|
149 |
+
# pytype static type analyzer
|
150 |
+
.pytype/
|
151 |
+
|
152 |
+
# Cython debug symbols
|
153 |
+
cython_debug/
|
154 |
+
|
155 |
+
# PyCharm
|
156 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
157 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
158 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
159 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
160 |
+
.idea/
|
161 |
+
|
162 |
+
# custom .gitignore
|
163 |
+
cache/
|
164 |
+
output/
|
165 |
+
wandb/
|
166 |
+
# Ignore everything in data directory
|
167 |
+
run_script*
|
168 |
+
|
169 |
+
saves/*
|
170 |
+
!saves/README.md
|
171 |
+
|
172 |
+
.gradio
|
LICENSE
ADDED
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
Apache License
|
2 |
+
Version 2.0, January 2004
|
3 |
+
http://www.apache.org/licenses/
|
4 |
+
|
5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
6 |
+
|
7 |
+
1. Definitions.
|
8 |
+
|
9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
11 |
+
|
12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
13 |
+
the copyright owner that is granting the License.
|
14 |
+
|
15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
16 |
+
other entities that control, are controlled by, or are under common
|
17 |
+
control with that entity. For the purposes of this definition,
|
18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
19 |
+
direction or management of such entity, whether by contract or
|
20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
22 |
+
|
23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
24 |
+
exercising permissions granted by this License.
|
25 |
+
|
26 |
+
"Source" form shall mean the preferred form for making modifications,
|
27 |
+
including but not limited to software source code, documentation
|
28 |
+
source, and configuration files.
|
29 |
+
|
30 |
+
"Object" form shall mean any form resulting from mechanical
|
31 |
+
transformation or translation of a Source form, including but
|
32 |
+
not limited to compiled object code, generated documentation,
|
33 |
+
and conversions to other media types.
|
34 |
+
|
35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
36 |
+
Object form, made available under the License, as indicated by a
|
37 |
+
copyright notice that is included in or attached to the work
|
38 |
+
(an example is provided in the Appendix below).
|
39 |
+
|
40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
41 |
+
form, that is based on (or derived from) the Work and for which the
|
42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
44 |
+
of this License, Derivative Works shall not include works that remain
|
45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
46 |
+
the Work and Derivative Works thereof.
|
47 |
+
|
48 |
+
"Contribution" shall mean any work of authorship, including
|
49 |
+
the original version of the Work and any modifications or additions
|
50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
54 |
+
means any form of electronic, verbal, or written communication sent
|
55 |
+
to the Licensor or its representatives, including but not limited to
|
56 |
+
communication on electronic mailing lists, source code control systems,
|
57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
59 |
+
excluding communication that is conspicuously marked or otherwise
|
60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
61 |
+
|
62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
64 |
+
subsequently incorporated within the Work.
|
65 |
+
|
66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
71 |
+
Work and such Derivative Works in Source or Object form.
|
72 |
+
|
73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
76 |
+
(except as stated in this section) patent license to make, have made,
|
77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
78 |
+
where such license applies only to those patent claims licensable
|
79 |
+
by such Contributor that are necessarily infringed by their
|
80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
82 |
+
institute patent litigation against any entity (including a
|
83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
84 |
+
or a Contribution incorporated within the Work constitutes direct
|
85 |
+
or contributory patent infringement, then any patent licenses
|
86 |
+
granted to You under this License for that Work shall terminate
|
87 |
+
as of the date such litigation is filed.
|
88 |
+
|
89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
90 |
+
Work or Derivative Works thereof in any medium, with or without
|
91 |
+
modifications, and in Source or Object form, provided that You
|
92 |
+
meet the following conditions:
|
93 |
+
|
94 |
+
(a) You must give any other recipients of the Work or
|
95 |
+
Derivative Works a copy of this License; and
|
96 |
+
|
97 |
+
(b) You must cause any modified files to carry prominent notices
|
98 |
+
stating that You changed the files; and
|
99 |
+
|
100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
101 |
+
that You distribute, all copyright, patent, trademark, and
|
102 |
+
attribution notices from the Source form of the Work,
|
103 |
+
excluding those notices that do not pertain to any part of
|
104 |
+
the Derivative Works; and
|
105 |
+
|
106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
107 |
+
distribution, then any Derivative Works that You distribute must
|
108 |
+
include a readable copy of the attribution notices contained
|
109 |
+
within such NOTICE file, excluding those notices that do not
|
110 |
+
pertain to any part of the Derivative Works, in at least one
|
111 |
+
of the following places: within a NOTICE text file distributed
|
112 |
+
as part of the Derivative Works; within the Source form or
|
113 |
+
documentation, if provided along with the Derivative Works; or,
|
114 |
+
within a display generated by the Derivative Works, if and
|
115 |
+
wherever such third-party notices normally appear. The contents
|
116 |
+
of the NOTICE file are for informational purposes only and
|
117 |
+
do not modify the License. You may add Your own attribution
|
118 |
+
notices within Derivative Works that You distribute, alongside
|
119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
120 |
+
that such additional attribution notices cannot be construed
|
121 |
+
as modifying the License.
|
122 |
+
|
123 |
+
You may add Your own copyright statement to Your modifications and
|
124 |
+
may provide additional or different license terms and conditions
|
125 |
+
for use, reproduction, or distribution of Your modifications, or
|
126 |
+
for any such Derivative Works as a whole, provided Your use,
|
127 |
+
reproduction, and distribution of the Work otherwise complies with
|
128 |
+
the conditions stated in this License.
|
129 |
+
|
130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
132 |
+
by You to the Licensor shall be under the terms and conditions of
|
133 |
+
this License, without any additional terms or conditions.
|
134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
135 |
+
the terms of any separate license agreement you may have executed
|
136 |
+
with Licensor regarding such Contributions.
|
137 |
+
|
138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
140 |
+
except as required for reasonable and customary use in describing the
|
141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
142 |
+
|
143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
144 |
+
agreed to in writing, Licensor provides the Work (and each
|
145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
147 |
+
implied, including, without limitation, any warranties or conditions
|
148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
150 |
+
appropriateness of using or redistributing the Work and assume any
|
151 |
+
risks associated with Your exercise of permissions under this License.
|
152 |
+
|
153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
154 |
+
whether in tort (including negligence), contract, or otherwise,
|
155 |
+
unless required by applicable law (such as deliberate and grossly
|
156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
157 |
+
liable to You for damages, including any direct, indirect, special,
|
158 |
+
incidental, or consequential damages of any character arising as a
|
159 |
+
result of this License or out of the use or inability to use the
|
160 |
+
Work (including but not limited to damages for loss of goodwill,
|
161 |
+
work stoppage, computer failure or malfunction, or any and all
|
162 |
+
other commercial damages or losses), even if such Contributor
|
163 |
+
has been advised of the possibility of such damages.
|
164 |
+
|
165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
168 |
+
or other liability obligations and/or rights consistent with this
|
169 |
+
License. However, in accepting such obligations, You may act only
|
170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
171 |
+
of any other Contributor, and only if You agree to indemnify,
|
172 |
+
defend, and hold each Contributor harmless for any liability
|
173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
174 |
+
of your accepting any such warranty or additional liability.
|
175 |
+
|
176 |
+
END OF TERMS AND CONDITIONS
|
177 |
+
|
178 |
+
APPENDIX: How to apply the Apache License to your work.
|
179 |
+
|
180 |
+
To apply the Apache License to your work, attach the following
|
181 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
182 |
+
replaced with your own identifying information. (Don't include
|
183 |
+
the brackets!) The text should be enclosed in the appropriate
|
184 |
+
comment syntax for the file format. We also recommend that a
|
185 |
+
file or class name and description of purpose be included on the
|
186 |
+
same "printed page" as the copyright notice for easier
|
187 |
+
identification within third-party archives.
|
188 |
+
|
189 |
+
Copyright [yyyy] [name of copyright owner]
|
190 |
+
|
191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
192 |
+
you may not use this file except in compliance with the License.
|
193 |
+
You may obtain a copy of the License at
|
194 |
+
|
195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
196 |
+
|
197 |
+
Unless required by applicable law or agreed to in writing, software
|
198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
200 |
+
See the License for the specific language governing permissions and
|
201 |
+
limitations under the License.
|
README.md
CHANGED
@@ -1,12 +1,190 @@
|
|
1 |
---
|
2 |
title: Llamole
|
3 |
-
|
4 |
-
colorFrom: gray
|
5 |
-
colorTo: purple
|
6 |
sdk: gradio
|
7 |
sdk_version: 5.15.0
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
|
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
---
|
2 |
title: Llamole
|
3 |
+
app_file: launch.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
sdk_version: 5.15.0
|
|
|
|
|
6 |
---
|
7 |
+
<!-- # 🧬 Llamole: Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning -->
|
8 |
|
9 |
+
<div align="center">
|
10 |
+
|
11 |
+
# 🔬 🧪 Llamole: Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning
|
12 |
+
|
13 |
+
[![](https://img.shields.io/badge/Paper-pink?style=plastic&logo=arXiv)](https://arxiv.org/abs/2410.04223)
|
14 |
+
[![](https://img.shields.io/badge/GitHub-blue?style=plastic&logo=github)](https://github.com/liugangcode/Llamole)
|
15 |
+
[![](https://img.shields.io/badge/Huggingface-orange?style=plastic&logo=huggingface)](https://huggingface.co/collections/liuganghuggingface/llamole-collection-67073a2e76b47d5fea909434)
|
16 |
+
|
17 |
+
</div>
|
18 |
+
|
19 |
+
## 🌟 Introduction
|
20 |
+
|
21 |
+
**Llamole** is a multimodal Large Language Model (LLM) that integrates a base LLM with the Graph Diffusion Transformer and Graph Neural Networks for multi-conditional molecular generation and multi-step reaction inference within texts.
|
22 |
+
|
23 |
+
📄 **[Paper](https://arxiv.org/abs/2410.04223)**: *Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning*
|
24 |
+
|
25 |
+
<div align="center">
|
26 |
+
|
27 |
+
<details>
|
28 |
+
<summary>🔍 Abstract</summary>
|
29 |
+
<p align="left">
|
30 |
+
While large language models (LLMs) have integrated images, adapting them to graphs remains challenging, limiting their applications in materials and drug design. This difficulty stems from the need for coherent autoregressive generation across texts and graphs. To address this, we introduce Llamole, the first multimodal LLM capable of interleaved text and graph generation, enabling molecular inverse design with retrosynthetic planning. Llamole integrates a base LLM with the Graph Diffusion Transformer and Graph Neural Networks for multi-conditional molecular generation and reaction inference within texts, while the LLM, with enhanced molecular understanding, flexibly controls activation among the different graph modules. Additionally, Llamole integrates A* search with LLM-based cost functions for efficient retrosynthetic planning. We create benchmarking datasets and conduct extensive experiments to evaluate Llamole against in-context learning and supervised fine-tuning. Llamole significantly outperforms 14 adapted LLMs across 12 metrics for controllable molecular design and retrosynthetic planning.
|
31 |
+
</p>
|
32 |
+
</details>
|
33 |
+
|
34 |
+
<img src="assets/llamole.png" alt="Llamole Design" width="80%"/>
|
35 |
+
|
36 |
+
</div>
|
37 |
+
|
38 |
+
---
|
39 |
+
|
40 |
+
## 🛠️ Environment Setup
|
41 |
+
|
42 |
+
Initialize the environment by following these steps:
|
43 |
+
|
44 |
+
```bash
|
45 |
+
conda create --name llamole python=3.11 -y
|
46 |
+
conda activate llamole
|
47 |
+
./install_environment.sh
|
48 |
+
```
|
49 |
+
|
50 |
+
Alternatively, you can install all required dependencies using the `requirements.sh` script.
|
51 |
+
|
52 |
+
---
|
53 |
+
|
54 |
+
## 🚀 Model Usage
|
55 |
+
|
56 |
+
### 🔧 Requirements
|
57 |
+
|
58 |
+
- **Hardware**: A single V100 or A6000 GPU for inference.
|
59 |
+
- **Configuration Files**:
|
60 |
+
- `config/train/{model}_lora.yaml`
|
61 |
+
- `config/generate/{model}_{task}.yaml`
|
62 |
+
|
63 |
+
### 📥 Automatic Model Download
|
64 |
+
|
65 |
+
On the first run, the necessary models will be automatically downloaded, including:
|
66 |
+
|
67 |
+
1. **Base LLMs** (Please ensure you have access to the model):
|
68 |
+
- [Llama-3.1-8b-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B)
|
69 |
+
- [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct)
|
70 |
+
- [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3)
|
71 |
+
|
72 |
+
2. **Pretrained Graph Models**:
|
73 |
+
- **Graph Decoder**: [Graph Diffusion Transformer](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphDiT)
|
74 |
+
- **Graph Encoder**: [GIN-based Encoder](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphEncoder)
|
75 |
+
- **Graph Predictor**: [GIN-based Predictor](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GNNPredictor)
|
76 |
+
|
77 |
+
3. **Adapters and Connectors** for integrating the base LLM with pretrained graph models.
|
78 |
+
|
79 |
+
### 🗂️ Manual Model Download
|
80 |
+
|
81 |
+
If you prefer to download the models manually, refer to [![](https://img.shields.io/badge/Huggingface-orange?style=plastic&logo=huggingface)](https://huggingface.co/collections/liuganghuggingface/llamole-collection-67073a2e76b47d5fea909434) and place them in the following directories:
|
82 |
+
|
83 |
+
- `saves/graph_decoder`
|
84 |
+
- `saves/graph_encoder`
|
85 |
+
- `saves/graph_predictor`
|
86 |
+
- `saves/{model_name}-Adapter`
|
87 |
+
|
88 |
+
---
|
89 |
+
|
90 |
+
### 🕸️ Gradio-based Web UI
|
91 |
+
|
92 |
+
Launch the web interface using Gradio:
|
93 |
+
|
94 |
+
```bash
|
95 |
+
python launch.py
|
96 |
+
```
|
97 |
+
|
98 |
+
The default base LLM is **Qwen2-7B-Instruct**. If you wish to change this, please modify the `args_dict` variable accordingly. Upon launch, the web UI will appear as shown below:
|
99 |
+
|
100 |
+
<div align="center">
|
101 |
+
<img src="assets/ui_example.png" alt="Web UI Example" width="60%">
|
102 |
+
</div>
|
103 |
+
|
104 |
+
|
105 |
+
---
|
106 |
+
|
107 |
+
### 💻 Command Line Usage
|
108 |
+
|
109 |
+
For command-line evaluation, specify the path to the configuration file:
|
110 |
+
|
111 |
+
```bash
|
112 |
+
python main.py eval config/generate/qwen_material.yaml
|
113 |
+
```
|
114 |
+
|
115 |
+
You can modify the configuration files to suit your custom datasets.
|
116 |
+
|
117 |
+
**Note**: Examples of training and evaluation datasets are available in the `data` folder. For more details, refer to `data/dataset_info.json`. To test generation on all MolQA questions, first download the dataset by running:
|
118 |
+
|
119 |
+
```bash
|
120 |
+
python main.py download_data
|
121 |
+
```
|
122 |
+
|
123 |
+
Then, update the configuration files to point to the downloaded dataset based on the names from `data/dataset_info.json`.
|
124 |
+
|
125 |
+
---
|
126 |
+
|
127 |
+
## 📚 Supervised Fine-Tuning
|
128 |
+
|
129 |
+
The codebase supports multimodal graph-text supervised fine-tuning. Follow these steps:
|
130 |
+
|
131 |
+
1. **Download MolQA Training Data**:
|
132 |
+
|
133 |
+
```bash
|
134 |
+
python main.py download_data
|
135 |
+
```
|
136 |
+
Then you may need to modify the configuration files in the `config` folder to point to the downloaded training data. Skipping this step and directly using the command from step 2 will result in training only on the example training set.
|
137 |
+
|
138 |
+
2. **Run Fine-Tuning**:
|
139 |
+
|
140 |
+
```bash
|
141 |
+
python main.py train config/train/mistral_lora.yaml
|
142 |
+
```
|
143 |
+
|
144 |
+
During the first run, pretrained graph models will be downloaded in the `saves` folder. Modify the configuration files as needed for your setup. An 80G A100 GPU is recommended for supervised fine-tuning.
|
145 |
+
|
146 |
+
---
|
147 |
+
## 📖 Citation
|
148 |
+
|
149 |
+
If you find this repository useful, please cite our paper:
|
150 |
+
|
151 |
+
```
|
152 |
+
@misc{liu2024llamole,
|
153 |
+
title={Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning},
|
154 |
+
author={Gang Liu and Michael Sun and Wojciech Matusik and Meng Jiang and Jie Chen},
|
155 |
+
year={2024},
|
156 |
+
eprint={2410.04223},
|
157 |
+
archivePrefix={arXiv},
|
158 |
+
primaryClass={cs.LG},
|
159 |
+
url={https://arxiv.org/abs/2410.04223},
|
160 |
+
}
|
161 |
+
```
|
162 |
+
|
163 |
+
```
|
164 |
+
@article{liu2024graphdit,
|
165 |
+
title={Graph Diffusion Transformers for Multi-Conditional Molecular Generation},
|
166 |
+
author={Liu, Gang and Xu, Jiaxin and Luo, Tengfei and Jiang, Meng},
|
167 |
+
journal={Thirty-Eighth Annual Conference on Neural Information Processing Systems},
|
168 |
+
year={2024}
|
169 |
+
}
|
170 |
+
```
|
171 |
+
|
172 |
+
---
|
173 |
+
|
174 |
+
## 📄 Acknowledgments
|
175 |
+
|
176 |
+
This codebase is built upon **[Llama-Factory](https://github.com/hiyouga/LLaMA-Factory)**. We extend our gratitude for their open-source contributions.
|
177 |
+
|
178 |
+
---
|
179 |
+
|
180 |
+
## 📂 Additional Resources
|
181 |
+
|
182 |
+
🔗 **Huggingface Models**: Llamole is developed with three variants (adapters) and three pretrained graph modules (encoder, decoder, predictor):
|
183 |
+
- **Base LLM Variant 1**: [Llama-3.1-8b-Instruct](https://huggingface.co/liuganghuggingface/Llamole-Llama-3.1-8B-Instruct-Adapter)
|
184 |
+
- **Base LLM Variant 2**: [Qwen2-7B-Instruct](https://huggingface.co/liuganghuggingface/Llamole-Qwen2-7B-Instruct-Adapter)
|
185 |
+
- **Base LLM Variant 3**: [Mistral-7B-Instruct-v0.3](https://huggingface.co/liuganghuggingface/Llamole-Mistral-7B-Instruct-v0.3-Adapter)
|
186 |
+
- **Pretrained Graph Decoder** for multi-conditional molecular generation: [Graph Diffusion Transformer](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphDiT)
|
187 |
+
- **Pretrained Graph Predictor** for one-step reaction prediction: [GNN Predictor](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GNNPredictor)
|
188 |
+
- **Pretrained Graph Encoder** for enhanced molecule understanding: [Graph Encoder](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphEncoder)
|
189 |
+
|
190 |
+
---
|
assets/llamole.png
ADDED
Git LFS Details
|
assets/ui_example.png
ADDED
Git LFS Details
|
config/generate/llama_drug.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### model
|
2 |
+
model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct
|
3 |
+
new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
|
4 |
+
graph_decoder_path: saves/graph_decoder
|
5 |
+
graph_encoder_path: saves/graph_encoder
|
6 |
+
graph_predictor_path: saves/graph_predictor
|
7 |
+
adapter_name_or_path: saves/Llama-3.1-8B-Instruct-Adapter
|
8 |
+
graph_lm_connector_path: saves/Llama-3.1-8B-Instruct-Adapter/connector
|
9 |
+
### generation
|
10 |
+
max_length: 512
|
11 |
+
max_new_tokens: 128
|
12 |
+
temperature: 0.6
|
13 |
+
top_p: 0.9
|
14 |
+
### method
|
15 |
+
finetuning_type: lora
|
16 |
+
do_train: false
|
17 |
+
flash_attn: disabled
|
18 |
+
learned_query_size: 8
|
19 |
+
### dataset
|
20 |
+
dataset: molqa_drug_examples
|
21 |
+
template: llama3
|
22 |
+
cutoff_len: 128
|
23 |
+
overwrite_cache: true
|
24 |
+
preprocessing_num_workers: 16
|
25 |
+
output_dir: null
|
26 |
+
bf16: true
|
27 |
+
pure_bf16: true
|
28 |
+
### eval
|
29 |
+
per_device_eval_batch_size: 6
|
30 |
+
report_to: 'none'
|
config/generate/llama_material.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### model
|
2 |
+
model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct
|
3 |
+
new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
|
4 |
+
graph_decoder_path: saves/graph_decoder
|
5 |
+
graph_encoder_path: saves/graph_encoder
|
6 |
+
graph_predictor_path: saves/graph_predictor
|
7 |
+
adapter_name_or_path: saves/Llama-3.1-8B-Instruct-Adapter
|
8 |
+
graph_lm_connector_path: saves/Llama-3.1-8B-Instruct-Adapter/connector
|
9 |
+
### generation
|
10 |
+
max_length: 512
|
11 |
+
max_new_tokens: 128
|
12 |
+
temperature: 0.6
|
13 |
+
top_p: 0.9
|
14 |
+
### method
|
15 |
+
finetuning_type: lora
|
16 |
+
do_train: false
|
17 |
+
flash_attn: disabled
|
18 |
+
learned_query_size: 8
|
19 |
+
### dataset
|
20 |
+
dataset: molqa_material_examples
|
21 |
+
template: llama3
|
22 |
+
cutoff_len: 128
|
23 |
+
overwrite_cache: true
|
24 |
+
preprocessing_num_workers: 16
|
25 |
+
output_dir: null
|
26 |
+
bf16: true
|
27 |
+
pure_bf16: true
|
28 |
+
### eval
|
29 |
+
per_device_eval_batch_size: 6
|
30 |
+
report_to: 'none'
|
config/generate/mistral_drug.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### model
|
2 |
+
model_name_or_path: mistralai/Mistral-7B-Instruct-v0.3
|
3 |
+
new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
|
4 |
+
graph_decoder_path: saves/graph_decoder
|
5 |
+
graph_encoder_path: saves/graph_encoder
|
6 |
+
graph_predictor_path: saves/graph_predictor
|
7 |
+
adapter_name_or_path: saves/Mistral-7B-Instruct-v0.3-Adapter
|
8 |
+
graph_lm_connector_path: saves/Mistral-7B-Instruct-v0.3-Adapter/connector
|
9 |
+
### generation
|
10 |
+
max_length: 512
|
11 |
+
max_new_tokens: 128
|
12 |
+
temperature: 0.6
|
13 |
+
top_p: 0.9
|
14 |
+
### method
|
15 |
+
finetuning_type: lora
|
16 |
+
do_train: false
|
17 |
+
flash_attn: disabled
|
18 |
+
learned_query_size: 8
|
19 |
+
### dataset
|
20 |
+
dataset: molqa_drug_examples
|
21 |
+
template: mistral
|
22 |
+
cutoff_len: 128
|
23 |
+
overwrite_cache: true
|
24 |
+
preprocessing_num_workers: 16
|
25 |
+
output_dir: null
|
26 |
+
bf16: true
|
27 |
+
pure_bf16: true
|
28 |
+
### eval
|
29 |
+
per_device_eval_batch_size: 6
|
30 |
+
report_to: 'none'
|
config/generate/mistral_material.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### model
|
2 |
+
model_name_or_path: mistralai/Mistral-7B-Instruct-v0.3
|
3 |
+
new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
|
4 |
+
graph_decoder_path: saves/graph_decoder
|
5 |
+
graph_encoder_path: saves/graph_encoder
|
6 |
+
graph_predictor_path: saves/graph_predictor
|
7 |
+
adapter_name_or_path: saves/Mistral-7B-Instruct-v0.3-Adapter
|
8 |
+
graph_lm_connector_path: saves/Mistral-7B-Instruct-v0.3-Adapter/connector
|
9 |
+
### generation
|
10 |
+
max_length: 512
|
11 |
+
max_new_tokens: 128
|
12 |
+
temperature: 0.6
|
13 |
+
top_p: 0.9
|
14 |
+
### method
|
15 |
+
finetuning_type: lora
|
16 |
+
do_train: false
|
17 |
+
flash_attn: disabled
|
18 |
+
learned_query_size: 8
|
19 |
+
### dataset
|
20 |
+
dataset: molqa_material_examples
|
21 |
+
template: mistral
|
22 |
+
cutoff_len: 128
|
23 |
+
overwrite_cache: true
|
24 |
+
preprocessing_num_workers: 16
|
25 |
+
output_dir: null
|
26 |
+
bf16: true
|
27 |
+
pure_bf16: true
|
28 |
+
### eval
|
29 |
+
per_device_eval_batch_size: 6
|
30 |
+
report_to: 'none'
|
config/generate/qwen_drug.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### model
|
2 |
+
model_name_or_path: Qwen/Qwen2-7B-Instruct
|
3 |
+
new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
|
4 |
+
graph_decoder_path: saves/graph_decoder
|
5 |
+
graph_encoder_path: saves/graph_encoder
|
6 |
+
graph_predictor_path: saves/graph_predictor
|
7 |
+
adapter_name_or_path: saves/Qwen2-7B-Instruct-Adapter
|
8 |
+
graph_lm_connector_path: saves/Qwen2-7B-Instruct-Adapter/connector
|
9 |
+
### generation
|
10 |
+
max_length: 512
|
11 |
+
max_new_tokens: 128
|
12 |
+
temperature: 0.6
|
13 |
+
top_p: 0.9
|
14 |
+
### method
|
15 |
+
finetuning_type: lora
|
16 |
+
do_train: false
|
17 |
+
flash_attn: disabled
|
18 |
+
learned_query_size: 8
|
19 |
+
### dataset
|
20 |
+
dataset: molqa_drug_examples
|
21 |
+
template: qwen
|
22 |
+
cutoff_len: 128
|
23 |
+
overwrite_cache: true
|
24 |
+
preprocessing_num_workers: 16
|
25 |
+
output_dir: null
|
26 |
+
bf16: true
|
27 |
+
pure_bf16: true
|
28 |
+
### eval
|
29 |
+
per_device_eval_batch_size: 6
|
30 |
+
report_to: 'none'
|
config/generate/qwen_material.yaml
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### model
|
2 |
+
model_name_or_path: Qwen/Qwen2-7B-Instruct
|
3 |
+
new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
|
4 |
+
graph_decoder_path: saves/graph_decoder
|
5 |
+
graph_encoder_path: saves/graph_encoder
|
6 |
+
graph_predictor_path: saves/graph_predictor
|
7 |
+
adapter_name_or_path: saves/Qwen2-7B-Instruct-Adapter
|
8 |
+
graph_lm_connector_path: saves/Qwen2-7B-Instruct-Adapter/connector
|
9 |
+
### generation
|
10 |
+
max_length: 512
|
11 |
+
max_new_tokens: 128
|
12 |
+
temperature: 0.6
|
13 |
+
top_p: 0.9
|
14 |
+
### method
|
15 |
+
finetuning_type: lora
|
16 |
+
do_train: false
|
17 |
+
flash_attn: disabled
|
18 |
+
learned_query_size: 8
|
19 |
+
### dataset
|
20 |
+
dataset: molqa_material_examples
|
21 |
+
template: qwen
|
22 |
+
cutoff_len: 128
|
23 |
+
overwrite_cache: true
|
24 |
+
preprocessing_num_workers: 16
|
25 |
+
output_dir: null
|
26 |
+
bf16: true
|
27 |
+
pure_bf16: true
|
28 |
+
### eval
|
29 |
+
per_device_eval_batch_size: 6
|
30 |
+
report_to: 'none'
|
config/train/llama_lora.yaml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### model
|
2 |
+
model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct
|
3 |
+
new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
|
4 |
+
graph_decoder_path: saves/graph_decoder
|
5 |
+
graph_encoder_path: saves/graph_encoder
|
6 |
+
graph_predictor_path: saves/graph_predictor
|
7 |
+
graph_lm_connector_path: null #specify when resume
|
8 |
+
### method
|
9 |
+
stage: mmsft
|
10 |
+
do_train: true
|
11 |
+
finetuning_type: lora
|
12 |
+
lora_target: all
|
13 |
+
flash_attn: disabled
|
14 |
+
learned_query_size: 8
|
15 |
+
### dataset
|
16 |
+
dataset: molqa_train_examples
|
17 |
+
template: llama3
|
18 |
+
cutoff_len: 2048
|
19 |
+
overwrite_cache: true
|
20 |
+
preprocessing_num_workers: 16
|
21 |
+
### output
|
22 |
+
output_dir: saves/Llama-3.1-8B-Instruct-Adapter
|
23 |
+
logging_steps: 10
|
24 |
+
save_steps: 2000
|
25 |
+
plot_loss: true
|
26 |
+
overwrite_output_dir: true
|
27 |
+
### to resume
|
28 |
+
# overwrite_output_dir: false
|
29 |
+
### train
|
30 |
+
per_device_train_batch_size: 10
|
31 |
+
gradient_accumulation_steps: 2
|
32 |
+
learning_rate: 1.0e-4
|
33 |
+
num_train_epochs: 4.0
|
34 |
+
lr_scheduler_type: cosine
|
35 |
+
warmup_ratio: 0.1
|
36 |
+
ddp_timeout: 180000000
|
37 |
+
bf16: true
|
38 |
+
pure_bf16: true
|
39 |
+
### train loss
|
40 |
+
loss_weight_retro: 1
|
41 |
+
loss_weight_design: 1
|
42 |
+
loss_weight_lm: 1
|
43 |
+
### eval
|
44 |
+
val_size: 0.1
|
45 |
+
per_device_eval_batch_size: 6
|
46 |
+
eval_strategy: steps
|
47 |
+
eval_steps: 2000
|
48 |
+
### specify if connected to wandb
|
49 |
+
report_to: 'none'
|
config/train/mistral_lora.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### model
|
2 |
+
model_name_or_path: mistralai/Mistral-7B-Instruct-v0.3
|
3 |
+
new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
|
4 |
+
graph_decoder_path: saves/graph_decoder
|
5 |
+
graph_encoder_path: saves/graph_encoder
|
6 |
+
graph_predictor_path: saves/graph_predictor
|
7 |
+
graph_lm_connector_path: null #specify when resume
|
8 |
+
### method
|
9 |
+
stage: mmsft
|
10 |
+
do_train: true
|
11 |
+
finetuning_type: lora
|
12 |
+
lora_target: all
|
13 |
+
flash_attn: disabled
|
14 |
+
learned_query_size: 8
|
15 |
+
### dataset
|
16 |
+
dataset: molqa_train_examples
|
17 |
+
template: mistral
|
18 |
+
cutoff_len: 2048
|
19 |
+
overwrite_cache: true
|
20 |
+
preprocessing_num_workers: 16
|
21 |
+
### output
|
22 |
+
output_dir: saves/Mistral-7B-Instruct-v0.3-Adapter
|
23 |
+
logging_steps: 10
|
24 |
+
save_steps: 2000
|
25 |
+
plot_loss: true
|
26 |
+
overwrite_output_dir: true
|
27 |
+
### to resume
|
28 |
+
# overwrite_output_dir: false
|
29 |
+
### train
|
30 |
+
per_device_train_batch_size: 6
|
31 |
+
gradient_accumulation_steps: 2
|
32 |
+
learning_rate: 1.0e-4
|
33 |
+
num_train_epochs: 4.0
|
34 |
+
lr_scheduler_type: cosine
|
35 |
+
warmup_ratio: 0.1
|
36 |
+
ddp_timeout: 180000000
|
37 |
+
bf16: true
|
38 |
+
pure_bf16: true
|
39 |
+
loss_weight_retro: 1
|
40 |
+
loss_weight_design: 1
|
41 |
+
loss_weight_lm: 1
|
42 |
+
### eval
|
43 |
+
val_size: 0.1
|
44 |
+
per_device_eval_batch_size: 6
|
45 |
+
eval_strategy: steps
|
46 |
+
eval_steps: 2000
|
47 |
+
### specify if connected to wandb
|
48 |
+
report_to: 'none'
|
config/train/qwen_lora.yaml
ADDED
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
### model
|
2 |
+
model_name_or_path: Qwen/Qwen2-7B-Instruct
|
3 |
+
new_special_tokens: <design_start>,<design_end>,<design_body>,<molecule>,<retro_start>,<retro_end>,<retro_body>,<rollback_start>,<rollback_end>
|
4 |
+
graph_decoder_path: saves/graph_decoder
|
5 |
+
graph_encoder_path: saves/graph_encoder
|
6 |
+
graph_predictor_path: saves/graph_predictor
|
7 |
+
graph_lm_connector_path: null #specify when resume
|
8 |
+
### method
|
9 |
+
stage: mmsft
|
10 |
+
do_train: true
|
11 |
+
finetuning_type: lora
|
12 |
+
lora_target: all
|
13 |
+
flash_attn: disabled
|
14 |
+
learned_query_size: 8
|
15 |
+
### dataset
|
16 |
+
dataset: molqa_train_examples
|
17 |
+
template: qwen
|
18 |
+
cutoff_len: 2048
|
19 |
+
overwrite_cache: true
|
20 |
+
preprocessing_num_workers: 16
|
21 |
+
### output
|
22 |
+
output_dir: saves/Qwen2-7B-Instruct-Adapter
|
23 |
+
logging_steps: 10
|
24 |
+
save_steps: 2000
|
25 |
+
plot_loss: true
|
26 |
+
overwrite_output_dir: true
|
27 |
+
### to resume
|
28 |
+
# overwrite_output_dir: false
|
29 |
+
### train
|
30 |
+
per_device_train_batch_size: 6
|
31 |
+
gradient_accumulation_steps: 2
|
32 |
+
learning_rate: 1.0e-4
|
33 |
+
num_train_epochs: 4.0
|
34 |
+
lr_scheduler_type: cosine
|
35 |
+
warmup_ratio: 0.1
|
36 |
+
ddp_timeout: 180000000
|
37 |
+
bf16: true
|
38 |
+
pure_bf16: true
|
39 |
+
loss_weight_retro: 1
|
40 |
+
loss_weight_design: 1
|
41 |
+
loss_weight_lm: 1
|
42 |
+
### eval
|
43 |
+
val_size: 0.1
|
44 |
+
per_device_eval_batch_size: 6
|
45 |
+
eval_strategy: steps
|
46 |
+
eval_steps: 2000
|
47 |
+
### specify if connected to wandb
|
48 |
+
report_to: 'none'
|
data/dataset_info.json
ADDED
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"molqa_train_examples": {
|
3 |
+
"file_name": "molqa_train_examples.json"
|
4 |
+
},
|
5 |
+
"molqa_material_examples": {
|
6 |
+
"file_name": "molqa_material_examples.json"
|
7 |
+
},
|
8 |
+
"molqa_drug_examples": {
|
9 |
+
"file_name": "molqa_drug_examples.json"
|
10 |
+
},
|
11 |
+
"molqa_train": {
|
12 |
+
"file_name": "molqa_train.json"
|
13 |
+
},
|
14 |
+
"molqa_material": {
|
15 |
+
"file_name": "molqa_material.json"
|
16 |
+
},
|
17 |
+
"molqa_drug": {
|
18 |
+
"file_name": "molqa_drug.json"
|
19 |
+
}
|
20 |
+
}
|
data/molqa_drug_examples.json
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"instruction": "Can you design and synthesize a molecule that inhibits both Human Immunodeficiency Virus and Beta-Secretase 1, with a molecular weight of around 335 and two aromatic rings connected to aliphatic chains?",
|
4 |
+
"input": "",
|
5 |
+
"output": "To satisfy the requirements: The molecule has a moderately complex structure, consisting of aromatic rings connected to aliphatic chains, and featuring carboxyl and hydroxyl functional groups. Therefore, the designed molecule is: <design_start><design_end><mol_start>CC(=O)OCCC(C)Oc1ccc(Cl)cc1Oc1ccccc1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CC(=O)OCCC(C)Oc1ccc(Cl)cc1Oc1ccccc1<mol_end>, follow these procedures: A solution of (R)-acetic acid 3-(toluene-4-sulfonyloxy)-butyl ester (5.89 g, 21 mmol) and 4-chloro-2-phenoxy-phenol (5.0 g, 23 mmol) in DMF (50 mL) is treated with cesium carbonate (7.4 g, 23 mmol). The solution is heated to 60\u00b0 C. and stirred overnight. The reaction is cooled and quenched with 1N HCl. The solution is partitioned in EtOAc and water. The organic is separated, washed with brine, and dried over sodium sulfate. The organic is filtered, and the solvent is removed to afford acetic acid 3-(4-chloro-2-phenoxy-phenoxy)-butyl ester, which is then diluted in methanol (100 mL) and treated with potassium carbonate (5.68 g, 40 mmol). The reaction is stirred for 2 hours at rt. The reaction is then partitioned in EtOAc and water. The organic layer is separated, washed with brine, and dried over sodium sulfate. The organic is filtered and the solvent is removed to afford the crude product. The crude is purified by silica gel column chromatography using 1/1 hexanes/EtOAc to elute the pure product. The solvent is removed to afford 4.35 g (72%) of the desired product. 1H NMR (400 MHz, CDCl3); MS (ES+) m/z mass calcd for C16H17ClO3 292, found 293 (M+1, 100%). The applied reaction is: <retro_start><retro_end>CC(=O)OCCC(C)Oc1ccc(Cl)cc1Oc1ccccc1>>CC(=O)OCCC(C)OS(=O)(=O)c1ccc(C)cc1.O=C(O)O.Oc1ccc(Cl)cc1Oc1ccccc1.[Cs+].[Cs+], which requires the reactants: CC(=O)OCCC(C)OS(=O)(=O)c1ccc(C)cc1 (available), O=C(O)O (available), Oc1ccc(Cl)cc1Oc1ccccc1 (not available, needs to be synthesized), [Cs+] (available), [Cs+] (available). This is step 2 in the retrosynthesis process. To synthesize <mol_start>Oc1ccc(Cl)cc1Oc1ccccc1<mol_end>, follow these procedures: A \u221240\u00b0 C. solution of 4-chloro-2-phenoxy-1-methoxy-benzene (7.16 g, 30.5 mmol) in dry CH2Cl2 (70 mL) is treated dropwise with borontribromide (22.9 g, 91.5 mmol) and then warmed to 0\u00b0 C. and stirred for 3 h under N2. The reaction is diluted with Et2O and quenched with water. The organic layer is dried (Na2SO4), and the solvent is removed in vacuo to afford 7.11 g (100%) of the title compound. Rf=0.30 (4/1 hexanes/acetone). 1H NMR (400 MHz, CDCl3); MS (ES\u2212) m/z mass calcd for C12H9O2Cl 220, found 219 (M\u22121, 100%). The applied reaction is: <retro_start><retro_end>Oc1ccc(Cl)cc1Oc1ccccc1>>COc1ccc(Cl)cc1Oc1ccccc1, which requires the reactants: COc1ccc(Cl)cc1Oc1ccccc1 (not available, needs to be synthesized). This is step 3 in the retrosynthesis process. To synthesize <mol_start>COc1ccc(Cl)cc1Oc1ccccc1<mol_end>, follow these procedures: A mixture of 2-bromo-4-chloro-1-methoxy-benzene (8.0 g, 36.1 mmol), phenol (6.80 g, 72.2 mmol), cesium carbonate (23.54 g, 72.2 mmol), copper (I) chloride (1.79 g, 18.1 mmol) and 2,2,6,6-tetramethyl-3,5-heptanedione (1.66 g, 9.00 mmol) in dry 1-methyl-2-pyrrolidinone (80 mL) is heated to 120\u00b0 C. for 20 hours under N2. The reaction is cooled, filtered and the filtrate quenched with 1 N HCl (50 mL). The filtrate is diluted with Et2O and extracted with water. The organic layer is dried (Na2SO4), and the solvent is removed in vacuo to afford crude product that is absorbed on silica gel and purified by flash chromatography using 9/1 hexanes/EtOAc to afford 7.42 g (88%) of the title compound. Rf=0.37 (4/1 hexanes/EtOAc). The applied reaction is: <retro_start><retro_end>COc1ccc(Cl)cc1Oc1ccccc1>>CC(C)(C)C(=O)CC(=O)C(C)(C)C.COc1ccc(Cl)cc1Br.O=C(O)O.[Cs+].[Cs+].[O-]c1ccccc1, which requires the reactants: CC(C)(C)C(=O)CC(=O)C(C)(C)C (available), COc1ccc(Cl)cc1Br (available), O=C(O)O (available), [Cs+] (available), [Cs+] (available), [O-]c1ccccc1 (available).",
|
6 |
+
"property": {
|
7 |
+
"HIV": 1.0,
|
8 |
+
"BACE": 1.0,
|
9 |
+
"SC": 3.1,
|
10 |
+
"SA": 2.46
|
11 |
+
},
|
12 |
+
"retro": [
|
13 |
+
null,
|
14 |
+
null,
|
15 |
+
null
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"instruction": "What is a possible molecular structure that can penetrate the blood-brain barrier and inhibit Human Immunodeficiency Virus, given a molecular weight of 209.249 and specific ring and bond characteristics?",
|
20 |
+
"input": "",
|
21 |
+
"output": "To satisfy the requirements: The molecule has an aromatic ring and a common functional group of carboxyl, which suggests a certain level of structural complexity, likely influencing its physical and chemical properties, and indicating a moderate synthetic complexity score. Therefore, the designed molecule is: <design_start><design_end><mol_start>CC(C)(C)OC(=O)Nc1ccc(N)cn1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CC(C)(C)OC(=O)Nc1ccc(N)cn1<mol_end>, follow these procedures: To a suspension of (tert-butoxy)-N-(5-nitro(2-pyridyl))carboxamide (0.27 g, 1.13 mmol) in methanol (2 mL), ethyl acetate (4 mL) and TEA (0.16 mL) was added 10% Pd/C (60 mg, 0.056 mmol) under argon. The reaction mixture was hydrogenated under 1 atm H2 for 20 hr, filtered through Celite and concentrated in vacuo to give N-(5-amino(2-pyridyl))(tert-butoxy)carboxamide (0.226 g, 97%). 1H-NMR (DMSO-d6): \u03b4 1.40 (s, 9H), 4.92 (br s, 2H), 6.89-6.91 (dd, 1H), 7.35-7.37 (d, 1H), 7.58 (d, 1H), 9.06 (s, 1H). The applied reaction is: <retro_start><retro_end>CC(C)(C)OC(=O)Nc1ccc(N)cn1>>CC(C)(C)OC(=O)Nc1ccc([N+](=O)[O-])cn1 with the template [#7;a:4]:[c:3]:[c:2]-[NH2;D1;+0:1]>>O=[N+;H0;D3:1](-[O-])-[c:2]:[c:3]:[#7;a:4], which requires the reactants: CC(C)(C)OC(=O)Nc1ccc([N+](=O)[O-])cn1 (available).",
|
22 |
+
"property": {
|
23 |
+
"BBBP": 1.0,
|
24 |
+
"HIV": 1.0,
|
25 |
+
"SC": 2.03,
|
26 |
+
"SA": 2.08
|
27 |
+
},
|
28 |
+
"retro": [
|
29 |
+
2965
|
30 |
+
]
|
31 |
+
},
|
32 |
+
{
|
33 |
+
"instruction": "What molecular structure can be designed to exhibit Blood-Brain Barrier Penetration and Human Immunodeficiency Virus Inhibition while featuring a specific combination of molecular weight, ring count, and functional groups?",
|
34 |
+
"input": "",
|
35 |
+
"output": "To satisfy the requirements: The molecule features a complex structure with two aromatic rings and a high number of rotatable bonds, indicating potential difficulties in synthesis. Its chemical properties are influenced by the presence of functional groups such as carboxyl and amino groups. The predicted properties suggest high substrate specificity for CYP2C9 and moderate substrate specificity for CYP3A4, while also indicating a low risk of carcinogenicity and high clearance in microsomes. The molecule's hydration free energy and skin reaction predictions are also noteworthy. Overall, the molecule's unique combination of structural features and predicted chemical properties makes it an interesting candidate for further investigation. Therefore, the designed molecule is: <design_start><design_end><mol_start>CC(=O)c1sc(NS(=O)(=O)c2cc(Br)c(Cl)s2)nc1C<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CC(=O)c1sc(NS(=O)(=O)c2cc(Br)c(Cl)s2)nc1C<mol_end>, follow these procedures: The title compound was prepared from 5-acety1-2-amino-4-methylthiazole (42 mg) and 4-bromo-5-chlorothiophene-2-sulfonyl chloride (80 mg) as described in the synthetic METHOD B to give a white solid (11.7 mg) with purity >90%: MS (pos) m/z 415.3, 417.3. The applied reaction is: <retro_start><retro_end>CC(=O)c1sc(NS(=O)(=O)c2cc(Br)c(Cl)s2)nc1C>>CC(=O)c1sc(N)nc1C.O=S(=O)(Cl)c1cc(Br)c(Cl)s1 with the template [#16;a:5]:[c:4]-[S;H0;D4;+0:1](=[O;D1;H0:2])(=[O;D1;H0:3])-[NH;D2;+0:9]-[c:7](:[#16;a:6]):[#7;a:8]>>Cl-[S;H0;D4;+0:1](=[O;D1;H0:2])(=[O;D1;H0:3])-[c:4]:[#16;a:5].[#16;a:6]:[c:7](:[#7;a:8])-[NH2;D1;+0:9], which requires the reactants: CC(=O)c1sc(N)nc1C (available), O=S(=O)(Cl)c1cc(Br)c(Cl)s1 (available).",
|
36 |
+
"property": {
|
37 |
+
"BBBP": 1.0,
|
38 |
+
"HIV": 1.0,
|
39 |
+
"SC": 2.83,
|
40 |
+
"SA": 2.74
|
41 |
+
},
|
42 |
+
"retro": [
|
43 |
+
145329
|
44 |
+
]
|
45 |
+
},
|
46 |
+
{
|
47 |
+
"instruction": "Can you design a molecule that penetrates the blood-brain barrier, inhibits Human Immunodeficiency Virus, and doesn't inhibit Beta-Secretase 1, while having a specific set of structural properties?",
|
48 |
+
"input": "",
|
49 |
+
"output": "To satisfy the requirements: The molecule exhibits a moderate level of structural complexity, featuring three aromatic rings and an aliphatic chain, with a molecular weight of 216.27. Its synthetic complexity score is 2.88, indicating it may be moderately challenging to synthesize. This structure is predicted to have a relatively low likelihood of mutagenicity, with an AMES score of 0.608. It also shows moderate to low likelihood of blood-brain barrier penetration, with a BBB_Martins score of 0.576. The molecule's oral bioavailability is predicted to be moderate, with a Bioavailability_Ma score of 0.797. Additionally, it is predicted to have a low to moderate interaction with various cytochrome P450 enzymes. Overall, the molecule's drug utility is influenced by its structural features, bioactivity, and pharmacokinetic properties, which suggest it may have potential as a therapeutic agent. Therefore, the designed molecule is: <design_start><design_end><mol_start>c1csc(Nc2nccn3ccnc23)c1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>c1csc(Nc2nccn3ccnc23)c1<mol_end>, follow these procedures: To a solution of (3-bromo-imidazo[1,2-a]pyrazin-8-yl)-methyl-amine (50 mg, 0.2 mmol) in tetrahydrofuran (2 ml) under nitrogen was added 2-thiopheneboronic acid (41 mg, 0.3 mmol), K2CO3 (1.1 ml of a 1 M solution in water) and 16 mg (0.1 mmol) of PdCl2(Dppf)CH2Cl2 (16 mg, 0.1 eq). The mixture was heated at 70\u00b0 C. in a sealed tube overnight. The product was precipitated by adding methanol. The filtrate was evaporated and purified by preparative thin layer chromatography on silica gel to give additional title compound for a total of 45 mg (100% yield). The applied reaction is: <retro_start><retro_end>c1csc(Nc2nccn3ccnc23)c1>>CNc1nccn2c(Br)cnc12.O=C(O)O.OB(O)c1cccs1.[K+].[K+] with the template [#7;a:4]:[c:3]:[cH;D2;+0:1]:[#7;a:2].[#7:5]-[c;H0;D3;+0:6]1:[cH;D2;+0:9]:[c:8]:[cH;D2;+0:7]:[s;H0;D2;+0:10]:1>>Br-[c;H0;D3;+0:1](:[#7;a:2]):[c:3]:[#7;a:4].[#7:5]-[CH3;D1;+0:6].O-B(-O)-[c;H0;D3;+0:7]1:[c:8]:[cH;D2;+0:9]:c:[s;H0;D2;+0:10]:1, which requires the reactants: CNc1nccn2c(Br)cnc12 (available), O=C(O)O (available), OB(O)c1cccs1 (available), [K+] (available), [K+] (available).",
|
50 |
+
"property": {
|
51 |
+
"BBBP": 1.0,
|
52 |
+
"HIV": 1.0,
|
53 |
+
"BACE": 0.0,
|
54 |
+
"SC": 2.88,
|
55 |
+
"SA": 2.75
|
56 |
+
},
|
57 |
+
"retro": [
|
58 |
+
155241
|
59 |
+
]
|
60 |
+
},
|
61 |
+
{
|
62 |
+
"instruction": "What molecule can be designed and synthesized to penetrate the blood-brain barrier and inhibit Human Immunodeficiency Virus, given the constraints of a molecular weight around 209.245 and a single aromatic ring?",
|
63 |
+
"input": "",
|
64 |
+
"output": "To satisfy the requirements: The molecule featuring an aromatic ring and an aliphatic chain, with functional groups including hydroxyl and carboxyl, exhibits moderate synthetic complexity and moderate lipophilicity, which may influence its physical and chemical properties. Therefore, the designed molecule is: <design_start><design_end><mol_start>CC(C)(C)OC(=O)Nc1ccccc1O<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CC(C)(C)OC(=O)Nc1ccccc1O<mol_end>, follow these procedures: To a solution of 2-aminophenol (2.84 g) in dichloromethane (120 ml) was added di-tert-butyl dicarbonate (6.55 g). The mixture was stirred at room temperature for 18 hours. The mixture was partitioned between water and dichloromethane. The organic extracts were dried (MgSO4) and evaporated to give a solid which was purified by column chromatography on silica eluting with a mixture of ethyl acetate and hexane (20:80) to give 2-tert-butyloxycarbonylaminophenol (1.80 g);. The applied reaction is: <retro_start><retro_end>CC(C)(C)OC(=O)Nc1ccccc1O>>CC(C)(C)OC(=O)OC(=O)OC(C)(C)C.Nc1ccccc1[O-] with the template [C;D1;H3:1]-[C:2](-[C;D1;H3:3])(-[C;D1;H3:4])-[#8:5]-[C;H0;D3;+0:6](=[O;D1;H0:7])-[NH;D2;+0:8]-[c:9]>>[C;D1;H3:1]-[C:2](-[C;D1;H3:3])(-[C;D1;H3:4])-[#8:5]-[C;H0;D3;+0:6](=[O;D1;H0:7])-O-[C;H0;D3;+0:6](=[O;D1;H0:7])-[#8:5]-[C:2](-[C;D1;H3:1])(-[C;D1;H3:3])-[C;D1;H3:4].[NH2;D1;+0:8]-[c:9], which requires the reactants: CC(C)(C)OC(=O)OC(=O)OC(C)(C)C (available), Nc1ccccc1[O-] (available).",
|
65 |
+
"property": {
|
66 |
+
"BBBP": 1.0,
|
67 |
+
"HIV": 1.0,
|
68 |
+
"SC": 1.7,
|
69 |
+
"SA": 1.73
|
70 |
+
},
|
71 |
+
"retro": [
|
72 |
+
38065
|
73 |
+
]
|
74 |
+
}
|
75 |
+
]
|
data/molqa_material_examples.json
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"instruction": "What is the optimal molecular design and synthesis route for a polymer with high CO2 gas permeability and low permeability to N2 and O2, featuring an aromatic ring and specific functional groups?",
|
4 |
+
"input": "",
|
5 |
+
"output": "To satisfy the requirements: The polymer, featuring an aromatic ring and functional groups such as carboxyl and hydroxyl, exhibits a moderate synthetic complexity score and high gas permeability, indicating potential for use in gas separation membranes with good efficiency. Therefore, the designed molecule is: <design_start><design_end><mol_start>*c1cc(C(=O)OCC)c(*)[nH]1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CCOC(=O)c1cc[nH]c1<mol_end>, follow these procedures: Ethanol (3.2 mL, 54 mmol) and 4-dimethylamino pyridine (DMAP) (54 mg, 0.45 mmol) were added to a mixture of pyrrole-3-carboxylic acid (500 mg, 4.5 mmol) and dicyclohexylcarbodiimide (1.11 g, 5.4 mmol) (Aldrich, Milwaukee, Wis.) in tetrahydrofuran (THF) (15 mL). After heating at 60\u00b0 C. for 10 hours, the reaction was cooled. The precipitate was filtered off, washed with ethyl acetate, the combined filtrate was concentrated and purified on a silica gel column to give 500 mg (81%) of 1H-pyrrole-3-carboxylic acid ethyl ester as a colorless oil. The applied reaction is: <retro_start><retro_end>CCOC(=O)c1cc[nH]c1>>C(=NC1CCCCC1)=NC1CCCCC1.CCO.O=C(O)c1cc[nH]c1 with the template [#7;a:5]:[c:4]:[c:3]-[C;H0;D3;+0:1](=[O;D1;H0:2])-[O;H0;D2;+0:7]-[C:6]>>O-[C;H0;D3;+0:1](=[O;D1;H0:2])-[c:3]:[c:4]:[#7;a:5].[C:6]-[OH;D1;+0:7], which requires the reactants: C(=NC1CCCCC1)=NC1CCCCC1 (available), CCO (available), O=C(O)c1cc[nH]c1 (available).",
|
6 |
+
"property": {
|
7 |
+
"CO2": 0.94,
|
8 |
+
"N2": 0.0,
|
9 |
+
"O2": 0.0,
|
10 |
+
"FFV": 0.381,
|
11 |
+
"SC": 2.28,
|
12 |
+
"SA": 4.21
|
13 |
+
},
|
14 |
+
"retro": [
|
15 |
+
36155
|
16 |
+
]
|
17 |
+
},
|
18 |
+
{
|
19 |
+
"instruction": "Can a polymer with a molecular weight of 184.22, one aromatic ring, and specific gas permeability, thermal conductivity, and electrical properties be designed and synthesized for gas separation applications?",
|
20 |
+
"input": "",
|
21 |
+
"output": "To satisfy the requirements: The polymer has an aromatic ring incorporated within its structure, which can contribute to its overall rigidity and mechanical strength. The presence of the sulfonamide group can also enhance the polymer's chemical stability and resistance to degradation. The predicted thermal properties, such as melting point and thermal conductivity, suggest that the polymer may exhibit suitable thermal behavior for various applications. Additionally, the predicted dielectric constant and electrical conductivity indicate that the polymer may possess suitable electrical properties. The gas permeability coefficient and gas diffusion coefficient suggest that the polymer may be suitable for gas separation applications. Overall, the polymer's material utility seems promising for creating polymeric materials with desirable properties. Therefore, the designed molecule is: <design_start><design_end><mol_start>*Nc1cc(*)ccc1S(=O)(=O)NC<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CNS(=O)(=O)c1ccccc1N<mol_end>, follow these procedures: N-Methyl-2-nitrobenzenesulfonamide (1.0 g, 4.6 mmol), 10% palladium on carbon (100 mg) and absolute ethanol (30 mL) were placed in a Parr bottle, the bottle placed on a Parr apparatus under a hydrogen atmosphere (45 psi), and shaken for 1 hour, The Parr bottle was evacuated and flushed with nitrogen, then the mixture was filtered through a disposable syringe filter. The filtrate was concentrated in vacuo to yield 0.81 g of 2-amino-N-methylbenzenesulfonamide. The applied reaction is: <retro_start><retro_end>CNS(=O)(=O)c1ccccc1N>>CNS(=O)(=O)c1ccccc1[N+](=O)[O-] with the template [NH2;D1;+0:1]-[c:2]>>O=[N+;H0;D3:1](-[O-])-[c:2], which requires the reactants: CNS(=O)(=O)c1ccccc1[N+](=O)[O-] (available).",
|
22 |
+
"property": {
|
23 |
+
"CO2": 1.746,
|
24 |
+
"O2": 1.247,
|
25 |
+
"FFV": 0.37,
|
26 |
+
"TC": 0.192,
|
27 |
+
"SC": 2.8,
|
28 |
+
"SA": 4.17
|
29 |
+
},
|
30 |
+
"retro": [
|
31 |
+
3
|
32 |
+
]
|
33 |
+
},
|
34 |
+
{
|
35 |
+
"instruction": "Can a molecule be designed and synthesized to exhibit high CO2 permeability, moderate fractional free volume, and specific thermal conductivity, while incorporating aromatic and carboxyl functional groups?",
|
36 |
+
"input": "",
|
37 |
+
"output": "To satisfy the requirements: The polymer, with its aromatic ring and carboxyl functional group, exhibits a moderate synthetic complexity score, suggesting a relatively accessible molecular structure for synthesis. Its molecular structure and functional groups are expected to impart specific physical and chemical properties, influencing its material utility for polymeric materials. Therefore, the designed molecule is: <design_start><design_end><mol_start>*CC(*)OC(=O)c1cc(C)cc(Cl)c1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>CCOC(=O)c1cc(C)cc(Cl)c1<mol_end>, follow these procedures: The title compound was synthesized from 3-chloro-5-methyl-benzoic acid under the same conditions as for Compound b1. The applied reaction is: <retro_start><retro_end>CCOC(=O)c1cc(C)cc(Cl)c1>>CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1.Cc1cc(Cl)cc(C(=O)O)c1, which requires the reactants: CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1 (not available, needs to be synthesized), Cc1cc(Cl)cc(C(=O)O)c1 (available). This is step 2 in the retrosynthesis process. To synthesize <mol_start>CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1<mol_end>, follow these procedures: Potassium carbonate (1.5 g, 11.2 mmol) and ethyl iodide (1.2 g, 7.4 mmol) were added to a solution of 4-bromo-3-trifluoromethyl-benzoic acid (1.0 g, 3.7 mmol) in DMF (5 ml), and the mixture was stirred at room temperature for 24 hours. The reaction mixture was diluted with ethyl acetate, and the organic layer was washed with water and saturated-saline, and dried over anhydrous sodium sulfate. The drying agent was removed by filtration. After concentration under reduced pressure, the resulting residue was purified by silica gel column chromatography (ethyl acetate/hexane) to yield the title compound (1.03 g, 94%) as a brown oily substance. The applied reaction is: <retro_start><retro_end>CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1>>CCI.O=C(O)c1ccc(Br)c(C(F)(F)F)c1.O=C([O-])[O-].[K+].[K+], which requires the reactants: CCI (available), O=C(O)c1ccc(Br)c(C(F)(F)F)c1 (available), O=C([O-])[O-] (available), [K+] (available), [K+] (available).",
|
38 |
+
"property": {
|
39 |
+
"CO2": 0.94,
|
40 |
+
"N2": 0.0,
|
41 |
+
"O2": 0.0,
|
42 |
+
"FFV": 0.375,
|
43 |
+
"TC": 0.223,
|
44 |
+
"SC": 2.58,
|
45 |
+
"SA": 4.01
|
46 |
+
},
|
47 |
+
"retro": [
|
48 |
+
null,
|
49 |
+
null
|
50 |
+
]
|
51 |
+
},
|
52 |
+
{
|
53 |
+
"instruction": "What molecular structure can be designed and synthesized to achieve CO2 permeability of 1.743 Barrer, N2 impermeability, and specific thermal conductivity and fractional free volume properties, while incorporating an aromatic ring and limited rotatable bonds?",
|
54 |
+
"input": "",
|
55 |
+
"output": "To satisfy the requirements: The polymer is a complex molecular structure incorporating an aromatic ring and a bromine functional group, which may provide enhanced thermal and mechanical properties, potentially useful for gas separation applications. Therefore, the designed molecule is: <design_start><design_end><mol_start>*c1ccc(*)c(CBr)c1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>BrCc1ccccc1<mol_end>, follow these procedures: 4.65 g (10 mM) of compound 1a were treated, while stirring, with 40 ml of 2 N HBr in glacial acetic acid for 45 min. at 20\u00b0 in the absence of moisture. The amino acid derivative dissolved with CO2 evolution. The reaction solution was added dropwise with vigorous stirring to 250 ml of absolute ether which resulted in the precipitation of 2HBr.H-Arg-pNA. The ethereal phase was sucked off, whereupon the solid phase was washed 4 times with portions of 100 ml of abs. ether in order to substantially remove benzyl bromide which had formed as a by-product as well as excess HBr and AcOH. The residue was dissolved in 50 ml of MeOH, the pH was adjusted to 4.5 by the addition of Et3N, and the solution was concentrated to dryness in vacuo at 30\u00b0. The resulting product was dissolved in 75 ml of MeOH and passed through a column of \"Sephadex\" LH-20 (cross-linked dextran gel) equilibrated with MeOH. From a fraction of the eluate there were obtained 4.18 g (91.6% of the theory) of amorphous compound 1b which was homogeneous in the SS as shown by TLC. Elementary analysis and calculation from the empirical formula C12H20N6O3Br2 gave the following values: C=31.15% (31.60%), H=4.35% (4.42%), N=18.84% (18.43%) and Br=34.81% (35.03%). The applied reaction is: <retro_start><retro_end>BrCc1ccccc1>>CC(C)(C#N)/N=N/C(C)(C)C#N.Cc1ccc(S(=O)(=O)Cl)cc1.O=C1CCC(=O)N1Br, which requires the reactants: CC(C)(C#N)/N=N/C(C)(C)C#N (available), Cc1ccc(S(=O)(=O)Cl)cc1 (available), O=C1CCC(=O)N1Br (available).",
|
56 |
+
"property": {
|
57 |
+
"CO2": 1.743,
|
58 |
+
"N2": 0.0,
|
59 |
+
"O2": 4.692,
|
60 |
+
"FFV": 0.411,
|
61 |
+
"TC": 0.256,
|
62 |
+
"SC": 2.25,
|
63 |
+
"SA": 5.14
|
64 |
+
},
|
65 |
+
"retro": [
|
66 |
+
null
|
67 |
+
]
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"instruction": "Can you design a molecule with moderate synthetic complexity, featuring a carboxyl functional group and an aromatic ring, that exhibits CO2 gas permeability of 9.176 Barrer, O2 gas permeability of 1.494 Barrer, fractional free volume of 0.374, and thermal conductivity of 0.228 W m^{-1} K^{-1}?",
|
71 |
+
"input": "",
|
72 |
+
"output": "To satisfy the requirements: The polymer exhibits a moderate synthetic complexity score, indicating that its synthesis is relatively feasible. The presence of an aromatic ring and a carboxyl functional group in its scaffold contributes to its moderate complexity. Therefore, the designed molecule is: <design_start><design_end><mol_start>*OC(=O)c1ccc(*)s1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>O=C([O-])c1cccs1<mol_end>, follow these procedures: To a 15-mL polypropylene centrifuge tube was added 3.86 mL of 50 mM potassium phosphate buffer (pH 7.0), 1.0 mL of a suspension of 22.1 mg dry cell weight E. coli SW132 wet cells (prepared as described in Example 10) in 50 mM potassium phosphate buffer (pH 7.0), and 0.1691 g of 2-thiophenecarbonitrile. The final concentration of 2-thiophencarbonitrile was 0.307 M. The reaction mixture was mixed on a rotating platform at 27\u00b0 C. After 30 min, 7.50 mL of 95:5 acetonitrile/water containing 0.30 M N,N-dimethylbenzamide (HPLC external standard) was added to the reaction, the resulting mixture centrifuged, and a 0.100 mL of the supernatant mixed with 0.900 mL of acetonitrile and analyzed by HPLC. The conversion of 2-thiophenecarbonitrile was 99.5%, and the yields of 2-thiophenecarboxamide and 2-thiophenecarboxylic acid were 98% and 0%, respectively. The applied reaction is: <retro_start><retro_end>O=C([O-])c1cccs1>>CN(C)C(=O)c1ccccc1.N#Cc1cccs1.O=P([O-])([O-])O.[2H]C([2H])([2H])C#N.[K+].[K+].[K+].[OH-], which requires the reactants: CN(C)C(=O)c1ccccc1 (available), N#Cc1cccs1 (available), O=P([O-])([O-])O (available), [2H]C([2H])([2H])C#N (available), [K+] (available), [K+] (available), [K+] (available), [OH-] (available).",
|
73 |
+
"property": {
|
74 |
+
"CO2": 9.176,
|
75 |
+
"O2": 1.494,
|
76 |
+
"FFV": 0.374,
|
77 |
+
"TC": 0.228,
|
78 |
+
"SC": 2.23,
|
79 |
+
"SA": 4.78
|
80 |
+
},
|
81 |
+
"retro": [
|
82 |
+
null
|
83 |
+
]
|
84 |
+
}
|
85 |
+
]
|
data/molqa_train_examples.json
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[
|
2 |
+
{
|
3 |
+
"instruction": "What is a feasible molecular design that meets the criteria of a molecular weight of 243.306, two rings including one aromatic and one aliphatic, two rotatable bonds, and two hydrogen bond acceptors, with a balanced hydrophobic and hydrophilic profile?",
|
4 |
+
"input": "",
|
5 |
+
"output": "To satisfy the requirements: The molecule has a scaffold consisting of a central aromatic ring linked to an amino-functionalized aliphatic chain, featuring a carboxyl group. Its structural complexity is moderate, with a balance of hydrophobic and hydrophilic regions, which could influence its physical and chemical properties. Therefore, the designed molecule is: <design_start><design_end><mol_start>C#CCC1CCN(C(=O)Oc2ccccc2)CC1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>C#CCC1CCN(C(=O)Oc2ccccc2)CC1<mol_end>, follow these procedures: Phenyl chloroformate (6.2 g, 40.2 mmol) was added to a solution of tert-butyl 4-(prop-2-ynyl)piperidine-1-carboxylate (1.65 g, 13.4 mmol) according to general procedure 1. Yield=0.600 g, 34%. m/z MH+=244.08. HPLC rt=10.3 min. The applied reaction is: <retro_start><retro_end>C#CCC1CCN(C(=O)Oc2ccccc2)CC1>>C#CCC1CCN(C(=O)OC(C)(C)C)CC1.O=C(Cl)Oc1ccccc1 with the template [#8:5]-[C;H0;D3;+0:4](=[O;D1;H0:6])-[N;H0;D3;+0:1](-[C:2])-[C:3]>>C-C(-C)(-C)-O-C(=O)-[N;H0;D3;+0:1](-[C:2])-[C:3].Cl-[C;H0;D3;+0:4](-[#8:5])=[O;D1;H0:6], which requires the reactants: C#CCC1CCN(C(=O)OC(C)(C)C)CC1 (available), O=C(Cl)Oc1ccccc1 (available).",
|
6 |
+
"property": {
|
7 |
+
"SC": 2.42,
|
8 |
+
"SA": 2.22
|
9 |
+
},
|
10 |
+
"retro": [
|
11 |
+
180575
|
12 |
+
]
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"instruction": "Can you design and synthesize a molecule with a molecular weight of around 206, containing 2 rings including 1 aromatic and 1 aliphatic, and 3 rotatable bonds, with no hydrogen bond donors and 2 acceptors?",
|
16 |
+
"input": "",
|
17 |
+
"output": "To satisfy the requirements: The molecule has a relatively low synthetic complexity score and a moderate synthetic accessibility score, indicating that it is relatively easy to synthesize. Its molecular structure is composed of an aromatic ring and an aliphatic chain, which may influence its physical and chemical properties. Therefore, the designed molecule is: <design_start><design_end><mol_start>Cc1ccc(C)n1CCN1CCCC1=O<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>Cc1ccc(C)n1CCN1CCCC1=O<mol_end>, follow these procedures: 5.7 g (0.05 mmol) of acetonylacetone and 6.4 g (0.05 mmol) of 1-(2-aminoethyl)-2-pyrrolidinone are heated under reflux in 70 ml of methanol for 2 h. The reaction mixture is then concentrated and, after the concentration, the crude reaction product is then precipitated by addition of petroleum ether. The crude product is recrystallized from diethyl ether. Yield: 3.2 g (31% of theory), Melting point: 66\u00b0-68\u00b0 C. Elemental analysis: C12H18N2O (206.29) calculated: C 69.9 H 8.8 N 13.6 0 7.8 found: C 69.4 H 8.7 N 13.7 0 8.11. The applied reaction is: <retro_start><retro_end>Cc1ccc(C)n1CCN1CCCC1=O>>CC(=O)CCC(C)=O.NCCN1CCCC1=O with the template [C:7]-[n;H0;D3;+0:8]1:[c;H0;D3;+0:1](-[C;D1;H3:2]):[cH;D2;+0:3]:[cH;D2;+0:4]:[c;H0;D3;+0:5]:1-[C;D1;H3:6]>>O=[C;H0;D3;+0:1](-[C;D1;H3:2])-[CH2;D2;+0:3]-[CH2;D2;+0:4]-[C;H0;D3;+0:5](=O)-[C;D1;H3:6].[C:7]-[NH2;D1;+0:8], which requires the reactants: CC(=O)CCC(C)=O (available), NCCN1CCCC1=O (available).",
|
18 |
+
"property": {
|
19 |
+
"SC": 2.69,
|
20 |
+
"SA": 2.34
|
21 |
+
},
|
22 |
+
"retro": [
|
23 |
+
18636
|
24 |
+
]
|
25 |
+
},
|
26 |
+
{
|
27 |
+
"instruction": "Can you design and synthesize a molecule with a molecular weight around 377, incorporating four rings including two aromatic and two aliphatic rings, and six rotatable bonds, with no hydrogen bond donors and four acceptors?",
|
28 |
+
"input": "",
|
29 |
+
"output": "To satisfy the requirements: The molecule has a moderately complex structure built around a core scaffold that incorporates aromatic rings and aliphatic chains, with the presence of functional groups and a moderate number of rotatable bonds, contributing to its moderate synthetic complexity score. Therefore, the designed molecule is: <design_start><design_end><mol_start>Cc1ccccc1N1CCN(CCCCN2C(=O)c3ccccc3C2=O)CC1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>Cc1ccccc1N1CCN(CCCCN2C(=O)c3ccccc3C2=O)CC1<mol_end>, follow these procedures: The title compound was synthesized applying the general procedure III step A using 1-(2-methylphenyl)piperazine (1.99 mmol, 350 mg), N-(4-bromobutyl)phthalimide (1.99 mmol, 560 mg), and K2CO3 (4.97 mmol, 686 mg) in 6 mL of acetonitrile. White solid 750 mg (99%). 1H NMR (400 MHz, Chloroform-d) \u03b4 7.92-7.79 (m, 2H), 7.72 (dt, J=5.5, 3.1 Hz, 2H), 7.16 (t, J=7.6 Hz, 2H), 7.08-6.88 (m, 2H), 3.74 (t, J=7.1 Hz, 2H), 2.93 (t, J=4.8 Hz, 4H), 2.60 (s, 4H), 2.52-2.39 (m, 2H), 2.30 (s, 3H), 1.75 (p, J=7.3 Hz, 2H), 1.59 (tt, J=9.7, 5.9 Hz, 2H). The applied reaction is: <retro_start><retro_end>Cc1ccccc1N1CCN(CCCCN2C(=O)c3ccccc3C2=O)CC1>>Cc1ccccc1[NH+]1CC[NH2+]CC1.O=C(O)O.O=C1c2ccccc2C(=O)N1CCCCBr.[K+].[K+], which requires the reactants: Cc1ccccc1[NH+]1CC[NH2+]CC1 (available), O=C(O)O (available), O=C1c2ccccc2C(=O)N1CCCCBr (available), [K+] (available), [K+] (available).",
|
30 |
+
"property": {
|
31 |
+
"SC": 2.93,
|
32 |
+
"SA": 1.97
|
33 |
+
},
|
34 |
+
"retro": [
|
35 |
+
null
|
36 |
+
]
|
37 |
+
},
|
38 |
+
{
|
39 |
+
"instruction": "Can you design a molecule with a molecular weight around 242, one aromatic ring, three rotatable bonds, and three hydrogen bond acceptors, and describe its synthesis process?",
|
40 |
+
"input": "",
|
41 |
+
"output": "To satisfy the requirements: The molecule has an aromatic ring and an aliphatic chain, with the presence of carboxyl, amino, and halogen groups, which contribute to its moderate synthetic complexity score, indicating a relatively challenging synthesis process. Therefore, the designed molecule is: <design_start><design_end><mol_start>COC(=O)c1ccccc1N(C)C(=O)CCl<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>COC(=O)c1ccccc1N(C)C(=O)CCl<mol_end>, follow these procedures: 25 Grams of methyl o-methylaminobenzoate was dissolved in 125 ml of acetone, to this solution was added a solution prepared by dissolving 20.92 g of potassium carbonate in 60 ml of water. Under an ice-cooled condition, 17.95 g of chloroacetyl chloride was added dropwise thereto, then the reaction mixture was stired at a room temperature for 1 hour. Acetone was removed by distillation, the residue obtained was extracted with chloroform, then the chloroform layer was washed with water, dried and the solvent was removed by distillation. The residue thus obtained was purified by a silica gel column chromatography to obtain 16.80 g of N-methyl-o-methoxycarbonyl-\u03b1-chloroacetanilide in the form of colorless oily substance. The applied reaction is: <retro_start><retro_end>COC(=O)c1ccccc1N(C)C(=O)CCl>>CNc1ccccc1C(=O)OC.O=C(Cl)CCl.O=C(O)O.[K+].[K+] with the template [C:2]-[C;H0;D3;+0:1](=[O;D1;H0:3])-[N;H0;D3;+0:5](-[C;D1;H3:4])-[c:6]>>Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3].[C;D1;H3:4]-[NH;D2;+0:5]-[c:6], which requires the reactants: CNc1ccccc1C(=O)OC (available), O=C(Cl)CCl (available), O=C(O)O (available), [K+] (available), [K+] (available).",
|
42 |
+
"property": {
|
43 |
+
"SC": 2.13,
|
44 |
+
"SA": 2.07
|
45 |
+
},
|
46 |
+
"retro": [
|
47 |
+
39003
|
48 |
+
]
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"instruction": "Can you design a molecule with a molecular weight of 317.186, three aromatic rings, and two rotatable bonds, and outline a synthesis pathway for it?",
|
52 |
+
"input": "",
|
53 |
+
"output": "To satisfy the requirements: The molecule features a complex aromatic scaffold with multiple rings, exhibiting moderate synthetic complexity and potential bioavailability, while also showing possible interactions with nuclear receptors and stress response pathways. Therefore, the designed molecule is: <design_start><design_end><mol_start>COc1ccccc1-c1cn2cc(Br)cc(C)c2n1<mol_end>. This is step 1 in the retrosynthesis process. To synthesize <mol_start>COc1ccccc1-c1cn2cc(Br)cc(C)c2n1<mol_end>, follow these procedures: A mixture of 2-bromo-2\u2032-methoxyacetophenone (0.500 g, 2.183 mmol) and 2-amino-5-bromo-3-methylpyridine (0.408 g, 2.183 mmol) in ethanol (8 mL) was heated to 80\u00b0 C. for 16 hours in a pressure vessel. A yellow thick precipitate formed. Water was added and the solid was extracted with ethyl acetate (some DCM and MeOH were added to help in the dissolution). The organic extracts were then washed with brine, dried over Na2SO4, filtered and concentrated and the residue was purified by silica gel chromatography (Biotage SNAP 100 g, 10 to 20% ethyl acetate in hexanes) to give 220 (0.375 g, 1.182 mmol, 54.2% yield) as a white solid. LRMS (ESI): calc. 316.02 found 317.1 (MH)+. The applied reaction is: <retro_start><retro_end>COc1ccccc1-c1cn2cc(Br)cc(C)c2n1>>COc1ccccc1C(=O)CBr.Cc1cc(Br)cnc1N.[OH-] with the template [c:4]:[c;H0;D3;+0:3](:[c:5])-[c;H0;D3;+0:2]1:[cH;D2;+0:1]:[n;H0;D3;+0:8](:[c:9]):[c:7]:[n;H0;D2;+0:6]:1>>Br-[CH2;D2;+0:1]-[C;H0;D3;+0:2](=O)-[c;H0;D3;+0:3](:[c:4]):[c:5].[NH2;D1;+0:6]-[c:7]:[n;H0;D2;+0:8]:[c:9], which requires the reactants: COc1ccccc1C(=O)CBr (available), Cc1cc(Br)cnc1N (available), [OH-] (available).",
|
54 |
+
"property": {
|
55 |
+
"SC": 2.94,
|
56 |
+
"SA": 2.19
|
57 |
+
},
|
58 |
+
"retro": [
|
59 |
+
10919
|
60 |
+
]
|
61 |
+
}
|
62 |
+
]
|
data/property_ranges.json
ADDED
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"BBBP": {
|
3 |
+
"min": 0.0,
|
4 |
+
"max": 1.0
|
5 |
+
},
|
6 |
+
"HIV": {
|
7 |
+
"min": 0.0,
|
8 |
+
"max": 1.0
|
9 |
+
},
|
10 |
+
"BACE": {
|
11 |
+
"min": 0.0,
|
12 |
+
"max": 1.0
|
13 |
+
},
|
14 |
+
"CO2": {
|
15 |
+
"min": 0.94,
|
16 |
+
"max": 1019.265
|
17 |
+
},
|
18 |
+
"N2": {
|
19 |
+
"min": 0.0,
|
20 |
+
"max": 73.417
|
21 |
+
},
|
22 |
+
"O2": {
|
23 |
+
"min": 0.0,
|
24 |
+
"max": 122.94
|
25 |
+
},
|
26 |
+
"FFV": {
|
27 |
+
"min": 0.324,
|
28 |
+
"max": 0.434
|
29 |
+
},
|
30 |
+
"TC": {
|
31 |
+
"min": 0.117,
|
32 |
+
"max": 0.38
|
33 |
+
},
|
34 |
+
"SC": {
|
35 |
+
"min": 1.0,
|
36 |
+
"max": 5.0
|
37 |
+
},
|
38 |
+
"SA": {
|
39 |
+
"min": 1.0,
|
40 |
+
"max": 8.48
|
41 |
+
}
|
42 |
+
}
|
install_environment.sh
ADDED
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/bin/bash
|
2 |
+
|
3 |
+
# Set non-interactive frontend
|
4 |
+
export DEBIAN_FRONTEND=noninteractive
|
5 |
+
|
6 |
+
# Activate the llama_factory environment
|
7 |
+
# Uncomment the following line if you need to create the environment
|
8 |
+
# conda create --name llamole python=3.11 -y
|
9 |
+
# conda activate llamole
|
10 |
+
|
11 |
+
# Function to get the current conda environment name
|
12 |
+
get_current_conda_env() {
|
13 |
+
echo "current: $CONDA_DEFAULT_ENV"
|
14 |
+
}
|
15 |
+
|
16 |
+
# Get the current conda environment
|
17 |
+
current_env=$(basename "$CONDA_PREFIX")
|
18 |
+
|
19 |
+
# Check if the current environment is 'llamole'
|
20 |
+
if [ "$current_env" != "llamole" ]; then
|
21 |
+
echo "current: $CONDA_DEFAULT_ENV"
|
22 |
+
echo "Current conda environment is neither 'llama_factory' nor 'llamole'."
|
23 |
+
echo "Please activate one of these environments before running this script."
|
24 |
+
echo "You can activate an environment using one of these commands:"
|
25 |
+
echo "conda activate llama_factory"
|
26 |
+
echo "conda activate llamole"
|
27 |
+
exit 1
|
28 |
+
fi
|
29 |
+
|
30 |
+
echo "Running in conda environment: $current_env"
|
31 |
+
|
32 |
+
# "pandas>=2.0.0" \
|
33 |
+
# Install packages using pip
|
34 |
+
pip install --no-cache-dir \
|
35 |
+
pyarrow \
|
36 |
+
"pandas>=1.5.3" \
|
37 |
+
"rdkit==2023.9.6" \
|
38 |
+
pyyaml \
|
39 |
+
ipykernel \
|
40 |
+
packaging \
|
41 |
+
gdown \
|
42 |
+
"fcd_torch==1.0.7" \
|
43 |
+
"omegaconf==2.3.0" \
|
44 |
+
"imageio==2.26.0" \
|
45 |
+
wandb \
|
46 |
+
pandarallel \
|
47 |
+
scipy \
|
48 |
+
einops \
|
49 |
+
sentencepiece \
|
50 |
+
tiktoken \
|
51 |
+
protobuf \
|
52 |
+
uvicorn \
|
53 |
+
pydantic \
|
54 |
+
fastapi \
|
55 |
+
sse-starlette \
|
56 |
+
"matplotlib>=3.7.0" \
|
57 |
+
fire \
|
58 |
+
"numpy<2.0.0" \
|
59 |
+
gradio
|
60 |
+
|
61 |
+
pip install --no-cache-dir hydra-core --upgrade
|
62 |
+
|
63 |
+
# Install PyTorch
|
64 |
+
pip install --no-cache-dir torch
|
65 |
+
|
66 |
+
# Install PyTorch Geometric and related packages
|
67 |
+
pip install --no-cache-dir torch_geometric
|
68 |
+
|
69 |
+
# for retro reaction
|
70 |
+
pip install rdchiral
|
71 |
+
pip install nltk
|
72 |
+
|
73 |
+
# Install transformers and related packages
|
74 |
+
pip install --no-cache-dir \
|
75 |
+
"transformers>=4.41.3" \
|
76 |
+
"datasets>=2.16.0" \
|
77 |
+
"accelerate>=0.30.1" \
|
78 |
+
"peft>=0.11.1" \
|
79 |
+
"trl>=0.8.6" \
|
80 |
+
"gradio>=4.0.0"
|
81 |
+
|
82 |
+
# Install mini-moses from GitHub
|
83 |
+
pip install --no-cache-dir git+https://github.com/igor-krawczuk/mini-moses
|
84 |
+
|
85 |
+
echo "Installation complete!"
|
launch.py
ADDED
@@ -0,0 +1,238 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Llamole Team
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's transformers library.
|
4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import json
|
19 |
+
import yaml
|
20 |
+
import numpy as np
|
21 |
+
import gradio as gr
|
22 |
+
import random
|
23 |
+
from rdkit import Chem
|
24 |
+
from rdkit.Chem import Draw
|
25 |
+
from rdkit.Chem import AllChem
|
26 |
+
|
27 |
+
from src.webui.workflow import load_model_and_tokenizer, process_input, generate
|
28 |
+
from src.webui.elements import create_input_components
|
29 |
+
|
30 |
+
# Load candidates
|
31 |
+
with open('data/molqa_material_examples.json', 'r') as f:
|
32 |
+
material_examples = json.load(f)
|
33 |
+
|
34 |
+
with open('data/molqa_drug_examples.json', 'r') as f:
|
35 |
+
drug_examples = json.load(f)
|
36 |
+
|
37 |
+
# Add type to each example
|
38 |
+
for example in material_examples:
|
39 |
+
example['type'] = 'Material'
|
40 |
+
|
41 |
+
for example in drug_examples:
|
42 |
+
example['type'] = 'Drug'
|
43 |
+
|
44 |
+
# Function to process property values
|
45 |
+
def process_property(value):
|
46 |
+
return 1e-8 if value == 0 else value
|
47 |
+
|
48 |
+
# Add type to each example and process property values
|
49 |
+
for example in material_examples:
|
50 |
+
example['type'] = 'Material'
|
51 |
+
for prop in ['CO2', 'N2', 'O2', 'FFV']:
|
52 |
+
if prop in example['property']:
|
53 |
+
example['property'][prop] = process_property(example['property'][prop])
|
54 |
+
|
55 |
+
# Combine examples
|
56 |
+
all_examples = material_examples + drug_examples
|
57 |
+
|
58 |
+
# Get default values from the first material example
|
59 |
+
default_values = drug_examples[0]
|
60 |
+
|
61 |
+
# Load property ranges and arguments
|
62 |
+
with open('data/property_ranges.json', 'r') as f:
|
63 |
+
property_ranges = json.load(f)
|
64 |
+
|
65 |
+
# with open('config/generate/qwen_material.yaml', 'r') as file:
|
66 |
+
with open('config/generate/llama_material.yaml', 'r') as file:
|
67 |
+
args_dict = yaml.safe_load(file)
|
68 |
+
|
69 |
+
# Load model and tokenizer outside the function
|
70 |
+
model, tokenizer, generating_args = load_model_and_tokenizer(args_dict)
|
71 |
+
|
72 |
+
def format_example(example):
|
73 |
+
formatted = [example['instruction']]
|
74 |
+
|
75 |
+
# Determine if it's a drug or material example based on properties
|
76 |
+
is_drug = any(prop in example.get('property', {}) for prop in ["HIV", "BBBP", "BACE"])
|
77 |
+
formatted.append("Drug" if is_drug else "Material")
|
78 |
+
|
79 |
+
# Handle drug properties
|
80 |
+
for prop in ["HIV", "BBBP", "BACE"]:
|
81 |
+
value = example.get('property', {}).get(prop, float('nan'))
|
82 |
+
formatted.append(value if not np.isnan(value) else "NAN")
|
83 |
+
|
84 |
+
# Handle material properties
|
85 |
+
for prop in ["CO2", "N2", "O2", "FFV", "TC"]:
|
86 |
+
value = example.get('property', {}).get(prop, float('nan'))
|
87 |
+
formatted.append(value if not np.isnan(value) else 0) # 0 represents NAN for material properties
|
88 |
+
|
89 |
+
# Handle synthetic properties
|
90 |
+
for prop in ["SC", "SA"]:
|
91 |
+
value = example.get('property', {}).get(prop, float('nan'))
|
92 |
+
formatted.append(value if not np.isnan(value) else float('nan'))
|
93 |
+
|
94 |
+
return formatted
|
95 |
+
|
96 |
+
# Prepare examples
|
97 |
+
formatted_examples = [format_example(example) for example in all_examples]
|
98 |
+
|
99 |
+
def random_example(examples):
|
100 |
+
example = random.choice(examples)
|
101 |
+
property_type = example['type']
|
102 |
+
|
103 |
+
outputs = [example['instruction'], property_type]
|
104 |
+
|
105 |
+
for prop in ["HIV", "BBBP", "BACE"]:
|
106 |
+
outputs.append(example['property'].get(prop, "NAN"))
|
107 |
+
|
108 |
+
for prop in ["CO2", "N2", "O2", "FFV", "TC"]:
|
109 |
+
outputs.append(example['property'].get(prop, 0))
|
110 |
+
|
111 |
+
for prop in ["SC", "SA"]:
|
112 |
+
outputs.append(example['property'].get(prop, float('nan')))
|
113 |
+
|
114 |
+
return outputs
|
115 |
+
|
116 |
+
def generate_and_visualize(instruction, property_type, HIV, BBBP, BACE, CO2, N2, O2, FFV, TC, SC, SA):
|
117 |
+
properties = {
|
118 |
+
"HIV": float('nan') if HIV == "NAN" else HIV,
|
119 |
+
"BBBP": float('nan') if BBBP == "NAN" else BBBP,
|
120 |
+
"BACE": float('nan') if BACE == "NAN" else BACE,
|
121 |
+
"CO2": float('nan') if CO2 == 0 else CO2,
|
122 |
+
"N2": float('nan') if N2 == 0 else N2,
|
123 |
+
"O2": float('nan') if O2 == 0 else O2,
|
124 |
+
"FFV": float('nan') if FFV == 0 else FFV,
|
125 |
+
"TC": float('nan') if TC == 0 else TC,
|
126 |
+
"SC": SC,
|
127 |
+
"SA": SA
|
128 |
+
}
|
129 |
+
|
130 |
+
# Filter out NaN values
|
131 |
+
properties = {k: v for k, v in properties.items() if not np.isnan(v)}
|
132 |
+
|
133 |
+
print('instruction', instruction)
|
134 |
+
print('properties', properties)
|
135 |
+
results = run_molqa(instruction, **properties)
|
136 |
+
|
137 |
+
llm_response = results.get('llm_response', 'No response generated')
|
138 |
+
llm_smiles = results.get('llm_smiles')
|
139 |
+
llm_reactions = results['llm_reactions']
|
140 |
+
|
141 |
+
molecule_img = visualize_molecule(llm_smiles) if llm_smiles else None
|
142 |
+
|
143 |
+
reaction_steps = []
|
144 |
+
reaction_imgs = []
|
145 |
+
if llm_reactions:
|
146 |
+
for i, reaction_dict in enumerate(llm_reactions):
|
147 |
+
reaction = reaction_dict.get('reaction')
|
148 |
+
if reaction:
|
149 |
+
reaction_steps.append(f"Step {i+1}: {reaction}")
|
150 |
+
reaction_imgs.append(visualize_reaction(reaction))
|
151 |
+
|
152 |
+
return (
|
153 |
+
llm_response,
|
154 |
+
llm_smiles if llm_smiles else "No SMILES generated",
|
155 |
+
molecule_img,
|
156 |
+
gr.JSON(value=reaction_steps, visible=bool(reaction_steps)),
|
157 |
+
gr.Gallery(value=reaction_imgs, visible=bool(reaction_imgs))
|
158 |
+
)
|
159 |
+
|
160 |
+
def run_molqa(instruction: str, **properties) -> dict:
|
161 |
+
# Filter out properties with NaN values
|
162 |
+
filtered_properties = {k: v for k, v in properties.items() if not np.isnan(v)}
|
163 |
+
|
164 |
+
input_data = {
|
165 |
+
"instruction": instruction,
|
166 |
+
"input": "",
|
167 |
+
"property": filtered_properties
|
168 |
+
}
|
169 |
+
|
170 |
+
dataloader, gen_kwargs = process_input(input_data, model, tokenizer, generating_args)
|
171 |
+
generated_results = generate(model, dataloader, gen_kwargs)
|
172 |
+
|
173 |
+
return generated_results
|
174 |
+
|
175 |
+
def visualize_molecule(smiles: str) -> np.ndarray:
|
176 |
+
mol = Chem.MolFromSmiles(smiles)
|
177 |
+
if mol is not None:
|
178 |
+
img = Draw.MolToImage(mol)
|
179 |
+
return np.array(img)
|
180 |
+
return np.zeros((300, 300, 3), dtype=np.uint8)
|
181 |
+
|
182 |
+
def visualize_reaction(reaction: str) -> np.ndarray:
|
183 |
+
rxn = AllChem.ReactionFromSmarts(reaction, useSmiles=True)
|
184 |
+
if rxn is not None:
|
185 |
+
img = Draw.ReactionToImage(rxn)
|
186 |
+
return np.array(img)
|
187 |
+
return np.zeros((300, 300, 3), dtype=np.uint8)
|
188 |
+
|
189 |
+
# Define property names and their full descriptions
|
190 |
+
property_names = {
|
191 |
+
"HIV": "HIV virus replication inhibition",
|
192 |
+
"BBBP": "Blood-brain barrier permeability",
|
193 |
+
"BACE": "Human β-secretase 1 inhibition",
|
194 |
+
"CO2": "CO2 Perm",
|
195 |
+
"N2": "N2 Perm",
|
196 |
+
"O2": "O2 Perm",
|
197 |
+
"FFV": "Fractional free volume",
|
198 |
+
"TC": "Thermal conductivity",
|
199 |
+
"SC": "Heuristic Synthetic Scores (SCScore)",
|
200 |
+
"SA": "Synthetic Synthetic Scores (SAScore)"
|
201 |
+
}
|
202 |
+
|
203 |
+
# Define outputs
|
204 |
+
outputs = [
|
205 |
+
gr.Textbox(label="Overall LLM Response"),
|
206 |
+
gr.Textbox(label="Generated SMILES"),
|
207 |
+
gr.Image(label="Generated Molecule"),
|
208 |
+
gr.JSON(label="Reaction Steps"),
|
209 |
+
gr.Gallery(label="Reaction Visualizations")
|
210 |
+
]
|
211 |
+
|
212 |
+
with gr.Blocks() as iface:
|
213 |
+
gr.Markdown("# Llamole Demo Interface")
|
214 |
+
gr.Markdown("Enter an instruction and property values to generate a molecule design.")
|
215 |
+
|
216 |
+
interface, instruction, property_type, drug_properties, material_properties, synthetic_properties = create_input_components(default_values, property_names, property_ranges)
|
217 |
+
|
218 |
+
random_btn = gr.Button("Random Example")
|
219 |
+
generate_btn = gr.Button("Generate")
|
220 |
+
|
221 |
+
for output in outputs:
|
222 |
+
output.render()
|
223 |
+
|
224 |
+
# Update the inputs for the generate button
|
225 |
+
all_inputs = [instruction, property_type]
|
226 |
+
all_inputs.extend(drug_properties.values())
|
227 |
+
all_inputs.extend(material_properties.values())
|
228 |
+
all_inputs.extend(synthetic_properties.values())
|
229 |
+
|
230 |
+
generate_btn.click(generate_and_visualize, inputs=all_inputs, outputs=outputs)
|
231 |
+
random_btn.click(
|
232 |
+
random_example,
|
233 |
+
inputs=gr.State(all_examples),
|
234 |
+
outputs=all_inputs
|
235 |
+
)
|
236 |
+
|
237 |
+
if __name__ == "__main__":
|
238 |
+
iface.launch(share=True)
|
main.py
ADDED
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Llamole Team
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's transformers library.
|
4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import sys
|
19 |
+
import os
|
20 |
+
from src.train.tuner import run_train, merge_adapter
|
21 |
+
from src.eval.workflow import run_eval
|
22 |
+
from tqdm import tqdm
|
23 |
+
|
24 |
+
from huggingface_hub import hf_hub_download
|
25 |
+
|
26 |
+
def download_data():
|
27 |
+
repo_id = "liuganghuggingface/Llamole-MolQA"
|
28 |
+
files_to_download = [
|
29 |
+
"molqa_drug.json",
|
30 |
+
"molqa_material.json",
|
31 |
+
"molqa_train.json"
|
32 |
+
]
|
33 |
+
local_dir = "data"
|
34 |
+
|
35 |
+
# Create the data directory if it doesn't exist
|
36 |
+
os.makedirs(local_dir, exist_ok=True)
|
37 |
+
|
38 |
+
print(f"Downloading files from {repo_id} to {local_dir}/")
|
39 |
+
for file in tqdm(files_to_download, desc="Downloading files"):
|
40 |
+
try:
|
41 |
+
hf_hub_download(
|
42 |
+
repo_id=repo_id,
|
43 |
+
filename=file,
|
44 |
+
repo_type="dataset",
|
45 |
+
local_dir=local_dir,
|
46 |
+
local_dir_use_symlinks=False
|
47 |
+
)
|
48 |
+
print(f"Successfully downloaded: {file}")
|
49 |
+
except Exception as e:
|
50 |
+
print(f"Error downloading {file}: {e}")
|
51 |
+
|
52 |
+
print("Download complete!")
|
53 |
+
|
54 |
+
if __name__ == "__main__":
|
55 |
+
command = sys.argv.pop(1) if len(sys.argv) != 1 else 'train'
|
56 |
+
if command == 'train':
|
57 |
+
run_train()
|
58 |
+
elif command == 'export':
|
59 |
+
merge_adapter()
|
60 |
+
elif command == 'eval':
|
61 |
+
run_eval()
|
62 |
+
elif command == 'download_data':
|
63 |
+
download_data()
|
64 |
+
else:
|
65 |
+
print(f"Invalid command: {command}. Please use 'train', 'export', 'eval', or 'download_data'.")
|
66 |
+
sys.exit(1)
|
requirements.txt
ADDED
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
accelerate==0.33.0
|
2 |
+
badam==1.2.2
|
3 |
+
bitsandbytes==0.44.1
|
4 |
+
datasets==2.21.0
|
5 |
+
deepspeed==0.15.2
|
6 |
+
galore_torch==1.0
|
7 |
+
gradio==5.0.1
|
8 |
+
huggingface_hub==0.24.5
|
9 |
+
jieba==0.42.1
|
10 |
+
matplotlib==3.9.2
|
11 |
+
MoD==0.3.0
|
12 |
+
modelscope==1.18.1
|
13 |
+
nltk==3.9.1
|
14 |
+
numpy==2.1.2
|
15 |
+
optuna==3.6.1
|
16 |
+
packaging==24.1
|
17 |
+
pandas==1.5.3
|
18 |
+
peft==0.12.0
|
19 |
+
PyYAML==6.0.2
|
20 |
+
PyYAML==6.0.2
|
21 |
+
rdchiral==1.1.0
|
22 |
+
rdkit==2023.9.6
|
23 |
+
rouge_chinese==1.0.3
|
24 |
+
safetensors==0.4.5
|
25 |
+
torch==2.4.0
|
26 |
+
torch_geometric==2.6.1
|
27 |
+
transformers==4.44.0
|
28 |
+
trl==0.9.6
|
29 |
+
typing_extensions==4.12.2
|
30 |
+
vllm==0.6.2
|
src/__init__.py
ADDED
File without changes
|
src/data/__init__.py
ADDED
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
# from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
|
16 |
+
from .data_utils import Role, split_dataset
|
17 |
+
from .loader import get_dataset
|
18 |
+
from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
|
19 |
+
from .collator import DataCollatorForSeqGraph
|
20 |
+
|
21 |
+
__all__ = [
|
22 |
+
"Role",
|
23 |
+
"split_dataset",
|
24 |
+
"get_dataset",
|
25 |
+
"TEMPLATES",
|
26 |
+
"Template",
|
27 |
+
"get_template_and_fix_tokenizer",
|
28 |
+
'DataCollatorForSeqGraph'
|
29 |
+
]
|
src/data/aligner.py
ADDED
@@ -0,0 +1,233 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
import re
|
17 |
+
from functools import partial
|
18 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Union, Tuple
|
19 |
+
|
20 |
+
from datasets import Features
|
21 |
+
|
22 |
+
from ..extras.logging import get_logger
|
23 |
+
from .data_utils import Role
|
24 |
+
|
25 |
+
if TYPE_CHECKING:
|
26 |
+
from datasets import Dataset, IterableDataset
|
27 |
+
from transformers import Seq2SeqTrainingArguments
|
28 |
+
|
29 |
+
from ..hparams import DataArguments
|
30 |
+
from .parser import DatasetAttr
|
31 |
+
|
32 |
+
|
33 |
+
logger = get_logger(__name__)
|
34 |
+
|
35 |
+
def extract_all_smiles(text):
|
36 |
+
pattern = r'<mol_start>(.*?)<mol_end>'
|
37 |
+
return re.findall(pattern, text)
|
38 |
+
|
39 |
+
def replace_all_smiles(text):
|
40 |
+
pattern = r'<mol_start>.*?<mol_end>'
|
41 |
+
return re.sub(pattern, '<molecule>', text)
|
42 |
+
|
43 |
+
def replace_smiles_with_callback(text):
|
44 |
+
def replace_mol(match):
|
45 |
+
design_end = match.group(1)
|
46 |
+
smiles = match.group(2)
|
47 |
+
# return f'{design_end}<molecule><callback_start>{smiles}<callback_end>'
|
48 |
+
return f'{design_end}<molecule><rollback_start>{smiles}<rollback_end>'
|
49 |
+
|
50 |
+
pattern = r'(<design_start><design_end>)<mol_start>(.*?)<mol_end>'
|
51 |
+
text = re.sub(pattern, replace_mol, text)
|
52 |
+
|
53 |
+
# Replace remaining molecules that are not immediately after <design_start><design_end>
|
54 |
+
remaining_pattern = r'<mol_start>.*?<mol_end>'
|
55 |
+
text = re.sub(remaining_pattern, '<molecule>', text)
|
56 |
+
|
57 |
+
return text
|
58 |
+
|
59 |
+
def dict_to_list(data_dict, mol_properties):
|
60 |
+
return [data_dict.get(prop, None) for prop in mol_properties]
|
61 |
+
|
62 |
+
def insert_bodies(text, num_insertions, retro_labels):
|
63 |
+
design_pattern = r'<design_start>(.*?)<design_end>'
|
64 |
+
retro_pattern = r'(This is step \d+ in the retrosynthesis process\..*?<retro_start>.*?<retro_end>)(.*?)(?=This is step \d+|$)'
|
65 |
+
|
66 |
+
def replace_design(match):
|
67 |
+
return f'<design_start>' + ''.join(['<design_body>'] * num_insertions) + f'<design_end>'
|
68 |
+
|
69 |
+
def replace_retro(match, label):
|
70 |
+
step_content = match.group(1)
|
71 |
+
remaining_text = match.group(2)
|
72 |
+
retro_match = re.search(r'<retro_start>(.*?)<retro_end>', step_content)
|
73 |
+
if retro_match and label is not None:
|
74 |
+
modified_content = f'<retro_start>' + ''.join(['<retro_body>'] * num_insertions) + f'<retro_end>'
|
75 |
+
return re.sub(r'<retro_start>.*?<retro_end>', modified_content, step_content)
|
76 |
+
return step_content + remaining_text
|
77 |
+
|
78 |
+
text = re.sub(design_pattern, replace_design, text)
|
79 |
+
|
80 |
+
steps = re.finditer(retro_pattern, text)
|
81 |
+
modified_text = ""
|
82 |
+
last_end = 0
|
83 |
+
|
84 |
+
for i, step in enumerate(steps):
|
85 |
+
label = retro_labels[i] if i < len(retro_labels) else None
|
86 |
+
modified_text += text[last_end:step.start()] + replace_retro(step, label)
|
87 |
+
last_end = step.end()
|
88 |
+
|
89 |
+
modified_text += text[last_end:]
|
90 |
+
return modified_text
|
91 |
+
|
92 |
+
def extract_retro_products(text):
|
93 |
+
pattern = r'<retro_end>(.*?)>>'
|
94 |
+
matches = re.findall(pattern, text)
|
95 |
+
return [match.strip() for match in matches]
|
96 |
+
|
97 |
+
def convert_molqa(
|
98 |
+
examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments"
|
99 |
+
) -> Dict[str, List[Any]]:
|
100 |
+
r"""
|
101 |
+
Converts alpaca format dataset to the standard format.
|
102 |
+
"""
|
103 |
+
outputs = {"prompt": [], "response": [], "system": [], "molecules": [], "property": [], "retro_labels": [], "retro_products": []}
|
104 |
+
|
105 |
+
mol_properties = ['BBBP', 'HIV', 'BACE', 'CO2', 'N2', 'O2', 'FFV', 'TC', 'SC', 'SA']
|
106 |
+
for i in range(len(examples[dataset_attr.prompt])):
|
107 |
+
prompt = []
|
108 |
+
if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list):
|
109 |
+
for old_prompt, old_response in examples[dataset_attr.history][i]:
|
110 |
+
prompt.append({"role": Role.USER.value, "content": old_prompt})
|
111 |
+
prompt.append({"role": Role.ASSISTANT.value, "content": old_response})
|
112 |
+
|
113 |
+
content = []
|
114 |
+
if dataset_attr.prompt and examples[dataset_attr.prompt][i]:
|
115 |
+
content.append(examples[dataset_attr.prompt][i])
|
116 |
+
|
117 |
+
if dataset_attr.query and examples[dataset_attr.query][i]:
|
118 |
+
content.append(examples[dataset_attr.query][i])
|
119 |
+
|
120 |
+
prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery"
|
121 |
+
|
122 |
+
if dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example
|
123 |
+
current_response = examples[dataset_attr.response][i]
|
124 |
+
smiles_list = extract_all_smiles(current_response)
|
125 |
+
modified_response = replace_smiles_with_callback(current_response)
|
126 |
+
retro_labels = examples[dataset_attr.retro][i] if dataset_attr.retro else []
|
127 |
+
retro_products = extract_retro_products(current_response)
|
128 |
+
modified_response = insert_bodies(modified_response, data_args.learned_query_size, retro_labels)
|
129 |
+
# modified_response = insert_bodies(modified_response, dataset_attr.learned_query_size, retro_labels)
|
130 |
+
response = [{"role": Role.ASSISTANT.value, "content": modified_response}]
|
131 |
+
else: # unsupervised
|
132 |
+
response = []
|
133 |
+
|
134 |
+
outputs["prompt"].append(prompt)
|
135 |
+
outputs["response"].append(response)
|
136 |
+
outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "")
|
137 |
+
outputs["molecules"].append(smiles_list)
|
138 |
+
outputs["property"].append(dict_to_list(examples[dataset_attr.property][i], mol_properties))
|
139 |
+
outputs["retro_labels"].append(retro_labels)
|
140 |
+
outputs["retro_products"].append(retro_products)
|
141 |
+
|
142 |
+
return outputs
|
143 |
+
|
144 |
+
def map_smiles_to_id(example, smiles_to_id):
|
145 |
+
example['molecules'] = [smiles_to_id[smiles] for smiles in example['molecules']]
|
146 |
+
return example
|
147 |
+
|
148 |
+
def align_dataset(
|
149 |
+
dataset: Union["Dataset", "IterableDataset"],
|
150 |
+
dataset_attr: "DatasetAttr",
|
151 |
+
data_args: "DataArguments",
|
152 |
+
training_args: "Seq2SeqTrainingArguments",
|
153 |
+
) -> Tuple[Union["Dataset", "IterableDataset"], Dict[int, str]]:
|
154 |
+
r"""
|
155 |
+
Aligns the dataset and maps unique SMILES strings to molecule IDs.
|
156 |
+
|
157 |
+
This function performs the following operations:
|
158 |
+
1. Converts the dataset to the required format (molqa).
|
159 |
+
2. Extracts all unique SMILES strings from the dataset.
|
160 |
+
3. Maps each unique SMILES string to a unique integer ID (0, 1, 2, ...).
|
161 |
+
4. Update 'molecules' field to each example, containing the mapped IDs.
|
162 |
+
|
163 |
+
The aligned dataset contains the following fields:
|
164 |
+
prompt: [{"role": "user", "content": "..."}] * (2T - 1)
|
165 |
+
response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset)
|
166 |
+
system: "..."
|
167 |
+
molecules: [List of SMILES string]
|
168 |
+
property: [List of float values]
|
169 |
+
retro_labels: [List of int values]
|
170 |
+
retro_products: [List of SMILES string]
|
171 |
+
|
172 |
+
Args:
|
173 |
+
dataset (Union["Dataset", "IterableDataset"]): The input dataset.
|
174 |
+
dataset_attr (DatasetAttr): Attributes of the dataset.
|
175 |
+
data_args (DataArguments): Arguments for data processing.
|
176 |
+
training_args (Seq2SeqTrainingArguments): Arguments for training.
|
177 |
+
|
178 |
+
Returns:
|
179 |
+
Tuple[Union["Dataset", "IterableDataset"], Dict[int, str]]:
|
180 |
+
- The aligned and converted dataset with molecule IDs.
|
181 |
+
- A dictionary mapping molecule IDs to their SMILES strings.
|
182 |
+
"""
|
183 |
+
assert dataset_attr.formatting == "molqa"
|
184 |
+
|
185 |
+
features = Features.from_dict(
|
186 |
+
{
|
187 |
+
"prompt": [
|
188 |
+
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
189 |
+
],
|
190 |
+
"response": [
|
191 |
+
{"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}}
|
192 |
+
],
|
193 |
+
"system": {"dtype": "string", "_type": "Value"},
|
194 |
+
"molecules": [{'dtype': "string", "_type": "Value"}],
|
195 |
+
"property": [{"dtype": "float", "_type": "Value"}],
|
196 |
+
"retro_labels": [{"dtype": "int32", "_type": "Value"}],
|
197 |
+
"retro_products": [{'dtype': "string", "_type": "Value"}],
|
198 |
+
}
|
199 |
+
)
|
200 |
+
|
201 |
+
convert_func = partial(convert_molqa, dataset_attr=dataset_attr, data_args=data_args)
|
202 |
+
aligned = dataset.map(
|
203 |
+
convert_func,
|
204 |
+
batched=True,
|
205 |
+
remove_columns=['instruction', 'input', 'output', 'property', 'retro'],
|
206 |
+
features=features,
|
207 |
+
num_proc=data_args.preprocessing_num_workers,
|
208 |
+
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
209 |
+
desc="Converting molqa format of dataset"
|
210 |
+
)
|
211 |
+
|
212 |
+
# Extract all unique SMILES strings and map them to molecule IDs
|
213 |
+
all_smiles = set()
|
214 |
+
for item in aligned:
|
215 |
+
all_smiles.update(item['molecules'])
|
216 |
+
all_smiles.update(item['retro_products'])
|
217 |
+
|
218 |
+
smiles_to_id = {smiles: idx for idx, smiles in enumerate(sorted(all_smiles))}
|
219 |
+
id_to_smiles = {idx: smiles for smiles, idx in smiles_to_id.items()}
|
220 |
+
|
221 |
+
def map_smiles_to_id(example, smiles_to_id):
|
222 |
+
example['molecules'] = [smiles_to_id[smiles] for smiles in example['molecules']]
|
223 |
+
example['retro_products'] = [smiles_to_id[smiles] for smiles in example['retro_products']]
|
224 |
+
return example
|
225 |
+
|
226 |
+
smiles_convert_func = partial(map_smiles_to_id, smiles_to_id=smiles_to_id)
|
227 |
+
|
228 |
+
aligned = aligned.map(
|
229 |
+
smiles_convert_func,
|
230 |
+
desc="Mapping SMILES to molecule IDs",
|
231 |
+
)
|
232 |
+
|
233 |
+
return aligned, id_to_smiles
|
src/data/collator.py
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import numpy as np
|
3 |
+
from dataclasses import dataclass
|
4 |
+
from typing import Any, Dict, List, Optional, Union
|
5 |
+
|
6 |
+
from torch_geometric.data import Batch as PyGBatch
|
7 |
+
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
8 |
+
from transformers.utils import PaddingStrategy
|
9 |
+
|
10 |
+
def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs):
|
11 |
+
"""
|
12 |
+
Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer.
|
13 |
+
"""
|
14 |
+
# To avoid errors when using Feature extractors
|
15 |
+
if not hasattr(tokenizer, "deprecation_warnings"):
|
16 |
+
return tokenizer.pad(*pad_args, **pad_kwargs)
|
17 |
+
|
18 |
+
# Save the state of the warning, then disable it
|
19 |
+
warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False)
|
20 |
+
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True
|
21 |
+
|
22 |
+
try:
|
23 |
+
padded = tokenizer.pad(*pad_args, **pad_kwargs)
|
24 |
+
finally:
|
25 |
+
# Restore the state of the warning.
|
26 |
+
tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state
|
27 |
+
|
28 |
+
return padded
|
29 |
+
|
30 |
+
@dataclass
|
31 |
+
class DataCollatorForSeqGraph:
|
32 |
+
"""
|
33 |
+
Data collator that will dynamically pad the inputs received, as well as the labels.
|
34 |
+
"""
|
35 |
+
tokenizer: PreTrainedTokenizerBase
|
36 |
+
mol_id_to_pyg: Dict[str, Any]
|
37 |
+
model: Optional[Any] = None
|
38 |
+
padding: Union[bool, str, PaddingStrategy] = True
|
39 |
+
max_length: Optional[int] = None
|
40 |
+
pad_to_multiple_of: Optional[int] = None
|
41 |
+
label_pad_token_id: int = -100
|
42 |
+
return_tensors: str = "pt"
|
43 |
+
|
44 |
+
def __call__(self, features, return_tensors=None):
|
45 |
+
if return_tensors is None:
|
46 |
+
return_tensors = self.return_tensors
|
47 |
+
|
48 |
+
label_name = "label" if "label" in features[0].keys() else "labels"
|
49 |
+
labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None
|
50 |
+
if labels is not None and all(label is None for label in labels):
|
51 |
+
labels = None
|
52 |
+
|
53 |
+
# Store molecule_ids, retro_labels, and retro_product_ids separately and remove from non_labels_features
|
54 |
+
molecule_ids_list = []
|
55 |
+
retro_labels_list = []
|
56 |
+
retro_products_list = []
|
57 |
+
non_labels_features = []
|
58 |
+
for feature in features:
|
59 |
+
new_feature = {k: v for k, v in feature.items() if k != label_name}
|
60 |
+
if 'molecule_ids' in new_feature:
|
61 |
+
molecule_ids_list.append(new_feature['molecule_ids'])
|
62 |
+
del new_feature['molecule_ids']
|
63 |
+
else:
|
64 |
+
molecule_ids_list.append(None)
|
65 |
+
if 'retro_labels' in new_feature:
|
66 |
+
retro_labels_list.append(new_feature['retro_labels'])
|
67 |
+
del new_feature['retro_labels']
|
68 |
+
else:
|
69 |
+
retro_labels_list.append(None)
|
70 |
+
if 'retro_product_ids' in new_feature:
|
71 |
+
retro_products_list.append(new_feature['retro_product_ids'])
|
72 |
+
del new_feature['retro_product_ids']
|
73 |
+
else:
|
74 |
+
retro_products_list.append(None)
|
75 |
+
non_labels_features.append(new_feature)
|
76 |
+
|
77 |
+
# Convert molecule IDs to PyG Data objects
|
78 |
+
molecule_graphs_list = []
|
79 |
+
design_graphs_list = []
|
80 |
+
for seq_idx, molecule_ids in enumerate(molecule_ids_list):
|
81 |
+
if molecule_ids is not None and len(molecule_ids) > 0:
|
82 |
+
for pos, mol_id in enumerate(molecule_ids):
|
83 |
+
if pos == 0:
|
84 |
+
design_graphs_list.append(self.mol_id_to_pyg[mol_id])
|
85 |
+
if mol_id != self.label_pad_token_id and mol_id in self.mol_id_to_pyg:
|
86 |
+
molecule_graphs_list.append(self.mol_id_to_pyg[mol_id])
|
87 |
+
|
88 |
+
# Convert retro_product_ids to PyG Data objects
|
89 |
+
retro_product_graphs_list = []
|
90 |
+
for seq_idx, retro_product_ids in enumerate(retro_products_list):
|
91 |
+
if retro_product_ids is not None and len(retro_product_ids) > 0:
|
92 |
+
for pos, mol_id in enumerate(retro_product_ids):
|
93 |
+
if mol_id != self.label_pad_token_id and mol_id in self.mol_id_to_pyg:
|
94 |
+
retro_product_graphs_list.append(self.mol_id_to_pyg[mol_id])
|
95 |
+
|
96 |
+
# Batch the PyG Data objects
|
97 |
+
if molecule_graphs_list:
|
98 |
+
batched_graphs = PyGBatch.from_data_list(molecule_graphs_list)
|
99 |
+
else:
|
100 |
+
batched_graphs = None
|
101 |
+
|
102 |
+
if design_graphs_list:
|
103 |
+
batched_design_graphs = PyGBatch.from_data_list(design_graphs_list)
|
104 |
+
else:
|
105 |
+
batched_design_graphs = None
|
106 |
+
|
107 |
+
if retro_product_graphs_list:
|
108 |
+
batched_retro_products = PyGBatch.from_data_list(retro_product_graphs_list)
|
109 |
+
else:
|
110 |
+
batched_retro_products = None
|
111 |
+
|
112 |
+
# Pad retro_labels
|
113 |
+
if retro_labels_list and any(retro_labels is not None for retro_labels in retro_labels_list):
|
114 |
+
max_retro_length = max(len(retro_labels) for retro_labels in retro_labels_list if retro_labels is not None)
|
115 |
+
padded_retro_labels = [
|
116 |
+
retro_labels + [self.label_pad_token_id] * (max_retro_length - len(retro_labels)) if retro_labels is not None else [self.label_pad_token_id] * max_retro_length
|
117 |
+
for retro_labels in retro_labels_list
|
118 |
+
]
|
119 |
+
else:
|
120 |
+
padded_retro_labels = None
|
121 |
+
|
122 |
+
# Pad other features
|
123 |
+
batch = pad_without_fast_tokenizer_warning(
|
124 |
+
self.tokenizer,
|
125 |
+
non_labels_features,
|
126 |
+
padding=self.padding,
|
127 |
+
max_length=self.max_length,
|
128 |
+
pad_to_multiple_of=self.pad_to_multiple_of,
|
129 |
+
return_tensors=return_tensors,
|
130 |
+
)
|
131 |
+
|
132 |
+
batch["molecule_graphs"] = batched_graphs
|
133 |
+
batch["design_graphs"] = batched_design_graphs
|
134 |
+
batch["retro_product_graphs"] = batched_retro_products
|
135 |
+
batch["retro_labels"] = torch.tensor(padded_retro_labels, dtype=torch.int64)
|
136 |
+
|
137 |
+
# Pad labels
|
138 |
+
if labels is not None:
|
139 |
+
max_label_length = max(len(l) for l in labels)
|
140 |
+
if self.pad_to_multiple_of is not None:
|
141 |
+
max_label_length = (
|
142 |
+
(max_label_length + self.pad_to_multiple_of - 1)
|
143 |
+
// self.pad_to_multiple_of
|
144 |
+
* self.pad_to_multiple_of
|
145 |
+
)
|
146 |
+
|
147 |
+
padding_side = self.tokenizer.padding_side
|
148 |
+
padded_labels = [
|
149 |
+
label + [self.label_pad_token_id] * (max_label_length - len(label))
|
150 |
+
if padding_side == "right"
|
151 |
+
else [self.label_pad_token_id] * (max_label_length - len(label)) + label
|
152 |
+
for label in labels
|
153 |
+
]
|
154 |
+
batch["labels"] = torch.tensor(padded_labels, dtype=torch.int64)
|
155 |
+
|
156 |
+
# Prepare decoder_input_ids
|
157 |
+
if (
|
158 |
+
labels is not None
|
159 |
+
and self.model is not None
|
160 |
+
and hasattr(self.model, "prepare_decoder_input_ids_from_labels")
|
161 |
+
):
|
162 |
+
decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"])
|
163 |
+
batch["decoder_input_ids"] = decoder_input_ids
|
164 |
+
|
165 |
+
return batch
|
src/data/data_utils.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from enum import Enum, unique
|
16 |
+
from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union
|
17 |
+
|
18 |
+
from datasets import concatenate_datasets, interleave_datasets
|
19 |
+
|
20 |
+
from ..extras.logging import get_logger
|
21 |
+
|
22 |
+
|
23 |
+
if TYPE_CHECKING:
|
24 |
+
from datasets import Dataset, IterableDataset
|
25 |
+
from transformers import Seq2SeqTrainingArguments
|
26 |
+
|
27 |
+
from ..hparams import DataArguments
|
28 |
+
|
29 |
+
|
30 |
+
logger = get_logger(__name__)
|
31 |
+
|
32 |
+
|
33 |
+
SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
|
34 |
+
|
35 |
+
|
36 |
+
@unique
|
37 |
+
class Role(str, Enum):
|
38 |
+
USER = "user"
|
39 |
+
ASSISTANT = "assistant"
|
40 |
+
SYSTEM = "system"
|
41 |
+
FUNCTION = "function"
|
42 |
+
OBSERVATION = "observation"
|
43 |
+
|
44 |
+
|
45 |
+
def merge_dataset(
|
46 |
+
all_datasets: List[Union["Dataset", "IterableDataset"]],
|
47 |
+
data_args: "DataArguments",
|
48 |
+
training_args: "Seq2SeqTrainingArguments",
|
49 |
+
) -> Union["Dataset", "IterableDataset"]:
|
50 |
+
if len(all_datasets) == 1:
|
51 |
+
return all_datasets[0]
|
52 |
+
elif data_args.mix_strategy == "concat":
|
53 |
+
if data_args.streaming:
|
54 |
+
logger.warning("The samples between different datasets will not be mixed in streaming mode.")
|
55 |
+
return concatenate_datasets(all_datasets)
|
56 |
+
elif data_args.mix_strategy.startswith("interleave"):
|
57 |
+
if not data_args.streaming:
|
58 |
+
logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
|
59 |
+
return interleave_datasets(
|
60 |
+
datasets=all_datasets,
|
61 |
+
probabilities=data_args.interleave_probs,
|
62 |
+
seed=training_args.seed,
|
63 |
+
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
64 |
+
)
|
65 |
+
else:
|
66 |
+
raise ValueError("Unknown mixing strategy.")
|
67 |
+
|
68 |
+
|
69 |
+
def split_dataset(
|
70 |
+
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments"
|
71 |
+
) -> Dict[str, "Dataset"]:
|
72 |
+
if training_args.do_train:
|
73 |
+
if data_args.val_size > 1e-6: # Split the dataset
|
74 |
+
val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size
|
75 |
+
dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed)
|
76 |
+
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
|
77 |
+
else:
|
78 |
+
if data_args.streaming:
|
79 |
+
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
|
80 |
+
return {"train_dataset": dataset}
|
81 |
+
else: # do_eval or do_predict
|
82 |
+
return {"eval_dataset": dataset}
|
src/data/formatter.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import re
|
17 |
+
from abc import ABC, abstractmethod
|
18 |
+
from dataclasses import dataclass, field
|
19 |
+
from typing import List, Literal, Optional, Tuple, Union
|
20 |
+
|
21 |
+
from .data_utils import SLOTS
|
22 |
+
from .tool_utils import DefaultToolUtils, GLM4ToolUtils
|
23 |
+
|
24 |
+
|
25 |
+
@dataclass
|
26 |
+
class Formatter(ABC):
|
27 |
+
slots: SLOTS = field(default_factory=list)
|
28 |
+
tool_format: Optional[Literal["default", "glm4"]] = None
|
29 |
+
|
30 |
+
@abstractmethod
|
31 |
+
def apply(self, **kwargs) -> SLOTS: ...
|
32 |
+
|
33 |
+
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
34 |
+
raise NotImplementedError
|
35 |
+
|
36 |
+
|
37 |
+
@dataclass
|
38 |
+
class EmptyFormatter(Formatter):
|
39 |
+
def __post_init__(self):
|
40 |
+
has_placeholder = False
|
41 |
+
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
42 |
+
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
|
43 |
+
has_placeholder = True
|
44 |
+
|
45 |
+
if has_placeholder:
|
46 |
+
raise ValueError("Empty formatter should not contain any placeholder.")
|
47 |
+
|
48 |
+
def apply(self, **kwargs) -> SLOTS:
|
49 |
+
return self.slots
|
50 |
+
|
51 |
+
|
52 |
+
@dataclass
|
53 |
+
class StringFormatter(Formatter):
|
54 |
+
def __post_init__(self):
|
55 |
+
has_placeholder = False
|
56 |
+
for slot in filter(lambda s: isinstance(s, str), self.slots):
|
57 |
+
if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot):
|
58 |
+
has_placeholder = True
|
59 |
+
|
60 |
+
if not has_placeholder:
|
61 |
+
raise ValueError("A placeholder is required in the string formatter.")
|
62 |
+
|
63 |
+
def apply(self, **kwargs) -> SLOTS:
|
64 |
+
elements = []
|
65 |
+
for slot in self.slots:
|
66 |
+
if isinstance(slot, str):
|
67 |
+
for name, value in kwargs.items():
|
68 |
+
if not isinstance(value, str):
|
69 |
+
raise RuntimeError("Expected a string, got {}".format(value))
|
70 |
+
|
71 |
+
slot = slot.replace("{{" + name + "}}", value, 1)
|
72 |
+
elements.append(slot)
|
73 |
+
elif isinstance(slot, (dict, set)):
|
74 |
+
elements.append(slot)
|
75 |
+
else:
|
76 |
+
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
77 |
+
|
78 |
+
return elements
|
79 |
+
|
80 |
+
@dataclass
|
81 |
+
class FunctionFormatter(Formatter):
|
82 |
+
def __post_init__(self):
|
83 |
+
if self.tool_format == "default":
|
84 |
+
self.slots = DefaultToolUtils.get_function_slots() + self.slots
|
85 |
+
elif self.tool_format == "glm4":
|
86 |
+
self.slots = GLM4ToolUtils.get_function_slots() + self.slots
|
87 |
+
else:
|
88 |
+
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
|
89 |
+
|
90 |
+
def apply(self, **kwargs) -> SLOTS:
|
91 |
+
content = kwargs.pop("content")
|
92 |
+
functions: List[Tuple[str, str]] = []
|
93 |
+
try:
|
94 |
+
tool_calls = json.loads(content)
|
95 |
+
if not isinstance(tool_calls, list): # parallel function call
|
96 |
+
tool_calls = [tool_calls]
|
97 |
+
|
98 |
+
for tool_call in tool_calls:
|
99 |
+
functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False)))
|
100 |
+
|
101 |
+
except json.JSONDecodeError:
|
102 |
+
functions = []
|
103 |
+
|
104 |
+
elements = []
|
105 |
+
for name, arguments in functions:
|
106 |
+
for slot in self.slots:
|
107 |
+
if isinstance(slot, str):
|
108 |
+
slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments)
|
109 |
+
elements.append(slot)
|
110 |
+
elif isinstance(slot, (dict, set)):
|
111 |
+
elements.append(slot)
|
112 |
+
else:
|
113 |
+
raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot)))
|
114 |
+
|
115 |
+
return elements
|
116 |
+
|
117 |
+
|
118 |
+
@dataclass
|
119 |
+
class ToolFormatter(Formatter):
|
120 |
+
def __post_init__(self):
|
121 |
+
if self.tool_format == "default":
|
122 |
+
self._tool_formatter = DefaultToolUtils.tool_formatter
|
123 |
+
self._tool_extractor = DefaultToolUtils.tool_extractor
|
124 |
+
elif self.tool_format == "glm4":
|
125 |
+
self._tool_formatter = GLM4ToolUtils.tool_formatter
|
126 |
+
self._tool_extractor = GLM4ToolUtils.tool_extractor
|
127 |
+
else:
|
128 |
+
raise NotImplementedError("Tool format {} was not found.".format(self.tool_format))
|
129 |
+
|
130 |
+
def apply(self, **kwargs) -> SLOTS:
|
131 |
+
content = kwargs.pop("content")
|
132 |
+
try:
|
133 |
+
tools = json.loads(content)
|
134 |
+
return [self._tool_formatter(tools) if len(tools) != 0 else ""]
|
135 |
+
except json.JSONDecodeError:
|
136 |
+
return [""]
|
137 |
+
|
138 |
+
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
139 |
+
return self._tool_extractor(content)
|
src/data/loader.py
ADDED
@@ -0,0 +1,149 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import inspect
|
16 |
+
import os
|
17 |
+
import sys
|
18 |
+
from typing import TYPE_CHECKING, Literal, Optional, Union
|
19 |
+
from functools import partial
|
20 |
+
|
21 |
+
import numpy as np
|
22 |
+
from datasets import load_dataset, load_from_disk
|
23 |
+
|
24 |
+
# from ..extras.constants import FILEEXT2TYPE
|
25 |
+
from ..extras.logging import get_logger
|
26 |
+
from ..extras.misc import has_tokenized_data
|
27 |
+
from .aligner import align_dataset
|
28 |
+
from .data_utils import merge_dataset
|
29 |
+
from .parser import get_dataset_attr
|
30 |
+
# from .preprocess import get_preprocess_and_print_func
|
31 |
+
from .template import get_template_and_fix_tokenizer
|
32 |
+
|
33 |
+
from .processors.mmsupervised import (
|
34 |
+
preprocess_mmsupervised_dataset,
|
35 |
+
print_supervised_dataset_example,
|
36 |
+
encode_graph_pyg
|
37 |
+
)
|
38 |
+
|
39 |
+
if TYPE_CHECKING:
|
40 |
+
from datasets import Dataset, IterableDataset
|
41 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments
|
42 |
+
|
43 |
+
from ..hparams import DataArguments, ModelArguments
|
44 |
+
from .parser import DatasetAttr
|
45 |
+
|
46 |
+
|
47 |
+
logger = get_logger(__name__)
|
48 |
+
|
49 |
+
|
50 |
+
def load_single_dataset(
|
51 |
+
dataset_attr: "DatasetAttr",
|
52 |
+
model_args: "ModelArguments",
|
53 |
+
data_args: "DataArguments",
|
54 |
+
training_args: "Seq2SeqTrainingArguments",
|
55 |
+
) -> Union["Dataset", "IterableDataset"]:
|
56 |
+
logger.info("Loading dataset {}...".format(dataset_attr))
|
57 |
+
|
58 |
+
data_files = []
|
59 |
+
assert dataset_attr.load_from == "file"
|
60 |
+
|
61 |
+
data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name)
|
62 |
+
data_files.append(data_path)
|
63 |
+
data_path = data_path.split(".")[-1]
|
64 |
+
|
65 |
+
if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0
|
66 |
+
kwargs = {"trust_remote_code": True}
|
67 |
+
else:
|
68 |
+
kwargs = {}
|
69 |
+
|
70 |
+
dataset = load_dataset(
|
71 |
+
path=data_path,
|
72 |
+
name=None,
|
73 |
+
data_dir=None,
|
74 |
+
data_files=data_files,
|
75 |
+
split=data_args.split,
|
76 |
+
cache_dir=model_args.cache_dir,
|
77 |
+
token=model_args.hf_hub_token,
|
78 |
+
streaming=False,
|
79 |
+
**kwargs,
|
80 |
+
)
|
81 |
+
|
82 |
+
converted_dataset, mol_id_to_smiles = align_dataset(dataset, dataset_attr, data_args, training_args)
|
83 |
+
return converted_dataset, mol_id_to_smiles
|
84 |
+
|
85 |
+
def get_dataset(
|
86 |
+
model_args: "ModelArguments",
|
87 |
+
data_args: "DataArguments",
|
88 |
+
training_args: "Seq2SeqTrainingArguments",
|
89 |
+
tokenizer: "PreTrainedTokenizer",
|
90 |
+
) -> Union["Dataset", "IterableDataset"]:
|
91 |
+
|
92 |
+
template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format)
|
93 |
+
if data_args.train_on_prompt and template.efficient_eos:
|
94 |
+
raise ValueError("Current template does not support `train_on_prompt`.")
|
95 |
+
print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer)
|
96 |
+
|
97 |
+
# Load tokenized dataset
|
98 |
+
if data_args.tokenized_path is not None:
|
99 |
+
if has_tokenized_data(data_args.tokenized_path):
|
100 |
+
mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path)
|
101 |
+
logger.warning("Loading dataset from disk will ignore other data arguments.")
|
102 |
+
dataset = load_from_disk(data_args.tokenized_path)
|
103 |
+
logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path))
|
104 |
+
# print_function(next(iter(dataset)))
|
105 |
+
data_iter = iter(dataset)
|
106 |
+
print_function(next(data_iter))
|
107 |
+
return mol_id_to_pyg, dataset
|
108 |
+
|
109 |
+
# Load tokenized dataset
|
110 |
+
with training_args.main_process_first(desc="load dataset"):
|
111 |
+
# current only support one dataset
|
112 |
+
dataset_attr = get_dataset_attr(data_args)
|
113 |
+
dataset, mol_id_to_smiles = load_single_dataset(dataset_attr, model_args, data_args, training_args)
|
114 |
+
|
115 |
+
with training_args.main_process_first(desc="pre-process dataset"):
|
116 |
+
preprocess_func = partial(
|
117 |
+
preprocess_mmsupervised_dataset,
|
118 |
+
template=template,
|
119 |
+
tokenizer=tokenizer,
|
120 |
+
data_args=data_args,
|
121 |
+
)
|
122 |
+
|
123 |
+
column_names = list(next(iter(dataset)).keys())
|
124 |
+
kwargs = {}
|
125 |
+
kwargs = dict(
|
126 |
+
num_proc=data_args.preprocessing_num_workers,
|
127 |
+
load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0),
|
128 |
+
desc="Running tokenizer on dataset",
|
129 |
+
)
|
130 |
+
|
131 |
+
dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs)
|
132 |
+
|
133 |
+
if data_args.tokenized_path is not None:
|
134 |
+
if training_args.should_save:
|
135 |
+
dataset.save_to_disk(data_args.tokenized_path)
|
136 |
+
mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path, mol_id_to_smiles=mol_id_to_smiles)
|
137 |
+
logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path))
|
138 |
+
logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path))
|
139 |
+
sys.exit(0)
|
140 |
+
else:
|
141 |
+
mol_id_to_pyg = encode_graph_pyg(mol_id_to_smiles=mol_id_to_smiles)
|
142 |
+
|
143 |
+
if training_args.should_log:
|
144 |
+
try:
|
145 |
+
print_function(next(iter(dataset)))
|
146 |
+
except StopIteration:
|
147 |
+
raise RuntimeError("Cannot find valid samples.")
|
148 |
+
|
149 |
+
return mol_id_to_pyg, dataset
|
src/data/parser.py
ADDED
@@ -0,0 +1,113 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import os
|
17 |
+
from dataclasses import dataclass
|
18 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional
|
19 |
+
|
20 |
+
from ..extras.constants import DATA_CONFIG
|
21 |
+
from ..extras.misc import use_modelscope
|
22 |
+
|
23 |
+
|
24 |
+
if TYPE_CHECKING:
|
25 |
+
from ..hparams import DataArguments
|
26 |
+
|
27 |
+
|
28 |
+
@dataclass
|
29 |
+
class DatasetAttr:
|
30 |
+
r"""
|
31 |
+
Dataset attributes.
|
32 |
+
"""
|
33 |
+
|
34 |
+
# basic configs
|
35 |
+
load_from: Literal["hf_hub", "ms_hub", "script", "file"]
|
36 |
+
dataset_name: str
|
37 |
+
formatting: Literal["alpaca", "sharegpt", "molqa"] = "molqa"
|
38 |
+
ranking: bool = False
|
39 |
+
# extra configs
|
40 |
+
subset: Optional[str] = None
|
41 |
+
folder: Optional[str] = None
|
42 |
+
num_samples: Optional[int] = None
|
43 |
+
# common columns
|
44 |
+
system: Optional[str] = None
|
45 |
+
tools: Optional[str] = None
|
46 |
+
images: Optional[str] = None
|
47 |
+
# rlhf columns
|
48 |
+
chosen: Optional[str] = None
|
49 |
+
rejected: Optional[str] = None
|
50 |
+
kto_tag: Optional[str] = None
|
51 |
+
# alpaca columns
|
52 |
+
prompt: Optional[str] = "instruction"
|
53 |
+
query: Optional[str] = "input"
|
54 |
+
response: Optional[str] = "output"
|
55 |
+
history: Optional[str] = None
|
56 |
+
# sharegpt columns
|
57 |
+
messages: Optional[str] = "conversations"
|
58 |
+
# sharegpt tags
|
59 |
+
role_tag: Optional[str] = "from"
|
60 |
+
content_tag: Optional[str] = "value"
|
61 |
+
user_tag: Optional[str] = "human"
|
62 |
+
assistant_tag: Optional[str] = "gpt"
|
63 |
+
observation_tag: Optional[str] = "observation"
|
64 |
+
function_tag: Optional[str] = "function_call"
|
65 |
+
system_tag: Optional[str] = "system"
|
66 |
+
# molqa columns
|
67 |
+
property: Optional[str] = 'property'
|
68 |
+
retro: Optional[str] = 'retro'
|
69 |
+
# learned_query_size: Optional[int] = None
|
70 |
+
|
71 |
+
def __repr__(self) -> str:
|
72 |
+
return self.dataset_name
|
73 |
+
|
74 |
+
def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None:
|
75 |
+
setattr(self, key, obj.get(key, default))
|
76 |
+
|
77 |
+
def get_dataset_attr(data_args: "DataArguments") -> List["DatasetAttr"]:
|
78 |
+
if data_args.dataset is not None:
|
79 |
+
dataset_name = data_args.dataset.strip()
|
80 |
+
else:
|
81 |
+
raise ValueError("Please specify the dataset name.")
|
82 |
+
|
83 |
+
try:
|
84 |
+
with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f:
|
85 |
+
dataset_info = json.load(f)
|
86 |
+
except Exception as err:
|
87 |
+
raise ValueError(
|
88 |
+
"Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err))
|
89 |
+
)
|
90 |
+
dataset_info = None
|
91 |
+
|
92 |
+
if dataset_name not in dataset_info:
|
93 |
+
raise ValueError("Undefined dataset {} in {}.".format(dataset_name, DATA_CONFIG))
|
94 |
+
|
95 |
+
dataset_attr = DatasetAttr("file", dataset_name=dataset_info[dataset_name]["file_name"])
|
96 |
+
|
97 |
+
print('dataset_info', dataset_info)
|
98 |
+
|
99 |
+
dataset_attr.set_attr("formatting", dataset_info[dataset_name], default="molqa")
|
100 |
+
dataset_attr.set_attr("ranking", dataset_info[dataset_name], default=False)
|
101 |
+
dataset_attr.set_attr("subset", dataset_info[dataset_name])
|
102 |
+
dataset_attr.set_attr("folder", dataset_info[dataset_name])
|
103 |
+
dataset_attr.set_attr("num_samples", dataset_info[dataset_name])
|
104 |
+
|
105 |
+
if "columns" in dataset_info[dataset_name]:
|
106 |
+
column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"]
|
107 |
+
assert dataset_attr.formatting == "molqa"
|
108 |
+
column_names.extend(["prompt", "query", "response", "history", "property", "retro"])
|
109 |
+
|
110 |
+
for column_name in column_names:
|
111 |
+
dataset_attr.set_attr(column_name, dataset_info[dataset_name]["columns"])
|
112 |
+
|
113 |
+
return dataset_attr
|
src/data/processors/__init__.py
ADDED
File without changes
|
src/data/processors/mmsupervised.py
ADDED
@@ -0,0 +1,335 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Llamole Team
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from collections import defaultdict
|
16 |
+
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
|
17 |
+
|
18 |
+
from ...extras.constants import IGNORE_INDEX, BOND_INDEX, NO_LABEL_INDEX
|
19 |
+
from ...extras.logging import get_logger
|
20 |
+
|
21 |
+
if TYPE_CHECKING:
|
22 |
+
from transformers import PreTrainedTokenizer, ProcessorMixin
|
23 |
+
|
24 |
+
from ...hparams import DataArguments
|
25 |
+
from ..template import Template
|
26 |
+
|
27 |
+
import os
|
28 |
+
from rdkit import Chem
|
29 |
+
import torch
|
30 |
+
from torch_geometric.data import Data, Batch
|
31 |
+
import pickle
|
32 |
+
|
33 |
+
logger = get_logger(__name__)
|
34 |
+
|
35 |
+
import os
|
36 |
+
import torch
|
37 |
+
from typing import Dict
|
38 |
+
from torch_geometric.data import Data
|
39 |
+
from rdkit import Chem
|
40 |
+
import pickle
|
41 |
+
|
42 |
+
|
43 |
+
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
|
44 |
+
if target_len * 2 < cutoff_len: # truncate source
|
45 |
+
max_target_len = cutoff_len
|
46 |
+
elif source_len * 2 < cutoff_len: # truncate target
|
47 |
+
max_target_len = cutoff_len - source_len
|
48 |
+
else: # truncate both
|
49 |
+
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
|
50 |
+
|
51 |
+
new_target_len = min(max_target_len, target_len)
|
52 |
+
new_source_len = max(cutoff_len - new_target_len, 0)
|
53 |
+
return new_source_len, new_target_len
|
54 |
+
|
55 |
+
def encode_graph_pyg(
|
56 |
+
data_path: Optional[str] = None, mol_id_to_smiles: Optional[Dict[str, str]] = None
|
57 |
+
) -> Dict[str, Data]:
|
58 |
+
"""
|
59 |
+
Converts molecule data to a dictionary of PyTorch Geometric Data objects, with caching functionality.
|
60 |
+
Uses a sparse representation for efficiency.
|
61 |
+
|
62 |
+
Args:
|
63 |
+
data_path (Optional[str]): Path to the Hugging Face dataset folder.
|
64 |
+
mol_id_to_smiles (Optional[Dict[str, str]]): Dictionary where keys are molecule IDs
|
65 |
+
and values are SMILES strings.
|
66 |
+
|
67 |
+
Returns:
|
68 |
+
Dict[str, Data]: Dictionary where keys are molecule IDs and values are
|
69 |
+
PyTorch Geometric Data objects.
|
70 |
+
|
71 |
+
Raises:
|
72 |
+
ValueError: If both data_path and mol_id_to_smiles are None, or if data_path is provided but loading fails.
|
73 |
+
"""
|
74 |
+
print(f"Current execution directory: {os.getcwd()}")
|
75 |
+
|
76 |
+
if data_path is None and mol_id_to_smiles is None:
|
77 |
+
raise ValueError("Either data_path or mol_id_to_smiles must be provided.")
|
78 |
+
|
79 |
+
if data_path is not None:
|
80 |
+
cache_file = os.path.join(data_path, "pyg_molecule.pickle")
|
81 |
+
|
82 |
+
# Try to load cached data
|
83 |
+
if os.path.exists(cache_file):
|
84 |
+
try:
|
85 |
+
with open(cache_file, "rb") as f:
|
86 |
+
return pickle.load(f)
|
87 |
+
except Exception as e:
|
88 |
+
print(f"Failed to load cached data: {e}")
|
89 |
+
|
90 |
+
mol_id_to_pyg = {}
|
91 |
+
|
92 |
+
for mol_id, smiles in mol_id_to_smiles.items():
|
93 |
+
mol = Chem.MolFromSmiles(smiles)
|
94 |
+
if mol is None:
|
95 |
+
raise ValueError(f"Invalid SMILES string for molecule {mol_id}: {smiles}")
|
96 |
+
|
97 |
+
type_idx = []
|
98 |
+
heavy_atom_indices = []
|
99 |
+
for atom in mol.GetAtoms():
|
100 |
+
if atom.GetAtomicNum() != 1: # Exclude hydrogen atoms
|
101 |
+
type_idx.append(
|
102 |
+
119 - 2 if atom.GetSymbol() == "*" else atom.GetAtomicNum() - 2
|
103 |
+
)
|
104 |
+
heavy_atom_indices.append(atom.GetIdx())
|
105 |
+
|
106 |
+
x = torch.LongTensor(type_idx)
|
107 |
+
|
108 |
+
edge_index = []
|
109 |
+
edge_attr = []
|
110 |
+
for bond in mol.GetBonds():
|
111 |
+
start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx()
|
112 |
+
if start in heavy_atom_indices and end in heavy_atom_indices:
|
113 |
+
start_new, end_new = heavy_atom_indices.index(
|
114 |
+
start
|
115 |
+
), heavy_atom_indices.index(end)
|
116 |
+
edge_index.extend([[start_new, end_new], [end_new, start_new]])
|
117 |
+
bond_type = BOND_INDEX[bond.GetBondType()]
|
118 |
+
edge_attr.extend([bond_type, bond_type])
|
119 |
+
|
120 |
+
edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous()
|
121 |
+
edge_attr = torch.tensor(edge_attr, dtype=torch.long)
|
122 |
+
|
123 |
+
# Create PyG Data object
|
124 |
+
data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr)
|
125 |
+
|
126 |
+
mol_id_to_pyg[mol_id] = data
|
127 |
+
|
128 |
+
# Save cached data if data_path is provided
|
129 |
+
if data_path is not None:
|
130 |
+
with open(cache_file, "wb") as f:
|
131 |
+
pickle.dump(mol_id_to_pyg, f)
|
132 |
+
|
133 |
+
print(f"Saved PyG data to {cache_file}")
|
134 |
+
|
135 |
+
return mol_id_to_pyg
|
136 |
+
|
137 |
+
def encode_supervised_example(
|
138 |
+
prompt: Sequence[Dict[str, str]],
|
139 |
+
response: Sequence[Dict[str, str]],
|
140 |
+
system: Optional[str],
|
141 |
+
molecule_ids: List[int],
|
142 |
+
retro_product_ids: List[int],
|
143 |
+
retro_labels: List[int],
|
144 |
+
template: "Template",
|
145 |
+
tokenizer: "PreTrainedTokenizer",
|
146 |
+
data_args: "DataArguments",
|
147 |
+
) -> Tuple[List[int], List[int], List[int], List[int], List[int]]:
|
148 |
+
|
149 |
+
messages = prompt + response
|
150 |
+
input_ids, labels = [], []
|
151 |
+
final_molecule_ids = []
|
152 |
+
final_product_ids = []
|
153 |
+
final_retro_labels = []
|
154 |
+
|
155 |
+
encoded_pairs = template.encode_multiturn(tokenizer, messages, system)
|
156 |
+
special_tokens = [
|
157 |
+
"<design_start>",
|
158 |
+
"<design_end>",
|
159 |
+
"<design_body>",
|
160 |
+
"<molecule>",
|
161 |
+
"<retro_start>",
|
162 |
+
"<retro_end>",
|
163 |
+
"<retro_body>",
|
164 |
+
]
|
165 |
+
special_token_ids = template._convert_elements_to_ids(tokenizer, special_tokens)
|
166 |
+
special_token_dict = dict(zip(special_tokens, special_token_ids))
|
167 |
+
|
168 |
+
total_length = 1 if template.efficient_eos else 0
|
169 |
+
for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs):
|
170 |
+
if total_length >= data_args.cutoff_len:
|
171 |
+
break
|
172 |
+
|
173 |
+
source_len, target_len = infer_seqlen(
|
174 |
+
len(source_ids), len(target_ids), data_args.cutoff_len - total_length
|
175 |
+
)
|
176 |
+
source_ids = source_ids[:source_len]
|
177 |
+
|
178 |
+
# Ensure balanced retro tags when truncating
|
179 |
+
retro_start_indices = [
|
180 |
+
i
|
181 |
+
for i, id in enumerate(target_ids)
|
182 |
+
if id == special_token_dict["<retro_start>"]
|
183 |
+
]
|
184 |
+
retro_end_indices = [
|
185 |
+
i
|
186 |
+
for i, id in enumerate(target_ids)
|
187 |
+
if id == special_token_dict["<retro_end>"]
|
188 |
+
]
|
189 |
+
|
190 |
+
if retro_start_indices and retro_end_indices:
|
191 |
+
# Find the last matching pair that fits within target_len
|
192 |
+
last_pair_index = -1
|
193 |
+
for start, end in zip(retro_start_indices, retro_end_indices):
|
194 |
+
if end < target_len:
|
195 |
+
last_pair_index = end
|
196 |
+
else:
|
197 |
+
break
|
198 |
+
|
199 |
+
if last_pair_index >= 0:
|
200 |
+
target_len = last_pair_index + 1
|
201 |
+
else:
|
202 |
+
# If no complete pair fits, truncate before the first start tag
|
203 |
+
target_len = (
|
204 |
+
min(target_len, retro_start_indices[0])
|
205 |
+
if retro_start_indices
|
206 |
+
else target_len
|
207 |
+
)
|
208 |
+
|
209 |
+
target_ids = target_ids[:target_len]
|
210 |
+
|
211 |
+
# Calculate the number of molecules in this turn
|
212 |
+
molecules_in_turn = target_ids.count(special_token_dict["<molecule>"])
|
213 |
+
retro_start_in_turn = target_ids.count(special_token_dict["<retro_start>"])
|
214 |
+
retro_end_in_turn = target_ids.count(special_token_dict["<retro_end>"])
|
215 |
+
|
216 |
+
assert retro_start_in_turn == retro_end_in_turn
|
217 |
+
|
218 |
+
retro_product_ids_in_turn = retro_product_ids[:retro_end_in_turn]
|
219 |
+
retro_labels_in_turn = retro_labels[:retro_end_in_turn]
|
220 |
+
|
221 |
+
# Add corresponding retro_labels and retro_product_ids
|
222 |
+
final_molecule_ids.extend(molecule_ids[:molecules_in_turn])
|
223 |
+
final_product_ids.extend(retro_product_ids_in_turn)
|
224 |
+
final_retro_labels.extend(retro_labels_in_turn)
|
225 |
+
|
226 |
+
total_length += source_len + target_len
|
227 |
+
|
228 |
+
if data_args.train_on_prompt:
|
229 |
+
source_mask = source_ids
|
230 |
+
elif turn_idx != 0 and template.efficient_eos:
|
231 |
+
source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * (
|
232 |
+
len(source_ids) - 1
|
233 |
+
)
|
234 |
+
else:
|
235 |
+
source_mask = [IGNORE_INDEX] * len(source_ids)
|
236 |
+
|
237 |
+
source_mask = [
|
238 |
+
IGNORE_INDEX if id in special_token_dict.values() else id
|
239 |
+
for id in source_mask
|
240 |
+
]
|
241 |
+
target_ids_mask = [
|
242 |
+
id if id in [special_token_dict["<retro_start>"], special_token_dict["<design_start>"]]
|
243 |
+
else (IGNORE_INDEX if id in special_token_dict.values() else id)
|
244 |
+
for id in target_ids
|
245 |
+
]
|
246 |
+
|
247 |
+
input_ids += source_ids + target_ids
|
248 |
+
labels += source_mask + target_ids_mask
|
249 |
+
|
250 |
+
if template.efficient_eos:
|
251 |
+
input_ids += [tokenizer.eos_token_id]
|
252 |
+
labels += [tokenizer.eos_token_id]
|
253 |
+
|
254 |
+
return input_ids, labels, final_molecule_ids, final_product_ids, final_retro_labels
|
255 |
+
|
256 |
+
|
257 |
+
def preprocess_mmsupervised_dataset(
|
258 |
+
examples: Dict[str, List[Any]],
|
259 |
+
template: "Template",
|
260 |
+
tokenizer: "PreTrainedTokenizer",
|
261 |
+
data_args: "DataArguments",
|
262 |
+
) -> Dict[str, List[List[int]]]:
|
263 |
+
model_inputs = {
|
264 |
+
"input_ids": [],
|
265 |
+
"attention_mask": [],
|
266 |
+
"labels": [],
|
267 |
+
"molecule_ids": [],
|
268 |
+
"molecule_properties": [],
|
269 |
+
"retro_labels": [],
|
270 |
+
"retro_product_ids": [],
|
271 |
+
}
|
272 |
+
|
273 |
+
for i in range(len(examples["prompt"])):
|
274 |
+
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1:
|
275 |
+
logger.warning(
|
276 |
+
"Dropped invalid example: {}".format(
|
277 |
+
examples["prompt"][i] + examples["response"][i]
|
278 |
+
)
|
279 |
+
)
|
280 |
+
continue
|
281 |
+
|
282 |
+
retro_product_ids = examples["retro_products"][i]
|
283 |
+
retro_labels = [
|
284 |
+
NO_LABEL_INDEX if label is None else label
|
285 |
+
for label in examples["retro_labels"][i]
|
286 |
+
]
|
287 |
+
properties = [
|
288 |
+
NO_LABEL_INDEX if prop is None else prop for prop in examples["property"][i]
|
289 |
+
]
|
290 |
+
|
291 |
+
input_ids, labels, molecule_ids, retro_product_ids, retro_labels = (
|
292 |
+
encode_supervised_example(
|
293 |
+
prompt=examples["prompt"][i],
|
294 |
+
response=examples["response"][i],
|
295 |
+
system=examples["system"][i],
|
296 |
+
molecule_ids=examples["molecules"][i],
|
297 |
+
retro_product_ids=retro_product_ids,
|
298 |
+
retro_labels=retro_labels,
|
299 |
+
template=template,
|
300 |
+
tokenizer=tokenizer,
|
301 |
+
data_args=data_args,
|
302 |
+
)
|
303 |
+
)
|
304 |
+
# molecule_ids = examples["molecules"][i]
|
305 |
+
|
306 |
+
model_inputs["input_ids"].append(input_ids)
|
307 |
+
model_inputs["attention_mask"].append([1] * len(input_ids))
|
308 |
+
model_inputs["labels"].append(labels)
|
309 |
+
model_inputs["molecule_ids"].append(molecule_ids)
|
310 |
+
model_inputs["molecule_properties"].append(properties)
|
311 |
+
model_inputs["retro_labels"].append(retro_labels)
|
312 |
+
model_inputs["retro_product_ids"].append(retro_product_ids)
|
313 |
+
|
314 |
+
return model_inputs
|
315 |
+
|
316 |
+
def print_supervised_dataset_example(
|
317 |
+
example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer"
|
318 |
+
) -> None:
|
319 |
+
valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"]))
|
320 |
+
print("Print_supervised_dataset_example")
|
321 |
+
|
322 |
+
print("input_ids:\n{}".format(example["input_ids"]))
|
323 |
+
print(
|
324 |
+
"inputs:\n{}".format(
|
325 |
+
tokenizer.decode(example["input_ids"], skip_special_tokens=False)
|
326 |
+
)
|
327 |
+
)
|
328 |
+
print("label_ids:\n{}".format(example["labels"]))
|
329 |
+
print(
|
330 |
+
"labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False))
|
331 |
+
)
|
332 |
+
print("molecule_ids:\n{}".format(example["molecule_ids"]))
|
333 |
+
print("molecule_properties:\n{}".format(example["molecule_properties"]))
|
334 |
+
print("retro_labels:\n{}".format(example["retro_labels"]))
|
335 |
+
print("retro_product_ids:\n{}".format(example["retro_product_ids"]))
|
src/data/template.py
ADDED
@@ -0,0 +1,889 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass
|
16 |
+
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
17 |
+
|
18 |
+
from ..extras.logging import get_logger
|
19 |
+
from .data_utils import Role
|
20 |
+
from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
|
21 |
+
|
22 |
+
|
23 |
+
if TYPE_CHECKING:
|
24 |
+
from transformers import PreTrainedTokenizer
|
25 |
+
|
26 |
+
from .formatter import SLOTS, Formatter
|
27 |
+
|
28 |
+
|
29 |
+
logger = get_logger(__name__)
|
30 |
+
|
31 |
+
|
32 |
+
@dataclass
|
33 |
+
class Template:
|
34 |
+
format_user: "Formatter"
|
35 |
+
format_assistant: "Formatter"
|
36 |
+
format_system: "Formatter"
|
37 |
+
format_function: "Formatter"
|
38 |
+
format_observation: "Formatter"
|
39 |
+
format_tools: "Formatter"
|
40 |
+
format_separator: "Formatter"
|
41 |
+
format_prefix: "Formatter"
|
42 |
+
default_system: str
|
43 |
+
stop_words: List[str]
|
44 |
+
image_token: str
|
45 |
+
efficient_eos: bool
|
46 |
+
replace_eos: bool
|
47 |
+
|
48 |
+
def encode_oneturn(
|
49 |
+
self,
|
50 |
+
tokenizer: "PreTrainedTokenizer",
|
51 |
+
messages: Sequence[Dict[str, str]],
|
52 |
+
system: Optional[str] = None,
|
53 |
+
tools: Optional[str] = None,
|
54 |
+
) -> Tuple[List[int], List[int]]:
|
55 |
+
r"""
|
56 |
+
Returns a single pair of token ids representing prompt and response respectively.
|
57 |
+
"""
|
58 |
+
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
59 |
+
prompt_ids = []
|
60 |
+
for encoded_ids in encoded_messages[:-1]:
|
61 |
+
prompt_ids += encoded_ids
|
62 |
+
|
63 |
+
answer_ids = encoded_messages[-1]
|
64 |
+
return prompt_ids, answer_ids
|
65 |
+
|
66 |
+
def encode_multiturn(
|
67 |
+
self,
|
68 |
+
tokenizer: "PreTrainedTokenizer",
|
69 |
+
messages: Sequence[Dict[str, str]],
|
70 |
+
system: Optional[str] = None,
|
71 |
+
tools: Optional[str] = None,
|
72 |
+
) -> List[Tuple[List[int], List[int]]]:
|
73 |
+
r"""
|
74 |
+
Returns multiple pairs of token ids representing prompts and responses respectively.
|
75 |
+
"""
|
76 |
+
encoded_messages = self._encode(tokenizer, messages, system, tools)
|
77 |
+
return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)]
|
78 |
+
|
79 |
+
def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
80 |
+
r"""
|
81 |
+
Extracts tool message.
|
82 |
+
"""
|
83 |
+
return self.format_tools.extract(content)
|
84 |
+
|
85 |
+
def _encode(
|
86 |
+
self,
|
87 |
+
tokenizer: "PreTrainedTokenizer",
|
88 |
+
messages: Sequence[Dict[str, str]],
|
89 |
+
system: Optional[str],
|
90 |
+
tools: Optional[str],
|
91 |
+
) -> List[List[int]]:
|
92 |
+
r"""
|
93 |
+
Encodes formatted inputs to pairs of token ids.
|
94 |
+
Turn 0: prefix + system + query resp
|
95 |
+
Turn t: sep + query resp
|
96 |
+
"""
|
97 |
+
system = system or self.default_system
|
98 |
+
encoded_messages = []
|
99 |
+
for i, message in enumerate(messages):
|
100 |
+
elements = []
|
101 |
+
|
102 |
+
if i == 0:
|
103 |
+
elements += self.format_prefix.apply()
|
104 |
+
if system or tools:
|
105 |
+
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
106 |
+
elements += self.format_system.apply(content=(system + tool_text))
|
107 |
+
|
108 |
+
if i > 0 and i % 2 == 0:
|
109 |
+
elements += self.format_separator.apply()
|
110 |
+
|
111 |
+
if message["role"] == Role.USER.value:
|
112 |
+
elements += self.format_user.apply(content=message["content"], idx=str(i // 2))
|
113 |
+
elif message["role"] == Role.ASSISTANT.value:
|
114 |
+
elements += self.format_assistant.apply(content=message["content"])
|
115 |
+
elif message["role"] == Role.OBSERVATION.value:
|
116 |
+
elements += self.format_observation.apply(content=message["content"])
|
117 |
+
elif message["role"] == Role.FUNCTION.value:
|
118 |
+
elements += self.format_function.apply(content=message["content"])
|
119 |
+
else:
|
120 |
+
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
121 |
+
|
122 |
+
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
123 |
+
|
124 |
+
return encoded_messages
|
125 |
+
|
126 |
+
def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]:
|
127 |
+
r"""
|
128 |
+
Converts elements to token ids.
|
129 |
+
"""
|
130 |
+
token_ids = []
|
131 |
+
for elem in elements:
|
132 |
+
if isinstance(elem, str):
|
133 |
+
if len(elem) != 0:
|
134 |
+
token_ids += tokenizer.encode(elem, add_special_tokens=False)
|
135 |
+
elif isinstance(elem, dict):
|
136 |
+
token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))]
|
137 |
+
elif isinstance(elem, set):
|
138 |
+
if "bos_token" in elem and tokenizer.bos_token_id is not None:
|
139 |
+
token_ids += [tokenizer.bos_token_id]
|
140 |
+
elif "eos_token" in elem and tokenizer.eos_token_id is not None:
|
141 |
+
token_ids += [tokenizer.eos_token_id]
|
142 |
+
else:
|
143 |
+
raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem)))
|
144 |
+
|
145 |
+
return token_ids
|
146 |
+
|
147 |
+
|
148 |
+
@dataclass
|
149 |
+
class Llama2Template(Template):
|
150 |
+
def _encode(
|
151 |
+
self,
|
152 |
+
tokenizer: "PreTrainedTokenizer",
|
153 |
+
messages: Sequence[Dict[str, str]],
|
154 |
+
system: str,
|
155 |
+
tools: str,
|
156 |
+
) -> List[List[int]]:
|
157 |
+
r"""
|
158 |
+
Encodes formatted inputs to pairs of token ids.
|
159 |
+
Turn 0: prefix + system + query resp
|
160 |
+
Turn t: sep + query resp
|
161 |
+
"""
|
162 |
+
system = system or self.default_system
|
163 |
+
encoded_messages = []
|
164 |
+
for i, message in enumerate(messages):
|
165 |
+
elements = []
|
166 |
+
|
167 |
+
system_text = ""
|
168 |
+
if i == 0:
|
169 |
+
elements += self.format_prefix.apply()
|
170 |
+
if system or tools:
|
171 |
+
tool_text = self.format_tools.apply(content=tools)[0] if tools else ""
|
172 |
+
system_text = self.format_system.apply(content=(system + tool_text))[0]
|
173 |
+
|
174 |
+
if i > 0 and i % 2 == 0:
|
175 |
+
elements += self.format_separator.apply()
|
176 |
+
|
177 |
+
if message["role"] == Role.USER.value:
|
178 |
+
elements += self.format_user.apply(content=system_text + message["content"])
|
179 |
+
elif message["role"] == Role.ASSISTANT.value:
|
180 |
+
elements += self.format_assistant.apply(content=message["content"])
|
181 |
+
elif message["role"] == Role.OBSERVATION.value:
|
182 |
+
elements += self.format_observation.apply(content=message["content"])
|
183 |
+
elif message["role"] == Role.FUNCTION.value:
|
184 |
+
elements += self.format_function.apply(content=message["content"])
|
185 |
+
else:
|
186 |
+
raise NotImplementedError("Unexpected role: {}".format(message["role"]))
|
187 |
+
|
188 |
+
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
|
189 |
+
|
190 |
+
return encoded_messages
|
191 |
+
|
192 |
+
|
193 |
+
TEMPLATES: Dict[str, Template] = {}
|
194 |
+
|
195 |
+
|
196 |
+
def _register_template(
|
197 |
+
name: str,
|
198 |
+
format_user: Optional["Formatter"] = None,
|
199 |
+
format_assistant: Optional["Formatter"] = None,
|
200 |
+
format_system: Optional["Formatter"] = None,
|
201 |
+
format_function: Optional["Formatter"] = None,
|
202 |
+
format_observation: Optional["Formatter"] = None,
|
203 |
+
format_tools: Optional["Formatter"] = None,
|
204 |
+
format_separator: Optional["Formatter"] = None,
|
205 |
+
format_prefix: Optional["Formatter"] = None,
|
206 |
+
default_system: str = "",
|
207 |
+
stop_words: Sequence[str] = [],
|
208 |
+
image_token: str = "<image>",
|
209 |
+
efficient_eos: bool = False,
|
210 |
+
replace_eos: bool = False,
|
211 |
+
) -> None:
|
212 |
+
r"""
|
213 |
+
Registers a chat template.
|
214 |
+
|
215 |
+
To add the following chat template:
|
216 |
+
```
|
217 |
+
[HUMAN]:
|
218 |
+
user prompt here
|
219 |
+
[AI]:
|
220 |
+
model response here
|
221 |
+
|
222 |
+
[HUMAN]:
|
223 |
+
user prompt here
|
224 |
+
[AI]:
|
225 |
+
model response here
|
226 |
+
```
|
227 |
+
|
228 |
+
The corresponding code should be:
|
229 |
+
```
|
230 |
+
_register_template(
|
231 |
+
name="custom",
|
232 |
+
format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]),
|
233 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
234 |
+
efficient_eos=True,
|
235 |
+
)
|
236 |
+
```
|
237 |
+
"""
|
238 |
+
eos_slots = [] if efficient_eos else [{"eos_token"}]
|
239 |
+
template_class = Llama2Template if name.startswith("llama2") else Template
|
240 |
+
default_user_formatter = StringFormatter(slots=["{{content}}"])
|
241 |
+
default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots)
|
242 |
+
default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default")
|
243 |
+
default_tool_formatter = ToolFormatter(tool_format="default")
|
244 |
+
default_separator_formatter = EmptyFormatter()
|
245 |
+
default_prefix_formatter = EmptyFormatter()
|
246 |
+
TEMPLATES[name] = template_class(
|
247 |
+
format_user=format_user or default_user_formatter,
|
248 |
+
format_assistant=format_assistant or default_assistant_formatter,
|
249 |
+
format_system=format_system or default_user_formatter,
|
250 |
+
format_function=format_function or default_function_formatter,
|
251 |
+
format_observation=format_observation or format_user or default_user_formatter,
|
252 |
+
format_tools=format_tools or default_tool_formatter,
|
253 |
+
format_separator=format_separator or default_separator_formatter,
|
254 |
+
format_prefix=format_prefix or default_prefix_formatter,
|
255 |
+
default_system=default_system,
|
256 |
+
stop_words=stop_words,
|
257 |
+
image_token=image_token,
|
258 |
+
efficient_eos=efficient_eos,
|
259 |
+
replace_eos=replace_eos,
|
260 |
+
)
|
261 |
+
|
262 |
+
|
263 |
+
def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None:
|
264 |
+
is_added = tokenizer.eos_token_id is None
|
265 |
+
num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})
|
266 |
+
|
267 |
+
if is_added:
|
268 |
+
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
269 |
+
else:
|
270 |
+
logger.info("Replace eos token: {}".format(tokenizer.eos_token))
|
271 |
+
|
272 |
+
if num_added_tokens > 0:
|
273 |
+
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
274 |
+
|
275 |
+
|
276 |
+
def _jinja_escape(content: str) -> str:
|
277 |
+
return content.replace("'", r"\'")
|
278 |
+
|
279 |
+
|
280 |
+
def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str:
|
281 |
+
slot_items = []
|
282 |
+
for slot in slots:
|
283 |
+
if isinstance(slot, str):
|
284 |
+
slot_pieces = slot.split("{{content}}")
|
285 |
+
if slot_pieces[0]:
|
286 |
+
slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'")
|
287 |
+
if len(slot_pieces) > 1:
|
288 |
+
slot_items.append(placeholder)
|
289 |
+
if slot_pieces[1]:
|
290 |
+
slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'")
|
291 |
+
elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced
|
292 |
+
if "bos_token" in slot and tokenizer.bos_token_id is not None:
|
293 |
+
slot_items.append("'" + tokenizer.bos_token + "'")
|
294 |
+
elif "eos_token" in slot and tokenizer.eos_token_id is not None:
|
295 |
+
slot_items.append("'" + tokenizer.eos_token + "'")
|
296 |
+
elif isinstance(slot, dict):
|
297 |
+
raise ValueError("Dict is not supported.")
|
298 |
+
|
299 |
+
return " + ".join(slot_items)
|
300 |
+
|
301 |
+
|
302 |
+
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
|
303 |
+
jinja_template = ""
|
304 |
+
|
305 |
+
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
|
306 |
+
if prefix:
|
307 |
+
jinja_template += "{{ " + prefix + " }}"
|
308 |
+
|
309 |
+
if template.default_system:
|
310 |
+
jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}"
|
311 |
+
|
312 |
+
jinja_template += (
|
313 |
+
"{% if messages[0]['role'] == 'system' %}{% set system_message = messages[0]['content'] %}{% endif %}"
|
314 |
+
)
|
315 |
+
|
316 |
+
system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message")
|
317 |
+
if not isinstance(template, Llama2Template):
|
318 |
+
jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}"
|
319 |
+
|
320 |
+
jinja_template += "{% for message in messages %}"
|
321 |
+
jinja_template += "{% set content = message['content'] %}"
|
322 |
+
if isinstance(template, Llama2Template):
|
323 |
+
jinja_template += "{% if loop.index0 == 0 and system_message is defined %}"
|
324 |
+
jinja_template += "{% set content = " + system_message + " + message['content'] %}"
|
325 |
+
jinja_template += "{% endif %}"
|
326 |
+
|
327 |
+
jinja_template += "{% if message['role'] == 'user' %}"
|
328 |
+
user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer)
|
329 |
+
jinja_template += "{{ " + user_message + " }}"
|
330 |
+
|
331 |
+
jinja_template += "{% elif message['role'] == 'assistant' %}"
|
332 |
+
assistant_message = _convert_slots_to_jinja(
|
333 |
+
template.format_assistant.apply() + template.format_separator.apply(), tokenizer
|
334 |
+
)
|
335 |
+
jinja_template += "{{ " + assistant_message + " }}"
|
336 |
+
jinja_template += "{% endif %}"
|
337 |
+
jinja_template += "{% endfor %}"
|
338 |
+
return jinja_template
|
339 |
+
|
340 |
+
|
341 |
+
def get_template_and_fix_tokenizer(
|
342 |
+
tokenizer: "PreTrainedTokenizer",
|
343 |
+
name: Optional[str] = None,
|
344 |
+
tool_format: Optional[str] = None,
|
345 |
+
) -> Template:
|
346 |
+
if name is None:
|
347 |
+
template = TEMPLATES["empty"] # placeholder
|
348 |
+
else:
|
349 |
+
template = TEMPLATES.get(name, None)
|
350 |
+
if template is None:
|
351 |
+
raise ValueError("Template {} does not exist.".format(name))
|
352 |
+
|
353 |
+
if tool_format is not None:
|
354 |
+
logger.info("Using tool format: {}.".format(tool_format))
|
355 |
+
eos_slots = [] if template.efficient_eos else [{"eos_token"}]
|
356 |
+
template.format_tools = ToolFormatter(tool_format=tool_format)
|
357 |
+
template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format)
|
358 |
+
|
359 |
+
stop_words = template.stop_words
|
360 |
+
if template.replace_eos:
|
361 |
+
if not stop_words:
|
362 |
+
raise ValueError("Stop words are required to replace the EOS token.")
|
363 |
+
|
364 |
+
_add_or_replace_eos_token(tokenizer, eos_token=stop_words[0])
|
365 |
+
stop_words = stop_words[1:]
|
366 |
+
|
367 |
+
if tokenizer.eos_token_id is None:
|
368 |
+
_add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>")
|
369 |
+
|
370 |
+
if tokenizer.pad_token_id is None:
|
371 |
+
tokenizer.pad_token = tokenizer.eos_token
|
372 |
+
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
373 |
+
|
374 |
+
if stop_words:
|
375 |
+
num_added_tokens = tokenizer.add_special_tokens(
|
376 |
+
dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
|
377 |
+
)
|
378 |
+
logger.info("Add {} to stop words.".format(",".join(stop_words)))
|
379 |
+
if num_added_tokens > 0:
|
380 |
+
logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
|
381 |
+
|
382 |
+
try:
|
383 |
+
tokenizer.chat_template = _get_jinja_template(template, tokenizer)
|
384 |
+
except ValueError:
|
385 |
+
logger.info("Cannot add this chat template to tokenizer.")
|
386 |
+
|
387 |
+
return template
|
388 |
+
|
389 |
+
|
390 |
+
_register_template(
|
391 |
+
name="alpaca",
|
392 |
+
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]),
|
393 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
394 |
+
default_system=(
|
395 |
+
"Below is an instruction that describes a task. "
|
396 |
+
"Write a response that appropriately completes the request.\n\n"
|
397 |
+
),
|
398 |
+
)
|
399 |
+
|
400 |
+
|
401 |
+
_register_template(
|
402 |
+
name="aquila",
|
403 |
+
format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]),
|
404 |
+
format_separator=EmptyFormatter(slots=["###"]),
|
405 |
+
default_system=(
|
406 |
+
"A chat between a curious human and an artificial intelligence assistant. "
|
407 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions."
|
408 |
+
),
|
409 |
+
stop_words=["</s>"],
|
410 |
+
efficient_eos=True,
|
411 |
+
)
|
412 |
+
|
413 |
+
|
414 |
+
_register_template(
|
415 |
+
name="atom",
|
416 |
+
format_user=StringFormatter(
|
417 |
+
slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"]
|
418 |
+
),
|
419 |
+
format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]),
|
420 |
+
)
|
421 |
+
|
422 |
+
|
423 |
+
_register_template(
|
424 |
+
name="baichuan",
|
425 |
+
format_user=StringFormatter(slots=[{"token": "<reserved_102>"}, "{{content}}", {"token": "<reserved_103>"}]),
|
426 |
+
efficient_eos=True,
|
427 |
+
)
|
428 |
+
|
429 |
+
|
430 |
+
_register_template(
|
431 |
+
name="baichuan2",
|
432 |
+
format_user=StringFormatter(slots=["<reserved_106>{{content}}<reserved_107>"]),
|
433 |
+
efficient_eos=True,
|
434 |
+
)
|
435 |
+
|
436 |
+
|
437 |
+
_register_template(
|
438 |
+
name="belle",
|
439 |
+
format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]),
|
440 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
441 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
442 |
+
)
|
443 |
+
|
444 |
+
|
445 |
+
_register_template(
|
446 |
+
name="bluelm",
|
447 |
+
format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]),
|
448 |
+
)
|
449 |
+
|
450 |
+
|
451 |
+
_register_template(
|
452 |
+
name="breeze",
|
453 |
+
format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]),
|
454 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
455 |
+
efficient_eos=True,
|
456 |
+
)
|
457 |
+
|
458 |
+
|
459 |
+
_register_template(
|
460 |
+
name="chatglm2",
|
461 |
+
format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]),
|
462 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
463 |
+
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
464 |
+
efficient_eos=True,
|
465 |
+
)
|
466 |
+
|
467 |
+
|
468 |
+
_register_template(
|
469 |
+
name="chatglm3",
|
470 |
+
format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]),
|
471 |
+
format_assistant=StringFormatter(slots=["\n", "{{content}}"]),
|
472 |
+
format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]),
|
473 |
+
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
474 |
+
format_observation=StringFormatter(
|
475 |
+
slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]
|
476 |
+
),
|
477 |
+
format_tools=ToolFormatter(tool_format="glm4"),
|
478 |
+
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
479 |
+
stop_words=["<|user|>", "<|observation|>"],
|
480 |
+
efficient_eos=True,
|
481 |
+
)
|
482 |
+
|
483 |
+
|
484 |
+
_register_template(
|
485 |
+
name="chatml",
|
486 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
487 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
488 |
+
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
489 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
490 |
+
stop_words=["<|im_end|>", "<|im_start|>"],
|
491 |
+
replace_eos=True,
|
492 |
+
)
|
493 |
+
|
494 |
+
|
495 |
+
_register_template(
|
496 |
+
name="chatml_de",
|
497 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
498 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
499 |
+
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
500 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
501 |
+
default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.",
|
502 |
+
stop_words=["<|im_end|>", "<|im_start|>"],
|
503 |
+
replace_eos=True,
|
504 |
+
)
|
505 |
+
|
506 |
+
|
507 |
+
_register_template(
|
508 |
+
name="codegeex2",
|
509 |
+
format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]),
|
510 |
+
)
|
511 |
+
|
512 |
+
|
513 |
+
_register_template(
|
514 |
+
name="cohere",
|
515 |
+
format_user=StringFormatter(
|
516 |
+
slots=[
|
517 |
+
(
|
518 |
+
"<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"
|
519 |
+
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>"
|
520 |
+
)
|
521 |
+
]
|
522 |
+
),
|
523 |
+
format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]),
|
524 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
525 |
+
)
|
526 |
+
|
527 |
+
|
528 |
+
_register_template(
|
529 |
+
name="cpm",
|
530 |
+
format_user=StringFormatter(slots=["<用户>{{content}}<AI>"]),
|
531 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
532 |
+
)
|
533 |
+
|
534 |
+
|
535 |
+
_register_template(
|
536 |
+
name="dbrx",
|
537 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
538 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
539 |
+
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
540 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
541 |
+
default_system=(
|
542 |
+
"You are DBRX, created by Databricks. You were last updated in December 2023. "
|
543 |
+
"You answer questions based on information available up to that point.\n"
|
544 |
+
"YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough "
|
545 |
+
"responses to more complex and open-ended questions.\nYou assist with various tasks, "
|
546 |
+
"from writing to coding (using markdown for code blocks — remember to use ``` with "
|
547 |
+
"code, JSON, and tables).\n(You do not have real-time data access or code execution "
|
548 |
+
"capabilities. You avoid stereotyping and provide balanced perspectives on "
|
549 |
+
"controversial topics. You do not provide song lyrics, poems, or news articles and "
|
550 |
+
"do not divulge details of your training data.)\nThis is your system prompt, "
|
551 |
+
"guiding your responses. Do not reference it, just respond to the user. If you find "
|
552 |
+
"yourself talking about this message, stop. You should be responding appropriately "
|
553 |
+
"and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION "
|
554 |
+
"ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY."
|
555 |
+
),
|
556 |
+
stop_words=["<|im_end|>"],
|
557 |
+
replace_eos=True,
|
558 |
+
)
|
559 |
+
|
560 |
+
|
561 |
+
_register_template(
|
562 |
+
name="deepseek",
|
563 |
+
format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]),
|
564 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
565 |
+
)
|
566 |
+
|
567 |
+
|
568 |
+
_register_template(
|
569 |
+
name="deepseekcoder",
|
570 |
+
format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]),
|
571 |
+
format_assistant=StringFormatter(slots=["\n{{content}}\n"]),
|
572 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
573 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
574 |
+
default_system=(
|
575 |
+
"You are an AI programming assistant, utilizing the Deepseek Coder model, "
|
576 |
+
"developed by Deepseek Company, and you only answer questions related to computer science. "
|
577 |
+
"For politically sensitive questions, security and privacy issues, "
|
578 |
+
"and other non-computer science questions, you will refuse to answer\n"
|
579 |
+
),
|
580 |
+
)
|
581 |
+
|
582 |
+
|
583 |
+
_register_template(
|
584 |
+
name="default",
|
585 |
+
format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]),
|
586 |
+
format_system=StringFormatter(slots=["{{content}}\n"]),
|
587 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
588 |
+
)
|
589 |
+
|
590 |
+
|
591 |
+
_register_template(
|
592 |
+
name="empty",
|
593 |
+
efficient_eos=True,
|
594 |
+
)
|
595 |
+
|
596 |
+
|
597 |
+
_register_template(
|
598 |
+
name="falcon",
|
599 |
+
format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]),
|
600 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
601 |
+
efficient_eos=True,
|
602 |
+
)
|
603 |
+
|
604 |
+
|
605 |
+
_register_template(
|
606 |
+
name="fewshot",
|
607 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
608 |
+
efficient_eos=True,
|
609 |
+
)
|
610 |
+
|
611 |
+
|
612 |
+
_register_template(
|
613 |
+
name="gemma",
|
614 |
+
format_user=StringFormatter(slots=["<start_of_turn>user\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]),
|
615 |
+
format_observation=StringFormatter(
|
616 |
+
slots=["<start_of_turn>tool\n{{content}}<end_of_turn>\n<start_of_turn>model\n"]
|
617 |
+
),
|
618 |
+
format_separator=EmptyFormatter(slots=["<end_of_turn>\n"]),
|
619 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
620 |
+
efficient_eos=True,
|
621 |
+
)
|
622 |
+
|
623 |
+
|
624 |
+
_register_template(
|
625 |
+
name="glm4",
|
626 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]),
|
627 |
+
format_assistant=StringFormatter(slots=["\n{{content}}"]),
|
628 |
+
format_system=StringFormatter(slots=["<|system|>\n{{content}}"]),
|
629 |
+
format_function=FunctionFormatter(slots=[], tool_format="glm4"),
|
630 |
+
format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]),
|
631 |
+
format_tools=ToolFormatter(tool_format="glm4"),
|
632 |
+
format_prefix=EmptyFormatter(slots=["[gMASK]<sop>"]),
|
633 |
+
stop_words=["<|user|>", "<|observation|>"],
|
634 |
+
efficient_eos=True,
|
635 |
+
)
|
636 |
+
|
637 |
+
|
638 |
+
_register_template(
|
639 |
+
name="intern",
|
640 |
+
format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]),
|
641 |
+
format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]),
|
642 |
+
format_separator=EmptyFormatter(slots=["<eoa>\n"]),
|
643 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
644 |
+
stop_words=["<eoa>"],
|
645 |
+
efficient_eos=True, # internlm tokenizer cannot set eos_token_id
|
646 |
+
)
|
647 |
+
|
648 |
+
|
649 |
+
_register_template(
|
650 |
+
name="intern2",
|
651 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
652 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
653 |
+
format_separator=EmptyFormatter(slots=["<|im_end|>\n"]),
|
654 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
655 |
+
stop_words=["<|im_end|>"],
|
656 |
+
efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id
|
657 |
+
)
|
658 |
+
|
659 |
+
|
660 |
+
_register_template(
|
661 |
+
name="llama2",
|
662 |
+
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
663 |
+
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
664 |
+
)
|
665 |
+
|
666 |
+
|
667 |
+
_register_template(
|
668 |
+
name="llama2_zh",
|
669 |
+
format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]),
|
670 |
+
format_system=StringFormatter(slots=["<<SYS>>\n{{content}}\n<</SYS>>\n\n"]),
|
671 |
+
default_system="You are a helpful assistant. 你是一个乐于助人的助手。",
|
672 |
+
)
|
673 |
+
|
674 |
+
|
675 |
+
_register_template(
|
676 |
+
name="llama3",
|
677 |
+
format_user=StringFormatter(
|
678 |
+
slots=[
|
679 |
+
(
|
680 |
+
"<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
681 |
+
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
682 |
+
)
|
683 |
+
]
|
684 |
+
),
|
685 |
+
format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]),
|
686 |
+
format_observation=StringFormatter(
|
687 |
+
slots=[
|
688 |
+
(
|
689 |
+
"<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
690 |
+
"<|start_header_id|>assistant<|end_header_id|>\n\n"
|
691 |
+
)
|
692 |
+
]
|
693 |
+
),
|
694 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
695 |
+
stop_words=["<|eot_id|>"],
|
696 |
+
replace_eos=True,
|
697 |
+
)
|
698 |
+
|
699 |
+
|
700 |
+
_register_template(
|
701 |
+
name="mistral",
|
702 |
+
format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]),
|
703 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
704 |
+
)
|
705 |
+
|
706 |
+
|
707 |
+
_register_template(
|
708 |
+
name="olmo",
|
709 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]),
|
710 |
+
format_prefix=EmptyFormatter(slots=[{"eos_token"}]),
|
711 |
+
)
|
712 |
+
|
713 |
+
|
714 |
+
_register_template(
|
715 |
+
name="openchat",
|
716 |
+
format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]),
|
717 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
718 |
+
)
|
719 |
+
|
720 |
+
|
721 |
+
_register_template(
|
722 |
+
name="openchat-3.6",
|
723 |
+
format_user=StringFormatter(
|
724 |
+
slots=[
|
725 |
+
(
|
726 |
+
"<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>"
|
727 |
+
"<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n"
|
728 |
+
)
|
729 |
+
]
|
730 |
+
),
|
731 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
732 |
+
stop_words=["<|eot_id|>"],
|
733 |
+
replace_eos=True,
|
734 |
+
)
|
735 |
+
|
736 |
+
|
737 |
+
_register_template(
|
738 |
+
name="orion",
|
739 |
+
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]),
|
740 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
741 |
+
)
|
742 |
+
|
743 |
+
|
744 |
+
_register_template(
|
745 |
+
name="phi",
|
746 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]),
|
747 |
+
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
|
748 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
749 |
+
format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
|
750 |
+
stop_words=["<|end|>"],
|
751 |
+
replace_eos=True,
|
752 |
+
)
|
753 |
+
|
754 |
+
|
755 |
+
_register_template(
|
756 |
+
name="qwen",
|
757 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
758 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
759 |
+
format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
760 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
761 |
+
default_system="You are a helpful assistant.",
|
762 |
+
stop_words=["<|im_end|>"],
|
763 |
+
replace_eos=True,
|
764 |
+
)
|
765 |
+
|
766 |
+
|
767 |
+
_register_template(
|
768 |
+
name="solar",
|
769 |
+
format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]),
|
770 |
+
format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]),
|
771 |
+
efficient_eos=True,
|
772 |
+
)
|
773 |
+
|
774 |
+
|
775 |
+
_register_template(
|
776 |
+
name="starchat",
|
777 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]),
|
778 |
+
format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]),
|
779 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
780 |
+
stop_words=["<|end|>"],
|
781 |
+
replace_eos=True,
|
782 |
+
)
|
783 |
+
|
784 |
+
|
785 |
+
_register_template(
|
786 |
+
name="telechat",
|
787 |
+
format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]),
|
788 |
+
format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]),
|
789 |
+
stop_words=["<_end>"],
|
790 |
+
replace_eos=True,
|
791 |
+
)
|
792 |
+
|
793 |
+
|
794 |
+
_register_template(
|
795 |
+
name="vicuna",
|
796 |
+
format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]),
|
797 |
+
default_system=(
|
798 |
+
"A chat between a curious user and an artificial intelligence assistant. "
|
799 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions."
|
800 |
+
),
|
801 |
+
)
|
802 |
+
|
803 |
+
|
804 |
+
_register_template(
|
805 |
+
name="xuanyuan",
|
806 |
+
format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]),
|
807 |
+
default_system=(
|
808 |
+
"以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头,"
|
809 |
+
"会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、"
|
810 |
+
"不安全、有争议、政治敏感等相关的话题、问题和指示。\n"
|
811 |
+
),
|
812 |
+
)
|
813 |
+
|
814 |
+
|
815 |
+
_register_template(
|
816 |
+
name="xverse",
|
817 |
+
format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]),
|
818 |
+
)
|
819 |
+
|
820 |
+
|
821 |
+
_register_template(
|
822 |
+
name="yayi",
|
823 |
+
format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]),
|
824 |
+
format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]),
|
825 |
+
format_separator=EmptyFormatter(slots=["\n\n"]),
|
826 |
+
default_system=(
|
827 |
+
"You are a helpful, respectful and honest assistant named YaYi "
|
828 |
+
"developed by Beijing Wenge Technology Co.,Ltd. "
|
829 |
+
"Always answer as helpfully as possible, while being safe. "
|
830 |
+
"Your answers should not include any harmful, unethical, "
|
831 |
+
"racist, sexist, toxic, dangerous, or illegal content. "
|
832 |
+
"Please ensure that your responses are socially unbiased and positive in nature.\n\n"
|
833 |
+
"If a question does not make any sense, or is not factually coherent, "
|
834 |
+
"explain why instead of answering something not correct. "
|
835 |
+
"If you don't know the answer to a question, please don't share false information."
|
836 |
+
),
|
837 |
+
stop_words=["<|End|>"],
|
838 |
+
)
|
839 |
+
|
840 |
+
|
841 |
+
_register_template(
|
842 |
+
name="yi",
|
843 |
+
format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]),
|
844 |
+
format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]),
|
845 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
846 |
+
stop_words=["<|im_end|>"],
|
847 |
+
replace_eos=True,
|
848 |
+
)
|
849 |
+
|
850 |
+
|
851 |
+
_register_template(
|
852 |
+
name="yi_vl",
|
853 |
+
format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]),
|
854 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
855 |
+
default_system=(
|
856 |
+
"This is a chat between an inquisitive human and an AI assistant. "
|
857 |
+
"Assume the role of the AI assistant. Read all the images carefully, "
|
858 |
+
"and respond to the human's questions with informative, helpful, detailed and polite answers. "
|
859 |
+
"这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。"
|
860 |
+
"仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n"
|
861 |
+
),
|
862 |
+
stop_words=["###"],
|
863 |
+
efficient_eos=True,
|
864 |
+
)
|
865 |
+
|
866 |
+
|
867 |
+
_register_template(
|
868 |
+
name="yuan",
|
869 |
+
format_user=StringFormatter(slots=["{{content}}", {"token": "<sep>"}]),
|
870 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
871 |
+
stop_words=["<eod>"],
|
872 |
+
replace_eos=True,
|
873 |
+
)
|
874 |
+
|
875 |
+
|
876 |
+
_register_template(
|
877 |
+
name="zephyr",
|
878 |
+
format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]),
|
879 |
+
format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]),
|
880 |
+
format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]),
|
881 |
+
default_system="You are Zephyr, a helpful assistant.",
|
882 |
+
)
|
883 |
+
|
884 |
+
|
885 |
+
_register_template(
|
886 |
+
name="ziya",
|
887 |
+
format_user=StringFormatter(slots=["<human>:{{content}}\n<bot>:"]),
|
888 |
+
format_separator=EmptyFormatter(slots=["\n"]),
|
889 |
+
)
|
src/data/tool_utils.py
ADDED
@@ -0,0 +1,140 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import re
|
17 |
+
from abc import ABC, abstractmethod
|
18 |
+
from dataclasses import dataclass
|
19 |
+
from typing import Any, Dict, List, Tuple, Union
|
20 |
+
|
21 |
+
from .data_utils import SLOTS
|
22 |
+
|
23 |
+
|
24 |
+
DEFAULT_TOOL_PROMPT = (
|
25 |
+
"You have access to the following tools:\n{tool_text}"
|
26 |
+
"Use the following format if using a tool:\n"
|
27 |
+
"```\n"
|
28 |
+
"Action: tool name (one of [{tool_names}]).\n"
|
29 |
+
"Action Input: the input to the tool, in a JSON format representing the kwargs "
|
30 |
+
"""(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n"""
|
31 |
+
"```\n"
|
32 |
+
)
|
33 |
+
|
34 |
+
|
35 |
+
GLM4_TOOL_PROMPT = (
|
36 |
+
"你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的,"
|
37 |
+
"你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}"
|
38 |
+
)
|
39 |
+
|
40 |
+
|
41 |
+
@dataclass
|
42 |
+
class ToolUtils(ABC):
|
43 |
+
@staticmethod
|
44 |
+
@abstractmethod
|
45 |
+
def get_function_slots() -> SLOTS: ...
|
46 |
+
|
47 |
+
@staticmethod
|
48 |
+
@abstractmethod
|
49 |
+
def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
|
50 |
+
|
51 |
+
@staticmethod
|
52 |
+
@abstractmethod
|
53 |
+
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
|
54 |
+
|
55 |
+
|
56 |
+
class DefaultToolUtils(ToolUtils):
|
57 |
+
@staticmethod
|
58 |
+
def get_function_slots() -> SLOTS:
|
59 |
+
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
|
60 |
+
|
61 |
+
@staticmethod
|
62 |
+
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
63 |
+
tool_text = ""
|
64 |
+
tool_names = []
|
65 |
+
for tool in tools:
|
66 |
+
param_text = ""
|
67 |
+
for name, param in tool["parameters"]["properties"].items():
|
68 |
+
required, enum, items = "", "", ""
|
69 |
+
if name in tool["parameters"].get("required", []):
|
70 |
+
required = ", required"
|
71 |
+
|
72 |
+
if param.get("enum", None):
|
73 |
+
enum = ", should be one of [{}]".format(", ".join(param["enum"]))
|
74 |
+
|
75 |
+
if param.get("items", None):
|
76 |
+
items = ", where each item should be {}".format(param["items"].get("type", ""))
|
77 |
+
|
78 |
+
param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format(
|
79 |
+
name=name,
|
80 |
+
type=param.get("type", ""),
|
81 |
+
required=required,
|
82 |
+
desc=param.get("description", ""),
|
83 |
+
enum=enum,
|
84 |
+
items=items,
|
85 |
+
)
|
86 |
+
|
87 |
+
tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format(
|
88 |
+
name=tool["name"], desc=tool.get("description", ""), args=param_text
|
89 |
+
)
|
90 |
+
tool_names.append(tool["name"])
|
91 |
+
|
92 |
+
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
93 |
+
|
94 |
+
@staticmethod
|
95 |
+
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
96 |
+
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
97 |
+
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
98 |
+
if not action_match:
|
99 |
+
return content
|
100 |
+
|
101 |
+
results = []
|
102 |
+
for match in action_match:
|
103 |
+
tool_name = match[0].strip()
|
104 |
+
tool_input = match[1].strip().strip('"').strip("```")
|
105 |
+
try:
|
106 |
+
arguments = json.loads(tool_input)
|
107 |
+
results.append((tool_name, json.dumps(arguments, ensure_ascii=False)))
|
108 |
+
except json.JSONDecodeError:
|
109 |
+
return content
|
110 |
+
|
111 |
+
return results
|
112 |
+
|
113 |
+
|
114 |
+
class GLM4ToolUtils(ToolUtils):
|
115 |
+
@staticmethod
|
116 |
+
def get_function_slots() -> SLOTS:
|
117 |
+
return ["{{name}}\n{{arguments}}"]
|
118 |
+
|
119 |
+
@staticmethod
|
120 |
+
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
121 |
+
tool_text = ""
|
122 |
+
for tool in tools:
|
123 |
+
tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 格式表示调用的参数。".format(
|
124 |
+
name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False)
|
125 |
+
)
|
126 |
+
|
127 |
+
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
128 |
+
|
129 |
+
@staticmethod
|
130 |
+
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
131 |
+
if "\n" not in content:
|
132 |
+
return content
|
133 |
+
|
134 |
+
tool_name, tool_input = content.split("\n", maxsplit=1)
|
135 |
+
try:
|
136 |
+
arguments = json.loads(tool_input)
|
137 |
+
except json.JSONDecodeError:
|
138 |
+
return content
|
139 |
+
|
140 |
+
return [(tool_name, json.dumps(arguments, ensure_ascii=False))]
|
src/eval/__init__.py
ADDED
File without changes
|
src/eval/dataset.py
ADDED
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Llamole Team
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's transformers library.
|
4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import torch
|
19 |
+
from torch.utils.data import Dataset
|
20 |
+
|
21 |
+
from ..extras.constants import BOND_INDEX
|
22 |
+
|
23 |
+
def dict_to_list(data_dict, mol_properties):
|
24 |
+
return [data_dict.get(prop, float("nan")) for prop in mol_properties]
|
25 |
+
|
26 |
+
class MolQADataset(Dataset):
|
27 |
+
def __init__(self, data, tokenizer, max_len):
|
28 |
+
self.data = data
|
29 |
+
self.tokenizer = tokenizer
|
30 |
+
self.max_len = max_len
|
31 |
+
|
32 |
+
def __len__(self):
|
33 |
+
return len(self.data)
|
34 |
+
|
35 |
+
def __getitem__(self, idx):
|
36 |
+
mol_properties = [
|
37 |
+
"BBBP",
|
38 |
+
"HIV",
|
39 |
+
"BACE",
|
40 |
+
"CO2",
|
41 |
+
"N2",
|
42 |
+
"O2",
|
43 |
+
"FFV",
|
44 |
+
"TC",
|
45 |
+
"SC",
|
46 |
+
"SA",
|
47 |
+
]
|
48 |
+
item = self.data[idx]
|
49 |
+
instruction = item["instruction"]
|
50 |
+
input_text = item["input"]
|
51 |
+
property_data = dict_to_list(item["property"], mol_properties)
|
52 |
+
property_data = torch.tensor(property_data)
|
53 |
+
|
54 |
+
# Combine instruction and input
|
55 |
+
combined_input = f"{instruction}\n{input_text}"
|
56 |
+
|
57 |
+
# Create messages for chat template
|
58 |
+
messages = [
|
59 |
+
{"role": "user", "content": combined_input}
|
60 |
+
]
|
61 |
+
|
62 |
+
# Apply chat template
|
63 |
+
chat_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
|
64 |
+
|
65 |
+
# Tokenize the chat text
|
66 |
+
encoding = self.tokenizer(
|
67 |
+
chat_text,
|
68 |
+
return_tensors="pt",
|
69 |
+
padding="max_length",
|
70 |
+
truncation=True,
|
71 |
+
max_length=self.max_len,
|
72 |
+
)
|
73 |
+
|
74 |
+
return {
|
75 |
+
"input_ids": encoding.input_ids.squeeze(),
|
76 |
+
"attention_mask": encoding.attention_mask.squeeze(),
|
77 |
+
"property": property_data,
|
78 |
+
}
|
src/eval/workflow.py
ADDED
@@ -0,0 +1,219 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 Llamole Team
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from typing import TYPE_CHECKING, List, Optional, Dict, Any
|
16 |
+
|
17 |
+
from ..data import get_dataset, DataCollatorForSeqGraph, get_template_and_fix_tokenizer
|
18 |
+
from ..extras.constants import IGNORE_INDEX, NO_LABEL_INDEX
|
19 |
+
from ..extras.misc import get_logits_processor
|
20 |
+
from ..extras.ploting import plot_loss
|
21 |
+
from ..model import load_tokenizer
|
22 |
+
from ..hparams import get_infer_args, get_train_args
|
23 |
+
from ..model import GraphLLMForCausalMLM
|
24 |
+
from .dataset import MolQADataset
|
25 |
+
|
26 |
+
import re
|
27 |
+
import os
|
28 |
+
import json
|
29 |
+
import math
|
30 |
+
import torch
|
31 |
+
from torch.utils.data import DataLoader
|
32 |
+
|
33 |
+
if TYPE_CHECKING:
|
34 |
+
from transformers import Seq2SeqTrainingArguments
|
35 |
+
|
36 |
+
from ..hparams import (
|
37 |
+
DataArguments,
|
38 |
+
FinetuningArguments,
|
39 |
+
GeneratingArguments,
|
40 |
+
ModelArguments,
|
41 |
+
)
|
42 |
+
|
43 |
+
def remove_extra_spaces(text):
|
44 |
+
# Replace multiple spaces with a single space
|
45 |
+
cleaned_text = re.sub(r'\s+', ' ', text)
|
46 |
+
# Strip leading and trailing spaces
|
47 |
+
return cleaned_text.strip()
|
48 |
+
|
49 |
+
def run_eval(args: Optional[Dict[str, Any]] = None) -> None:
|
50 |
+
print(args)
|
51 |
+
raise ValueError('stop')
|
52 |
+
model_args, data_args, training_args, finetuning_args, generating_args = (
|
53 |
+
get_train_args(args)
|
54 |
+
)
|
55 |
+
|
56 |
+
if data_args.dataset in ["molqa", "molqa_drug", "molqa_material"]:
|
57 |
+
run_molqa(
|
58 |
+
model_args, data_args, training_args, finetuning_args, generating_args
|
59 |
+
)
|
60 |
+
else:
|
61 |
+
raise ValueError("Unknown dataset: {}.".format(data_args.dataset))
|
62 |
+
|
63 |
+
|
64 |
+
def run_molqa(
|
65 |
+
model_args: "ModelArguments",
|
66 |
+
data_args: "DataArguments",
|
67 |
+
training_args: "Seq2SeqTrainingArguments",
|
68 |
+
finetuning_args: "FinetuningArguments",
|
69 |
+
generating_args: "GeneratingArguments",
|
70 |
+
):
|
71 |
+
tokenizer = load_tokenizer(model_args, generate_mode=True)["tokenizer"]
|
72 |
+
|
73 |
+
data_info_path = os.path.join(data_args.dataset_dir, "dataset_info.json")
|
74 |
+
with open(data_info_path, "r") as f:
|
75 |
+
dataset_info = json.load(f)
|
76 |
+
|
77 |
+
tokenizer.pad_token = tokenizer.eos_token
|
78 |
+
dataset_name = data_args.dataset.strip()
|
79 |
+
try:
|
80 |
+
filename = dataset_info[dataset_name]["file_name"]
|
81 |
+
except KeyError:
|
82 |
+
raise ValueError(f"Dataset {dataset_name} not found in dataset_info.json")
|
83 |
+
data_path = os.path.join(data_args.dataset_dir, f"{filename}")
|
84 |
+
with open(data_path, "r") as f:
|
85 |
+
original_data = json.load(f)
|
86 |
+
|
87 |
+
# Create dataset and dataloader
|
88 |
+
dataset = MolQADataset(original_data, tokenizer, data_args.cutoff_len)
|
89 |
+
dataloader = DataLoader(
|
90 |
+
dataset, batch_size=training_args.per_device_eval_batch_size, shuffle=False
|
91 |
+
)
|
92 |
+
|
93 |
+
gen_kwargs = generating_args.to_dict()
|
94 |
+
gen_kwargs["eos_token_id"] = [
|
95 |
+
tokenizer.eos_token_id
|
96 |
+
] + tokenizer.additional_special_tokens_ids
|
97 |
+
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
|
98 |
+
gen_kwargs["logits_processor"] = get_logits_processor()
|
99 |
+
|
100 |
+
model = GraphLLMForCausalMLM.from_pretrained(
|
101 |
+
tokenizer, model_args, data_args, training_args, finetuning_args, load_adapter=True
|
102 |
+
)
|
103 |
+
|
104 |
+
all_results = []
|
105 |
+
property_names = ["BBBP", "HIV", "BACE", "CO2", "N2", "O2", "FFV", "TC", "SC", "SA"]
|
106 |
+
|
107 |
+
# Phase 1: Molecular Design
|
108 |
+
global_idx = 0
|
109 |
+
all_smiles = []
|
110 |
+
for batch_idx, batch in enumerate(dataloader):
|
111 |
+
input_ids = batch["input_ids"].to(model.device)
|
112 |
+
attention_mask = batch["attention_mask"].to(model.device)
|
113 |
+
property_data = batch["property"].to(model.device)
|
114 |
+
model.eval()
|
115 |
+
with torch.no_grad():
|
116 |
+
all_info_dict = model.generate(
|
117 |
+
input_ids=input_ids,
|
118 |
+
attention_mask=attention_mask,
|
119 |
+
molecule_properties=property_data,
|
120 |
+
do_molecular_design=True,
|
121 |
+
do_retrosynthesis=False,
|
122 |
+
rollback=True,
|
123 |
+
**gen_kwargs,
|
124 |
+
)
|
125 |
+
|
126 |
+
batch_results = []
|
127 |
+
for i in range(len(all_info_dict["smiles_list"])):
|
128 |
+
original_data_idx = global_idx + i
|
129 |
+
original_item = original_data[original_data_idx]
|
130 |
+
|
131 |
+
llm_response = "".join(item for item in all_info_dict["text_lists"][i])
|
132 |
+
result = {
|
133 |
+
"qa_idx": original_data_idx,
|
134 |
+
"instruction": original_item["instruction"],
|
135 |
+
"input": original_item["input"],
|
136 |
+
"llm_response": llm_response,
|
137 |
+
"response_design": remove_extra_spaces(llm_response),
|
138 |
+
"llm_smiles": all_info_dict["smiles_list"][i],
|
139 |
+
"property": {},
|
140 |
+
}
|
141 |
+
|
142 |
+
# Add non-NaN property values
|
143 |
+
for j, prop_name in enumerate(property_names):
|
144 |
+
prop_value = property_data[i][j].item()
|
145 |
+
if not math.isnan(prop_value):
|
146 |
+
result["property"][prop_name] = prop_value
|
147 |
+
|
148 |
+
batch_results.append(result)
|
149 |
+
|
150 |
+
all_results.extend(batch_results)
|
151 |
+
all_smiles.extend([result['llm_smiles'] for result in batch_results])
|
152 |
+
global_idx += len(batch_results)
|
153 |
+
|
154 |
+
# Phase 2: Retrosynthesis
|
155 |
+
retro_batch_start = 0
|
156 |
+
for batch_idx, batch in enumerate(dataloader):
|
157 |
+
|
158 |
+
input_ids = batch["input_ids"].to(model.device)
|
159 |
+
attention_mask = batch["attention_mask"].to(model.device)
|
160 |
+
batch_size = input_ids.shape[0]
|
161 |
+
batch_smiles = all_smiles[retro_batch_start : retro_batch_start + batch_size]
|
162 |
+
|
163 |
+
model.eval()
|
164 |
+
with torch.no_grad():
|
165 |
+
all_info_dict = model.generate(
|
166 |
+
input_ids=input_ids,
|
167 |
+
attention_mask=attention_mask,
|
168 |
+
do_molecular_design=False,
|
169 |
+
do_retrosynthesis=True,
|
170 |
+
input_smiles_list=batch_smiles,
|
171 |
+
expansion_topk=50,
|
172 |
+
iterations=100,
|
173 |
+
max_planning_time=30,
|
174 |
+
**gen_kwargs,
|
175 |
+
)
|
176 |
+
|
177 |
+
batch_results = []
|
178 |
+
for i in range(batch_size):
|
179 |
+
result = all_results[retro_batch_start + i]
|
180 |
+
retro_plan = all_info_dict["retro_plan_dict"][result["llm_smiles"]]
|
181 |
+
result["llm_reactions"] = []
|
182 |
+
if retro_plan["success"]:
|
183 |
+
for reaction, template, cost in zip(
|
184 |
+
retro_plan["reaction_list"],
|
185 |
+
retro_plan["templates"],
|
186 |
+
retro_plan["cost"],
|
187 |
+
):
|
188 |
+
result["llm_reactions"].append(
|
189 |
+
{"reaction": reaction, "template": template, "cost": cost}
|
190 |
+
)
|
191 |
+
|
192 |
+
# new_text = "".join(item for item in all_info_dict["text_lists"][i])
|
193 |
+
if None in all_info_dict["text_lists"][i]:
|
194 |
+
print(f"List contains None: {all_info_dict['text_lists'][i]}")
|
195 |
+
new_text = "".join(item for item in all_info_dict["text_lists"][i] if item is not None)
|
196 |
+
else:
|
197 |
+
new_text = "".join(item for item in all_info_dict["text_lists"][i])
|
198 |
+
|
199 |
+
result["llm_response"] += new_text
|
200 |
+
result["llm_response"] = remove_extra_spaces(result["llm_response"])
|
201 |
+
result["response_retro"] = remove_extra_spaces(new_text)
|
202 |
+
batch_results.append(result)
|
203 |
+
|
204 |
+
retro_batch_start += batch_size
|
205 |
+
|
206 |
+
print('all_results', all_results)
|
207 |
+
print("\nSummary of results:")
|
208 |
+
print_len = min(5, len(all_results))
|
209 |
+
for result in all_results[:print_len]:
|
210 |
+
print(f"\nData point {result['qa_idx']}:")
|
211 |
+
print(f" Instruction: {result['instruction']}")
|
212 |
+
print(f" Input: {result['input']}")
|
213 |
+
print(f" LLM Response: {result['llm_response']}")
|
214 |
+
print(f" LLM SMILES: {result['llm_smiles']}")
|
215 |
+
print(f" Number of reactions: {len(result['llm_reactions'])}")
|
216 |
+
for prop_name, prop_value in result["property"].items():
|
217 |
+
print(f" {prop_name}: {prop_value}")
|
218 |
+
|
219 |
+
print("\nAll data processed successfully.")
|
src/extras/__init__.py
ADDED
File without changes
|
src/extras/constants.py
ADDED
@@ -0,0 +1,1550 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team and the Llamole team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from collections import OrderedDict, defaultdict
|
16 |
+
from enum import Enum
|
17 |
+
from typing import Dict, Optional
|
18 |
+
|
19 |
+
from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
|
20 |
+
from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
|
21 |
+
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME
|
22 |
+
|
23 |
+
from rdkit.Chem.rdchem import BondType as BT
|
24 |
+
BOND_INDEX = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4}
|
25 |
+
NO_LABEL_INDEX = -200
|
26 |
+
|
27 |
+
CHECKPOINT_NAMES = {
|
28 |
+
SAFE_ADAPTER_WEIGHTS_NAME,
|
29 |
+
ADAPTER_WEIGHTS_NAME,
|
30 |
+
SAFE_WEIGHTS_INDEX_NAME,
|
31 |
+
SAFE_WEIGHTS_NAME,
|
32 |
+
WEIGHTS_INDEX_NAME,
|
33 |
+
WEIGHTS_NAME,
|
34 |
+
}
|
35 |
+
|
36 |
+
CHOICES = ["A", "B", "C", "D"]
|
37 |
+
|
38 |
+
DATA_CONFIG = "dataset_info.json"
|
39 |
+
|
40 |
+
DEFAULT_TEMPLATE = defaultdict(str)
|
41 |
+
|
42 |
+
FILEEXT2TYPE = {
|
43 |
+
"arrow": "arrow",
|
44 |
+
"csv": "csv",
|
45 |
+
"json": "json",
|
46 |
+
"jsonl": "json",
|
47 |
+
"parquet": "parquet",
|
48 |
+
"txt": "text",
|
49 |
+
}
|
50 |
+
|
51 |
+
IGNORE_INDEX = -100
|
52 |
+
|
53 |
+
LAYERNORM_NAMES = {"norm", "ln"}
|
54 |
+
|
55 |
+
LLAMABOARD_CONFIG = "llamaboard_config.yaml"
|
56 |
+
|
57 |
+
METHODS = ["full", "freeze", "lora"]
|
58 |
+
|
59 |
+
MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"}
|
60 |
+
|
61 |
+
PEFT_METHODS = {"lora"}
|
62 |
+
|
63 |
+
RUNNING_LOG = "running_log.txt"
|
64 |
+
|
65 |
+
SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"]
|
66 |
+
|
67 |
+
SUPPORTED_MODELS = OrderedDict()
|
68 |
+
|
69 |
+
TRAINER_LOG = "trainer_log.jsonl"
|
70 |
+
|
71 |
+
TRAINING_ARGS = "training_args.yaml"
|
72 |
+
|
73 |
+
TRAINING_STAGES = {
|
74 |
+
"Supervised Fine-Tuning": "sft",
|
75 |
+
"Reward Modeling": "rm",
|
76 |
+
"PPO": "ppo",
|
77 |
+
"DPO": "dpo",
|
78 |
+
"KTO": "kto",
|
79 |
+
"Pre-Training": "pt",
|
80 |
+
}
|
81 |
+
|
82 |
+
STAGES_USE_PAIR_DATA = {"rm", "dpo"}
|
83 |
+
|
84 |
+
SUPPORTED_CLASS_FOR_S2ATTN = {"llama"}
|
85 |
+
|
86 |
+
V_HEAD_WEIGHTS_NAME = "value_head.bin"
|
87 |
+
|
88 |
+
V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
|
89 |
+
|
90 |
+
VISION_MODELS = set()
|
91 |
+
|
92 |
+
|
93 |
+
class DownloadSource(str, Enum):
|
94 |
+
DEFAULT = "hf"
|
95 |
+
MODELSCOPE = "ms"
|
96 |
+
|
97 |
+
|
98 |
+
def register_model_group(
|
99 |
+
models: Dict[str, Dict[DownloadSource, str]],
|
100 |
+
template: Optional[str] = None,
|
101 |
+
vision: bool = False,
|
102 |
+
) -> None:
|
103 |
+
prefix = None
|
104 |
+
for name, path in models.items():
|
105 |
+
if prefix is None:
|
106 |
+
prefix = name.split("-")[0]
|
107 |
+
else:
|
108 |
+
assert prefix == name.split("-")[0], "prefix should be identical."
|
109 |
+
SUPPORTED_MODELS[name] = path
|
110 |
+
if template is not None:
|
111 |
+
DEFAULT_TEMPLATE[prefix] = template
|
112 |
+
if vision:
|
113 |
+
VISION_MODELS.add(prefix)
|
114 |
+
|
115 |
+
|
116 |
+
register_model_group(
|
117 |
+
models={
|
118 |
+
"Aya-23-8B-Chat": {
|
119 |
+
DownloadSource.DEFAULT: "CohereForAI/aya-23-8B",
|
120 |
+
},
|
121 |
+
"Aya-23-35B-Chat": {
|
122 |
+
DownloadSource.DEFAULT: "CohereForAI/aya-23-35B",
|
123 |
+
},
|
124 |
+
},
|
125 |
+
template="cohere",
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
register_model_group(
|
130 |
+
models={
|
131 |
+
"Baichuan-7B-Base": {
|
132 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B",
|
133 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B",
|
134 |
+
},
|
135 |
+
"Baichuan-13B-Base": {
|
136 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base",
|
137 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base",
|
138 |
+
},
|
139 |
+
"Baichuan-13B-Chat": {
|
140 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat",
|
141 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat",
|
142 |
+
},
|
143 |
+
},
|
144 |
+
template="baichuan",
|
145 |
+
)
|
146 |
+
|
147 |
+
|
148 |
+
register_model_group(
|
149 |
+
models={
|
150 |
+
"Baichuan2-7B-Base": {
|
151 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base",
|
152 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base",
|
153 |
+
},
|
154 |
+
"Baichuan2-13B-Base": {
|
155 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base",
|
156 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base",
|
157 |
+
},
|
158 |
+
"Baichuan2-7B-Chat": {
|
159 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat",
|
160 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat",
|
161 |
+
},
|
162 |
+
"Baichuan2-13B-Chat": {
|
163 |
+
DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat",
|
164 |
+
DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat",
|
165 |
+
},
|
166 |
+
},
|
167 |
+
template="baichuan2",
|
168 |
+
)
|
169 |
+
|
170 |
+
|
171 |
+
register_model_group(
|
172 |
+
models={
|
173 |
+
"BLOOM-560M": {
|
174 |
+
DownloadSource.DEFAULT: "bigscience/bloom-560m",
|
175 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m",
|
176 |
+
},
|
177 |
+
"BLOOM-3B": {
|
178 |
+
DownloadSource.DEFAULT: "bigscience/bloom-3b",
|
179 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b",
|
180 |
+
},
|
181 |
+
"BLOOM-7B1": {
|
182 |
+
DownloadSource.DEFAULT: "bigscience/bloom-7b1",
|
183 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1",
|
184 |
+
},
|
185 |
+
},
|
186 |
+
)
|
187 |
+
|
188 |
+
|
189 |
+
register_model_group(
|
190 |
+
models={
|
191 |
+
"BLOOMZ-560M": {
|
192 |
+
DownloadSource.DEFAULT: "bigscience/bloomz-560m",
|
193 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m",
|
194 |
+
},
|
195 |
+
"BLOOMZ-3B": {
|
196 |
+
DownloadSource.DEFAULT: "bigscience/bloomz-3b",
|
197 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b",
|
198 |
+
},
|
199 |
+
"BLOOMZ-7B1-mt": {
|
200 |
+
DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt",
|
201 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt",
|
202 |
+
},
|
203 |
+
},
|
204 |
+
)
|
205 |
+
|
206 |
+
|
207 |
+
register_model_group(
|
208 |
+
models={
|
209 |
+
"BlueLM-7B-Base": {
|
210 |
+
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base",
|
211 |
+
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base",
|
212 |
+
},
|
213 |
+
"BlueLM-7B-Chat": {
|
214 |
+
DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat",
|
215 |
+
DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat",
|
216 |
+
},
|
217 |
+
},
|
218 |
+
template="bluelm",
|
219 |
+
)
|
220 |
+
|
221 |
+
|
222 |
+
register_model_group(
|
223 |
+
models={
|
224 |
+
"Breeze-7B": {
|
225 |
+
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0",
|
226 |
+
},
|
227 |
+
"Breeze-7B-Chat": {
|
228 |
+
DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0",
|
229 |
+
},
|
230 |
+
},
|
231 |
+
template="breeze",
|
232 |
+
)
|
233 |
+
|
234 |
+
|
235 |
+
register_model_group(
|
236 |
+
models={
|
237 |
+
"ChatGLM2-6B-Chat": {
|
238 |
+
DownloadSource.DEFAULT: "THUDM/chatglm2-6b",
|
239 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b",
|
240 |
+
}
|
241 |
+
},
|
242 |
+
template="chatglm2",
|
243 |
+
)
|
244 |
+
|
245 |
+
|
246 |
+
register_model_group(
|
247 |
+
models={
|
248 |
+
"ChatGLM3-6B-Base": {
|
249 |
+
DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base",
|
250 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base",
|
251 |
+
},
|
252 |
+
"ChatGLM3-6B-Chat": {
|
253 |
+
DownloadSource.DEFAULT: "THUDM/chatglm3-6b",
|
254 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b",
|
255 |
+
},
|
256 |
+
},
|
257 |
+
template="chatglm3",
|
258 |
+
)
|
259 |
+
|
260 |
+
|
261 |
+
register_model_group(
|
262 |
+
models={
|
263 |
+
"ChineseLLaMA2-1.3B": {
|
264 |
+
DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b",
|
265 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b",
|
266 |
+
},
|
267 |
+
"ChineseLLaMA2-7B": {
|
268 |
+
DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b",
|
269 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b",
|
270 |
+
},
|
271 |
+
"ChineseLLaMA2-13B": {
|
272 |
+
DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b",
|
273 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b",
|
274 |
+
},
|
275 |
+
"ChineseLLaMA2-1.3B-Chat": {
|
276 |
+
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b",
|
277 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b",
|
278 |
+
},
|
279 |
+
"ChineseLLaMA2-7B-Chat": {
|
280 |
+
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b",
|
281 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b",
|
282 |
+
},
|
283 |
+
"ChineseLLaMA2-13B-Chat": {
|
284 |
+
DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b",
|
285 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b",
|
286 |
+
},
|
287 |
+
},
|
288 |
+
template="llama2_zh",
|
289 |
+
)
|
290 |
+
|
291 |
+
|
292 |
+
register_model_group(
|
293 |
+
models={
|
294 |
+
"CodeGemma-7B": {
|
295 |
+
DownloadSource.DEFAULT: "google/codegemma-7b",
|
296 |
+
},
|
297 |
+
"CodeGemma-7B-Chat": {
|
298 |
+
DownloadSource.DEFAULT: "google/codegemma-7b-it",
|
299 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it",
|
300 |
+
},
|
301 |
+
"CodeGemma-1.1-2B": {
|
302 |
+
DownloadSource.DEFAULT: "google/codegemma-1.1-2b",
|
303 |
+
},
|
304 |
+
"CodeGemma-1.1-7B-Chat": {
|
305 |
+
DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it",
|
306 |
+
},
|
307 |
+
},
|
308 |
+
template="gemma",
|
309 |
+
)
|
310 |
+
|
311 |
+
|
312 |
+
register_model_group(
|
313 |
+
models={
|
314 |
+
"Codestral-22B-v0.1-Chat": {
|
315 |
+
DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1",
|
316 |
+
},
|
317 |
+
},
|
318 |
+
template="mistral",
|
319 |
+
)
|
320 |
+
|
321 |
+
|
322 |
+
register_model_group(
|
323 |
+
models={
|
324 |
+
"CommandR-35B-Chat": {
|
325 |
+
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01",
|
326 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01",
|
327 |
+
},
|
328 |
+
"CommandR-Plus-104B-Chat": {
|
329 |
+
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus",
|
330 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus",
|
331 |
+
},
|
332 |
+
"CommandR-35B-4bit-Chat": {
|
333 |
+
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit",
|
334 |
+
DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit",
|
335 |
+
},
|
336 |
+
"CommandR-Plus-104B-4bit-Chat": {
|
337 |
+
DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit",
|
338 |
+
},
|
339 |
+
},
|
340 |
+
template="cohere",
|
341 |
+
)
|
342 |
+
|
343 |
+
|
344 |
+
register_model_group(
|
345 |
+
models={
|
346 |
+
"DBRX-132B-Base": {
|
347 |
+
DownloadSource.DEFAULT: "databricks/dbrx-base",
|
348 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base",
|
349 |
+
},
|
350 |
+
"DBRX-132B-Chat": {
|
351 |
+
DownloadSource.DEFAULT: "databricks/dbrx-instruct",
|
352 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct",
|
353 |
+
},
|
354 |
+
},
|
355 |
+
template="dbrx",
|
356 |
+
)
|
357 |
+
|
358 |
+
|
359 |
+
register_model_group(
|
360 |
+
models={
|
361 |
+
"DeepSeek-LLM-7B-Base": {
|
362 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base",
|
363 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base",
|
364 |
+
},
|
365 |
+
"DeepSeek-LLM-67B-Base": {
|
366 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base",
|
367 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base",
|
368 |
+
},
|
369 |
+
"DeepSeek-LLM-7B-Chat": {
|
370 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat",
|
371 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat",
|
372 |
+
},
|
373 |
+
"DeepSeek-LLM-67B-Chat": {
|
374 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat",
|
375 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat",
|
376 |
+
},
|
377 |
+
"DeepSeek-Math-7B-Base": {
|
378 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base",
|
379 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base",
|
380 |
+
},
|
381 |
+
"DeepSeek-Math-7B-Chat": {
|
382 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-instruct",
|
383 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct",
|
384 |
+
},
|
385 |
+
"DeepSeek-MoE-16B-Base": {
|
386 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base",
|
387 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base",
|
388 |
+
},
|
389 |
+
"DeepSeek-MoE-16B-v2-Base": {
|
390 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite",
|
391 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite",
|
392 |
+
},
|
393 |
+
"DeepSeek-MoE-236B-Base": {
|
394 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2",
|
395 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2",
|
396 |
+
},
|
397 |
+
"DeepSeek-MoE-16B-Chat": {
|
398 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat",
|
399 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat",
|
400 |
+
},
|
401 |
+
"DeepSeek-MoE-16B-v2-Chat": {
|
402 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat",
|
403 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat",
|
404 |
+
},
|
405 |
+
"DeepSeek-MoE-236B-Chat": {
|
406 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat",
|
407 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat",
|
408 |
+
},
|
409 |
+
"DeepSeek-MoE-Coder-16B-Base": {
|
410 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base",
|
411 |
+
},
|
412 |
+
"DeepSeek-MoE-Coder-236B-Base": {
|
413 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base",
|
414 |
+
},
|
415 |
+
"DeepSeek-MoE-Coder-16B-Chat": {
|
416 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct",
|
417 |
+
},
|
418 |
+
"DeepSeek-MoE-Coder-236B-Chat": {
|
419 |
+
DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct",
|
420 |
+
},
|
421 |
+
},
|
422 |
+
template="deepseek",
|
423 |
+
)
|
424 |
+
|
425 |
+
|
426 |
+
register_model_group(
|
427 |
+
models={
|
428 |
+
"DeepSeekCoder-6.7B-Base": {
|
429 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base",
|
430 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base",
|
431 |
+
},
|
432 |
+
"DeepSeekCoder-7B-Base": {
|
433 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5",
|
434 |
+
},
|
435 |
+
"DeepSeekCoder-33B-Base": {
|
436 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base",
|
437 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base",
|
438 |
+
},
|
439 |
+
"DeepSeekCoder-6.7B-Chat": {
|
440 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
441 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct",
|
442 |
+
},
|
443 |
+
"DeepSeekCoder-7B-Chat": {
|
444 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5",
|
445 |
+
},
|
446 |
+
"DeepSeekCoder-33B-Chat": {
|
447 |
+
DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct",
|
448 |
+
DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct",
|
449 |
+
},
|
450 |
+
},
|
451 |
+
template="deepseekcoder",
|
452 |
+
)
|
453 |
+
|
454 |
+
|
455 |
+
register_model_group(
|
456 |
+
models={
|
457 |
+
"Falcon-7B": {
|
458 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-7b",
|
459 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b",
|
460 |
+
},
|
461 |
+
"Falcon-11B": {
|
462 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-11B",
|
463 |
+
},
|
464 |
+
"Falcon-40B": {
|
465 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-40b",
|
466 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b",
|
467 |
+
},
|
468 |
+
"Falcon-180B": {
|
469 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-180b",
|
470 |
+
DownloadSource.MODELSCOPE: "modelscope/falcon-180B",
|
471 |
+
},
|
472 |
+
"Falcon-7B-Chat": {
|
473 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct",
|
474 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct",
|
475 |
+
},
|
476 |
+
"Falcon-40B-Chat": {
|
477 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct",
|
478 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct",
|
479 |
+
},
|
480 |
+
"Falcon-180B-Chat": {
|
481 |
+
DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat",
|
482 |
+
DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat",
|
483 |
+
},
|
484 |
+
},
|
485 |
+
template="falcon",
|
486 |
+
)
|
487 |
+
|
488 |
+
|
489 |
+
register_model_group(
|
490 |
+
models={
|
491 |
+
"Gemma-2B": {
|
492 |
+
DownloadSource.DEFAULT: "google/gemma-2b",
|
493 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b",
|
494 |
+
},
|
495 |
+
"Gemma-7B": {
|
496 |
+
DownloadSource.DEFAULT: "google/gemma-7b",
|
497 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it",
|
498 |
+
},
|
499 |
+
"Gemma-2B-Chat": {
|
500 |
+
DownloadSource.DEFAULT: "google/gemma-2b-it",
|
501 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b",
|
502 |
+
},
|
503 |
+
"Gemma-7B-Chat": {
|
504 |
+
DownloadSource.DEFAULT: "google/gemma-7b-it",
|
505 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it",
|
506 |
+
},
|
507 |
+
"Gemma-1.1-2B-Chat": {
|
508 |
+
DownloadSource.DEFAULT: "google/gemma-1.1-2b-it",
|
509 |
+
},
|
510 |
+
"Gemma-1.1-7B-Chat": {
|
511 |
+
DownloadSource.DEFAULT: "google/gemma-1.1-7b-it",
|
512 |
+
},
|
513 |
+
"Gemma-2-9B": {
|
514 |
+
DownloadSource.DEFAULT: "google/gemma-2-9b",
|
515 |
+
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b",
|
516 |
+
},
|
517 |
+
"Gemma-2-27B": {
|
518 |
+
DownloadSource.DEFAULT: "google/gemma-2-27b",
|
519 |
+
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b",
|
520 |
+
},
|
521 |
+
"Gemma-2-9B-Chat": {
|
522 |
+
DownloadSource.DEFAULT: "google/gemma-2-9b-it",
|
523 |
+
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it",
|
524 |
+
},
|
525 |
+
"Gemma-2-27B-Chat": {
|
526 |
+
DownloadSource.DEFAULT: "google/gemma-2-27b-it",
|
527 |
+
DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it",
|
528 |
+
},
|
529 |
+
},
|
530 |
+
template="gemma",
|
531 |
+
)
|
532 |
+
|
533 |
+
|
534 |
+
register_model_group(
|
535 |
+
models={
|
536 |
+
"GLM-4-9B": {
|
537 |
+
DownloadSource.DEFAULT: "THUDM/glm-4-9b",
|
538 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b",
|
539 |
+
},
|
540 |
+
"GLM-4-9B-Chat": {
|
541 |
+
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat",
|
542 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat",
|
543 |
+
},
|
544 |
+
"GLM-4-9B-1M-Chat": {
|
545 |
+
DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m",
|
546 |
+
DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m",
|
547 |
+
},
|
548 |
+
},
|
549 |
+
template="glm4",
|
550 |
+
)
|
551 |
+
|
552 |
+
|
553 |
+
register_model_group(
|
554 |
+
models={
|
555 |
+
"InternLM-7B": {
|
556 |
+
DownloadSource.DEFAULT: "internlm/internlm-7b",
|
557 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b",
|
558 |
+
},
|
559 |
+
"InternLM-20B": {
|
560 |
+
DownloadSource.DEFAULT: "internlm/internlm-20b",
|
561 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b",
|
562 |
+
},
|
563 |
+
"InternLM-7B-Chat": {
|
564 |
+
DownloadSource.DEFAULT: "internlm/internlm-chat-7b",
|
565 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b",
|
566 |
+
},
|
567 |
+
"InternLM-20B-Chat": {
|
568 |
+
DownloadSource.DEFAULT: "internlm/internlm-chat-20b",
|
569 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b",
|
570 |
+
},
|
571 |
+
},
|
572 |
+
template="intern",
|
573 |
+
)
|
574 |
+
|
575 |
+
|
576 |
+
register_model_group(
|
577 |
+
models={
|
578 |
+
"InternLM2-7B": {
|
579 |
+
DownloadSource.DEFAULT: "internlm/internlm2-7b",
|
580 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b",
|
581 |
+
},
|
582 |
+
"InternLM2-20B": {
|
583 |
+
DownloadSource.DEFAULT: "internlm/internlm2-20b",
|
584 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b",
|
585 |
+
},
|
586 |
+
"InternLM2-7B-Chat": {
|
587 |
+
DownloadSource.DEFAULT: "internlm/internlm2-chat-7b",
|
588 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b",
|
589 |
+
},
|
590 |
+
"InternLM2-20B-Chat": {
|
591 |
+
DownloadSource.DEFAULT: "internlm/internlm2-chat-20b",
|
592 |
+
DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b",
|
593 |
+
},
|
594 |
+
},
|
595 |
+
template="intern2",
|
596 |
+
)
|
597 |
+
|
598 |
+
|
599 |
+
register_model_group(
|
600 |
+
models={
|
601 |
+
"Jamba-v0.1": {
|
602 |
+
DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1",
|
603 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1",
|
604 |
+
}
|
605 |
+
},
|
606 |
+
)
|
607 |
+
|
608 |
+
|
609 |
+
register_model_group(
|
610 |
+
models={
|
611 |
+
"LingoWhale-8B": {
|
612 |
+
DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B",
|
613 |
+
DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B",
|
614 |
+
}
|
615 |
+
},
|
616 |
+
)
|
617 |
+
|
618 |
+
|
619 |
+
register_model_group(
|
620 |
+
models={
|
621 |
+
"LLaMA-7B": {
|
622 |
+
DownloadSource.DEFAULT: "huggyllama/llama-7b",
|
623 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-7b",
|
624 |
+
},
|
625 |
+
"LLaMA-13B": {
|
626 |
+
DownloadSource.DEFAULT: "huggyllama/llama-13b",
|
627 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-13b",
|
628 |
+
},
|
629 |
+
"LLaMA-30B": {
|
630 |
+
DownloadSource.DEFAULT: "huggyllama/llama-30b",
|
631 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-30b",
|
632 |
+
},
|
633 |
+
"LLaMA-65B": {
|
634 |
+
DownloadSource.DEFAULT: "huggyllama/llama-65b",
|
635 |
+
DownloadSource.MODELSCOPE: "skyline2006/llama-65b",
|
636 |
+
},
|
637 |
+
}
|
638 |
+
)
|
639 |
+
|
640 |
+
|
641 |
+
register_model_group(
|
642 |
+
models={
|
643 |
+
"LLaMA2-7B": {
|
644 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf",
|
645 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms",
|
646 |
+
},
|
647 |
+
"LLaMA2-13B": {
|
648 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf",
|
649 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms",
|
650 |
+
},
|
651 |
+
"LLaMA2-70B": {
|
652 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf",
|
653 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms",
|
654 |
+
},
|
655 |
+
"LLaMA2-7B-Chat": {
|
656 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf",
|
657 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms",
|
658 |
+
},
|
659 |
+
"LLaMA2-13B-Chat": {
|
660 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf",
|
661 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms",
|
662 |
+
},
|
663 |
+
"LLaMA2-70B-Chat": {
|
664 |
+
DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf",
|
665 |
+
DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms",
|
666 |
+
},
|
667 |
+
},
|
668 |
+
template="llama2",
|
669 |
+
)
|
670 |
+
|
671 |
+
|
672 |
+
register_model_group(
|
673 |
+
models={
|
674 |
+
"LLaMA3-8B": {
|
675 |
+
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B",
|
676 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B",
|
677 |
+
},
|
678 |
+
"LLaMA3-70B": {
|
679 |
+
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B",
|
680 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B",
|
681 |
+
},
|
682 |
+
"LLaMA3-8B-Chat": {
|
683 |
+
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct",
|
684 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct",
|
685 |
+
},
|
686 |
+
"LLaMA3-70B-Chat": {
|
687 |
+
DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct",
|
688 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct",
|
689 |
+
},
|
690 |
+
"LLaMA3-8B-Chinese-Chat": {
|
691 |
+
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat",
|
692 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat",
|
693 |
+
},
|
694 |
+
"LLaMA3-70B-Chinese-Chat": {
|
695 |
+
DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat",
|
696 |
+
},
|
697 |
+
},
|
698 |
+
template="llama3",
|
699 |
+
)
|
700 |
+
|
701 |
+
|
702 |
+
register_model_group(
|
703 |
+
models={
|
704 |
+
"LLaVA1.5-7B-Chat": {
|
705 |
+
DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf",
|
706 |
+
},
|
707 |
+
"LLaVA1.5-13B-Chat": {
|
708 |
+
DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf",
|
709 |
+
},
|
710 |
+
},
|
711 |
+
template="vicuna",
|
712 |
+
vision=True,
|
713 |
+
)
|
714 |
+
|
715 |
+
|
716 |
+
register_model_group(
|
717 |
+
models={
|
718 |
+
"MiniCPM-2B-SFT-Chat": {
|
719 |
+
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16",
|
720 |
+
DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16",
|
721 |
+
},
|
722 |
+
"MiniCPM-2B-DPO-Chat": {
|
723 |
+
DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16",
|
724 |
+
DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16",
|
725 |
+
},
|
726 |
+
},
|
727 |
+
template="cpm",
|
728 |
+
)
|
729 |
+
|
730 |
+
|
731 |
+
register_model_group(
|
732 |
+
models={
|
733 |
+
"Mistral-7B-v0.1": {
|
734 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1",
|
735 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1",
|
736 |
+
},
|
737 |
+
"Mistral-7B-v0.1-Chat": {
|
738 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1",
|
739 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1",
|
740 |
+
},
|
741 |
+
"Mistral-7B-v0.2": {
|
742 |
+
DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf",
|
743 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf",
|
744 |
+
},
|
745 |
+
"Mistral-7B-v0.2-Chat": {
|
746 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2",
|
747 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2",
|
748 |
+
},
|
749 |
+
"Mistral-7B-v0.3": {
|
750 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.3",
|
751 |
+
},
|
752 |
+
"Mistral-7B-v0.3-Chat": {
|
753 |
+
DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3",
|
754 |
+
},
|
755 |
+
},
|
756 |
+
template="mistral",
|
757 |
+
)
|
758 |
+
|
759 |
+
|
760 |
+
register_model_group(
|
761 |
+
models={
|
762 |
+
"Mixtral-8x7B-v0.1": {
|
763 |
+
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1",
|
764 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1",
|
765 |
+
},
|
766 |
+
"Mixtral-8x7B-v0.1-Chat": {
|
767 |
+
DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1",
|
768 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1",
|
769 |
+
},
|
770 |
+
"Mixtral-8x22B-v0.1": {
|
771 |
+
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1",
|
772 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1",
|
773 |
+
},
|
774 |
+
"Mixtral-8x22B-v0.1-Chat": {
|
775 |
+
DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1",
|
776 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1",
|
777 |
+
},
|
778 |
+
},
|
779 |
+
template="mistral",
|
780 |
+
)
|
781 |
+
|
782 |
+
|
783 |
+
register_model_group(
|
784 |
+
models={
|
785 |
+
"OLMo-1B": {
|
786 |
+
DownloadSource.DEFAULT: "allenai/OLMo-1B-hf",
|
787 |
+
},
|
788 |
+
"OLMo-7B": {
|
789 |
+
DownloadSource.DEFAULT: "allenai/OLMo-7B-hf",
|
790 |
+
},
|
791 |
+
"OLMo-7B-Chat": {
|
792 |
+
DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf",
|
793 |
+
},
|
794 |
+
"OLMo-1.7-7B": {
|
795 |
+
DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf",
|
796 |
+
},
|
797 |
+
},
|
798 |
+
)
|
799 |
+
|
800 |
+
|
801 |
+
register_model_group(
|
802 |
+
models={
|
803 |
+
"OpenChat3.5-7B-Chat": {
|
804 |
+
DownloadSource.DEFAULT: "openchat/openchat-3.5-0106",
|
805 |
+
DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106",
|
806 |
+
}
|
807 |
+
},
|
808 |
+
template="openchat",
|
809 |
+
)
|
810 |
+
|
811 |
+
|
812 |
+
register_model_group(
|
813 |
+
models={
|
814 |
+
"OpenChat3.6-8B-Chat": {
|
815 |
+
DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522",
|
816 |
+
}
|
817 |
+
},
|
818 |
+
template="openchat-3.6",
|
819 |
+
)
|
820 |
+
|
821 |
+
|
822 |
+
register_model_group(
|
823 |
+
models={
|
824 |
+
"Orion-14B-Base": {
|
825 |
+
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base",
|
826 |
+
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base",
|
827 |
+
},
|
828 |
+
"Orion-14B-Chat": {
|
829 |
+
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat",
|
830 |
+
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat",
|
831 |
+
},
|
832 |
+
"Orion-14B-Long-Chat": {
|
833 |
+
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat",
|
834 |
+
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat",
|
835 |
+
},
|
836 |
+
"Orion-14B-RAG-Chat": {
|
837 |
+
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG",
|
838 |
+
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG",
|
839 |
+
},
|
840 |
+
"Orion-14B-Plugin-Chat": {
|
841 |
+
DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin",
|
842 |
+
DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin",
|
843 |
+
},
|
844 |
+
},
|
845 |
+
template="orion",
|
846 |
+
)
|
847 |
+
|
848 |
+
|
849 |
+
register_model_group(
|
850 |
+
models={
|
851 |
+
"PaliGemma-3B-pt-224": {
|
852 |
+
DownloadSource.DEFAULT: "google/paligemma-3b-pt-224",
|
853 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224",
|
854 |
+
},
|
855 |
+
"PaliGemma-3B-pt-448": {
|
856 |
+
DownloadSource.DEFAULT: "google/paligemma-3b-pt-448",
|
857 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448",
|
858 |
+
},
|
859 |
+
"PaliGemma-3B-pt-896": {
|
860 |
+
DownloadSource.DEFAULT: "google/paligemma-3b-pt-896",
|
861 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896",
|
862 |
+
},
|
863 |
+
"PaliGemma-3B-mix-224": {
|
864 |
+
DownloadSource.DEFAULT: "google/paligemma-3b-mix-224",
|
865 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224",
|
866 |
+
},
|
867 |
+
"PaliGemma-3B-mix-448": {
|
868 |
+
DownloadSource.DEFAULT: "google/paligemma-3b-mix-448",
|
869 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448",
|
870 |
+
},
|
871 |
+
},
|
872 |
+
vision=True,
|
873 |
+
)
|
874 |
+
|
875 |
+
|
876 |
+
register_model_group(
|
877 |
+
models={
|
878 |
+
"Phi-1.5-1.3B": {
|
879 |
+
DownloadSource.DEFAULT: "microsoft/phi-1_5",
|
880 |
+
DownloadSource.MODELSCOPE: "allspace/PHI_1-5",
|
881 |
+
},
|
882 |
+
"Phi-2-2.7B": {
|
883 |
+
DownloadSource.DEFAULT: "microsoft/phi-2",
|
884 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2",
|
885 |
+
},
|
886 |
+
}
|
887 |
+
)
|
888 |
+
|
889 |
+
|
890 |
+
register_model_group(
|
891 |
+
models={
|
892 |
+
"Phi3-4B-4k-Chat": {
|
893 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct",
|
894 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct",
|
895 |
+
},
|
896 |
+
"Phi3-4B-128k-Chat": {
|
897 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct",
|
898 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct",
|
899 |
+
},
|
900 |
+
"Phi3-7B-8k-Chat": {
|
901 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct",
|
902 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct",
|
903 |
+
},
|
904 |
+
"Phi3-7B-128k-Chat": {
|
905 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct",
|
906 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct",
|
907 |
+
},
|
908 |
+
"Phi3-14B-8k-Chat": {
|
909 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct",
|
910 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct",
|
911 |
+
},
|
912 |
+
"Phi3-14B-128k-Chat": {
|
913 |
+
DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct",
|
914 |
+
DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct",
|
915 |
+
},
|
916 |
+
},
|
917 |
+
template="phi",
|
918 |
+
)
|
919 |
+
|
920 |
+
|
921 |
+
register_model_group(
|
922 |
+
models={
|
923 |
+
"Qwen-1.8B": {
|
924 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B",
|
925 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B",
|
926 |
+
},
|
927 |
+
"Qwen-7B": {
|
928 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B",
|
929 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B",
|
930 |
+
},
|
931 |
+
"Qwen-14B": {
|
932 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B",
|
933 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B",
|
934 |
+
},
|
935 |
+
"Qwen-72B": {
|
936 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B",
|
937 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B",
|
938 |
+
},
|
939 |
+
"Qwen-1.8B-Chat": {
|
940 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat",
|
941 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat",
|
942 |
+
},
|
943 |
+
"Qwen-7B-Chat": {
|
944 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat",
|
945 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat",
|
946 |
+
},
|
947 |
+
"Qwen-14B-Chat": {
|
948 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat",
|
949 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat",
|
950 |
+
},
|
951 |
+
"Qwen-72B-Chat": {
|
952 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat",
|
953 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat",
|
954 |
+
},
|
955 |
+
"Qwen-1.8B-int8-Chat": {
|
956 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8",
|
957 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8",
|
958 |
+
},
|
959 |
+
"Qwen-1.8B-int4-Chat": {
|
960 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4",
|
961 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4",
|
962 |
+
},
|
963 |
+
"Qwen-7B-int8-Chat": {
|
964 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8",
|
965 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8",
|
966 |
+
},
|
967 |
+
"Qwen-7B-int4-Chat": {
|
968 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4",
|
969 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4",
|
970 |
+
},
|
971 |
+
"Qwen-14B-int8-Chat": {
|
972 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8",
|
973 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8",
|
974 |
+
},
|
975 |
+
"Qwen-14B-int4-Chat": {
|
976 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4",
|
977 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4",
|
978 |
+
},
|
979 |
+
"Qwen-72B-int8-Chat": {
|
980 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8",
|
981 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8",
|
982 |
+
},
|
983 |
+
"Qwen-72B-int4-Chat": {
|
984 |
+
DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4",
|
985 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4",
|
986 |
+
},
|
987 |
+
},
|
988 |
+
template="qwen",
|
989 |
+
)
|
990 |
+
|
991 |
+
|
992 |
+
register_model_group(
|
993 |
+
models={
|
994 |
+
"Qwen1.5-0.5B": {
|
995 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B",
|
996 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B",
|
997 |
+
},
|
998 |
+
"Qwen1.5-1.8B": {
|
999 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B",
|
1000 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B",
|
1001 |
+
},
|
1002 |
+
"Qwen1.5-4B": {
|
1003 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B",
|
1004 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B",
|
1005 |
+
},
|
1006 |
+
"Qwen1.5-7B": {
|
1007 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B",
|
1008 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B",
|
1009 |
+
},
|
1010 |
+
"Qwen1.5-14B": {
|
1011 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B",
|
1012 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B",
|
1013 |
+
},
|
1014 |
+
"Qwen1.5-32B": {
|
1015 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B",
|
1016 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B",
|
1017 |
+
},
|
1018 |
+
"Qwen1.5-72B": {
|
1019 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B",
|
1020 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B",
|
1021 |
+
},
|
1022 |
+
"Qwen1.5-110B": {
|
1023 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B",
|
1024 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B",
|
1025 |
+
},
|
1026 |
+
"Qwen1.5-MoE-A2.7B": {
|
1027 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B",
|
1028 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B",
|
1029 |
+
},
|
1030 |
+
"Qwen1.5-Code-7B": {
|
1031 |
+
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B",
|
1032 |
+
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B",
|
1033 |
+
},
|
1034 |
+
"Qwen1.5-0.5B-Chat": {
|
1035 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat",
|
1036 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat",
|
1037 |
+
},
|
1038 |
+
"Qwen1.5-1.8B-Chat": {
|
1039 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat",
|
1040 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat",
|
1041 |
+
},
|
1042 |
+
"Qwen1.5-4B-Chat": {
|
1043 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat",
|
1044 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat",
|
1045 |
+
},
|
1046 |
+
"Qwen1.5-7B-Chat": {
|
1047 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat",
|
1048 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat",
|
1049 |
+
},
|
1050 |
+
"Qwen1.5-14B-Chat": {
|
1051 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat",
|
1052 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat",
|
1053 |
+
},
|
1054 |
+
"Qwen1.5-32B-Chat": {
|
1055 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat",
|
1056 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat",
|
1057 |
+
},
|
1058 |
+
"Qwen1.5-72B-Chat": {
|
1059 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat",
|
1060 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat",
|
1061 |
+
},
|
1062 |
+
"Qwen1.5-110B-Chat": {
|
1063 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat",
|
1064 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat",
|
1065 |
+
},
|
1066 |
+
"Qwen1.5-MoE-A2.7B-Chat": {
|
1067 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat",
|
1068 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat",
|
1069 |
+
},
|
1070 |
+
"Qwen1.5-Code-7B-Chat": {
|
1071 |
+
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat",
|
1072 |
+
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat",
|
1073 |
+
},
|
1074 |
+
"Qwen1.5-0.5B-int8-Chat": {
|
1075 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
1076 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8",
|
1077 |
+
},
|
1078 |
+
"Qwen1.5-0.5B-int4-Chat": {
|
1079 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ",
|
1080 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ",
|
1081 |
+
},
|
1082 |
+
"Qwen1.5-1.8B-int8-Chat": {
|
1083 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
1084 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8",
|
1085 |
+
},
|
1086 |
+
"Qwen1.5-1.8B-int4-Chat": {
|
1087 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ",
|
1088 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ",
|
1089 |
+
},
|
1090 |
+
"Qwen1.5-4B-int8-Chat": {
|
1091 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
1092 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8",
|
1093 |
+
},
|
1094 |
+
"Qwen1.5-4B-int4-Chat": {
|
1095 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ",
|
1096 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ",
|
1097 |
+
},
|
1098 |
+
"Qwen1.5-7B-int8-Chat": {
|
1099 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
1100 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8",
|
1101 |
+
},
|
1102 |
+
"Qwen1.5-7B-int4-Chat": {
|
1103 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ",
|
1104 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ",
|
1105 |
+
},
|
1106 |
+
"Qwen1.5-14B-int8-Chat": {
|
1107 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
1108 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8",
|
1109 |
+
},
|
1110 |
+
"Qwen1.5-14B-int4-Chat": {
|
1111 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ",
|
1112 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ",
|
1113 |
+
},
|
1114 |
+
"Qwen1.5-32B-int4-Chat": {
|
1115 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ",
|
1116 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ",
|
1117 |
+
},
|
1118 |
+
"Qwen1.5-72B-int8-Chat": {
|
1119 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
1120 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8",
|
1121 |
+
},
|
1122 |
+
"Qwen1.5-72B-int4-Chat": {
|
1123 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ",
|
1124 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ",
|
1125 |
+
},
|
1126 |
+
"Qwen1.5-110B-int4-Chat": {
|
1127 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ",
|
1128 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ",
|
1129 |
+
},
|
1130 |
+
"Qwen1.5-MoE-A2.7B-int4-Chat": {
|
1131 |
+
DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
1132 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4",
|
1133 |
+
},
|
1134 |
+
"Qwen1.5-Code-7B-int4-Chat": {
|
1135 |
+
DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ",
|
1136 |
+
DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ",
|
1137 |
+
},
|
1138 |
+
},
|
1139 |
+
template="qwen",
|
1140 |
+
)
|
1141 |
+
|
1142 |
+
|
1143 |
+
register_model_group(
|
1144 |
+
models={
|
1145 |
+
"Qwen2-0.5B": {
|
1146 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B",
|
1147 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B",
|
1148 |
+
},
|
1149 |
+
"Qwen2-1.5B": {
|
1150 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B",
|
1151 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B",
|
1152 |
+
},
|
1153 |
+
"Qwen2-7B": {
|
1154 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-7B",
|
1155 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B",
|
1156 |
+
},
|
1157 |
+
"Qwen2-72B": {
|
1158 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-72B",
|
1159 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B",
|
1160 |
+
},
|
1161 |
+
"Qwen2-MoE-57B": {
|
1162 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B",
|
1163 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B",
|
1164 |
+
},
|
1165 |
+
"Qwen2-0.5B-Chat": {
|
1166 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct",
|
1167 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct",
|
1168 |
+
},
|
1169 |
+
"Qwen2-1.5B-Chat": {
|
1170 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct",
|
1171 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct",
|
1172 |
+
},
|
1173 |
+
"Qwen2-7B-Chat": {
|
1174 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct",
|
1175 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct",
|
1176 |
+
},
|
1177 |
+
"Qwen2-72B-Chat": {
|
1178 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct",
|
1179 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct",
|
1180 |
+
},
|
1181 |
+
"Qwen2-MoE-57B-Chat": {
|
1182 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct",
|
1183 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct",
|
1184 |
+
},
|
1185 |
+
"Qwen2-0.5B-int8-Chat": {
|
1186 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
|
1187 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8",
|
1188 |
+
},
|
1189 |
+
"Qwen2-0.5B-int4-Chat": {
|
1190 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ",
|
1191 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ",
|
1192 |
+
},
|
1193 |
+
"Qwen2-1.5B-int8-Chat": {
|
1194 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
|
1195 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8",
|
1196 |
+
},
|
1197 |
+
"Qwen2-1.5B-int4-Chat": {
|
1198 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ",
|
1199 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ",
|
1200 |
+
},
|
1201 |
+
"Qwen2-7B-int8-Chat": {
|
1202 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8",
|
1203 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8",
|
1204 |
+
},
|
1205 |
+
"Qwen2-7B-int4-Chat": {
|
1206 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ",
|
1207 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ",
|
1208 |
+
},
|
1209 |
+
"Qwen2-72B-int8-Chat": {
|
1210 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-GPTQ-Int8",
|
1211 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8",
|
1212 |
+
},
|
1213 |
+
"Qwen2-72B-int4-Chat": {
|
1214 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ",
|
1215 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ",
|
1216 |
+
},
|
1217 |
+
"Qwen2-MoE-57B-int4-Chat": {
|
1218 |
+
DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
|
1219 |
+
DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4",
|
1220 |
+
},
|
1221 |
+
},
|
1222 |
+
template="qwen",
|
1223 |
+
)
|
1224 |
+
|
1225 |
+
|
1226 |
+
register_model_group(
|
1227 |
+
models={
|
1228 |
+
"SOLAR-10.7B": {
|
1229 |
+
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0",
|
1230 |
+
},
|
1231 |
+
"SOLAR-10.7B-Chat": {
|
1232 |
+
DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0",
|
1233 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0",
|
1234 |
+
},
|
1235 |
+
},
|
1236 |
+
template="solar",
|
1237 |
+
)
|
1238 |
+
|
1239 |
+
|
1240 |
+
register_model_group(
|
1241 |
+
models={
|
1242 |
+
"Skywork-13B-Base": {
|
1243 |
+
DownloadSource.DEFAULT: "Skywork/Skywork-13B-base",
|
1244 |
+
DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base",
|
1245 |
+
}
|
1246 |
+
}
|
1247 |
+
)
|
1248 |
+
|
1249 |
+
|
1250 |
+
register_model_group(
|
1251 |
+
models={
|
1252 |
+
"StarCoder2-3B": {
|
1253 |
+
DownloadSource.DEFAULT: "bigcode/starcoder2-3b",
|
1254 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b",
|
1255 |
+
},
|
1256 |
+
"StarCoder2-7B": {
|
1257 |
+
DownloadSource.DEFAULT: "bigcode/starcoder2-7b",
|
1258 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b",
|
1259 |
+
},
|
1260 |
+
"StarCoder2-15B": {
|
1261 |
+
DownloadSource.DEFAULT: "bigcode/starcoder2-15b",
|
1262 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b",
|
1263 |
+
},
|
1264 |
+
}
|
1265 |
+
)
|
1266 |
+
|
1267 |
+
|
1268 |
+
register_model_group(
|
1269 |
+
models={
|
1270 |
+
"TeleChat-1B-Chat": {
|
1271 |
+
DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B",
|
1272 |
+
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B",
|
1273 |
+
},
|
1274 |
+
"TeleChat-7B-Chat": {
|
1275 |
+
DownloadSource.DEFAULT: "Tele-AI/telechat-7B",
|
1276 |
+
DownloadSource.MODELSCOPE: "TeleAI/telechat-7B",
|
1277 |
+
},
|
1278 |
+
"TeleChat-12B-Chat": {
|
1279 |
+
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B",
|
1280 |
+
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B",
|
1281 |
+
},
|
1282 |
+
"TeleChat-12B-v2-Chat": {
|
1283 |
+
DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2",
|
1284 |
+
DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2",
|
1285 |
+
},
|
1286 |
+
},
|
1287 |
+
template="telechat",
|
1288 |
+
)
|
1289 |
+
|
1290 |
+
|
1291 |
+
register_model_group(
|
1292 |
+
models={
|
1293 |
+
"Vicuna1.5-7B-Chat": {
|
1294 |
+
DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5",
|
1295 |
+
DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5",
|
1296 |
+
},
|
1297 |
+
"Vicuna1.5-13B-Chat": {
|
1298 |
+
DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5",
|
1299 |
+
DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5",
|
1300 |
+
},
|
1301 |
+
},
|
1302 |
+
template="vicuna",
|
1303 |
+
)
|
1304 |
+
|
1305 |
+
|
1306 |
+
register_model_group(
|
1307 |
+
models={
|
1308 |
+
"XuanYuan-6B": {
|
1309 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B",
|
1310 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B",
|
1311 |
+
},
|
1312 |
+
"XuanYuan-70B": {
|
1313 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B",
|
1314 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B",
|
1315 |
+
},
|
1316 |
+
"XuanYuan-2-70B": {
|
1317 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B",
|
1318 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B",
|
1319 |
+
},
|
1320 |
+
"XuanYuan-6B-Chat": {
|
1321 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat",
|
1322 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat",
|
1323 |
+
},
|
1324 |
+
"XuanYuan-70B-Chat": {
|
1325 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
1326 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat",
|
1327 |
+
},
|
1328 |
+
"XuanYuan-2-70B-Chat": {
|
1329 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
1330 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat",
|
1331 |
+
},
|
1332 |
+
"XuanYuan-6B-int8-Chat": {
|
1333 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
1334 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit",
|
1335 |
+
},
|
1336 |
+
"XuanYuan-6B-int4-Chat": {
|
1337 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
1338 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit",
|
1339 |
+
},
|
1340 |
+
"XuanYuan-70B-int8-Chat": {
|
1341 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
1342 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit",
|
1343 |
+
},
|
1344 |
+
"XuanYuan-70B-int4-Chat": {
|
1345 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
1346 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit",
|
1347 |
+
},
|
1348 |
+
"XuanYuan-2-70B-int8-Chat": {
|
1349 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
1350 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit",
|
1351 |
+
},
|
1352 |
+
"XuanYuan-2-70B-int4-Chat": {
|
1353 |
+
DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
1354 |
+
DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit",
|
1355 |
+
},
|
1356 |
+
},
|
1357 |
+
template="xuanyuan",
|
1358 |
+
)
|
1359 |
+
|
1360 |
+
|
1361 |
+
register_model_group(
|
1362 |
+
models={
|
1363 |
+
"XVERSE-7B": {
|
1364 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-7B",
|
1365 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B",
|
1366 |
+
},
|
1367 |
+
"XVERSE-13B": {
|
1368 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-13B",
|
1369 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B",
|
1370 |
+
},
|
1371 |
+
"XVERSE-65B": {
|
1372 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-65B",
|
1373 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B",
|
1374 |
+
},
|
1375 |
+
"XVERSE-65B-2": {
|
1376 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-65B-2",
|
1377 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2",
|
1378 |
+
},
|
1379 |
+
"XVERSE-7B-Chat": {
|
1380 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat",
|
1381 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat",
|
1382 |
+
},
|
1383 |
+
"XVERSE-13B-Chat": {
|
1384 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat",
|
1385 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat",
|
1386 |
+
},
|
1387 |
+
"XVERSE-65B-Chat": {
|
1388 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat",
|
1389 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat",
|
1390 |
+
},
|
1391 |
+
"XVERSE-MoE-A4.2B": {
|
1392 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B",
|
1393 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B",
|
1394 |
+
},
|
1395 |
+
"XVERSE-7B-int8-Chat": {
|
1396 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
1397 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8",
|
1398 |
+
},
|
1399 |
+
"XVERSE-7B-int4-Chat": {
|
1400 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
1401 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4",
|
1402 |
+
},
|
1403 |
+
"XVERSE-13B-int8-Chat": {
|
1404 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
1405 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8",
|
1406 |
+
},
|
1407 |
+
"XVERSE-13B-int4-Chat": {
|
1408 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
1409 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4",
|
1410 |
+
},
|
1411 |
+
"XVERSE-65B-int4-Chat": {
|
1412 |
+
DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
1413 |
+
DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4",
|
1414 |
+
},
|
1415 |
+
},
|
1416 |
+
template="xverse",
|
1417 |
+
)
|
1418 |
+
|
1419 |
+
|
1420 |
+
register_model_group(
|
1421 |
+
models={
|
1422 |
+
"Yayi-7B": {
|
1423 |
+
DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2",
|
1424 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2",
|
1425 |
+
},
|
1426 |
+
"Yayi-13B": {
|
1427 |
+
DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2",
|
1428 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2",
|
1429 |
+
},
|
1430 |
+
},
|
1431 |
+
template="yayi",
|
1432 |
+
)
|
1433 |
+
|
1434 |
+
|
1435 |
+
register_model_group(
|
1436 |
+
models={
|
1437 |
+
"Yi-6B": {
|
1438 |
+
DownloadSource.DEFAULT: "01-ai/Yi-6B",
|
1439 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-6B",
|
1440 |
+
},
|
1441 |
+
"Yi-9B": {
|
1442 |
+
DownloadSource.DEFAULT: "01-ai/Yi-9B",
|
1443 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-9B",
|
1444 |
+
},
|
1445 |
+
"Yi-34B": {
|
1446 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B",
|
1447 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B",
|
1448 |
+
},
|
1449 |
+
"Yi-6B-Chat": {
|
1450 |
+
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat",
|
1451 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat",
|
1452 |
+
},
|
1453 |
+
"Yi-34B-Chat": {
|
1454 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat",
|
1455 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat",
|
1456 |
+
},
|
1457 |
+
"Yi-6B-int8-Chat": {
|
1458 |
+
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits",
|
1459 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits",
|
1460 |
+
},
|
1461 |
+
"Yi-6B-int4-Chat": {
|
1462 |
+
DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits",
|
1463 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits",
|
1464 |
+
},
|
1465 |
+
"Yi-34B-int8-Chat": {
|
1466 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits",
|
1467 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits",
|
1468 |
+
},
|
1469 |
+
"Yi-34B-int4-Chat": {
|
1470 |
+
DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits",
|
1471 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits",
|
1472 |
+
},
|
1473 |
+
"Yi-1.5-6B": {
|
1474 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B",
|
1475 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B",
|
1476 |
+
},
|
1477 |
+
"Yi-1.5-9B": {
|
1478 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B",
|
1479 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B",
|
1480 |
+
},
|
1481 |
+
"Yi-1.5-34B": {
|
1482 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B",
|
1483 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B",
|
1484 |
+
},
|
1485 |
+
"Yi-1.5-6B-Chat": {
|
1486 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat",
|
1487 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat",
|
1488 |
+
},
|
1489 |
+
"Yi-1.5-9B-Chat": {
|
1490 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat",
|
1491 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat",
|
1492 |
+
},
|
1493 |
+
"Yi-1.5-34B-Chat": {
|
1494 |
+
DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat",
|
1495 |
+
DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat",
|
1496 |
+
},
|
1497 |
+
},
|
1498 |
+
template="yi",
|
1499 |
+
)
|
1500 |
+
|
1501 |
+
|
1502 |
+
register_model_group(
|
1503 |
+
models={
|
1504 |
+
"YiVL-6B-Chat": {
|
1505 |
+
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf",
|
1506 |
+
},
|
1507 |
+
"YiVL-34B-Chat": {
|
1508 |
+
DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf",
|
1509 |
+
},
|
1510 |
+
},
|
1511 |
+
template="yi_vl",
|
1512 |
+
vision=True,
|
1513 |
+
)
|
1514 |
+
|
1515 |
+
|
1516 |
+
register_model_group(
|
1517 |
+
models={
|
1518 |
+
"Yuan2-2B-Chat": {
|
1519 |
+
DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf",
|
1520 |
+
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf",
|
1521 |
+
},
|
1522 |
+
"Yuan2-51B-Chat": {
|
1523 |
+
DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf",
|
1524 |
+
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf",
|
1525 |
+
},
|
1526 |
+
"Yuan2-102B-Chat": {
|
1527 |
+
DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf",
|
1528 |
+
DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf",
|
1529 |
+
},
|
1530 |
+
},
|
1531 |
+
template="yuan",
|
1532 |
+
)
|
1533 |
+
|
1534 |
+
|
1535 |
+
register_model_group(
|
1536 |
+
models={
|
1537 |
+
"Zephyr-7B-Alpha-Chat": {
|
1538 |
+
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha",
|
1539 |
+
DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha",
|
1540 |
+
},
|
1541 |
+
"Zephyr-7B-Beta-Chat": {
|
1542 |
+
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta",
|
1543 |
+
DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta",
|
1544 |
+
},
|
1545 |
+
"Zephyr-141B-ORPO-Chat": {
|
1546 |
+
DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1",
|
1547 |
+
},
|
1548 |
+
},
|
1549 |
+
template="zephyr",
|
1550 |
+
)
|
src/extras/env.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's transformers library.
|
4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import platform
|
19 |
+
|
20 |
+
import accelerate
|
21 |
+
import datasets
|
22 |
+
import peft
|
23 |
+
import torch
|
24 |
+
import transformers
|
25 |
+
import trl
|
26 |
+
from transformers.utils import is_torch_cuda_available, is_torch_npu_available
|
27 |
+
|
28 |
+
|
29 |
+
VERSION = "0.8.3.dev0"
|
30 |
+
|
31 |
+
|
32 |
+
def print_env() -> None:
|
33 |
+
info = {
|
34 |
+
"`llamafactory` version": VERSION,
|
35 |
+
"Platform": platform.platform(),
|
36 |
+
"Python version": platform.python_version(),
|
37 |
+
"PyTorch version": torch.__version__,
|
38 |
+
"Transformers version": transformers.__version__,
|
39 |
+
"Datasets version": datasets.__version__,
|
40 |
+
"Accelerate version": accelerate.__version__,
|
41 |
+
"PEFT version": peft.__version__,
|
42 |
+
"TRL version": trl.__version__,
|
43 |
+
}
|
44 |
+
|
45 |
+
if is_torch_cuda_available():
|
46 |
+
info["PyTorch version"] += " (GPU)"
|
47 |
+
info["GPU type"] = torch.cuda.get_device_name()
|
48 |
+
|
49 |
+
if is_torch_npu_available():
|
50 |
+
info["PyTorch version"] += " (NPU)"
|
51 |
+
info["NPU type"] = torch.npu.get_device_name()
|
52 |
+
info["CANN version"] = torch.version.cann
|
53 |
+
|
54 |
+
try:
|
55 |
+
import deepspeed # type: ignore
|
56 |
+
|
57 |
+
info["DeepSpeed version"] = deepspeed.__version__
|
58 |
+
except Exception:
|
59 |
+
pass
|
60 |
+
|
61 |
+
try:
|
62 |
+
import bitsandbytes
|
63 |
+
|
64 |
+
info["Bitsandbytes version"] = bitsandbytes.__version__
|
65 |
+
except Exception:
|
66 |
+
pass
|
67 |
+
|
68 |
+
try:
|
69 |
+
import vllm
|
70 |
+
|
71 |
+
info["vLLM version"] = vllm.__version__
|
72 |
+
except Exception:
|
73 |
+
pass
|
74 |
+
|
75 |
+
print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n")
|
src/extras/logging.py
ADDED
@@ -0,0 +1,82 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import logging
|
16 |
+
import os
|
17 |
+
import sys
|
18 |
+
from concurrent.futures import ThreadPoolExecutor
|
19 |
+
|
20 |
+
from .constants import RUNNING_LOG
|
21 |
+
|
22 |
+
|
23 |
+
class LoggerHandler(logging.Handler):
|
24 |
+
r"""
|
25 |
+
Logger handler used in Web UI.
|
26 |
+
"""
|
27 |
+
|
28 |
+
def __init__(self, output_dir: str) -> None:
|
29 |
+
super().__init__()
|
30 |
+
formatter = logging.Formatter(
|
31 |
+
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
32 |
+
)
|
33 |
+
self.setLevel(logging.INFO)
|
34 |
+
self.setFormatter(formatter)
|
35 |
+
|
36 |
+
os.makedirs(output_dir, exist_ok=True)
|
37 |
+
self.running_log = os.path.join(output_dir, RUNNING_LOG)
|
38 |
+
if os.path.exists(self.running_log):
|
39 |
+
os.remove(self.running_log)
|
40 |
+
|
41 |
+
self.thread_pool = ThreadPoolExecutor(max_workers=1)
|
42 |
+
|
43 |
+
def _write_log(self, log_entry: str) -> None:
|
44 |
+
with open(self.running_log, "a", encoding="utf-8") as f:
|
45 |
+
f.write(log_entry + "\n\n")
|
46 |
+
|
47 |
+
def emit(self, record) -> None:
|
48 |
+
if record.name == "httpx":
|
49 |
+
return
|
50 |
+
|
51 |
+
log_entry = self.format(record)
|
52 |
+
self.thread_pool.submit(self._write_log, log_entry)
|
53 |
+
|
54 |
+
def close(self) -> None:
|
55 |
+
self.thread_pool.shutdown(wait=True)
|
56 |
+
return super().close()
|
57 |
+
|
58 |
+
|
59 |
+
def get_logger(name: str) -> logging.Logger:
|
60 |
+
r"""
|
61 |
+
Gets a standard logger with a stream hander to stdout.
|
62 |
+
"""
|
63 |
+
formatter = logging.Formatter(
|
64 |
+
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
65 |
+
)
|
66 |
+
handler = logging.StreamHandler(sys.stdout)
|
67 |
+
handler.setFormatter(formatter)
|
68 |
+
|
69 |
+
logger = logging.getLogger(name)
|
70 |
+
logger.setLevel(logging.INFO)
|
71 |
+
logger.addHandler(handler)
|
72 |
+
|
73 |
+
return logger
|
74 |
+
|
75 |
+
|
76 |
+
def reset_logging() -> None:
|
77 |
+
r"""
|
78 |
+
Removes basic config of root logger. (unused in script)
|
79 |
+
"""
|
80 |
+
root = logging.getLogger()
|
81 |
+
list(map(root.removeHandler, root.handlers))
|
82 |
+
list(map(root.removeFilter, root.filters))
|
src/extras/misc.py
ADDED
@@ -0,0 +1,215 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's PEFT library.
|
4 |
+
# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import gc
|
19 |
+
import os
|
20 |
+
from typing import TYPE_CHECKING, Tuple
|
21 |
+
|
22 |
+
import torch
|
23 |
+
import transformers.dynamic_module_utils
|
24 |
+
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
|
25 |
+
from transformers.dynamic_module_utils import get_relative_imports
|
26 |
+
from transformers.utils import (
|
27 |
+
is_torch_bf16_gpu_available,
|
28 |
+
is_torch_cuda_available,
|
29 |
+
is_torch_mps_available,
|
30 |
+
is_torch_npu_available,
|
31 |
+
is_torch_xpu_available,
|
32 |
+
)
|
33 |
+
from transformers.utils.versions import require_version
|
34 |
+
|
35 |
+
from .logging import get_logger
|
36 |
+
|
37 |
+
|
38 |
+
_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
|
39 |
+
try:
|
40 |
+
_is_bf16_available = is_torch_bf16_gpu_available()
|
41 |
+
except Exception:
|
42 |
+
_is_bf16_available = False
|
43 |
+
|
44 |
+
|
45 |
+
if TYPE_CHECKING:
|
46 |
+
from ..hparams import ModelArguments
|
47 |
+
|
48 |
+
|
49 |
+
logger = get_logger(__name__)
|
50 |
+
|
51 |
+
|
52 |
+
class AverageMeter:
|
53 |
+
r"""
|
54 |
+
Computes and stores the average and current value.
|
55 |
+
"""
|
56 |
+
|
57 |
+
def __init__(self):
|
58 |
+
self.reset()
|
59 |
+
|
60 |
+
def reset(self):
|
61 |
+
self.val = 0
|
62 |
+
self.avg = 0
|
63 |
+
self.sum = 0
|
64 |
+
self.count = 0
|
65 |
+
|
66 |
+
def update(self, val, n=1):
|
67 |
+
self.val = val
|
68 |
+
self.sum += val * n
|
69 |
+
self.count += n
|
70 |
+
self.avg = self.sum / self.count
|
71 |
+
|
72 |
+
|
73 |
+
def check_dependencies() -> None:
|
74 |
+
r"""
|
75 |
+
Checks the version of the required packages.
|
76 |
+
"""
|
77 |
+
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
78 |
+
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
|
79 |
+
else:
|
80 |
+
require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2")
|
81 |
+
require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0")
|
82 |
+
require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1")
|
83 |
+
require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1")
|
84 |
+
require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6")
|
85 |
+
|
86 |
+
|
87 |
+
def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
|
88 |
+
r"""
|
89 |
+
Returns the number of trainable parameters and number of all parameters in the model.
|
90 |
+
"""
|
91 |
+
trainable_params, all_param = 0, 0
|
92 |
+
for param in model.parameters():
|
93 |
+
num_params = param.numel()
|
94 |
+
# if using DS Zero 3 and the weights are initialized empty
|
95 |
+
if num_params == 0 and hasattr(param, "ds_numel"):
|
96 |
+
num_params = param.ds_numel
|
97 |
+
|
98 |
+
# Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize
|
99 |
+
if param.__class__.__name__ == "Params4bit":
|
100 |
+
if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"):
|
101 |
+
num_bytes = param.quant_storage.itemsize
|
102 |
+
elif hasattr(param, "element_size"): # for older pytorch version
|
103 |
+
num_bytes = param.element_size()
|
104 |
+
else:
|
105 |
+
num_bytes = 1
|
106 |
+
|
107 |
+
num_params = num_params * 2 * num_bytes
|
108 |
+
|
109 |
+
all_param += num_params
|
110 |
+
if param.requires_grad:
|
111 |
+
trainable_params += num_params
|
112 |
+
|
113 |
+
return trainable_params, all_param
|
114 |
+
|
115 |
+
|
116 |
+
def get_current_device() -> "torch.device":
|
117 |
+
r"""
|
118 |
+
Gets the current available device.
|
119 |
+
"""
|
120 |
+
if is_torch_xpu_available():
|
121 |
+
device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
122 |
+
elif is_torch_npu_available():
|
123 |
+
device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
124 |
+
elif is_torch_mps_available():
|
125 |
+
device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
126 |
+
elif is_torch_cuda_available():
|
127 |
+
device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0"))
|
128 |
+
else:
|
129 |
+
device = "cpu"
|
130 |
+
|
131 |
+
return torch.device(device)
|
132 |
+
|
133 |
+
|
134 |
+
def get_device_count() -> int:
|
135 |
+
r"""
|
136 |
+
Gets the number of available GPU or NPU devices.
|
137 |
+
"""
|
138 |
+
if is_torch_npu_available():
|
139 |
+
return torch.npu.device_count()
|
140 |
+
elif is_torch_cuda_available():
|
141 |
+
return torch.cuda.device_count()
|
142 |
+
else:
|
143 |
+
return 0
|
144 |
+
|
145 |
+
|
146 |
+
def get_logits_processor() -> "LogitsProcessorList":
|
147 |
+
r"""
|
148 |
+
Gets logits processor that removes NaN and Inf logits.
|
149 |
+
"""
|
150 |
+
logits_processor = LogitsProcessorList()
|
151 |
+
logits_processor.append(InfNanRemoveLogitsProcessor())
|
152 |
+
return logits_processor
|
153 |
+
|
154 |
+
|
155 |
+
def has_tokenized_data(path: "os.PathLike") -> bool:
|
156 |
+
r"""
|
157 |
+
Checks if the path has a tokenized dataset.
|
158 |
+
"""
|
159 |
+
return os.path.isdir(path) and len(os.listdir(path)) > 0
|
160 |
+
|
161 |
+
|
162 |
+
def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype":
|
163 |
+
r"""
|
164 |
+
Infers the optimal dtype according to the model_dtype and device compatibility.
|
165 |
+
"""
|
166 |
+
if _is_bf16_available and model_dtype == torch.bfloat16:
|
167 |
+
return torch.bfloat16
|
168 |
+
elif _is_fp16_available:
|
169 |
+
return torch.float16
|
170 |
+
else:
|
171 |
+
return torch.float32
|
172 |
+
|
173 |
+
|
174 |
+
def is_gpu_or_npu_available() -> bool:
|
175 |
+
r"""
|
176 |
+
Checks if the GPU or NPU is available.
|
177 |
+
"""
|
178 |
+
return is_torch_npu_available() or is_torch_cuda_available()
|
179 |
+
|
180 |
+
|
181 |
+
def skip_check_imports() -> None:
|
182 |
+
if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]:
|
183 |
+
transformers.dynamic_module_utils.check_imports = get_relative_imports
|
184 |
+
|
185 |
+
|
186 |
+
def torch_gc() -> None:
|
187 |
+
r"""
|
188 |
+
Collects GPU or NPU memory.
|
189 |
+
"""
|
190 |
+
gc.collect()
|
191 |
+
if is_torch_xpu_available():
|
192 |
+
torch.xpu.empty_cache()
|
193 |
+
elif is_torch_npu_available():
|
194 |
+
torch.npu.empty_cache()
|
195 |
+
elif is_torch_mps_available():
|
196 |
+
torch.mps.empty_cache()
|
197 |
+
elif is_torch_cuda_available():
|
198 |
+
torch.cuda.empty_cache()
|
199 |
+
|
200 |
+
|
201 |
+
def try_download_model_from_ms(model_args: "ModelArguments") -> str:
|
202 |
+
if not use_modelscope() or os.path.exists(model_args.model_name_or_path):
|
203 |
+
return model_args.model_name_or_path
|
204 |
+
|
205 |
+
try:
|
206 |
+
from modelscope import snapshot_download
|
207 |
+
|
208 |
+
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
209 |
+
return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir)
|
210 |
+
except ImportError:
|
211 |
+
raise ImportError("Please install modelscope via `pip install modelscope -U`")
|
212 |
+
|
213 |
+
|
214 |
+
def use_modelscope() -> bool:
|
215 |
+
return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"]
|
src/extras/packages.py
ADDED
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's transformers library.
|
4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
import importlib.metadata
|
19 |
+
import importlib.util
|
20 |
+
from functools import lru_cache
|
21 |
+
from typing import TYPE_CHECKING
|
22 |
+
|
23 |
+
from packaging import version
|
24 |
+
|
25 |
+
|
26 |
+
if TYPE_CHECKING:
|
27 |
+
from packaging.version import Version
|
28 |
+
|
29 |
+
|
30 |
+
def _is_package_available(name: str) -> bool:
|
31 |
+
return importlib.util.find_spec(name) is not None
|
32 |
+
|
33 |
+
|
34 |
+
def _get_package_version(name: str) -> "Version":
|
35 |
+
try:
|
36 |
+
return version.parse(importlib.metadata.version(name))
|
37 |
+
except Exception:
|
38 |
+
return version.parse("0.0.0")
|
39 |
+
|
40 |
+
|
41 |
+
def is_fastapi_available():
|
42 |
+
return _is_package_available("fastapi")
|
43 |
+
|
44 |
+
|
45 |
+
def is_galore_available():
|
46 |
+
return _is_package_available("galore_torch")
|
47 |
+
|
48 |
+
|
49 |
+
def is_gradio_available():
|
50 |
+
return _is_package_available("gradio")
|
51 |
+
|
52 |
+
|
53 |
+
def is_matplotlib_available():
|
54 |
+
return _is_package_available("matplotlib")
|
55 |
+
|
56 |
+
|
57 |
+
def is_pillow_available():
|
58 |
+
return _is_package_available("PIL")
|
59 |
+
|
60 |
+
|
61 |
+
def is_requests_available():
|
62 |
+
return _is_package_available("requests")
|
63 |
+
|
64 |
+
|
65 |
+
def is_rouge_available():
|
66 |
+
return _is_package_available("rouge_chinese")
|
67 |
+
|
68 |
+
|
69 |
+
def is_starlette_available():
|
70 |
+
return _is_package_available("sse_starlette")
|
71 |
+
|
72 |
+
|
73 |
+
def is_uvicorn_available():
|
74 |
+
return _is_package_available("uvicorn")
|
75 |
+
|
76 |
+
|
77 |
+
def is_vllm_available():
|
78 |
+
return _is_package_available("vllm")
|
79 |
+
|
80 |
+
|
81 |
+
@lru_cache
|
82 |
+
def is_vllm_version_greater_than_0_5():
|
83 |
+
return _get_package_version("vllm") >= version.parse("0.5.0")
|
src/extras/ploting.py
ADDED
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import json
|
16 |
+
import math
|
17 |
+
import os
|
18 |
+
from typing import Any, Dict, List
|
19 |
+
|
20 |
+
from transformers.trainer import TRAINER_STATE_NAME
|
21 |
+
|
22 |
+
from .logging import get_logger
|
23 |
+
from .packages import is_matplotlib_available
|
24 |
+
|
25 |
+
|
26 |
+
if is_matplotlib_available():
|
27 |
+
import matplotlib.figure
|
28 |
+
import matplotlib.pyplot as plt
|
29 |
+
|
30 |
+
|
31 |
+
logger = get_logger(__name__)
|
32 |
+
|
33 |
+
|
34 |
+
def smooth(scalars: List[float]) -> List[float]:
|
35 |
+
r"""
|
36 |
+
EMA implementation according to TensorBoard.
|
37 |
+
"""
|
38 |
+
if len(scalars) == 0:
|
39 |
+
return []
|
40 |
+
|
41 |
+
last = scalars[0]
|
42 |
+
smoothed = []
|
43 |
+
weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function
|
44 |
+
for next_val in scalars:
|
45 |
+
smoothed_val = last * weight + (1 - weight) * next_val
|
46 |
+
smoothed.append(smoothed_val)
|
47 |
+
last = smoothed_val
|
48 |
+
return smoothed
|
49 |
+
|
50 |
+
|
51 |
+
def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure":
|
52 |
+
r"""
|
53 |
+
Plots loss curves in LlamaBoard.
|
54 |
+
"""
|
55 |
+
plt.close("all")
|
56 |
+
plt.switch_backend("agg")
|
57 |
+
fig = plt.figure()
|
58 |
+
ax = fig.add_subplot(111)
|
59 |
+
steps, losses = [], []
|
60 |
+
for log in trainer_log:
|
61 |
+
if log.get("loss", None):
|
62 |
+
steps.append(log["current_steps"])
|
63 |
+
losses.append(log["loss"])
|
64 |
+
|
65 |
+
ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original")
|
66 |
+
ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed")
|
67 |
+
ax.legend()
|
68 |
+
ax.set_xlabel("step")
|
69 |
+
ax.set_ylabel("loss")
|
70 |
+
return fig
|
71 |
+
|
72 |
+
|
73 |
+
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
|
74 |
+
r"""
|
75 |
+
Plots loss curves and saves the image.
|
76 |
+
"""
|
77 |
+
plt.switch_backend("agg")
|
78 |
+
with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f:
|
79 |
+
data = json.load(f)
|
80 |
+
|
81 |
+
for key in keys:
|
82 |
+
steps, metrics = [], []
|
83 |
+
for i in range(len(data["log_history"])):
|
84 |
+
if key in data["log_history"][i]:
|
85 |
+
steps.append(data["log_history"][i]["step"])
|
86 |
+
metrics.append(data["log_history"][i][key])
|
87 |
+
|
88 |
+
if len(metrics) == 0:
|
89 |
+
logger.warning(f"No metric {key} to plot.")
|
90 |
+
continue
|
91 |
+
|
92 |
+
plt.figure()
|
93 |
+
plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original")
|
94 |
+
plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed")
|
95 |
+
plt.title("training {} of {}".format(key, save_dictionary))
|
96 |
+
plt.xlabel("step")
|
97 |
+
plt.ylabel(key)
|
98 |
+
plt.legend()
|
99 |
+
figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_")))
|
100 |
+
plt.savefig(figure_path, format="png", dpi=100)
|
101 |
+
print("Figure saved at:", figure_path)
|
src/hparams/__init__.py
ADDED
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from .data_args import DataArguments
|
16 |
+
from .evaluation_args import EvaluationArguments
|
17 |
+
from .finetuning_args import FinetuningArguments
|
18 |
+
from .generating_args import GeneratingArguments
|
19 |
+
from .model_args import ModelArguments
|
20 |
+
from .parser import get_infer_args, get_train_args
|
21 |
+
|
22 |
+
|
23 |
+
__all__ = [
|
24 |
+
"DataArguments",
|
25 |
+
"EvaluationArguments",
|
26 |
+
"FinetuningArguments",
|
27 |
+
"GeneratingArguments",
|
28 |
+
"ModelArguments",
|
29 |
+
"get_infer_args",
|
30 |
+
"get_train_args",
|
31 |
+
]
|
src/hparams/data_args.py
ADDED
@@ -0,0 +1,110 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 HuggingFace Inc., the LlamaFactory team, and the Llamole team.
|
2 |
+
#
|
3 |
+
# This code is inspired by the HuggingFace's transformers library.
|
4 |
+
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
5 |
+
#
|
6 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
7 |
+
# you may not use this file except in compliance with the License.
|
8 |
+
# You may obtain a copy of the License at
|
9 |
+
#
|
10 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
11 |
+
#
|
12 |
+
# Unless required by applicable law or agreed to in writing, software
|
13 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
14 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
15 |
+
# See the License for the specific language governing permissions and
|
16 |
+
# limitations under the License.
|
17 |
+
|
18 |
+
from dataclasses import dataclass, field
|
19 |
+
from typing import Literal, Optional
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class DataArguments:
|
24 |
+
r"""
|
25 |
+
Arguments pertaining to what data we are going to input our model for training and evaluation.
|
26 |
+
"""
|
27 |
+
|
28 |
+
template: Optional[str] = field(
|
29 |
+
default=None,
|
30 |
+
metadata={"help": "Which template to use for constructing prompts in training and inference."},
|
31 |
+
)
|
32 |
+
dataset: Optional[str] = field(
|
33 |
+
default=None,
|
34 |
+
metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."},
|
35 |
+
)
|
36 |
+
dataset_dir: str = field(
|
37 |
+
default="data",
|
38 |
+
metadata={"help": "Path to the folder containing the datasets."},
|
39 |
+
)
|
40 |
+
split: str = field(
|
41 |
+
default="train",
|
42 |
+
metadata={"help": "Which dataset split to use for training and evaluation."},
|
43 |
+
)
|
44 |
+
cutoff_len: int = field(
|
45 |
+
default=1024,
|
46 |
+
metadata={"help": "The cutoff length of the tokenized inputs in the dataset."},
|
47 |
+
)
|
48 |
+
train_on_prompt: bool = field(
|
49 |
+
default=False,
|
50 |
+
metadata={"help": "Whether to disable the mask on the prompt or not."},
|
51 |
+
)
|
52 |
+
# number of learnable queries
|
53 |
+
learned_query_size: int = field(
|
54 |
+
default=8,
|
55 |
+
metadata={"help": "Number of queries to use in training."},
|
56 |
+
)
|
57 |
+
|
58 |
+
# used for evaluation with specific starting idx
|
59 |
+
start_idx: int = field(
|
60 |
+
default=0,
|
61 |
+
metadata={"help": "Start index of the evaluation data for drug"},
|
62 |
+
)
|
63 |
+
buffer_size: int = field(
|
64 |
+
default=16384,
|
65 |
+
metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."},
|
66 |
+
)
|
67 |
+
mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field(
|
68 |
+
default="concat",
|
69 |
+
metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
|
70 |
+
)
|
71 |
+
overwrite_cache: bool = field(
|
72 |
+
default=False,
|
73 |
+
metadata={"help": "Overwrite the cached training and evaluation sets."},
|
74 |
+
)
|
75 |
+
preprocessing_num_workers: Optional[int] = field(
|
76 |
+
default=None,
|
77 |
+
metadata={"help": "The number of processes to use for the pre-processing."},
|
78 |
+
)
|
79 |
+
max_samples: Optional[int] = field(
|
80 |
+
default=None,
|
81 |
+
metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
|
82 |
+
)
|
83 |
+
eval_num_beams: Optional[int] = field(
|
84 |
+
default=None,
|
85 |
+
metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
|
86 |
+
)
|
87 |
+
ignore_pad_token_for_loss: bool = field(
|
88 |
+
default=True,
|
89 |
+
metadata={
|
90 |
+
"help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation."
|
91 |
+
},
|
92 |
+
)
|
93 |
+
val_size: float = field(
|
94 |
+
default=0.0,
|
95 |
+
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."},
|
96 |
+
)
|
97 |
+
packing: Optional[bool] = field(
|
98 |
+
default=None,
|
99 |
+
metadata={
|
100 |
+
"help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training."
|
101 |
+
},
|
102 |
+
)
|
103 |
+
tool_format: Optional[str] = field(
|
104 |
+
default=None,
|
105 |
+
metadata={"help": "Tool format to use for constructing function calling examples."},
|
106 |
+
)
|
107 |
+
tokenized_path: Optional[str] = field(
|
108 |
+
default=None,
|
109 |
+
metadata={"help": "Path to save or load the tokenized datasets."},
|
110 |
+
)
|
src/hparams/evaluation_args.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
import os
|
16 |
+
from dataclasses import dataclass, field
|
17 |
+
from typing import Literal, Optional
|
18 |
+
|
19 |
+
from datasets import DownloadMode
|
20 |
+
|
21 |
+
|
22 |
+
@dataclass
|
23 |
+
class EvaluationArguments:
|
24 |
+
r"""
|
25 |
+
Arguments pertaining to specify the evaluation parameters.
|
26 |
+
"""
|
27 |
+
|
28 |
+
task: str = field(
|
29 |
+
metadata={"help": "Name of the evaluation task."},
|
30 |
+
)
|
31 |
+
task_dir: str = field(
|
32 |
+
default="evaluation",
|
33 |
+
metadata={"help": "Path to the folder containing the evaluation datasets."},
|
34 |
+
)
|
35 |
+
batch_size: int = field(
|
36 |
+
default=4,
|
37 |
+
metadata={"help": "The batch size per GPU for evaluation."},
|
38 |
+
)
|
39 |
+
seed: int = field(
|
40 |
+
default=42,
|
41 |
+
metadata={"help": "Random seed to be used with data loaders."},
|
42 |
+
)
|
43 |
+
lang: Literal["en", "zh"] = field(
|
44 |
+
default="en",
|
45 |
+
metadata={"help": "Language used at evaluation."},
|
46 |
+
)
|
47 |
+
n_shot: int = field(
|
48 |
+
default=5,
|
49 |
+
metadata={"help": "Number of examplars for few-shot learning."},
|
50 |
+
)
|
51 |
+
save_dir: Optional[str] = field(
|
52 |
+
default=None,
|
53 |
+
metadata={"help": "Path to save the evaluation results."},
|
54 |
+
)
|
55 |
+
download_mode: DownloadMode = field(
|
56 |
+
default=DownloadMode.REUSE_DATASET_IF_EXISTS,
|
57 |
+
metadata={"help": "Download mode used for the evaluation datasets."},
|
58 |
+
)
|
59 |
+
|
60 |
+
def __post_init__(self):
|
61 |
+
if self.save_dir is not None and os.path.exists(self.save_dir):
|
62 |
+
raise ValueError("`save_dir` already exists, use another one.")
|
src/hparams/finetuning_args.py
ADDED
@@ -0,0 +1,402 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright 2024 the LlamaFactory team.
|
2 |
+
#
|
3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
4 |
+
# you may not use this file except in compliance with the License.
|
5 |
+
# You may obtain a copy of the License at
|
6 |
+
#
|
7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
8 |
+
#
|
9 |
+
# Unless required by applicable law or agreed to in writing, software
|
10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12 |
+
# See the License for the specific language governing permissions and
|
13 |
+
# limitations under the License.
|
14 |
+
|
15 |
+
from dataclasses import dataclass, field
|
16 |
+
from typing import List, Literal, Optional
|
17 |
+
|
18 |
+
|
19 |
+
@dataclass
|
20 |
+
class FreezeArguments:
|
21 |
+
r"""
|
22 |
+
Arguments pertaining to the freeze (partial-parameter) training.
|
23 |
+
"""
|
24 |
+
|
25 |
+
freeze_trainable_layers: int = field(
|
26 |
+
default=2,
|
27 |
+
metadata={
|
28 |
+
"help": (
|
29 |
+
"The number of trainable layers for freeze (partial-parameter) fine-tuning. "
|
30 |
+
"Positive numbers mean the last n layers are set as trainable, "
|
31 |
+
"negative numbers mean the first n layers are set as trainable."
|
32 |
+
)
|
33 |
+
},
|
34 |
+
)
|
35 |
+
freeze_trainable_modules: str = field(
|
36 |
+
default="all",
|
37 |
+
metadata={
|
38 |
+
"help": (
|
39 |
+
"Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
|
40 |
+
"Use commas to separate multiple modules. "
|
41 |
+
"Use `all` to specify all the available modules."
|
42 |
+
)
|
43 |
+
},
|
44 |
+
)
|
45 |
+
freeze_extra_modules: Optional[str] = field(
|
46 |
+
default=None,
|
47 |
+
metadata={
|
48 |
+
"help": (
|
49 |
+
"Name(s) of modules apart from hidden layers to be set as trainable "
|
50 |
+
"for freeze (partial-parameter) fine-tuning. "
|
51 |
+
"Use commas to separate multiple modules."
|
52 |
+
)
|
53 |
+
},
|
54 |
+
)
|
55 |
+
|
56 |
+
|
57 |
+
@dataclass
|
58 |
+
class LoraArguments:
|
59 |
+
r"""
|
60 |
+
Arguments pertaining to the LoRA training.
|
61 |
+
"""
|
62 |
+
|
63 |
+
additional_target: Optional[str] = field(
|
64 |
+
default=None,
|
65 |
+
metadata={
|
66 |
+
"help": (
|
67 |
+
"Name(s) of modules apart from LoRA layers to be set as trainable "
|
68 |
+
"and saved in the final checkpoint. "
|
69 |
+
"Use commas to separate multiple modules."
|
70 |
+
)
|
71 |
+
},
|
72 |
+
)
|
73 |
+
lora_alpha: Optional[int] = field(
|
74 |
+
default=None,
|
75 |
+
metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
|
76 |
+
)
|
77 |
+
lora_dropout: float = field(
|
78 |
+
default=0.0,
|
79 |
+
metadata={"help": "Dropout rate for the LoRA fine-tuning."},
|
80 |
+
)
|
81 |
+
lora_rank: int = field(
|
82 |
+
default=8,
|
83 |
+
metadata={"help": "The intrinsic dimension for LoRA fine-tuning."},
|
84 |
+
)
|
85 |
+
lora_target: str = field(
|
86 |
+
default="all",
|
87 |
+
metadata={
|
88 |
+
"help": (
|
89 |
+
"Name(s) of target modules to apply LoRA. "
|
90 |
+
"Use commas to separate multiple modules. "
|
91 |
+
"Use `all` to specify all the linear modules."
|
92 |
+
)
|
93 |
+
},
|
94 |
+
)
|
95 |
+
loraplus_lr_ratio: Optional[float] = field(
|
96 |
+
default=None,
|
97 |
+
metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
|
98 |
+
)
|
99 |
+
loraplus_lr_embedding: float = field(
|
100 |
+
default=1e-6,
|
101 |
+
metadata={"help": "LoRA plus learning rate for lora embedding layers."},
|
102 |
+
)
|
103 |
+
use_rslora: bool = field(
|
104 |
+
default=False,
|
105 |
+
metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."},
|
106 |
+
)
|
107 |
+
use_dora: bool = field(
|
108 |
+
default=False,
|
109 |
+
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
|
110 |
+
)
|
111 |
+
pissa_init: bool = field(
|
112 |
+
default=False,
|
113 |
+
metadata={"help": "Whether or not to initialize a PiSSA adapter."},
|
114 |
+
)
|
115 |
+
pissa_iter: int = field(
|
116 |
+
default=16,
|
117 |
+
metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
|
118 |
+
)
|
119 |
+
pissa_convert: bool = field(
|
120 |
+
default=False,
|
121 |
+
metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
|
122 |
+
)
|
123 |
+
create_new_adapter: bool = field(
|
124 |
+
default=False,
|
125 |
+
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
|
126 |
+
)
|
127 |
+
|
128 |
+
|
129 |
+
@dataclass
|
130 |
+
class RLHFArguments:
|
131 |
+
r"""
|
132 |
+
Arguments pertaining to the PPO, DPO and KTO training.
|
133 |
+
"""
|
134 |
+
|
135 |
+
pref_beta: float = field(
|
136 |
+
default=0.1,
|
137 |
+
metadata={"help": "The beta parameter in the preference loss."},
|
138 |
+
)
|
139 |
+
pref_ftx: float = field(
|
140 |
+
default=0.0,
|
141 |
+
metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."},
|
142 |
+
)
|
143 |
+
pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field(
|
144 |
+
default="sigmoid",
|
145 |
+
metadata={"help": "The type of DPO loss to use."},
|
146 |
+
)
|
147 |
+
dpo_label_smoothing: float = field(
|
148 |
+
default=0.0,
|
149 |
+
metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."},
|
150 |
+
)
|
151 |
+
kto_chosen_weight: float = field(
|
152 |
+
default=1.0,
|
153 |
+
metadata={"help": "The weight factor of the desirable losses in KTO training."},
|
154 |
+
)
|
155 |
+
kto_rejected_weight: float = field(
|
156 |
+
default=1.0,
|
157 |
+
metadata={"help": "The weight factor of the undesirable losses in KTO training."},
|
158 |
+
)
|
159 |
+
simpo_gamma: float = field(
|
160 |
+
default=0.5,
|
161 |
+
metadata={"help": "The target reward margin term in SimPO loss."},
|
162 |
+
)
|
163 |
+
ppo_buffer_size: int = field(
|
164 |
+
default=1,
|
165 |
+
metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."},
|
166 |
+
)
|
167 |
+
ppo_epochs: int = field(
|
168 |
+
default=4,
|
169 |
+
metadata={"help": "The number of epochs to perform in a PPO optimization step."},
|
170 |
+
)
|
171 |
+
ppo_score_norm: bool = field(
|
172 |
+
default=False,
|
173 |
+
metadata={"help": "Use score normalization in PPO training."},
|
174 |
+
)
|
175 |
+
ppo_target: float = field(
|
176 |
+
default=6.0,
|
177 |
+
metadata={"help": "Target KL value for adaptive KL control in PPO training."},
|
178 |
+
)
|
179 |
+
ppo_whiten_rewards: bool = field(
|
180 |
+
default=False,
|
181 |
+
metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
|
182 |
+
)
|
183 |
+
ref_model: Optional[str] = field(
|
184 |
+
default=None,
|
185 |
+
metadata={"help": "Path to the reference model used for the PPO or DPO training."},
|
186 |
+
)
|
187 |
+
ref_model_adapters: Optional[str] = field(
|
188 |
+
default=None,
|
189 |
+
metadata={"help": "Path to the adapters of the reference model."},
|
190 |
+
)
|
191 |
+
ref_model_quantization_bit: Optional[int] = field(
|
192 |
+
default=None,
|
193 |
+
metadata={"help": "The number of bits to quantize the reference model."},
|
194 |
+
)
|
195 |
+
reward_model: Optional[str] = field(
|
196 |
+
default=None,
|
197 |
+
metadata={"help": "Path to the reward model used for the PPO training."},
|
198 |
+
)
|
199 |
+
reward_model_adapters: Optional[str] = field(
|
200 |
+
default=None,
|
201 |
+
metadata={"help": "Path to the adapters of the reward model."},
|
202 |
+
)
|
203 |
+
reward_model_quantization_bit: Optional[int] = field(
|
204 |
+
default=None,
|
205 |
+
metadata={"help": "The number of bits to quantize the reward model."},
|
206 |
+
)
|
207 |
+
reward_model_type: Literal["lora", "full", "api"] = field(
|
208 |
+
default="lora",
|
209 |
+
metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
|
210 |
+
)
|
211 |
+
|
212 |
+
|
213 |
+
@dataclass
|
214 |
+
class GaloreArguments:
|
215 |
+
r"""
|
216 |
+
Arguments pertaining to the GaLore algorithm.
|
217 |
+
"""
|
218 |
+
|
219 |
+
use_galore: bool = field(
|
220 |
+
default=False,
|
221 |
+
metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."},
|
222 |
+
)
|
223 |
+
galore_target: str = field(
|
224 |
+
default="all",
|
225 |
+
metadata={
|
226 |
+
"help": (
|
227 |
+
"Name(s) of modules to apply GaLore. Use commas to separate multiple modules. "
|
228 |
+
"Use `all` to specify all the linear modules."
|
229 |
+
)
|
230 |
+
},
|
231 |
+
)
|
232 |
+
galore_rank: int = field(
|
233 |
+
default=16,
|
234 |
+
metadata={"help": "The rank of GaLore gradients."},
|
235 |
+
)
|
236 |
+
galore_update_interval: int = field(
|
237 |
+
default=200,
|
238 |
+
metadata={"help": "Number of steps to update the GaLore projection."},
|
239 |
+
)
|
240 |
+
galore_scale: float = field(
|
241 |
+
default=0.25,
|
242 |
+
metadata={"help": "GaLore scaling coefficient."},
|
243 |
+
)
|
244 |
+
galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field(
|
245 |
+
default="std",
|
246 |
+
metadata={"help": "Type of GaLore projection."},
|
247 |
+
)
|
248 |
+
galore_layerwise: bool = field(
|
249 |
+
default=False,
|
250 |
+
metadata={"help": "Whether or not to enable layer-wise update to further save memory."},
|
251 |
+
)
|
252 |
+
|
253 |
+
|
254 |
+
@dataclass
|
255 |
+
class BAdamArgument:
|
256 |
+
r"""
|
257 |
+
Arguments pertaining to the BAdam optimizer.
|
258 |
+
"""
|
259 |
+
|
260 |
+
use_badam: bool = field(
|
261 |
+
default=False,
|
262 |
+
metadata={"help": "Whether or not to use the BAdam optimizer."},
|
263 |
+
)
|
264 |
+
badam_mode: Literal["layer", "ratio"] = field(
|
265 |
+
default="layer",
|
266 |
+
metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
|
267 |
+
)
|
268 |
+
badam_start_block: Optional[int] = field(
|
269 |
+
default=None,
|
270 |
+
metadata={"help": "The starting block index for layer-wise BAdam."},
|
271 |
+
)
|
272 |
+
badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
|
273 |
+
default="ascending",
|
274 |
+
metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
|
275 |
+
)
|
276 |
+
badam_switch_interval: Optional[int] = field(
|
277 |
+
default=50,
|
278 |
+
metadata={
|
279 |
+
"help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
|
280 |
+
},
|
281 |
+
)
|
282 |
+
badam_update_ratio: float = field(
|
283 |
+
default=0.05,
|
284 |
+
metadata={"help": "The ratio of the update for ratio-wise BAdam."},
|
285 |
+
)
|
286 |
+
badam_mask_mode: Literal["adjacent", "scatter"] = field(
|
287 |
+
default="adjacent",
|
288 |
+
metadata={
|
289 |
+
"help": (
|
290 |
+
"The mode of the mask for BAdam optimizer. "
|
291 |
+
"`adjacent` means that the trainable parameters are adjacent to each other, "
|
292 |
+
"`scatter` means that trainable parameters are randomly choosed from the weight."
|
293 |
+
)
|
294 |
+
},
|
295 |
+
)
|
296 |
+
badam_verbose: int = field(
|
297 |
+
default=0,
|
298 |
+
metadata={
|
299 |
+
"help": (
|
300 |
+
"The verbosity level of BAdam optimizer. "
|
301 |
+
"0 for no print, 1 for print the block prefix, 2 for print trainable parameters."
|
302 |
+
)
|
303 |
+
},
|
304 |
+
)
|
305 |
+
|
306 |
+
|
307 |
+
@dataclass
|
308 |
+
class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument):
|
309 |
+
r"""
|
310 |
+
Arguments pertaining to which techniques we are going to fine-tuning with.
|
311 |
+
"""
|
312 |
+
|
313 |
+
loss_weight_lm: int = field(
|
314 |
+
default=1,
|
315 |
+
metadata={"help": "Loss weight for the language next token prediction."},
|
316 |
+
)
|
317 |
+
loss_weight_design: int = field(
|
318 |
+
default=1,
|
319 |
+
metadata={"help": "Loss weight for molecular design."},
|
320 |
+
)
|
321 |
+
loss_weight_retro: int = field(
|
322 |
+
default=1,
|
323 |
+
metadata={"help": "Loss weight for retrosynthesis."},
|
324 |
+
)
|
325 |
+
|
326 |
+
pure_bf16: bool = field(
|
327 |
+
default=False,
|
328 |
+
metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."},
|
329 |
+
)
|
330 |
+
stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field(
|
331 |
+
default="sft",
|
332 |
+
metadata={"help": "Which stage will be performed in training."},
|
333 |
+
)
|
334 |
+
finetuning_type: Literal["lora", "freeze", "full"] = field(
|
335 |
+
default="lora",
|
336 |
+
metadata={"help": "Which fine-tuning method to use."},
|
337 |
+
)
|
338 |
+
use_llama_pro: bool = field(
|
339 |
+
default=False,
|
340 |
+
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
341 |
+
)
|
342 |
+
freeze_vision_tower: bool = field(
|
343 |
+
default=True,
|
344 |
+
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
|
345 |
+
)
|
346 |
+
train_mm_proj_only: bool = field(
|
347 |
+
default=False,
|
348 |
+
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
|
349 |
+
)
|
350 |
+
plot_loss: bool = field(
|
351 |
+
default=False,
|
352 |
+
metadata={"help": "Whether or not to save the training loss curves."},
|
353 |
+
)
|
354 |
+
|
355 |
+
def __post_init__(self):
|
356 |
+
def split_arg(arg):
|
357 |
+
if isinstance(arg, str):
|
358 |
+
return [item.strip() for item in arg.split(",")]
|
359 |
+
return arg
|
360 |
+
|
361 |
+
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
|
362 |
+
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
|
363 |
+
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
364 |
+
self.lora_target: List[str] = split_arg(self.lora_target)
|
365 |
+
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
366 |
+
self.galore_target: List[str] = split_arg(self.galore_target)
|
367 |
+
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
368 |
+
self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
|
369 |
+
|
370 |
+
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
371 |
+
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
372 |
+
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
373 |
+
|
374 |
+
if self.stage == "ppo" and self.reward_model is None:
|
375 |
+
raise ValueError("`reward_model` is necessary for PPO training.")
|
376 |
+
|
377 |
+
if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora":
|
378 |
+
raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.")
|
379 |
+
|
380 |
+
if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6:
|
381 |
+
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
|
382 |
+
|
383 |
+
if self.use_llama_pro and self.finetuning_type == "full":
|
384 |
+
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
|
385 |
+
|
386 |
+
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
|
387 |
+
raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
|
388 |
+
|
389 |
+
if self.use_galore and self.use_badam:
|
390 |
+
raise ValueError("Cannot use GaLore with BAdam together.")
|
391 |
+
|
392 |
+
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
393 |
+
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
|
394 |
+
|
395 |
+
if self.pissa_init and self.finetuning_type != "lora":
|
396 |
+
raise ValueError("`pissa_init` is only valid for LoRA training.")
|
397 |
+
|
398 |
+
if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
|
399 |
+
raise ValueError("Cannot use PiSSA for current training stage.")
|
400 |
+
|
401 |
+
if self.train_mm_proj_only and self.finetuning_type != "full":
|
402 |
+
raise ValueError("`train_mm_proj_only` is only valid for full training.")
|