diff --git a/.gitattributes b/.gitattributes index a6344aac8c09253b3b630fb776ae94478aa0275b..95780644263ecc62dd91867afd13c6f1fcc2d920 100644 --- a/.gitattributes +++ b/.gitattributes @@ -1,35 +1,4 @@ -*.7z filter=lfs diff=lfs merge=lfs -text -*.arrow filter=lfs diff=lfs merge=lfs -text -*.bin filter=lfs diff=lfs merge=lfs -text -*.bz2 filter=lfs diff=lfs merge=lfs -text -*.ckpt filter=lfs diff=lfs merge=lfs -text -*.ftz filter=lfs diff=lfs merge=lfs -text -*.gz filter=lfs diff=lfs merge=lfs -text -*.h5 filter=lfs diff=lfs merge=lfs -text -*.joblib filter=lfs diff=lfs merge=lfs -text -*.lfs.* filter=lfs diff=lfs merge=lfs -text -*.mlmodel filter=lfs diff=lfs merge=lfs -text -*.model filter=lfs diff=lfs merge=lfs -text -*.msgpack filter=lfs diff=lfs merge=lfs -text -*.npy filter=lfs diff=lfs merge=lfs -text -*.npz filter=lfs diff=lfs merge=lfs -text -*.onnx filter=lfs diff=lfs merge=lfs -text -*.ot filter=lfs diff=lfs merge=lfs -text -*.parquet filter=lfs diff=lfs merge=lfs -text -*.pb filter=lfs diff=lfs merge=lfs -text -*.pickle filter=lfs diff=lfs merge=lfs -text -*.pkl filter=lfs diff=lfs merge=lfs -text -*.pt filter=lfs diff=lfs merge=lfs -text -*.pth filter=lfs diff=lfs merge=lfs -text -*.rar filter=lfs diff=lfs merge=lfs -text -*.safetensors filter=lfs diff=lfs merge=lfs -text -saved_model/**/* filter=lfs diff=lfs merge=lfs -text -*.tar.* filter=lfs diff=lfs merge=lfs -text -*.tar filter=lfs diff=lfs merge=lfs -text -*.tflite filter=lfs diff=lfs merge=lfs -text -*.tgz filter=lfs diff=lfs merge=lfs -text -*.wasm filter=lfs diff=lfs merge=lfs -text -*.xz filter=lfs diff=lfs merge=lfs -text -*.zip filter=lfs diff=lfs merge=lfs -text -*.zst filter=lfs diff=lfs merge=lfs -text -*tfevents* filter=lfs diff=lfs merge=lfs -text +# Auto detect text files and perform LF normalization +* text=auto +assets/llamole.png filter=lfs diff=lfs merge=lfs -text +assets/ui_example.png filter=lfs diff=lfs merge=lfs -text diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..e65bbe3e61c5d2c38a82a44df6ea59e2ab926331 --- /dev/null +++ b/.gitignore @@ -0,0 +1,172 @@ +# Byte-compiled / optimized / DLL files +__pycache__/ +*.py[cod] +*$py.class + +# C extensions +*.so + +# Distribution / packaging +.Python +build/ +develop-eggs/ +dist/ +downloads/ +eggs/ +.eggs/ +lib/ +lib64/ +parts/ +sdist/ +var/ +wheels/ +share/python-wheels/ +*.egg-info/ +.installed.cfg +*.egg +MANIFEST + +# PyInstaller +# Usually these files are written by a python script from a template +# before PyInstaller builds the exe, so as to inject date/other infos into it. +*.manifest +*.spec + +# Installer logs +pip-log.txt +pip-delete-this-directory.txt + +# Unit test / coverage reports +htmlcov/ +.tox/ +.nox/ +.coverage +.coverage.* +.cache +nosetests.xml +coverage.xml +*.cover +*.py,cover +.hypothesis/ +.pytest_cache/ +cover/ + +# Translations +*.mo +*.pot + +# Django stuff: +*.log +local_settings.py +db.sqlite3 +db.sqlite3-journal + +# Flask stuff: +instance/ +.webassets-cache + +# Scrapy stuff: +.scrapy + +# Sphinx documentation +docs/_build/ + +# PyBuilder +.pybuilder/ +target/ + +# Jupyter Notebook +.ipynb_checkpoints + +# IPython +profile_default/ +ipython_config.py + +# pyenv +# For a library or package, you might want to ignore these files since the code is +# intended to run in multiple environments; otherwise, check them in: +# .python-version + +# pipenv +# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control. 
+# However, in case of collaboration, if having platform-specific dependencies or dependencies +# having no cross-platform support, pipenv may install dependencies that don't work, or not +# install all needed dependencies. +#Pipfile.lock + +# poetry +# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control. +# This is especially recommended for binary packages to ensure reproducibility, and is more +# commonly ignored for libraries. +# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control +#poetry.lock + +# pdm +# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control. +#pdm.lock +# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it +# in version control. +# https://pdm.fming.dev/#use-with-ide +.pdm.toml + +# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm +__pypackages__/ + +# Celery stuff +celerybeat-schedule +celerybeat.pid + +# SageMath parsed files +*.sage.py + +# Environments +.env +.venv +env/ +venv/ +ENV/ +env.bak/ +venv.bak/ + +# Spyder project settings +.spyderproject +.spyproject + +# Rope project settings +.ropeproject + +# mkdocs documentation +/site + +# mypy +.mypy_cache/ +.dmypy.json +dmypy.json + +# Pyre type checker +.pyre/ + +# pytype static type analyzer +.pytype/ + +# Cython debug symbols +cython_debug/ + +# PyCharm +# JetBrains specific template is maintained in a separate JetBrains.gitignore that can +# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore +# and can be added to the global gitignore or merged into this file. For a more nuclear +# option (not recommended) you can uncomment the following to ignore the entire idea folder. +.idea/ + +# custom .gitignore +cache/ +output/ +wandb/ +# Ignore everything in data directory +run_script* + +saves/* +!saves/README.md + +.gradio \ No newline at end of file diff --git a/LICENSE b/LICENSE new file mode 100644 index 0000000000000000000000000000000000000000..b09cd7856d58590578ee1a4f3ad45d1310a97f87 --- /dev/null +++ b/LICENSE @@ -0,0 +1,201 @@ +Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. 
+ + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. 
You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. 
In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/README.md b/README.md index f71dbcd28ab6d925eb7db48e7e9739d92d410223..d0757408fe88552cf1840fb904e5a7fdaa0ed4a2 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,190 @@ --- title: Llamole -emoji: 🐢 -colorFrom: gray -colorTo: purple +app_file: launch.py sdk: gradio sdk_version: 5.15.0 -app_file: app.py -pinned: false --- + -Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference +
+ +# 🔬 🧪 Llamole: Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning + +[![](https://img.shields.io/badge/Paper-pink?style=plastic&logo=arXiv)](https://arxiv.org/abs/2410.04223) +[![](https://img.shields.io/badge/GitHub-blue?style=plastic&logo=github)](https://github.com/liugangcode/Llamole) +[![](https://img.shields.io/badge/Huggingface-orange?style=plastic&logo=huggingface)](https://huggingface.co/collections/liuganghuggingface/llamole-collection-67073a2e76b47d5fea909434) + +
+ +## 🌟 Introduction + +**Llamole** is a multimodal Large Language Model (LLM) that integrates a base LLM with the Graph Diffusion Transformer and Graph Neural Networks for multi-conditional molecular generation and multi-step reaction inference within texts. + +📄 **[Paper](https://arxiv.org/abs/2410.04223)**: *Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning* + +
+<details>
+<summary>🔍 Abstract</summary>
+
+While large language models (LLMs) have integrated images, adapting them to graphs remains challenging, limiting their applications in materials and drug design. This difficulty stems from the need for coherent autoregressive generation across texts and graphs. To address this, we introduce Llamole, the first multimodal LLM capable of interleaved text and graph generation, enabling molecular inverse design with retrosynthetic planning. Llamole integrates a base LLM with the Graph Diffusion Transformer and Graph Neural Networks for multi-conditional molecular generation and reaction inference within texts, while the LLM, with enhanced molecular understanding, flexibly controls activation among the different graph modules. Additionally, Llamole integrates A* search with LLM-based cost functions for efficient retrosynthetic planning. We create benchmarking datasets and conduct extensive experiments to evaluate Llamole against in-context learning and supervised fine-tuning. Llamole significantly outperforms 14 adapted LLMs across 12 metrics for controllable molecular design and retrosynthetic planning.
+
+</details>
+
+![Llamole Design](assets/llamole.png)
+
+ +--- + +## 🛠️ Environment Setup + +Initialize the environment by following these steps: + +```bash +conda create --name llamole python=3.11 -y +conda activate llamole +./install_environment.sh +``` + +Alternatively, you can install all required dependencies using the `requirements.sh` script. + +--- + +## 🚀 Model Usage + +### 🔧 Requirements + +- **Hardware**: A single V100 or A6000 GPU for inference. +- **Configuration Files**: + - `config/train/{model}_lora.yaml` + - `config/generate/{model}_{task}.yaml` + +### 📥 Automatic Model Download + +On the first run, the necessary models will be automatically downloaded, including: + +1. **Base LLMs** (Please ensure you have access to the model): + - [Llama-3.1-8b-Instruct](https://huggingface.co/meta-llama/Llama-3.1-8B) + - [Qwen2-7B-Instruct](https://huggingface.co/Qwen/Qwen2-7B-Instruct) + - [Mistral-7B-Instruct-v0.3](https://huggingface.co/mistralai/Mistral-7B-Instruct-v0.3) + +2. **Pretrained Graph Models**: + - **Graph Decoder**: [Graph Diffusion Transformer](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphDiT) + - **Graph Encoder**: [GIN-based Encoder](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphEncoder) + - **Graph Predictor**: [GIN-based Predictor](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GNNPredictor) + +3. **Adapters and Connectors** for integrating the base LLM with pretrained graph models. + +### 🗂️ Manual Model Download + +If you prefer to download the models manually, refer to [![](https://img.shields.io/badge/Huggingface-orange?style=plastic&logo=huggingface)](https://huggingface.co/collections/liuganghuggingface/llamole-collection-67073a2e76b47d5fea909434) and place them in the following directories: + +- `saves/graph_decoder` +- `saves/graph_encoder` +- `saves/graph_predictor` +- `saves/{model_name}-Adapter` + +--- + +### 🕸️ Gradio-based Web UI + +Launch the web interface using Gradio: + +```bash +python launch.py +``` + +The default base LLM is **Qwen2-7B-Instruct**. If you wish to change this, please modify the `args_dict` variable accordingly. Upon launch, the web UI will appear as shown below: + +
+![Web UI Example](assets/ui_example.png)
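The paragraph above notes that the default base LLM for the web UI is Qwen2-7B-Instruct and that it can be changed by editing the `args_dict` variable in `launch.py`. As a hedged sketch only — assuming `args_dict` mirrors the fields of the `config/generate/*.yaml` files, which may not match the actual variable in `launch.py` — switching to the Mistral variant could look like this:

```python
# Illustrative sketch: point the web UI at Mistral-7B-Instruct-v0.3 instead of
# the default Qwen2-7B-Instruct. The keys below mirror config/generate/*.yaml;
# the real args_dict in launch.py may use different names or extra entries.
args_dict = {
    "model_name_or_path": "mistralai/Mistral-7B-Instruct-v0.3",
    "adapter_name_or_path": "saves/Mistral-7B-Instruct-v0.3-Adapter",
    "graph_lm_connector_path": "saves/Mistral-7B-Instruct-v0.3-Adapter/connector",
    "template": "mistral",
}
```

The adapter and connector paths above follow the `saves/{model_name}-Adapter` layout described in the Manual Model Download section.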
+ + +--- + +### 💻 Command Line Usage + +For command-line evaluation, specify the path to the configuration file: + +```bash +python main.py eval config/generate/qwen_material.yaml +``` + +You can modify the configuration files to suit your custom datasets. + +**Note**: Examples of training and evaluation datasets are available in the `data` folder. For more details, refer to `data/dataset_info.json`. To test generation on all MolQA questions, first download the dataset by running: + +```bash +python main.py download_data +``` + +Then, update the configuration files to point to the downloaded dataset based on the names from `data/dataset_info.json`. + +--- + +## 📚 Supervised Fine-Tuning + +The codebase supports multimodal graph-text supervised fine-tuning. Follow these steps: + +1. **Download MolQA Training Data**: + + ```bash + python main.py download_data + ``` + Then you may need to modify the configuration files in the `config` folder to point to the downloaded training data. Skipping this step and directly using the command from step 2 will result in training only on the example training set. + +2. **Run Fine-Tuning**: + + ```bash + python main.py train config/train/mistral_lora.yaml + ``` + + During the first run, pretrained graph models will be downloaded in the `saves` folder. Modify the configuration files as needed for your setup. An 80G A100 GPU is recommended for supervised fine-tuning. + +--- +## 📖 Citation + +If you find this repository useful, please cite our paper: + +``` +@misc{liu2024llamole, + title={Multimodal Large Language Models for Inverse Molecular Design with Retrosynthetic Planning}, + author={Gang Liu and Michael Sun and Wojciech Matusik and Meng Jiang and Jie Chen}, + year={2024}, + eprint={2410.04223}, + archivePrefix={arXiv}, + primaryClass={cs.LG}, + url={https://arxiv.org/abs/2410.04223}, +} +``` + +``` +@article{liu2024graphdit, + title={Graph Diffusion Transformers for Multi-Conditional Molecular Generation}, + author={Liu, Gang and Xu, Jiaxin and Luo, Tengfei and Jiang, Meng}, + journal={Thirty-Eighth Annual Conference on Neural Information Processing Systems}, + year={2024} +} +``` + +--- + +## 📄 Acknowledgments + +This codebase is built upon **[Llama-Factory](https://github.com/hiyouga/LLaMA-Factory)**. We extend our gratitude for their open-source contributions. 
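As a side note to the Command Line Usage and Supervised Fine-Tuning sections above: both resolve the `dataset:` field of a YAML config through `data/dataset_info.json`. The sketch below shows one hedged way to register an additional MolQA-style dataset; the names `molqa_custom` and `molqa_custom.json` are purely illustrative and are not shipped with the repository.

```python
import json
from pathlib import Path

# Illustrative sketch: add a new entry to data/dataset_info.json so that a
# training or generation config can reference it via `dataset: molqa_custom`.
info_path = Path("data/dataset_info.json")
dataset_info = json.loads(info_path.read_text())
dataset_info["molqa_custom"] = {"file_name": "molqa_custom.json"}  # hypothetical names
info_path.write_text(json.dumps(dataset_info, indent=4))
```

The referenced JSON file would need to follow the same structure as the `molqa_*_examples.json` files in the `data` folder (fields such as `instruction`, `input`, `output`, `property`, and `retro`).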
+ +--- + +## 📂 Additional Resources + +🔗 **Huggingface Models**: Llamole is developed with three variants (adapters) and three pretrained graph modules (encoder, decoder, predictor): +- **Base LLM Variant 1**: [Llama-3.1-8b-Instruct](https://huggingface.co/liuganghuggingface/Llamole-Llama-3.1-8B-Instruct-Adapter) +- **Base LLM Variant 2**: [Qwen2-7B-Instruct](https://huggingface.co/liuganghuggingface/Llamole-Qwen2-7B-Instruct-Adapter) +- **Base LLM Variant 3**: [Mistral-7B-Instruct-v0.3](https://huggingface.co/liuganghuggingface/Llamole-Mistral-7B-Instruct-v0.3-Adapter) +- **Pretrained Graph Decoder** for multi-conditional molecular generation: [Graph Diffusion Transformer](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphDiT) +- **Pretrained Graph Predictor** for one-step reaction prediction: [GNN Predictor](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GNNPredictor) +- **Pretrained Graph Encoder** for enhanced molecule understanding: [Graph Encoder](https://huggingface.co/liuganghuggingface/Llamole-Pretrained-GraphEncoder) + +--- diff --git a/assets/llamole.png b/assets/llamole.png new file mode 100644 index 0000000000000000000000000000000000000000..219559fac13f78692b09fc18d490de6481a6f7eb --- /dev/null +++ b/assets/llamole.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:8f98ad05d897399cd4db1c618f78509e19b0b832f3aa2d938121bbc5a6e77166 +size 771154 diff --git a/assets/ui_example.png b/assets/ui_example.png new file mode 100644 index 0000000000000000000000000000000000000000..f45c02874b33de5f64a5daf3188e181a0ab407db --- /dev/null +++ b/assets/ui_example.png @@ -0,0 +1,3 @@ +version https://git-lfs.github.com/spec/v1 +oid sha256:36a5f766fa0f9da927fea34d046939d88fc07ddbdaa3f35244f5bf6ea873b246 +size 150578 diff --git a/config/generate/llama_drug.yaml b/config/generate/llama_drug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bae79f29709d14eed525cd369ae70d3e6d27d87b --- /dev/null +++ b/config/generate/llama_drug.yaml @@ -0,0 +1,30 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct +new_special_tokens: ,,,,,,,, +graph_decoder_path: saves/graph_decoder +graph_encoder_path: saves/graph_encoder +graph_predictor_path: saves/graph_predictor +adapter_name_or_path: saves/Llama-3.1-8B-Instruct-Adapter +graph_lm_connector_path: saves/Llama-3.1-8B-Instruct-Adapter/connector +### generation +max_length: 512 +max_new_tokens: 128 +temperature: 0.6 +top_p: 0.9 +### method +finetuning_type: lora +do_train: false +flash_attn: disabled +learned_query_size: 8 +### dataset +dataset: molqa_drug_examples +template: llama3 +cutoff_len: 128 +overwrite_cache: true +preprocessing_num_workers: 16 +output_dir: null +bf16: true +pure_bf16: true +### eval +per_device_eval_batch_size: 6 +report_to: 'none' \ No newline at end of file diff --git a/config/generate/llama_material.yaml b/config/generate/llama_material.yaml new file mode 100644 index 0000000000000000000000000000000000000000..e06e2c848f08c5dd77ee58431d22e26ad941f7c8 --- /dev/null +++ b/config/generate/llama_material.yaml @@ -0,0 +1,30 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct +new_special_tokens: ,,,,,,,, +graph_decoder_path: saves/graph_decoder +graph_encoder_path: saves/graph_encoder +graph_predictor_path: saves/graph_predictor +adapter_name_or_path: saves/Llama-3.1-8B-Instruct-Adapter +graph_lm_connector_path: saves/Llama-3.1-8B-Instruct-Adapter/connector +### generation +max_length: 512 +max_new_tokens: 128 +temperature: 
0.6 +top_p: 0.9 +### method +finetuning_type: lora +do_train: false +flash_attn: disabled +learned_query_size: 8 +### dataset +dataset: molqa_material_examples +template: llama3 +cutoff_len: 128 +overwrite_cache: true +preprocessing_num_workers: 16 +output_dir: null +bf16: true +pure_bf16: true +### eval +per_device_eval_batch_size: 6 +report_to: 'none' \ No newline at end of file diff --git a/config/generate/mistral_drug.yaml b/config/generate/mistral_drug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..bc6ad005380b534e5e75bdedeed3b540882a64cd --- /dev/null +++ b/config/generate/mistral_drug.yaml @@ -0,0 +1,30 @@ +### model +model_name_or_path: mistralai/Mistral-7B-Instruct-v0.3 +new_special_tokens: ,,,,,,,, +graph_decoder_path: saves/graph_decoder +graph_encoder_path: saves/graph_encoder +graph_predictor_path: saves/graph_predictor +adapter_name_or_path: saves/Mistral-7B-Instruct-v0.3-Adapter +graph_lm_connector_path: saves/Mistral-7B-Instruct-v0.3-Adapter/connector +### generation +max_length: 512 +max_new_tokens: 128 +temperature: 0.6 +top_p: 0.9 +### method +finetuning_type: lora +do_train: false +flash_attn: disabled +learned_query_size: 8 +### dataset +dataset: molqa_drug_examples +template: mistral +cutoff_len: 128 +overwrite_cache: true +preprocessing_num_workers: 16 +output_dir: null +bf16: true +pure_bf16: true +### eval +per_device_eval_batch_size: 6 +report_to: 'none' \ No newline at end of file diff --git a/config/generate/mistral_material.yaml b/config/generate/mistral_material.yaml new file mode 100644 index 0000000000000000000000000000000000000000..5f578dbf11926055588cc2fe2c449ee9b9d7e1d6 --- /dev/null +++ b/config/generate/mistral_material.yaml @@ -0,0 +1,30 @@ +### model +model_name_or_path: mistralai/Mistral-7B-Instruct-v0.3 +new_special_tokens: ,,,,,,,, +graph_decoder_path: saves/graph_decoder +graph_encoder_path: saves/graph_encoder +graph_predictor_path: saves/graph_predictor +adapter_name_or_path: saves/Mistral-7B-Instruct-v0.3-Adapter +graph_lm_connector_path: saves/Mistral-7B-Instruct-v0.3-Adapter/connector +### generation +max_length: 512 +max_new_tokens: 128 +temperature: 0.6 +top_p: 0.9 +### method +finetuning_type: lora +do_train: false +flash_attn: disabled +learned_query_size: 8 +### dataset +dataset: molqa_material_examples +template: mistral +cutoff_len: 128 +overwrite_cache: true +preprocessing_num_workers: 16 +output_dir: null +bf16: true +pure_bf16: true +### eval +per_device_eval_batch_size: 6 +report_to: 'none' \ No newline at end of file diff --git a/config/generate/qwen_drug.yaml b/config/generate/qwen_drug.yaml new file mode 100644 index 0000000000000000000000000000000000000000..b9babecd750cb21f5167c56e47e9f9290c3f6f05 --- /dev/null +++ b/config/generate/qwen_drug.yaml @@ -0,0 +1,30 @@ +### model +model_name_or_path: Qwen/Qwen2-7B-Instruct +new_special_tokens: ,,,,,,,, +graph_decoder_path: saves/graph_decoder +graph_encoder_path: saves/graph_encoder +graph_predictor_path: saves/graph_predictor +adapter_name_or_path: saves/Qwen2-7B-Instruct-Adapter +graph_lm_connector_path: saves/Qwen2-7B-Instruct-Adapter/connector +### generation +max_length: 512 +max_new_tokens: 128 +temperature: 0.6 +top_p: 0.9 +### method +finetuning_type: lora +do_train: false +flash_attn: disabled +learned_query_size: 8 +### dataset +dataset: molqa_drug_examples +template: qwen +cutoff_len: 128 +overwrite_cache: true +preprocessing_num_workers: 16 +output_dir: null +bf16: true +pure_bf16: true +### eval +per_device_eval_batch_size: 6 +report_to: 
'none' \ No newline at end of file diff --git a/config/generate/qwen_material.yaml b/config/generate/qwen_material.yaml new file mode 100644 index 0000000000000000000000000000000000000000..71ea710bd40c5768d2be0871fb59107e881dd37c --- /dev/null +++ b/config/generate/qwen_material.yaml @@ -0,0 +1,30 @@ +### model +model_name_or_path: Qwen/Qwen2-7B-Instruct +new_special_tokens: ,,,,,,,, +graph_decoder_path: saves/graph_decoder +graph_encoder_path: saves/graph_encoder +graph_predictor_path: saves/graph_predictor +adapter_name_or_path: saves/Qwen2-7B-Instruct-Adapter +graph_lm_connector_path: saves/Qwen2-7B-Instruct-Adapter/connector +### generation +max_length: 512 +max_new_tokens: 128 +temperature: 0.6 +top_p: 0.9 +### method +finetuning_type: lora +do_train: false +flash_attn: disabled +learned_query_size: 8 +### dataset +dataset: molqa_material_examples +template: qwen +cutoff_len: 128 +overwrite_cache: true +preprocessing_num_workers: 16 +output_dir: null +bf16: true +pure_bf16: true +### eval +per_device_eval_batch_size: 6 +report_to: 'none' \ No newline at end of file diff --git a/config/train/llama_lora.yaml b/config/train/llama_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..17543117e02510c1190a3e3d797f0ccd447e1919 --- /dev/null +++ b/config/train/llama_lora.yaml @@ -0,0 +1,49 @@ +### model +model_name_or_path: meta-llama/Meta-Llama-3.1-8B-Instruct +new_special_tokens: ,,,,,,,, +graph_decoder_path: saves/graph_decoder +graph_encoder_path: saves/graph_encoder +graph_predictor_path: saves/graph_predictor +graph_lm_connector_path: null #specify when resume +### method +stage: mmsft +do_train: true +finetuning_type: lora +lora_target: all +flash_attn: disabled +learned_query_size: 8 +### dataset +dataset: molqa_train_examples +template: llama3 +cutoff_len: 2048 +overwrite_cache: true +preprocessing_num_workers: 16 +### output +output_dir: saves/Llama-3.1-8B-Instruct-Adapter +logging_steps: 10 +save_steps: 2000 +plot_loss: true +overwrite_output_dir: true +### to resume +# overwrite_output_dir: false +### train +per_device_train_batch_size: 10 +gradient_accumulation_steps: 2 +learning_rate: 1.0e-4 +num_train_epochs: 4.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +ddp_timeout: 180000000 +bf16: true +pure_bf16: true +### train loss +loss_weight_retro: 1 +loss_weight_design: 1 +loss_weight_lm: 1 +### eval +val_size: 0.1 +per_device_eval_batch_size: 6 +eval_strategy: steps +eval_steps: 2000 +### specify if connected to wandb +report_to: 'none' \ No newline at end of file diff --git a/config/train/mistral_lora.yaml b/config/train/mistral_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..a2446b9b2ad1c81d493a3d1d18cf087c2154a1c2 --- /dev/null +++ b/config/train/mistral_lora.yaml @@ -0,0 +1,48 @@ +### model +model_name_or_path: mistralai/Mistral-7B-Instruct-v0.3 +new_special_tokens: ,,,,,,,, +graph_decoder_path: saves/graph_decoder +graph_encoder_path: saves/graph_encoder +graph_predictor_path: saves/graph_predictor +graph_lm_connector_path: null #specify when resume +### method +stage: mmsft +do_train: true +finetuning_type: lora +lora_target: all +flash_attn: disabled +learned_query_size: 8 +### dataset +dataset: molqa_train_examples +template: mistral +cutoff_len: 2048 +overwrite_cache: true +preprocessing_num_workers: 16 +### output +output_dir: saves/Mistral-7B-Instruct-v0.3-Adapter +logging_steps: 10 +save_steps: 2000 +plot_loss: true +overwrite_output_dir: true +### to resume +# overwrite_output_dir: false +### train 
+per_device_train_batch_size: 6 +gradient_accumulation_steps: 2 +learning_rate: 1.0e-4 +num_train_epochs: 4.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +ddp_timeout: 180000000 +bf16: true +pure_bf16: true +loss_weight_retro: 1 +loss_weight_design: 1 +loss_weight_lm: 1 +### eval +val_size: 0.1 +per_device_eval_batch_size: 6 +eval_strategy: steps +eval_steps: 2000 +### specify if connected to wandb +report_to: 'none' \ No newline at end of file diff --git a/config/train/qwen_lora.yaml b/config/train/qwen_lora.yaml new file mode 100644 index 0000000000000000000000000000000000000000..4110f3b8c27b69f7931097bad04f290c2581e7f8 --- /dev/null +++ b/config/train/qwen_lora.yaml @@ -0,0 +1,48 @@ +### model +model_name_or_path: Qwen/Qwen2-7B-Instruct +new_special_tokens: ,,,,,,,, +graph_decoder_path: saves/graph_decoder +graph_encoder_path: saves/graph_encoder +graph_predictor_path: saves/graph_predictor +graph_lm_connector_path: null #specify when resume +### method +stage: mmsft +do_train: true +finetuning_type: lora +lora_target: all +flash_attn: disabled +learned_query_size: 8 +### dataset +dataset: molqa_train_examples +template: qwen +cutoff_len: 2048 +overwrite_cache: true +preprocessing_num_workers: 16 +### output +output_dir: saves/Qwen2-7B-Instruct-Adapter +logging_steps: 10 +save_steps: 2000 +plot_loss: true +overwrite_output_dir: true +### to resume +# overwrite_output_dir: false +### train +per_device_train_batch_size: 6 +gradient_accumulation_steps: 2 +learning_rate: 1.0e-4 +num_train_epochs: 4.0 +lr_scheduler_type: cosine +warmup_ratio: 0.1 +ddp_timeout: 180000000 +bf16: true +pure_bf16: true +loss_weight_retro: 1 +loss_weight_design: 1 +loss_weight_lm: 1 +### eval +val_size: 0.1 +per_device_eval_batch_size: 6 +eval_strategy: steps +eval_steps: 2000 +### specify if connected to wandb +report_to: 'none' \ No newline at end of file diff --git a/data/dataset_info.json b/data/dataset_info.json new file mode 100644 index 0000000000000000000000000000000000000000..65736361cfb62706c4e026bacba8a2e9ab135805 --- /dev/null +++ b/data/dataset_info.json @@ -0,0 +1,20 @@ +{ + "molqa_train_examples": { + "file_name": "molqa_train_examples.json" + }, + "molqa_material_examples": { + "file_name": "molqa_material_examples.json" + }, + "molqa_drug_examples": { + "file_name": "molqa_drug_examples.json" + }, + "molqa_train": { + "file_name": "molqa_train.json" + }, + "molqa_material": { + "file_name": "molqa_material.json" + }, + "molqa_drug": { + "file_name": "molqa_drug.json" + } +} \ No newline at end of file diff --git a/data/molqa_drug_examples.json b/data/molqa_drug_examples.json new file mode 100644 index 0000000000000000000000000000000000000000..5a5e37dfa0077ef4a349aa7ee70fd4e0d19fa07e --- /dev/null +++ b/data/molqa_drug_examples.json @@ -0,0 +1,75 @@ +[ + { + "instruction": "Can you design and synthesize a molecule that inhibits both Human Immunodeficiency Virus and Beta-Secretase 1, with a molecular weight of around 335 and two aromatic rings connected to aliphatic chains?", + "input": "", + "output": "To satisfy the requirements: The molecule has a moderately complex structure, consisting of aromatic rings connected to aliphatic chains, and featuring carboxyl and hydroxyl functional groups. Therefore, the designed molecule is: CC(=O)OCCC(C)Oc1ccc(Cl)cc1Oc1ccccc1. This is step 1 in the retrosynthesis process. 
To synthesize CC(=O)OCCC(C)Oc1ccc(Cl)cc1Oc1ccccc1, follow these procedures: A solution of (R)-acetic acid 3-(toluene-4-sulfonyloxy)-butyl ester (5.89 g, 21 mmol) and 4-chloro-2-phenoxy-phenol (5.0 g, 23 mmol) in DMF (50 mL) is treated with cesium carbonate (7.4 g, 23 mmol). The solution is heated to 60\u00b0 C. and stirred overnight. The reaction is cooled and quenched with 1N HCl. The solution is partitioned in EtOAc and water. The organic is separated, washed with brine, and dried over sodium sulfate. The organic is filtered, and the solvent is removed to afford acetic acid 3-(4-chloro-2-phenoxy-phenoxy)-butyl ester, which is then diluted in methanol (100 mL) and treated with potassium carbonate (5.68 g, 40 mmol). The reaction is stirred for 2 hours at rt. The reaction is then partitioned in EtOAc and water. The organic layer is separated, washed with brine, and dried over sodium sulfate. The organic is filtered and the solvent is removed to afford the crude product. The crude is purified by silica gel column chromatography using 1/1 hexanes/EtOAc to elute the pure product. The solvent is removed to afford 4.35 g (72%) of the desired product. 1H NMR (400 MHz, CDCl3); MS (ES+) m/z mass calcd for C16H17ClO3 292, found 293 (M+1, 100%). The applied reaction is: CC(=O)OCCC(C)Oc1ccc(Cl)cc1Oc1ccccc1>>CC(=O)OCCC(C)OS(=O)(=O)c1ccc(C)cc1.O=C(O)O.Oc1ccc(Cl)cc1Oc1ccccc1.[Cs+].[Cs+], which requires the reactants: CC(=O)OCCC(C)OS(=O)(=O)c1ccc(C)cc1 (available), O=C(O)O (available), Oc1ccc(Cl)cc1Oc1ccccc1 (not available, needs to be synthesized), [Cs+] (available), [Cs+] (available). This is step 2 in the retrosynthesis process. To synthesize Oc1ccc(Cl)cc1Oc1ccccc1, follow these procedures: A \u221240\u00b0 C. solution of 4-chloro-2-phenoxy-1-methoxy-benzene (7.16 g, 30.5 mmol) in dry CH2Cl2 (70 mL) is treated dropwise with borontribromide (22.9 g, 91.5 mmol) and then warmed to 0\u00b0 C. and stirred for 3 h under N2. The reaction is diluted with Et2O and quenched with water. The organic layer is dried (Na2SO4), and the solvent is removed in vacuo to afford 7.11 g (100%) of the title compound. Rf=0.30 (4/1 hexanes/acetone). 1H NMR (400 MHz, CDCl3); MS (ES\u2212) m/z mass calcd for C12H9O2Cl 220, found 219 (M\u22121, 100%). The applied reaction is: Oc1ccc(Cl)cc1Oc1ccccc1>>COc1ccc(Cl)cc1Oc1ccccc1, which requires the reactants: COc1ccc(Cl)cc1Oc1ccccc1 (not available, needs to be synthesized). This is step 3 in the retrosynthesis process. To synthesize COc1ccc(Cl)cc1Oc1ccccc1, follow these procedures: A mixture of 2-bromo-4-chloro-1-methoxy-benzene (8.0 g, 36.1 mmol), phenol (6.80 g, 72.2 mmol), cesium carbonate (23.54 g, 72.2 mmol), copper (I) chloride (1.79 g, 18.1 mmol) and 2,2,6,6-tetramethyl-3,5-heptanedione (1.66 g, 9.00 mmol) in dry 1-methyl-2-pyrrolidinone (80 mL) is heated to 120\u00b0 C. for 20 hours under N2. The reaction is cooled, filtered and the filtrate quenched with 1 N HCl (50 mL). The filtrate is diluted with Et2O and extracted with water. The organic layer is dried (Na2SO4), and the solvent is removed in vacuo to afford crude product that is absorbed on silica gel and purified by flash chromatography using 9/1 hexanes/EtOAc to afford 7.42 g (88%) of the title compound. Rf=0.37 (4/1 hexanes/EtOAc). 
The applied reaction is: COc1ccc(Cl)cc1Oc1ccccc1>>CC(C)(C)C(=O)CC(=O)C(C)(C)C.COc1ccc(Cl)cc1Br.O=C(O)O.[Cs+].[Cs+].[O-]c1ccccc1, which requires the reactants: CC(C)(C)C(=O)CC(=O)C(C)(C)C (available), COc1ccc(Cl)cc1Br (available), O=C(O)O (available), [Cs+] (available), [Cs+] (available), [O-]c1ccccc1 (available).", + "property": { + "HIV": 1.0, + "BACE": 1.0, + "SC": 3.1, + "SA": 2.46 + }, + "retro": [ + null, + null, + null + ] + }, + { + "instruction": "What is a possible molecular structure that can penetrate the blood-brain barrier and inhibit Human Immunodeficiency Virus, given a molecular weight of 209.249 and specific ring and bond characteristics?", + "input": "", + "output": "To satisfy the requirements: The molecule has an aromatic ring and a common functional group of carboxyl, which suggests a certain level of structural complexity, likely influencing its physical and chemical properties, and indicating a moderate synthetic complexity score. Therefore, the designed molecule is: CC(C)(C)OC(=O)Nc1ccc(N)cn1. This is step 1 in the retrosynthesis process. To synthesize CC(C)(C)OC(=O)Nc1ccc(N)cn1, follow these procedures: To a suspension of (tert-butoxy)-N-(5-nitro(2-pyridyl))carboxamide (0.27 g, 1.13 mmol) in methanol (2 mL), ethyl acetate (4 mL) and TEA (0.16 mL) was added 10% Pd/C (60 mg, 0.056 mmol) under argon. The reaction mixture was hydrogenated under 1 atm H2 for 20 hr, filtered through Celite and concentrated in vacuo to give N-(5-amino(2-pyridyl))(tert-butoxy)carboxamide (0.226 g, 97%). 1H-NMR (DMSO-d6): \u03b4 1.40 (s, 9H), 4.92 (br s, 2H), 6.89-6.91 (dd, 1H), 7.35-7.37 (d, 1H), 7.58 (d, 1H), 9.06 (s, 1H). The applied reaction is: CC(C)(C)OC(=O)Nc1ccc(N)cn1>>CC(C)(C)OC(=O)Nc1ccc([N+](=O)[O-])cn1 with the template [#7;a:4]:[c:3]:[c:2]-[NH2;D1;+0:1]>>O=[N+;H0;D3:1](-[O-])-[c:2]:[c:3]:[#7;a:4], which requires the reactants: CC(C)(C)OC(=O)Nc1ccc([N+](=O)[O-])cn1 (available).", + "property": { + "BBBP": 1.0, + "HIV": 1.0, + "SC": 2.03, + "SA": 2.08 + }, + "retro": [ + 2965 + ] + }, + { + "instruction": "What molecular structure can be designed to exhibit Blood-Brain Barrier Penetration and Human Immunodeficiency Virus Inhibition while featuring a specific combination of molecular weight, ring count, and functional groups?", + "input": "", + "output": "To satisfy the requirements: The molecule features a complex structure with two aromatic rings and a high number of rotatable bonds, indicating potential difficulties in synthesis. Its chemical properties are influenced by the presence of functional groups such as carboxyl and amino groups. The predicted properties suggest high substrate specificity for CYP2C9 and moderate substrate specificity for CYP3A4, while also indicating a low risk of carcinogenicity and high clearance in microsomes. The molecule's hydration free energy and skin reaction predictions are also noteworthy. Overall, the molecule's unique combination of structural features and predicted chemical properties makes it an interesting candidate for further investigation. Therefore, the designed molecule is: CC(=O)c1sc(NS(=O)(=O)c2cc(Br)c(Cl)s2)nc1C. This is step 1 in the retrosynthesis process. To synthesize CC(=O)c1sc(NS(=O)(=O)c2cc(Br)c(Cl)s2)nc1C, follow these procedures: The title compound was prepared from 5-acety1-2-amino-4-methylthiazole (42 mg) and 4-bromo-5-chlorothiophene-2-sulfonyl chloride (80 mg) as described in the synthetic METHOD B to give a white solid (11.7 mg) with purity >90%: MS (pos) m/z 415.3, 417.3. 
The applied reaction is: CC(=O)c1sc(NS(=O)(=O)c2cc(Br)c(Cl)s2)nc1C>>CC(=O)c1sc(N)nc1C.O=S(=O)(Cl)c1cc(Br)c(Cl)s1 with the template [#16;a:5]:[c:4]-[S;H0;D4;+0:1](=[O;D1;H0:2])(=[O;D1;H0:3])-[NH;D2;+0:9]-[c:7](:[#16;a:6]):[#7;a:8]>>Cl-[S;H0;D4;+0:1](=[O;D1;H0:2])(=[O;D1;H0:3])-[c:4]:[#16;a:5].[#16;a:6]:[c:7](:[#7;a:8])-[NH2;D1;+0:9], which requires the reactants: CC(=O)c1sc(N)nc1C (available), O=S(=O)(Cl)c1cc(Br)c(Cl)s1 (available).", + "property": { + "BBBP": 1.0, + "HIV": 1.0, + "SC": 2.83, + "SA": 2.74 + }, + "retro": [ + 145329 + ] + }, + { + "instruction": "Can you design a molecule that penetrates the blood-brain barrier, inhibits Human Immunodeficiency Virus, and doesn't inhibit Beta-Secretase 1, while having a specific set of structural properties?", + "input": "", + "output": "To satisfy the requirements: The molecule exhibits a moderate level of structural complexity, featuring three aromatic rings and an aliphatic chain, with a molecular weight of 216.27. Its synthetic complexity score is 2.88, indicating it may be moderately challenging to synthesize. This structure is predicted to have a relatively low likelihood of mutagenicity, with an AMES score of 0.608. It also shows moderate to low likelihood of blood-brain barrier penetration, with a BBB_Martins score of 0.576. The molecule's oral bioavailability is predicted to be moderate, with a Bioavailability_Ma score of 0.797. Additionally, it is predicted to have a low to moderate interaction with various cytochrome P450 enzymes. Overall, the molecule's drug utility is influenced by its structural features, bioactivity, and pharmacokinetic properties, which suggest it may have potential as a therapeutic agent. Therefore, the designed molecule is: c1csc(Nc2nccn3ccnc23)c1. This is step 1 in the retrosynthesis process. To synthesize c1csc(Nc2nccn3ccnc23)c1, follow these procedures: To a solution of (3-bromo-imidazo[1,2-a]pyrazin-8-yl)-methyl-amine (50 mg, 0.2 mmol) in tetrahydrofuran (2 ml) under nitrogen was added 2-thiopheneboronic acid (41 mg, 0.3 mmol), K2CO3 (1.1 ml of a 1 M solution in water) and 16 mg (0.1 mmol) of PdCl2(Dppf)CH2Cl2 (16 mg, 0.1 eq). The mixture was heated at 70\u00b0 C. in a sealed tube overnight. The product was precipitated by adding methanol. The filtrate was evaporated and purified by preparative thin layer chromatography on silica gel to give additional title compound for a total of 45 mg (100% yield). 
The applied reaction is: c1csc(Nc2nccn3ccnc23)c1>>CNc1nccn2c(Br)cnc12.O=C(O)O.OB(O)c1cccs1.[K+].[K+] with the template [#7;a:4]:[c:3]:[cH;D2;+0:1]:[#7;a:2].[#7:5]-[c;H0;D3;+0:6]1:[cH;D2;+0:9]:[c:8]:[cH;D2;+0:7]:[s;H0;D2;+0:10]:1>>Br-[c;H0;D3;+0:1](:[#7;a:2]):[c:3]:[#7;a:4].[#7:5]-[CH3;D1;+0:6].O-B(-O)-[c;H0;D3;+0:7]1:[c:8]:[cH;D2;+0:9]:c:[s;H0;D2;+0:10]:1, which requires the reactants: CNc1nccn2c(Br)cnc12 (available), O=C(O)O (available), OB(O)c1cccs1 (available), [K+] (available), [K+] (available).", + "property": { + "BBBP": 1.0, + "HIV": 1.0, + "BACE": 0.0, + "SC": 2.88, + "SA": 2.75 + }, + "retro": [ + 155241 + ] + }, + { + "instruction": "What molecule can be designed and synthesized to penetrate the blood-brain barrier and inhibit Human Immunodeficiency Virus, given the constraints of a molecular weight around 209.245 and a single aromatic ring?", + "input": "", + "output": "To satisfy the requirements: The molecule featuring an aromatic ring and an aliphatic chain, with functional groups including hydroxyl and carboxyl, exhibits moderate synthetic complexity and moderate lipophilicity, which may influence its physical and chemical properties. Therefore, the designed molecule is: CC(C)(C)OC(=O)Nc1ccccc1O. This is step 1 in the retrosynthesis process. To synthesize CC(C)(C)OC(=O)Nc1ccccc1O, follow these procedures: To a solution of 2-aminophenol (2.84 g) in dichloromethane (120 ml) was added di-tert-butyl dicarbonate (6.55 g). The mixture was stirred at room temperature for 18 hours. The mixture was partitioned between water and dichloromethane. The organic extracts were dried (MgSO4) and evaporated to give a solid which was purified by column chromatography on silica eluting with a mixture of ethyl acetate and hexane (20:80) to give 2-tert-butyloxycarbonylaminophenol (1.80 g);. The applied reaction is: CC(C)(C)OC(=O)Nc1ccccc1O>>CC(C)(C)OC(=O)OC(=O)OC(C)(C)C.Nc1ccccc1[O-] with the template [C;D1;H3:1]-[C:2](-[C;D1;H3:3])(-[C;D1;H3:4])-[#8:5]-[C;H0;D3;+0:6](=[O;D1;H0:7])-[NH;D2;+0:8]-[c:9]>>[C;D1;H3:1]-[C:2](-[C;D1;H3:3])(-[C;D1;H3:4])-[#8:5]-[C;H0;D3;+0:6](=[O;D1;H0:7])-O-[C;H0;D3;+0:6](=[O;D1;H0:7])-[#8:5]-[C:2](-[C;D1;H3:1])(-[C;D1;H3:3])-[C;D1;H3:4].[NH2;D1;+0:8]-[c:9], which requires the reactants: CC(C)(C)OC(=O)OC(=O)OC(C)(C)C (available), Nc1ccccc1[O-] (available).", + "property": { + "BBBP": 1.0, + "HIV": 1.0, + "SC": 1.7, + "SA": 1.73 + }, + "retro": [ + 38065 + ] + } +] \ No newline at end of file diff --git a/data/molqa_material_examples.json b/data/molqa_material_examples.json new file mode 100644 index 0000000000000000000000000000000000000000..9f3cfcb8802b1be01035924bcdd61bcd4640b303 --- /dev/null +++ b/data/molqa_material_examples.json @@ -0,0 +1,85 @@ +[ + { + "instruction": "What is the optimal molecular design and synthesis route for a polymer with high CO2 gas permeability and low permeability to N2 and O2, featuring an aromatic ring and specific functional groups?", + "input": "", + "output": "To satisfy the requirements: The polymer, featuring an aromatic ring and functional groups such as carboxyl and hydroxyl, exhibits a moderate synthetic complexity score and high gas permeability, indicating potential for use in gas separation membranes with good efficiency. Therefore, the designed molecule is: *c1cc(C(=O)OCC)c(*)[nH]1. This is step 1 in the retrosynthesis process. 
To synthesize CCOC(=O)c1cc[nH]c1, follow these procedures: Ethanol (3.2 mL, 54 mmol) and 4-dimethylamino pyridine (DMAP) (54 mg, 0.45 mmol) were added to a mixture of pyrrole-3-carboxylic acid (500 mg, 4.5 mmol) and dicyclohexylcarbodiimide (1.11 g, 5.4 mmol) (Aldrich, Milwaukee, Wis.) in tetrahydrofuran (THF) (15 mL). After heating at 60\u00b0 C. for 10 hours, the reaction was cooled. The precipitate was filtered off, washed with ethyl acetate, the combined filtrate was concentrated and purified on a silica gel column to give 500 mg (81%) of 1H-pyrrole-3-carboxylic acid ethyl ester as a colorless oil. The applied reaction is: CCOC(=O)c1cc[nH]c1>>C(=NC1CCCCC1)=NC1CCCCC1.CCO.O=C(O)c1cc[nH]c1 with the template [#7;a:5]:[c:4]:[c:3]-[C;H0;D3;+0:1](=[O;D1;H0:2])-[O;H0;D2;+0:7]-[C:6]>>O-[C;H0;D3;+0:1](=[O;D1;H0:2])-[c:3]:[c:4]:[#7;a:5].[C:6]-[OH;D1;+0:7], which requires the reactants: C(=NC1CCCCC1)=NC1CCCCC1 (available), CCO (available), O=C(O)c1cc[nH]c1 (available).", + "property": { + "CO2": 0.94, + "N2": 0.0, + "O2": 0.0, + "FFV": 0.381, + "SC": 2.28, + "SA": 4.21 + }, + "retro": [ + 36155 + ] + }, + { + "instruction": "Can a polymer with a molecular weight of 184.22, one aromatic ring, and specific gas permeability, thermal conductivity, and electrical properties be designed and synthesized for gas separation applications?", + "input": "", + "output": "To satisfy the requirements: The polymer has an aromatic ring incorporated within its structure, which can contribute to its overall rigidity and mechanical strength. The presence of the sulfonamide group can also enhance the polymer's chemical stability and resistance to degradation. The predicted thermal properties, such as melting point and thermal conductivity, suggest that the polymer may exhibit suitable thermal behavior for various applications. Additionally, the predicted dielectric constant and electrical conductivity indicate that the polymer may possess suitable electrical properties. The gas permeability coefficient and gas diffusion coefficient suggest that the polymer may be suitable for gas separation applications. Overall, the polymer's material utility seems promising for creating polymeric materials with desirable properties. Therefore, the designed molecule is: *Nc1cc(*)ccc1S(=O)(=O)NC. This is step 1 in the retrosynthesis process. To synthesize CNS(=O)(=O)c1ccccc1N, follow these procedures: N-Methyl-2-nitrobenzenesulfonamide (1.0 g, 4.6 mmol), 10% palladium on carbon (100 mg) and absolute ethanol (30 mL) were placed in a Parr bottle, the bottle placed on a Parr apparatus under a hydrogen atmosphere (45 psi), and shaken for 1 hour, The Parr bottle was evacuated and flushed with nitrogen, then the mixture was filtered through a disposable syringe filter. The filtrate was concentrated in vacuo to yield 0.81 g of 2-amino-N-methylbenzenesulfonamide. 
The applied reaction is: CNS(=O)(=O)c1ccccc1N>>CNS(=O)(=O)c1ccccc1[N+](=O)[O-] with the template [NH2;D1;+0:1]-[c:2]>>O=[N+;H0;D3:1](-[O-])-[c:2], which requires the reactants: CNS(=O)(=O)c1ccccc1[N+](=O)[O-] (available).", + "property": { + "CO2": 1.746, + "O2": 1.247, + "FFV": 0.37, + "TC": 0.192, + "SC": 2.8, + "SA": 4.17 + }, + "retro": [ + 3 + ] + }, + { + "instruction": "Can a molecule be designed and synthesized to exhibit high CO2 permeability, moderate fractional free volume, and specific thermal conductivity, while incorporating aromatic and carboxyl functional groups?", + "input": "", + "output": "To satisfy the requirements: The polymer, with its aromatic ring and carboxyl functional group, exhibits a moderate synthetic complexity score, suggesting a relatively accessible molecular structure for synthesis. Its molecular structure and functional groups are expected to impart specific physical and chemical properties, influencing its material utility for polymeric materials. Therefore, the designed molecule is: *CC(*)OC(=O)c1cc(C)cc(Cl)c1. This is step 1 in the retrosynthesis process. To synthesize CCOC(=O)c1cc(C)cc(Cl)c1, follow these procedures: The title compound was synthesized from 3-chloro-5-methyl-benzoic acid under the same conditions as for Compound b1. The applied reaction is: CCOC(=O)c1cc(C)cc(Cl)c1>>CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1.Cc1cc(Cl)cc(C(=O)O)c1, which requires the reactants: CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1 (not available, needs to be synthesized), Cc1cc(Cl)cc(C(=O)O)c1 (available). This is step 2 in the retrosynthesis process. To synthesize CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1, follow these procedures: Potassium carbonate (1.5 g, 11.2 mmol) and ethyl iodide (1.2 g, 7.4 mmol) were added to a solution of 4-bromo-3-trifluoromethyl-benzoic acid (1.0 g, 3.7 mmol) in DMF (5 ml), and the mixture was stirred at room temperature for 24 hours. The reaction mixture was diluted with ethyl acetate, and the organic layer was washed with water and saturated-saline, and dried over anhydrous sodium sulfate. The drying agent was removed by filtration. After concentration under reduced pressure, the resulting residue was purified by silica gel column chromatography (ethyl acetate/hexane) to yield the title compound (1.03 g, 94%) as a brown oily substance. The applied reaction is: CCOC(=O)c1ccc(Br)c(C(F)(F)F)c1>>CCI.O=C(O)c1ccc(Br)c(C(F)(F)F)c1.O=C([O-])[O-].[K+].[K+], which requires the reactants: CCI (available), O=C(O)c1ccc(Br)c(C(F)(F)F)c1 (available), O=C([O-])[O-] (available), [K+] (available), [K+] (available).", + "property": { + "CO2": 0.94, + "N2": 0.0, + "O2": 0.0, + "FFV": 0.375, + "TC": 0.223, + "SC": 2.58, + "SA": 4.01 + }, + "retro": [ + null, + null + ] + }, + { + "instruction": "What molecular structure can be designed and synthesized to achieve CO2 permeability of 1.743 Barrer, N2 impermeability, and specific thermal conductivity and fractional free volume properties, while incorporating an aromatic ring and limited rotatable bonds?", + "input": "", + "output": "To satisfy the requirements: The polymer is a complex molecular structure incorporating an aromatic ring and a bromine functional group, which may provide enhanced thermal and mechanical properties, potentially useful for gas separation applications. Therefore, the designed molecule is: *c1ccc(*)c(CBr)c1. This is step 1 in the retrosynthesis process. 
To synthesize BrCc1ccccc1, follow these procedures: 4.65 g (10 mM) of compound 1a were treated, while stirring, with 40 ml of 2 N HBr in glacial acetic acid for 45 min. at 20\u00b0 in the absence of moisture. The amino acid derivative dissolved with CO2 evolution. The reaction solution was added dropwise with vigorous stirring to 250 ml of absolute ether which resulted in the precipitation of 2HBr.H-Arg-pNA. The ethereal phase was sucked off, whereupon the solid phase was washed 4 times with portions of 100 ml of abs. ether in order to substantially remove benzyl bromide which had formed as a by-product as well as excess HBr and AcOH. The residue was dissolved in 50 ml of MeOH, the pH was adjusted to 4.5 by the addition of Et3N, and the solution was concentrated to dryness in vacuo at 30\u00b0. The resulting product was dissolved in 75 ml of MeOH and passed through a column of \"Sephadex\" LH-20 (cross-linked dextran gel) equilibrated with MeOH. From a fraction of the eluate there were obtained 4.18 g (91.6% of the theory) of amorphous compound 1b which was homogeneous in the SS as shown by TLC. Elementary analysis and calculation from the empirical formula C12H20N6O3Br2 gave the following values: C=31.15% (31.60%), H=4.35% (4.42%), N=18.84% (18.43%) and Br=34.81% (35.03%). The applied reaction is: BrCc1ccccc1>>CC(C)(C#N)/N=N/C(C)(C)C#N.Cc1ccc(S(=O)(=O)Cl)cc1.O=C1CCC(=O)N1Br, which requires the reactants: CC(C)(C#N)/N=N/C(C)(C)C#N (available), Cc1ccc(S(=O)(=O)Cl)cc1 (available), O=C1CCC(=O)N1Br (available).", + "property": { + "CO2": 1.743, + "N2": 0.0, + "O2": 4.692, + "FFV": 0.411, + "TC": 0.256, + "SC": 2.25, + "SA": 5.14 + }, + "retro": [ + null + ] + }, + { + "instruction": "Can you design a molecule with moderate synthetic complexity, featuring a carboxyl functional group and an aromatic ring, that exhibits CO2 gas permeability of 9.176 Barrer, O2 gas permeability of 1.494 Barrer, fractional free volume of 0.374, and thermal conductivity of 0.228 W m^{-1} K^{-1}?", + "input": "", + "output": "To satisfy the requirements: The polymer exhibits a moderate synthetic complexity score, indicating that its synthesis is relatively feasible. The presence of an aromatic ring and a carboxyl functional group in its scaffold contributes to its moderate complexity. Therefore, the designed molecule is: *OC(=O)c1ccc(*)s1. This is step 1 in the retrosynthesis process. To synthesize O=C([O-])c1cccs1, follow these procedures: To a 15-mL polypropylene centrifuge tube was added 3.86 mL of 50 mM potassium phosphate buffer (pH 7.0), 1.0 mL of a suspension of 22.1 mg dry cell weight E. coli SW132 wet cells (prepared as described in Example 10) in 50 mM potassium phosphate buffer (pH 7.0), and 0.1691 g of 2-thiophenecarbonitrile. The final concentration of 2-thiophencarbonitrile was 0.307 M. The reaction mixture was mixed on a rotating platform at 27\u00b0 C. After 30 min, 7.50 mL of 95:5 acetonitrile/water containing 0.30 M N,N-dimethylbenzamide (HPLC external standard) was added to the reaction, the resulting mixture centrifuged, and a 0.100 mL of the supernatant mixed with 0.900 mL of acetonitrile and analyzed by HPLC. The conversion of 2-thiophenecarbonitrile was 99.5%, and the yields of 2-thiophenecarboxamide and 2-thiophenecarboxylic acid were 98% and 0%, respectively. 
The applied reaction is: O=C([O-])c1cccs1>>CN(C)C(=O)c1ccccc1.N#Cc1cccs1.O=P([O-])([O-])O.[2H]C([2H])([2H])C#N.[K+].[K+].[K+].[OH-], which requires the reactants: CN(C)C(=O)c1ccccc1 (available), N#Cc1cccs1 (available), O=P([O-])([O-])O (available), [2H]C([2H])([2H])C#N (available), [K+] (available), [K+] (available), [K+] (available), [OH-] (available).", + "property": { + "CO2": 9.176, + "O2": 1.494, + "FFV": 0.374, + "TC": 0.228, + "SC": 2.23, + "SA": 4.78 + }, + "retro": [ + null + ] + } +] \ No newline at end of file diff --git a/data/molqa_train_examples.json b/data/molqa_train_examples.json new file mode 100644 index 0000000000000000000000000000000000000000..ff85e69a1791a0553a7ea45bbeaaeeaab6d5717a --- /dev/null +++ b/data/molqa_train_examples.json @@ -0,0 +1,62 @@ +[ + { + "instruction": "What is a feasible molecular design that meets the criteria of a molecular weight of 243.306, two rings including one aromatic and one aliphatic, two rotatable bonds, and two hydrogen bond acceptors, with a balanced hydrophobic and hydrophilic profile?", + "input": "", + "output": "To satisfy the requirements: The molecule has a scaffold consisting of a central aromatic ring linked to an amino-functionalized aliphatic chain, featuring a carboxyl group. Its structural complexity is moderate, with a balance of hydrophobic and hydrophilic regions, which could influence its physical and chemical properties. Therefore, the designed molecule is: C#CCC1CCN(C(=O)Oc2ccccc2)CC1. This is step 1 in the retrosynthesis process. To synthesize C#CCC1CCN(C(=O)Oc2ccccc2)CC1, follow these procedures: Phenyl chloroformate (6.2 g, 40.2 mmol) was added to a solution of tert-butyl 4-(prop-2-ynyl)piperidine-1-carboxylate (1.65 g, 13.4 mmol) according to general procedure 1. Yield=0.600 g, 34%. m/z MH+=244.08. HPLC rt=10.3 min. The applied reaction is: C#CCC1CCN(C(=O)Oc2ccccc2)CC1>>C#CCC1CCN(C(=O)OC(C)(C)C)CC1.O=C(Cl)Oc1ccccc1 with the template [#8:5]-[C;H0;D3;+0:4](=[O;D1;H0:6])-[N;H0;D3;+0:1](-[C:2])-[C:3]>>C-C(-C)(-C)-O-C(=O)-[N;H0;D3;+0:1](-[C:2])-[C:3].Cl-[C;H0;D3;+0:4](-[#8:5])=[O;D1;H0:6], which requires the reactants: C#CCC1CCN(C(=O)OC(C)(C)C)CC1 (available), O=C(Cl)Oc1ccccc1 (available).", + "property": { + "SC": 2.42, + "SA": 2.22 + }, + "retro": [ + 180575 + ] + }, + { + "instruction": "Can you design and synthesize a molecule with a molecular weight of around 206, containing 2 rings including 1 aromatic and 1 aliphatic, and 3 rotatable bonds, with no hydrogen bond donors and 2 acceptors?", + "input": "", + "output": "To satisfy the requirements: The molecule has a relatively low synthetic complexity score and a moderate synthetic accessibility score, indicating that it is relatively easy to synthesize. Its molecular structure is composed of an aromatic ring and an aliphatic chain, which may influence its physical and chemical properties. Therefore, the designed molecule is: Cc1ccc(C)n1CCN1CCCC1=O. This is step 1 in the retrosynthesis process. To synthesize Cc1ccc(C)n1CCN1CCCC1=O, follow these procedures: 5.7 g (0.05 mmol) of acetonylacetone and 6.4 g (0.05 mmol) of 1-(2-aminoethyl)-2-pyrrolidinone are heated under reflux in 70 ml of methanol for 2 h. The reaction mixture is then concentrated and, after the concentration, the crude reaction product is then precipitated by addition of petroleum ether. The crude product is recrystallized from diethyl ether. Yield: 3.2 g (31% of theory), Melting point: 66\u00b0-68\u00b0 C. 
Elemental analysis: C12H18N2O (206.29) calculated: C 69.9 H 8.8 N 13.6 0 7.8 found: C 69.4 H 8.7 N 13.7 0 8.11. The applied reaction is: Cc1ccc(C)n1CCN1CCCC1=O>>CC(=O)CCC(C)=O.NCCN1CCCC1=O with the template [C:7]-[n;H0;D3;+0:8]1:[c;H0;D3;+0:1](-[C;D1;H3:2]):[cH;D2;+0:3]:[cH;D2;+0:4]:[c;H0;D3;+0:5]:1-[C;D1;H3:6]>>O=[C;H0;D3;+0:1](-[C;D1;H3:2])-[CH2;D2;+0:3]-[CH2;D2;+0:4]-[C;H0;D3;+0:5](=O)-[C;D1;H3:6].[C:7]-[NH2;D1;+0:8], which requires the reactants: CC(=O)CCC(C)=O (available), NCCN1CCCC1=O (available).", + "property": { + "SC": 2.69, + "SA": 2.34 + }, + "retro": [ + 18636 + ] + }, + { + "instruction": "Can you design and synthesize a molecule with a molecular weight around 377, incorporating four rings including two aromatic and two aliphatic rings, and six rotatable bonds, with no hydrogen bond donors and four acceptors?", + "input": "", + "output": "To satisfy the requirements: The molecule has a moderately complex structure built around a core scaffold that incorporates aromatic rings and aliphatic chains, with the presence of functional groups and a moderate number of rotatable bonds, contributing to its moderate synthetic complexity score. Therefore, the designed molecule is: Cc1ccccc1N1CCN(CCCCN2C(=O)c3ccccc3C2=O)CC1. This is step 1 in the retrosynthesis process. To synthesize Cc1ccccc1N1CCN(CCCCN2C(=O)c3ccccc3C2=O)CC1, follow these procedures: The title compound was synthesized applying the general procedure III step A using 1-(2-methylphenyl)piperazine (1.99 mmol, 350 mg), N-(4-bromobutyl)phthalimide (1.99 mmol, 560 mg), and K2CO3 (4.97 mmol, 686 mg) in 6 mL of acetonitrile. White solid 750 mg (99%). 1H NMR (400 MHz, Chloroform-d) \u03b4 7.92-7.79 (m, 2H), 7.72 (dt, J=5.5, 3.1 Hz, 2H), 7.16 (t, J=7.6 Hz, 2H), 7.08-6.88 (m, 2H), 3.74 (t, J=7.1 Hz, 2H), 2.93 (t, J=4.8 Hz, 4H), 2.60 (s, 4H), 2.52-2.39 (m, 2H), 2.30 (s, 3H), 1.75 (p, J=7.3 Hz, 2H), 1.59 (tt, J=9.7, 5.9 Hz, 2H). The applied reaction is: Cc1ccccc1N1CCN(CCCCN2C(=O)c3ccccc3C2=O)CC1>>Cc1ccccc1[NH+]1CC[NH2+]CC1.O=C(O)O.O=C1c2ccccc2C(=O)N1CCCCBr.[K+].[K+], which requires the reactants: Cc1ccccc1[NH+]1CC[NH2+]CC1 (available), O=C(O)O (available), O=C1c2ccccc2C(=O)N1CCCCBr (available), [K+] (available), [K+] (available).", + "property": { + "SC": 2.93, + "SA": 1.97 + }, + "retro": [ + null + ] + }, + { + "instruction": "Can you design a molecule with a molecular weight around 242, one aromatic ring, three rotatable bonds, and three hydrogen bond acceptors, and describe its synthesis process?", + "input": "", + "output": "To satisfy the requirements: The molecule has an aromatic ring and an aliphatic chain, with the presence of carboxyl, amino, and halogen groups, which contribute to its moderate synthetic complexity score, indicating a relatively challenging synthesis process. Therefore, the designed molecule is: COC(=O)c1ccccc1N(C)C(=O)CCl. This is step 1 in the retrosynthesis process. To synthesize COC(=O)c1ccccc1N(C)C(=O)CCl, follow these procedures: 25 Grams of methyl o-methylaminobenzoate was dissolved in 125 ml of acetone, to this solution was added a solution prepared by dissolving 20.92 g of potassium carbonate in 60 ml of water. Under an ice-cooled condition, 17.95 g of chloroacetyl chloride was added dropwise thereto, then the reaction mixture was stired at a room temperature for 1 hour. Acetone was removed by distillation, the residue obtained was extracted with chloroform, then the chloroform layer was washed with water, dried and the solvent was removed by distillation. 
The residue thus obtained was purified by a silica gel column chromatography to obtain 16.80 g of N-methyl-o-methoxycarbonyl-\u03b1-chloroacetanilide in the form of colorless oily substance. The applied reaction is: COC(=O)c1ccccc1N(C)C(=O)CCl>>CNc1ccccc1C(=O)OC.O=C(Cl)CCl.O=C(O)O.[K+].[K+] with the template [C:2]-[C;H0;D3;+0:1](=[O;D1;H0:3])-[N;H0;D3;+0:5](-[C;D1;H3:4])-[c:6]>>Cl-[C;H0;D3;+0:1](-[C:2])=[O;D1;H0:3].[C;D1;H3:4]-[NH;D2;+0:5]-[c:6], which requires the reactants: CNc1ccccc1C(=O)OC (available), O=C(Cl)CCl (available), O=C(O)O (available), [K+] (available), [K+] (available).", + "property": { + "SC": 2.13, + "SA": 2.07 + }, + "retro": [ + 39003 + ] + }, + { + "instruction": "Can you design a molecule with a molecular weight of 317.186, three aromatic rings, and two rotatable bonds, and outline a synthesis pathway for it?", + "input": "", + "output": "To satisfy the requirements: The molecule features a complex aromatic scaffold with multiple rings, exhibiting moderate synthetic complexity and potential bioavailability, while also showing possible interactions with nuclear receptors and stress response pathways. Therefore, the designed molecule is: COc1ccccc1-c1cn2cc(Br)cc(C)c2n1. This is step 1 in the retrosynthesis process. To synthesize COc1ccccc1-c1cn2cc(Br)cc(C)c2n1, follow these procedures: A mixture of 2-bromo-2\u2032-methoxyacetophenone (0.500 g, 2.183 mmol) and 2-amino-5-bromo-3-methylpyridine (0.408 g, 2.183 mmol) in ethanol (8 mL) was heated to 80\u00b0 C. for 16 hours in a pressure vessel. A yellow thick precipitate formed. Water was added and the solid was extracted with ethyl acetate (some DCM and MeOH were added to help in the dissolution). The organic extracts were then washed with brine, dried over Na2SO4, filtered and concentrated and the residue was purified by silica gel chromatography (Biotage SNAP 100 g, 10 to 20% ethyl acetate in hexanes) to give 220 (0.375 g, 1.182 mmol, 54.2% yield) as a white solid. LRMS (ESI): calc. 316.02 found 317.1 (MH)+. 
The applied reaction is: COc1ccccc1-c1cn2cc(Br)cc(C)c2n1>>COc1ccccc1C(=O)CBr.Cc1cc(Br)cnc1N.[OH-] with the template [c:4]:[c;H0;D3;+0:3](:[c:5])-[c;H0;D3;+0:2]1:[cH;D2;+0:1]:[n;H0;D3;+0:8](:[c:9]):[c:7]:[n;H0;D2;+0:6]:1>>Br-[CH2;D2;+0:1]-[C;H0;D3;+0:2](=O)-[c;H0;D3;+0:3](:[c:4]):[c:5].[NH2;D1;+0:6]-[c:7]:[n;H0;D2;+0:8]:[c:9], which requires the reactants: COc1ccccc1C(=O)CBr (available), Cc1cc(Br)cnc1N (available), [OH-] (available).", + "property": { + "SC": 2.94, + "SA": 2.19 + }, + "retro": [ + 10919 + ] + } +] \ No newline at end of file diff --git a/data/property_ranges.json b/data/property_ranges.json new file mode 100644 index 0000000000000000000000000000000000000000..cad714371b949ecdb69cd0574d61ff202aa56cc5 --- /dev/null +++ b/data/property_ranges.json @@ -0,0 +1,42 @@ +{ + "BBBP": { + "min": 0.0, + "max": 1.0 + }, + "HIV": { + "min": 0.0, + "max": 1.0 + }, + "BACE": { + "min": 0.0, + "max": 1.0 + }, + "CO2": { + "min": 0.94, + "max": 1019.265 + }, + "N2": { + "min": 0.0, + "max": 73.417 + }, + "O2": { + "min": 0.0, + "max": 122.94 + }, + "FFV": { + "min": 0.324, + "max": 0.434 + }, + "TC": { + "min": 0.117, + "max": 0.38 + }, + "SC": { + "min": 1.0, + "max": 5.0 + }, + "SA": { + "min": 1.0, + "max": 8.48 + } +} \ No newline at end of file diff --git a/install_environment.sh b/install_environment.sh new file mode 100644 index 0000000000000000000000000000000000000000..878ec6131eea1bca3b9ee7a4afcb394a89e2e54b --- /dev/null +++ b/install_environment.sh @@ -0,0 +1,85 @@ +#!/bin/bash + +# Set non-interactive frontend +export DEBIAN_FRONTEND=noninteractive + +# Activate the llama_factory environment +# Uncomment the following line if you need to create the environment +# conda create --name llamole python=3.11 -y +# conda activate llamole + +# Function to get the current conda environment name +get_current_conda_env() { + echo "current: $CONDA_DEFAULT_ENV" +} + +# Get the current conda environment +current_env=$(basename "$CONDA_PREFIX") + +# Check if the current environment is 'llamole' +if [ "$current_env" != "llamole" ]; then + echo "current: $CONDA_DEFAULT_ENV" + echo "Current conda environment is neither 'llama_factory' nor 'llamole'." + echo "Please activate one of these environments before running this script." + echo "You can activate an environment using one of these commands:" + echo "conda activate llama_factory" + echo "conda activate llamole" + exit 1 +fi + +echo "Running in conda environment: $current_env" + +# "pandas>=2.0.0" \ +# Install packages using pip +pip install --no-cache-dir \ + pyarrow \ + "pandas>=1.5.3" \ + "rdkit==2023.9.6" \ + pyyaml \ + ipykernel \ + packaging \ + gdown \ + "fcd_torch==1.0.7" \ + "omegaconf==2.3.0" \ + "imageio==2.26.0" \ + wandb \ + pandarallel \ + scipy \ + einops \ + sentencepiece \ + tiktoken \ + protobuf \ + uvicorn \ + pydantic \ + fastapi \ + sse-starlette \ + "matplotlib>=3.7.0" \ + fire \ + "numpy<2.0.0" \ + gradio + +pip install --no-cache-dir hydra-core --upgrade + +# Install PyTorch +pip install --no-cache-dir torch + +# Install PyTorch Geometric and related packages +pip install --no-cache-dir torch_geometric + +# for retro reaction +pip install rdchiral +pip install nltk + +# Install transformers and related packages +pip install --no-cache-dir \ + "transformers>=4.41.3" \ + "datasets>=2.16.0" \ + "accelerate>=0.30.1" \ + "peft>=0.11.1" \ + "trl>=0.8.6" \ + "gradio>=4.0.0" + +# Install mini-moses from GitHub +pip install --no-cache-dir git+https://github.com/igor-krawczuk/mini-moses + +echo "Installation complete!" 
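+
+# Optional post-install sanity check (illustrative, not required by the original setup):
+# confirm that the core dependencies installed above import cleanly in the active environment, e.g.
+#   python -c "import torch, torch_geometric, rdkit, transformers; print('imports ok')"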
\ No newline at end of file diff --git a/launch.py b/launch.py new file mode 100644 index 0000000000000000000000000000000000000000..4e31fcd739f403ca3ce24861e176ac545d76e987 --- /dev/null +++ b/launch.py @@ -0,0 +1,238 @@ +# Copyright 2024 Llamole Team +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import yaml +import numpy as np +import gradio as gr +import random +from rdkit import Chem +from rdkit.Chem import Draw +from rdkit.Chem import AllChem + +from src.webui.workflow import load_model_and_tokenizer, process_input, generate +from src.webui.elements import create_input_components + +# Load candidates +with open('data/molqa_material_examples.json', 'r') as f: + material_examples = json.load(f) + +with open('data/molqa_drug_examples.json', 'r') as f: + drug_examples = json.load(f) + +# Add type to each example +for example in material_examples: + example['type'] = 'Material' + +for example in drug_examples: + example['type'] = 'Drug' + +# Function to process property values +def process_property(value): + return 1e-8 if value == 0 else value + +# Add type to each example and process property values +for example in material_examples: + example['type'] = 'Material' + for prop in ['CO2', 'N2', 'O2', 'FFV']: + if prop in example['property']: + example['property'][prop] = process_property(example['property'][prop]) + +# Combine examples +all_examples = material_examples + drug_examples + +# Get default values from the first material example +default_values = drug_examples[0] + +# Load property ranges and arguments +with open('data/property_ranges.json', 'r') as f: + property_ranges = json.load(f) + +# with open('config/generate/qwen_material.yaml', 'r') as file: +with open('config/generate/llama_material.yaml', 'r') as file: + args_dict = yaml.safe_load(file) + +# Load model and tokenizer outside the function +model, tokenizer, generating_args = load_model_and_tokenizer(args_dict) + +def format_example(example): + formatted = [example['instruction']] + + # Determine if it's a drug or material example based on properties + is_drug = any(prop in example.get('property', {}) for prop in ["HIV", "BBBP", "BACE"]) + formatted.append("Drug" if is_drug else "Material") + + # Handle drug properties + for prop in ["HIV", "BBBP", "BACE"]: + value = example.get('property', {}).get(prop, float('nan')) + formatted.append(value if not np.isnan(value) else "NAN") + + # Handle material properties + for prop in ["CO2", "N2", "O2", "FFV", "TC"]: + value = example.get('property', {}).get(prop, float('nan')) + formatted.append(value if not np.isnan(value) else 0) # 0 represents NAN for material properties + + # Handle synthetic properties + for prop in ["SC", "SA"]: + value = example.get('property', {}).get(prop, float('nan')) + formatted.append(value if not np.isnan(value) else float('nan')) + + return formatted + +# 
Prepare examples +formatted_examples = [format_example(example) for example in all_examples] + +def random_example(examples): + example = random.choice(examples) + property_type = example['type'] + + outputs = [example['instruction'], property_type] + + for prop in ["HIV", "BBBP", "BACE"]: + outputs.append(example['property'].get(prop, "NAN")) + + for prop in ["CO2", "N2", "O2", "FFV", "TC"]: + outputs.append(example['property'].get(prop, 0)) + + for prop in ["SC", "SA"]: + outputs.append(example['property'].get(prop, float('nan'))) + + return outputs + +def generate_and_visualize(instruction, property_type, HIV, BBBP, BACE, CO2, N2, O2, FFV, TC, SC, SA): + properties = { + "HIV": float('nan') if HIV == "NAN" else HIV, + "BBBP": float('nan') if BBBP == "NAN" else BBBP, + "BACE": float('nan') if BACE == "NAN" else BACE, + "CO2": float('nan') if CO2 == 0 else CO2, + "N2": float('nan') if N2 == 0 else N2, + "O2": float('nan') if O2 == 0 else O2, + "FFV": float('nan') if FFV == 0 else FFV, + "TC": float('nan') if TC == 0 else TC, + "SC": SC, + "SA": SA + } + + # Filter out NaN values + properties = {k: v for k, v in properties.items() if not np.isnan(v)} + + print('instruction', instruction) + print('properties', properties) + results = run_molqa(instruction, **properties) + + llm_response = results.get('llm_response', 'No response generated') + llm_smiles = results.get('llm_smiles') + llm_reactions = results['llm_reactions'] + + molecule_img = visualize_molecule(llm_smiles) if llm_smiles else None + + reaction_steps = [] + reaction_imgs = [] + if llm_reactions: + for i, reaction_dict in enumerate(llm_reactions): + reaction = reaction_dict.get('reaction') + if reaction: + reaction_steps.append(f"Step {i+1}: {reaction}") + reaction_imgs.append(visualize_reaction(reaction)) + + return ( + llm_response, + llm_smiles if llm_smiles else "No SMILES generated", + molecule_img, + gr.JSON(value=reaction_steps, visible=bool(reaction_steps)), + gr.Gallery(value=reaction_imgs, visible=bool(reaction_imgs)) + ) + +def run_molqa(instruction: str, **properties) -> dict: + # Filter out properties with NaN values + filtered_properties = {k: v for k, v in properties.items() if not np.isnan(v)} + + input_data = { + "instruction": instruction, + "input": "", + "property": filtered_properties + } + + dataloader, gen_kwargs = process_input(input_data, model, tokenizer, generating_args) + generated_results = generate(model, dataloader, gen_kwargs) + + return generated_results + +def visualize_molecule(smiles: str) -> np.ndarray: + mol = Chem.MolFromSmiles(smiles) + if mol is not None: + img = Draw.MolToImage(mol) + return np.array(img) + return np.zeros((300, 300, 3), dtype=np.uint8) + +def visualize_reaction(reaction: str) -> np.ndarray: + rxn = AllChem.ReactionFromSmarts(reaction, useSmiles=True) + if rxn is not None: + img = Draw.ReactionToImage(rxn) + return np.array(img) + return np.zeros((300, 300, 3), dtype=np.uint8) + +# Define property names and their full descriptions +property_names = { + "HIV": "HIV virus replication inhibition", + "BBBP": "Blood-brain barrier permeability", + "BACE": "Human β-secretase 1 inhibition", + "CO2": "CO2 Perm", + "N2": "N2 Perm", + "O2": "O2 Perm", + "FFV": "Fractional free volume", + "TC": "Thermal conductivity", + "SC": "Heuristic Synthetic Scores (SCScore)", + "SA": "Synthetic Synthetic Scores (SAScore)" +} + +# Define outputs +outputs = [ + gr.Textbox(label="Overall LLM Response"), + gr.Textbox(label="Generated SMILES"), + gr.Image(label="Generated Molecule"), + 
gr.JSON(label="Reaction Steps"), + gr.Gallery(label="Reaction Visualizations") +] + +with gr.Blocks() as iface: + gr.Markdown("# Llamole Demo Interface") + gr.Markdown("Enter an instruction and property values to generate a molecule design.") + + interface, instruction, property_type, drug_properties, material_properties, synthetic_properties = create_input_components(default_values, property_names, property_ranges) + + random_btn = gr.Button("Random Example") + generate_btn = gr.Button("Generate") + + for output in outputs: + output.render() + + # Update the inputs for the generate button + all_inputs = [instruction, property_type] + all_inputs.extend(drug_properties.values()) + all_inputs.extend(material_properties.values()) + all_inputs.extend(synthetic_properties.values()) + + generate_btn.click(generate_and_visualize, inputs=all_inputs, outputs=outputs) + random_btn.click( + random_example, + inputs=gr.State(all_examples), + outputs=all_inputs + ) + +if __name__ == "__main__": + iface.launch(share=True) \ No newline at end of file diff --git a/main.py b/main.py new file mode 100644 index 0000000000000000000000000000000000000000..1ef96f0db4d33180cbc6df272f2478ebbd05835b --- /dev/null +++ b/main.py @@ -0,0 +1,66 @@ +# Copyright 2024 Llamole Team +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import sys +import os +from src.train.tuner import run_train, merge_adapter +from src.eval.workflow import run_eval +from tqdm import tqdm + +from huggingface_hub import hf_hub_download + +def download_data(): + repo_id = "liuganghuggingface/Llamole-MolQA" + files_to_download = [ + "molqa_drug.json", + "molqa_material.json", + "molqa_train.json" + ] + local_dir = "data" + + # Create the data directory if it doesn't exist + os.makedirs(local_dir, exist_ok=True) + + print(f"Downloading files from {repo_id} to {local_dir}/") + for file in tqdm(files_to_download, desc="Downloading files"): + try: + hf_hub_download( + repo_id=repo_id, + filename=file, + repo_type="dataset", + local_dir=local_dir, + local_dir_use_symlinks=False + ) + print(f"Successfully downloaded: {file}") + except Exception as e: + print(f"Error downloading {file}: {e}") + + print("Download complete!") + +if __name__ == "__main__": + command = sys.argv.pop(1) if len(sys.argv) != 1 else 'train' + if command == 'train': + run_train() + elif command == 'export': + merge_adapter() + elif command == 'eval': + run_eval() + elif command == 'download_data': + download_data() + else: + print(f"Invalid command: {command}. 
Please use 'train', 'export', 'eval', or 'download_data'.") + sys.exit(1) \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..a2fe4f69c2186fa1f41ecd3cc1c645fe9ef217d6 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,30 @@ +accelerate==0.33.0 +badam==1.2.2 +bitsandbytes==0.44.1 +datasets==2.21.0 +deepspeed==0.15.2 +galore_torch==1.0 +gradio==5.0.1 +huggingface_hub==0.24.5 +jieba==0.42.1 +matplotlib==3.9.2 +MoD==0.3.0 +modelscope==1.18.1 +nltk==3.9.1 +numpy==2.1.2 +optuna==3.6.1 +packaging==24.1 +pandas==1.5.3 +peft==0.12.0 +PyYAML==6.0.2 +PyYAML==6.0.2 +rdchiral==1.1.0 +rdkit==2023.9.6 +rouge_chinese==1.0.3 +safetensors==0.4.5 +torch==2.4.0 +torch_geometric==2.6.1 +transformers==4.44.0 +trl==0.9.6 +typing_extensions==4.12.2 +vllm==0.6.2 diff --git a/src/__init__.py b/src/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/__init__.py b/src/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..fa8b6b9c00d80fa71c76337d15dc5f6cd8aa5175 --- /dev/null +++ b/src/data/__init__.py @@ -0,0 +1,29 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding +from .data_utils import Role, split_dataset +from .loader import get_dataset +from .template import TEMPLATES, Template, get_template_and_fix_tokenizer +from .collator import DataCollatorForSeqGraph + +__all__ = [ + "Role", + "split_dataset", + "get_dataset", + "TEMPLATES", + "Template", + "get_template_and_fix_tokenizer", + 'DataCollatorForSeqGraph' +] diff --git a/src/data/aligner.py b/src/data/aligner.py new file mode 100644 index 0000000000000000000000000000000000000000..8d33041d5f5b48196dcf4d25e87cce3d3c179fa1 --- /dev/null +++ b/src/data/aligner.py @@ -0,0 +1,233 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
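+
+# Illustrative note on the molqa records handled by this module (field names and values
+# taken from one of the examples in data/molqa_train_examples.json; the special tokens
+# embedded in "output" are model-specific and omitted here):
+#
+#   {
+#     "instruction": "What is a feasible molecular design that meets ...?",
+#     "input": "",
+#     "output": "... Therefore, the designed molecule is: <SMILES>. This is step 1 in the retrosynthesis process. ...",
+#     "property": {"SC": 2.42, "SA": 2.22},
+#     "retro": [180575]
+#   }
+#
+# convert_molqa() below turns each record into chat-style prompt/response messages and
+# extracts the designed SMILES, the property vector, the retro labels, and the retro products.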
+ +import os +import re +from functools import partial +from typing import TYPE_CHECKING, Any, Dict, List, Union, Tuple + +from datasets import Features + +from ..extras.logging import get_logger +from .data_utils import Role + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import Seq2SeqTrainingArguments + + from ..hparams import DataArguments + from .parser import DatasetAttr + + +logger = get_logger(__name__) + +def extract_all_smiles(text): + pattern = r'(.*?)' + return re.findall(pattern, text) + +def replace_all_smiles(text): + pattern = r'.*?' + return re.sub(pattern, '', text) + +def replace_smiles_with_callback(text): + def replace_mol(match): + design_end = match.group(1) + smiles = match.group(2) + # return f'{design_end}{smiles}' + return f'{design_end}{smiles}' + + pattern = r'()(.*?)' + text = re.sub(pattern, replace_mol, text) + + # Replace remaining molecules that are not immediately after + remaining_pattern = r'.*?' + text = re.sub(remaining_pattern, '', text) + + return text + +def dict_to_list(data_dict, mol_properties): + return [data_dict.get(prop, None) for prop in mol_properties] + +def insert_bodies(text, num_insertions, retro_labels): + design_pattern = r'(.*?)' + retro_pattern = r'(This is step \d+ in the retrosynthesis process\..*?.*?)(.*?)(?=This is step \d+|$)' + + def replace_design(match): + return f'' + ''.join([''] * num_insertions) + f'' + + def replace_retro(match, label): + step_content = match.group(1) + remaining_text = match.group(2) + retro_match = re.search(r'(.*?)', step_content) + if retro_match and label is not None: + modified_content = f'' + ''.join([''] * num_insertions) + f'' + return re.sub(r'.*?', modified_content, step_content) + return step_content + remaining_text + + text = re.sub(design_pattern, replace_design, text) + + steps = re.finditer(retro_pattern, text) + modified_text = "" + last_end = 0 + + for i, step in enumerate(steps): + label = retro_labels[i] if i < len(retro_labels) else None + modified_text += text[last_end:step.start()] + replace_retro(step, label) + last_end = step.end() + + modified_text += text[last_end:] + return modified_text + +def extract_retro_products(text): + pattern = r'(.*?)>>' + matches = re.findall(pattern, text) + return [match.strip() for match in matches] + +def convert_molqa( + examples: Dict[str, List[Any]], dataset_attr: "DatasetAttr", data_args: "DataArguments" +) -> Dict[str, List[Any]]: + r""" + Converts alpaca format dataset to the standard format. 
+ """ + outputs = {"prompt": [], "response": [], "system": [], "molecules": [], "property": [], "retro_labels": [], "retro_products": []} + + mol_properties = ['BBBP', 'HIV', 'BACE', 'CO2', 'N2', 'O2', 'FFV', 'TC', 'SC', 'SA'] + for i in range(len(examples[dataset_attr.prompt])): + prompt = [] + if dataset_attr.history and isinstance(examples[dataset_attr.history][i], list): + for old_prompt, old_response in examples[dataset_attr.history][i]: + prompt.append({"role": Role.USER.value, "content": old_prompt}) + prompt.append({"role": Role.ASSISTANT.value, "content": old_response}) + + content = [] + if dataset_attr.prompt and examples[dataset_attr.prompt][i]: + content.append(examples[dataset_attr.prompt][i]) + + if dataset_attr.query and examples[dataset_attr.query][i]: + content.append(examples[dataset_attr.query][i]) + + prompt.append({"role": Role.USER.value, "content": "\n".join(content)}) # "prompt\nquery" + + if dataset_attr.response and isinstance(examples[dataset_attr.response][i], str): # normal example + current_response = examples[dataset_attr.response][i] + smiles_list = extract_all_smiles(current_response) + modified_response = replace_smiles_with_callback(current_response) + retro_labels = examples[dataset_attr.retro][i] if dataset_attr.retro else [] + retro_products = extract_retro_products(current_response) + modified_response = insert_bodies(modified_response, data_args.learned_query_size, retro_labels) + # modified_response = insert_bodies(modified_response, dataset_attr.learned_query_size, retro_labels) + response = [{"role": Role.ASSISTANT.value, "content": modified_response}] + else: # unsupervised + response = [] + + outputs["prompt"].append(prompt) + outputs["response"].append(response) + outputs["system"].append(examples[dataset_attr.system][i] if dataset_attr.system else "") + outputs["molecules"].append(smiles_list) + outputs["property"].append(dict_to_list(examples[dataset_attr.property][i], mol_properties)) + outputs["retro_labels"].append(retro_labels) + outputs["retro_products"].append(retro_products) + + return outputs + +def map_smiles_to_id(example, smiles_to_id): + example['molecules'] = [smiles_to_id[smiles] for smiles in example['molecules']] + return example + +def align_dataset( + dataset: Union["Dataset", "IterableDataset"], + dataset_attr: "DatasetAttr", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", +) -> Tuple[Union["Dataset", "IterableDataset"], Dict[int, str]]: + r""" + Aligns the dataset and maps unique SMILES strings to molecule IDs. + + This function performs the following operations: + 1. Converts the dataset to the required format (molqa). + 2. Extracts all unique SMILES strings from the dataset. + 3. Maps each unique SMILES string to a unique integer ID (0, 1, 2, ...). + 4. Update 'molecules' field to each example, containing the mapped IDs. + + The aligned dataset contains the following fields: + prompt: [{"role": "user", "content": "..."}] * (2T - 1) + response: [{"role": "assistant", "content": "..."}] * N (N > 1 for ranking dataset) + system: "..." + molecules: [List of SMILES string] + property: [List of float values] + retro_labels: [List of int values] + retro_products: [List of SMILES string] + + Args: + dataset (Union["Dataset", "IterableDataset"]): The input dataset. + dataset_attr (DatasetAttr): Attributes of the dataset. + data_args (DataArguments): Arguments for data processing. + training_args (Seq2SeqTrainingArguments): Arguments for training. 
+ + Returns: + Tuple[Union["Dataset", "IterableDataset"], Dict[int, str]]: + - The aligned and converted dataset with molecule IDs. + - A dictionary mapping molecule IDs to their SMILES strings. + """ + assert dataset_attr.formatting == "molqa" + + features = Features.from_dict( + { + "prompt": [ + {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}} + ], + "response": [ + {"role": {"dtype": "string", "_type": "Value"}, "content": {"dtype": "string", "_type": "Value"}} + ], + "system": {"dtype": "string", "_type": "Value"}, + "molecules": [{'dtype': "string", "_type": "Value"}], + "property": [{"dtype": "float", "_type": "Value"}], + "retro_labels": [{"dtype": "int32", "_type": "Value"}], + "retro_products": [{'dtype': "string", "_type": "Value"}], + } + ) + + convert_func = partial(convert_molqa, dataset_attr=dataset_attr, data_args=data_args) + aligned = dataset.map( + convert_func, + batched=True, + remove_columns=['instruction', 'input', 'output', 'property', 'retro'], + features=features, + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), + desc="Converting molqa format of dataset" + ) + + # Extract all unique SMILES strings and map them to molecule IDs + all_smiles = set() + for item in aligned: + all_smiles.update(item['molecules']) + all_smiles.update(item['retro_products']) + + smiles_to_id = {smiles: idx for idx, smiles in enumerate(sorted(all_smiles))} + id_to_smiles = {idx: smiles for smiles, idx in smiles_to_id.items()} + + def map_smiles_to_id(example, smiles_to_id): + example['molecules'] = [smiles_to_id[smiles] for smiles in example['molecules']] + example['retro_products'] = [smiles_to_id[smiles] for smiles in example['retro_products']] + return example + + smiles_convert_func = partial(map_smiles_to_id, smiles_to_id=smiles_to_id) + + aligned = aligned.map( + smiles_convert_func, + desc="Mapping SMILES to molecule IDs", + ) + + return aligned, id_to_smiles \ No newline at end of file diff --git a/src/data/collator.py b/src/data/collator.py new file mode 100644 index 0000000000000000000000000000000000000000..989ac073a842146529da1c9bb69146788b9dec31 --- /dev/null +++ b/src/data/collator.py @@ -0,0 +1,165 @@ +import torch +import numpy as np +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Union + +from torch_geometric.data import Batch as PyGBatch +from transformers.tokenization_utils_base import PreTrainedTokenizerBase +from transformers.utils import PaddingStrategy + +def pad_without_fast_tokenizer_warning(tokenizer, *pad_args, **pad_kwargs): + """ + Pads without triggering the warning about how using the pad function is sub-optimal when using a fast tokenizer. + """ + # To avoid errors when using Feature extractors + if not hasattr(tokenizer, "deprecation_warnings"): + return tokenizer.pad(*pad_args, **pad_kwargs) + + # Save the state of the warning, then disable it + warning_state = tokenizer.deprecation_warnings.get("Asking-to-pad-a-fast-tokenizer", False) + tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = True + + try: + padded = tokenizer.pad(*pad_args, **pad_kwargs) + finally: + # Restore the state of the warning. + tokenizer.deprecation_warnings["Asking-to-pad-a-fast-tokenizer"] = warning_state + + return padded + +@dataclass +class DataCollatorForSeqGraph: + """ + Data collator that will dynamically pad the inputs received, as well as the labels. 
+ """ + tokenizer: PreTrainedTokenizerBase + mol_id_to_pyg: Dict[str, Any] + model: Optional[Any] = None + padding: Union[bool, str, PaddingStrategy] = True + max_length: Optional[int] = None + pad_to_multiple_of: Optional[int] = None + label_pad_token_id: int = -100 + return_tensors: str = "pt" + + def __call__(self, features, return_tensors=None): + if return_tensors is None: + return_tensors = self.return_tensors + + label_name = "label" if "label" in features[0].keys() else "labels" + labels = [feature[label_name] for feature in features] if label_name in features[0].keys() else None + if labels is not None and all(label is None for label in labels): + labels = None + + # Store molecule_ids, retro_labels, and retro_product_ids separately and remove from non_labels_features + molecule_ids_list = [] + retro_labels_list = [] + retro_products_list = [] + non_labels_features = [] + for feature in features: + new_feature = {k: v for k, v in feature.items() if k != label_name} + if 'molecule_ids' in new_feature: + molecule_ids_list.append(new_feature['molecule_ids']) + del new_feature['molecule_ids'] + else: + molecule_ids_list.append(None) + if 'retro_labels' in new_feature: + retro_labels_list.append(new_feature['retro_labels']) + del new_feature['retro_labels'] + else: + retro_labels_list.append(None) + if 'retro_product_ids' in new_feature: + retro_products_list.append(new_feature['retro_product_ids']) + del new_feature['retro_product_ids'] + else: + retro_products_list.append(None) + non_labels_features.append(new_feature) + + # Convert molecule IDs to PyG Data objects + molecule_graphs_list = [] + design_graphs_list = [] + for seq_idx, molecule_ids in enumerate(molecule_ids_list): + if molecule_ids is not None and len(molecule_ids) > 0: + for pos, mol_id in enumerate(molecule_ids): + if pos == 0: + design_graphs_list.append(self.mol_id_to_pyg[mol_id]) + if mol_id != self.label_pad_token_id and mol_id in self.mol_id_to_pyg: + molecule_graphs_list.append(self.mol_id_to_pyg[mol_id]) + + # Convert retro_product_ids to PyG Data objects + retro_product_graphs_list = [] + for seq_idx, retro_product_ids in enumerate(retro_products_list): + if retro_product_ids is not None and len(retro_product_ids) > 0: + for pos, mol_id in enumerate(retro_product_ids): + if mol_id != self.label_pad_token_id and mol_id in self.mol_id_to_pyg: + retro_product_graphs_list.append(self.mol_id_to_pyg[mol_id]) + + # Batch the PyG Data objects + if molecule_graphs_list: + batched_graphs = PyGBatch.from_data_list(molecule_graphs_list) + else: + batched_graphs = None + + if design_graphs_list: + batched_design_graphs = PyGBatch.from_data_list(design_graphs_list) + else: + batched_design_graphs = None + + if retro_product_graphs_list: + batched_retro_products = PyGBatch.from_data_list(retro_product_graphs_list) + else: + batched_retro_products = None + + # Pad retro_labels + if retro_labels_list and any(retro_labels is not None for retro_labels in retro_labels_list): + max_retro_length = max(len(retro_labels) for retro_labels in retro_labels_list if retro_labels is not None) + padded_retro_labels = [ + retro_labels + [self.label_pad_token_id] * (max_retro_length - len(retro_labels)) if retro_labels is not None else [self.label_pad_token_id] * max_retro_length + for retro_labels in retro_labels_list + ] + else: + padded_retro_labels = None + + # Pad other features + batch = pad_without_fast_tokenizer_warning( + self.tokenizer, + non_labels_features, + padding=self.padding, + max_length=self.max_length, + 
pad_to_multiple_of=self.pad_to_multiple_of, + return_tensors=return_tensors, + ) + + batch["molecule_graphs"] = batched_graphs + batch["design_graphs"] = batched_design_graphs + batch["retro_product_graphs"] = batched_retro_products + batch["retro_labels"] = torch.tensor(padded_retro_labels, dtype=torch.int64) + + # Pad labels + if labels is not None: + max_label_length = max(len(l) for l in labels) + if self.pad_to_multiple_of is not None: + max_label_length = ( + (max_label_length + self.pad_to_multiple_of - 1) + // self.pad_to_multiple_of + * self.pad_to_multiple_of + ) + + padding_side = self.tokenizer.padding_side + padded_labels = [ + label + [self.label_pad_token_id] * (max_label_length - len(label)) + if padding_side == "right" + else [self.label_pad_token_id] * (max_label_length - len(label)) + label + for label in labels + ] + batch["labels"] = torch.tensor(padded_labels, dtype=torch.int64) + + # Prepare decoder_input_ids + if ( + labels is not None + and self.model is not None + and hasattr(self.model, "prepare_decoder_input_ids_from_labels") + ): + decoder_input_ids = self.model.prepare_decoder_input_ids_from_labels(labels=batch["labels"]) + batch["decoder_input_ids"] = decoder_input_ids + + return batch \ No newline at end of file diff --git a/src/data/data_utils.py b/src/data/data_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..561aadb7720af4b085ff59a364b9744daa16b954 --- /dev/null +++ b/src/data/data_utils.py @@ -0,0 +1,82 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
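+
+# Note on split_dataset() below (behaviour read from the code; the numbers are only
+# illustrative): data_args.val_size > 1 is treated as an absolute number of held-out
+# samples (val_size=1000 reserves 1000 examples for evaluation), while 0 < val_size <= 1
+# is treated as a fraction (val_size=0.1 holds out 10% of the training data).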
+ +from enum import Enum, unique +from typing import TYPE_CHECKING, Dict, List, Sequence, Set, Union + +from datasets import concatenate_datasets, interleave_datasets + +from ..extras.logging import get_logger + + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import Seq2SeqTrainingArguments + + from ..hparams import DataArguments + + +logger = get_logger(__name__) + + +SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] + + +@unique +class Role(str, Enum): + USER = "user" + ASSISTANT = "assistant" + SYSTEM = "system" + FUNCTION = "function" + OBSERVATION = "observation" + + +def merge_dataset( + all_datasets: List[Union["Dataset", "IterableDataset"]], + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", +) -> Union["Dataset", "IterableDataset"]: + if len(all_datasets) == 1: + return all_datasets[0] + elif data_args.mix_strategy == "concat": + if data_args.streaming: + logger.warning("The samples between different datasets will not be mixed in streaming mode.") + return concatenate_datasets(all_datasets) + elif data_args.mix_strategy.startswith("interleave"): + if not data_args.streaming: + logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") + return interleave_datasets( + datasets=all_datasets, + probabilities=data_args.interleave_probs, + seed=training_args.seed, + stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", + ) + else: + raise ValueError("Unknown mixing strategy.") + + +def split_dataset( + dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", training_args: "Seq2SeqTrainingArguments" +) -> Dict[str, "Dataset"]: + if training_args.do_train: + if data_args.val_size > 1e-6: # Split the dataset + val_size = int(data_args.val_size) if data_args.val_size > 1 else data_args.val_size + dataset = dataset.train_test_split(test_size=val_size, seed=training_args.seed) + return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]} + else: + if data_args.streaming: + dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed) + return {"train_dataset": dataset} + else: # do_eval or do_predict + return {"eval_dataset": dataset} diff --git a/src/data/formatter.py b/src/data/formatter.py new file mode 100644 index 0000000000000000000000000000000000000000..2086900afefa97481f8e30838b69ec4827bbf3ad --- /dev/null +++ b/src/data/formatter.py @@ -0,0 +1,139 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass, field +from typing import List, Literal, Optional, Tuple, Union + +from .data_utils import SLOTS +from .tool_utils import DefaultToolUtils, GLM4ToolUtils + + +@dataclass +class Formatter(ABC): + slots: SLOTS = field(default_factory=list) + tool_format: Optional[Literal["default", "glm4"]] = None + + @abstractmethod + def apply(self, **kwargs) -> SLOTS: ... 
+ + def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: + raise NotImplementedError + + +@dataclass +class EmptyFormatter(Formatter): + def __post_init__(self): + has_placeholder = False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot): + has_placeholder = True + + if has_placeholder: + raise ValueError("Empty formatter should not contain any placeholder.") + + def apply(self, **kwargs) -> SLOTS: + return self.slots + + +@dataclass +class StringFormatter(Formatter): + def __post_init__(self): + has_placeholder = False + for slot in filter(lambda s: isinstance(s, str), self.slots): + if re.search(r"\{\{[a-zA-Z_][a-zA-Z0-9_]*\}\}", slot): + has_placeholder = True + + if not has_placeholder: + raise ValueError("A placeholder is required in the string formatter.") + + def apply(self, **kwargs) -> SLOTS: + elements = [] + for slot in self.slots: + if isinstance(slot, str): + for name, value in kwargs.items(): + if not isinstance(value, str): + raise RuntimeError("Expected a string, got {}".format(value)) + + slot = slot.replace("{{" + name + "}}", value, 1) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) + else: + raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) + + return elements + +@dataclass +class FunctionFormatter(Formatter): + def __post_init__(self): + if self.tool_format == "default": + self.slots = DefaultToolUtils.get_function_slots() + self.slots + elif self.tool_format == "glm4": + self.slots = GLM4ToolUtils.get_function_slots() + self.slots + else: + raise NotImplementedError("Tool format {} was not found.".format(self.tool_format)) + + def apply(self, **kwargs) -> SLOTS: + content = kwargs.pop("content") + functions: List[Tuple[str, str]] = [] + try: + tool_calls = json.loads(content) + if not isinstance(tool_calls, list): # parallel function call + tool_calls = [tool_calls] + + for tool_call in tool_calls: + functions.append((tool_call["name"], json.dumps(tool_call["arguments"], ensure_ascii=False))) + + except json.JSONDecodeError: + functions = [] + + elements = [] + for name, arguments in functions: + for slot in self.slots: + if isinstance(slot, str): + slot = slot.replace("{{name}}", name).replace("{{arguments}}", arguments) + elements.append(slot) + elif isinstance(slot, (dict, set)): + elements.append(slot) + else: + raise RuntimeError("Input must be string, set[str] or dict[str, str], got {}".format(type(slot))) + + return elements + + +@dataclass +class ToolFormatter(Formatter): + def __post_init__(self): + if self.tool_format == "default": + self._tool_formatter = DefaultToolUtils.tool_formatter + self._tool_extractor = DefaultToolUtils.tool_extractor + elif self.tool_format == "glm4": + self._tool_formatter = GLM4ToolUtils.tool_formatter + self._tool_extractor = GLM4ToolUtils.tool_extractor + else: + raise NotImplementedError("Tool format {} was not found.".format(self.tool_format)) + + def apply(self, **kwargs) -> SLOTS: + content = kwargs.pop("content") + try: + tools = json.loads(content) + return [self._tool_formatter(tools) if len(tools) != 0 else ""] + except json.JSONDecodeError: + return [""] + + def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: + return self._tool_extractor(content) diff --git a/src/data/loader.py b/src/data/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..0647d46e70e155e86eff42ac0243a2276a4ed83e --- /dev/null +++ 
b/src/data/loader.py @@ -0,0 +1,149 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +import os +import sys +from typing import TYPE_CHECKING, Literal, Optional, Union +from functools import partial + +import numpy as np +from datasets import load_dataset, load_from_disk + +# from ..extras.constants import FILEEXT2TYPE +from ..extras.logging import get_logger +from ..extras.misc import has_tokenized_data +from .aligner import align_dataset +from .data_utils import merge_dataset +from .parser import get_dataset_attr +# from .preprocess import get_preprocess_and_print_func +from .template import get_template_and_fix_tokenizer + +from .processors.mmsupervised import ( + preprocess_mmsupervised_dataset, + print_supervised_dataset_example, + encode_graph_pyg +) + +if TYPE_CHECKING: + from datasets import Dataset, IterableDataset + from transformers import PreTrainedTokenizer, ProcessorMixin, Seq2SeqTrainingArguments + + from ..hparams import DataArguments, ModelArguments + from .parser import DatasetAttr + + +logger = get_logger(__name__) + + +def load_single_dataset( + dataset_attr: "DatasetAttr", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", +) -> Union["Dataset", "IterableDataset"]: + logger.info("Loading dataset {}...".format(dataset_attr)) + + data_files = [] + assert dataset_attr.load_from == "file" + + data_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + data_files.append(data_path) + data_path = data_path.split(".")[-1] + + if "trust_remote_code" in inspect.signature(load_dataset).parameters: # for datasets==2.16.0 + kwargs = {"trust_remote_code": True} + else: + kwargs = {} + + dataset = load_dataset( + path=data_path, + name=None, + data_dir=None, + data_files=data_files, + split=data_args.split, + cache_dir=model_args.cache_dir, + token=model_args.hf_hub_token, + streaming=False, + **kwargs, + ) + + converted_dataset, mol_id_to_smiles = align_dataset(dataset, dataset_attr, data_args, training_args) + return converted_dataset, mol_id_to_smiles + +def get_dataset( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + tokenizer: "PreTrainedTokenizer", +) -> Union["Dataset", "IterableDataset"]: + + template = get_template_and_fix_tokenizer(tokenizer, data_args.template, data_args.tool_format) + if data_args.train_on_prompt and template.efficient_eos: + raise ValueError("Current template does not support `train_on_prompt`.") + print_function = partial(print_supervised_dataset_example, tokenizer=tokenizer) + + # Load tokenized dataset + if data_args.tokenized_path is not None: + if has_tokenized_data(data_args.tokenized_path): + mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path) + logger.warning("Loading dataset from disk will ignore other data arguments.") + dataset = load_from_disk(data_args.tokenized_path) + logger.info("Loaded tokenized dataset from 
{}.".format(data_args.tokenized_path)) + # print_function(next(iter(dataset))) + data_iter = iter(dataset) + print_function(next(data_iter)) + return mol_id_to_pyg, dataset + + # Load tokenized dataset + with training_args.main_process_first(desc="load dataset"): + # current only support one dataset + dataset_attr = get_dataset_attr(data_args) + dataset, mol_id_to_smiles = load_single_dataset(dataset_attr, model_args, data_args, training_args) + + with training_args.main_process_first(desc="pre-process dataset"): + preprocess_func = partial( + preprocess_mmsupervised_dataset, + template=template, + tokenizer=tokenizer, + data_args=data_args, + ) + + column_names = list(next(iter(dataset)).keys()) + kwargs = {} + kwargs = dict( + num_proc=data_args.preprocessing_num_workers, + load_from_cache_file=(not data_args.overwrite_cache) or (training_args.local_process_index != 0), + desc="Running tokenizer on dataset", + ) + + dataset = dataset.map(preprocess_func, batched=True, remove_columns=column_names, **kwargs) + + if data_args.tokenized_path is not None: + if training_args.should_save: + dataset.save_to_disk(data_args.tokenized_path) + mol_id_to_pyg = encode_graph_pyg(data_path=data_args.tokenized_path, mol_id_to_smiles=mol_id_to_smiles) + logger.info("Tokenized dataset saved at {}.".format(data_args.tokenized_path)) + logger.info("Please restart the training with `tokenized_path: {}`.".format(data_args.tokenized_path)) + sys.exit(0) + else: + mol_id_to_pyg = encode_graph_pyg(mol_id_to_smiles=mol_id_to_smiles) + + if training_args.should_log: + try: + print_function(next(iter(dataset))) + except StopIteration: + raise RuntimeError("Cannot find valid samples.") + + return mol_id_to_pyg, dataset diff --git a/src/data/parser.py b/src/data/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..8364e965e693506fe49fc824dfcffb21c16abf69 --- /dev/null +++ b/src/data/parser.py @@ -0,0 +1,113 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Dict, List, Literal, Optional + +from ..extras.constants import DATA_CONFIG +from ..extras.misc import use_modelscope + + +if TYPE_CHECKING: + from ..hparams import DataArguments + + +@dataclass +class DatasetAttr: + r""" + Dataset attributes. 
+ """ + + # basic configs + load_from: Literal["hf_hub", "ms_hub", "script", "file"] + dataset_name: str + formatting: Literal["alpaca", "sharegpt", "molqa"] = "molqa" + ranking: bool = False + # extra configs + subset: Optional[str] = None + folder: Optional[str] = None + num_samples: Optional[int] = None + # common columns + system: Optional[str] = None + tools: Optional[str] = None + images: Optional[str] = None + # rlhf columns + chosen: Optional[str] = None + rejected: Optional[str] = None + kto_tag: Optional[str] = None + # alpaca columns + prompt: Optional[str] = "instruction" + query: Optional[str] = "input" + response: Optional[str] = "output" + history: Optional[str] = None + # sharegpt columns + messages: Optional[str] = "conversations" + # sharegpt tags + role_tag: Optional[str] = "from" + content_tag: Optional[str] = "value" + user_tag: Optional[str] = "human" + assistant_tag: Optional[str] = "gpt" + observation_tag: Optional[str] = "observation" + function_tag: Optional[str] = "function_call" + system_tag: Optional[str] = "system" + # molqa columns + property: Optional[str] = 'property' + retro: Optional[str] = 'retro' + # learned_query_size: Optional[int] = None + + def __repr__(self) -> str: + return self.dataset_name + + def set_attr(self, key: str, obj: Dict[str, Any], default: Optional[Any] = None) -> None: + setattr(self, key, obj.get(key, default)) + +def get_dataset_attr(data_args: "DataArguments") -> List["DatasetAttr"]: + if data_args.dataset is not None: + dataset_name = data_args.dataset.strip() + else: + raise ValueError("Please specify the dataset name.") + + try: + with open(os.path.join(data_args.dataset_dir, DATA_CONFIG), "r") as f: + dataset_info = json.load(f) + except Exception as err: + raise ValueError( + "Cannot open {} due to {}.".format(os.path.join(data_args.dataset_dir, DATA_CONFIG), str(err)) + ) + dataset_info = None + + if dataset_name not in dataset_info: + raise ValueError("Undefined dataset {} in {}.".format(dataset_name, DATA_CONFIG)) + + dataset_attr = DatasetAttr("file", dataset_name=dataset_info[dataset_name]["file_name"]) + + print('dataset_info', dataset_info) + + dataset_attr.set_attr("formatting", dataset_info[dataset_name], default="molqa") + dataset_attr.set_attr("ranking", dataset_info[dataset_name], default=False) + dataset_attr.set_attr("subset", dataset_info[dataset_name]) + dataset_attr.set_attr("folder", dataset_info[dataset_name]) + dataset_attr.set_attr("num_samples", dataset_info[dataset_name]) + + if "columns" in dataset_info[dataset_name]: + column_names = ["system", "tools", "images", "chosen", "rejected", "kto_tag"] + assert dataset_attr.formatting == "molqa" + column_names.extend(["prompt", "query", "response", "history", "property", "retro"]) + + for column_name in column_names: + dataset_attr.set_attr(column_name, dataset_info[dataset_name]["columns"]) + + return dataset_attr \ No newline at end of file diff --git a/src/data/processors/__init__.py b/src/data/processors/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/data/processors/mmsupervised.py b/src/data/processors/mmsupervised.py new file mode 100644 index 0000000000000000000000000000000000000000..7597e902c6ad7d7fc41a530362a2ed30d84930a8 --- /dev/null +++ b/src/data/processors/mmsupervised.py @@ -0,0 +1,335 @@ +# Copyright 2024 Llamole Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import os
+import pickle
+from collections import defaultdict
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
+
+import torch
+from rdkit import Chem
+from torch_geometric.data import Data, Batch
+
+from ...extras.constants import IGNORE_INDEX, BOND_INDEX, NO_LABEL_INDEX
+from ...extras.logging import get_logger
+
+if TYPE_CHECKING:
+    from transformers import PreTrainedTokenizer, ProcessorMixin
+
+    from ...hparams import DataArguments
+    from ..template import Template
+
+logger = get_logger(__name__)
+
+
+def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
+    if target_len * 2 < cutoff_len:  # truncate source
+        max_target_len = cutoff_len
+    elif source_len * 2 < cutoff_len:  # truncate target
+        max_target_len = cutoff_len - source_len
+    else:  # truncate both
+        max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
+
+    new_target_len = min(max_target_len, target_len)
+    new_source_len = max(cutoff_len - new_target_len, 0)
+    return new_source_len, new_target_len
+
+def encode_graph_pyg(
+    data_path: Optional[str] = None, mol_id_to_smiles: Optional[Dict[str, str]] = None
+) -> Dict[str, Data]:
+    """
+    Converts molecule data to a dictionary of PyTorch Geometric Data objects, with caching functionality.
+    Uses a sparse representation for efficiency.
+
+    Args:
+        data_path (Optional[str]): Path to the Hugging Face dataset folder.
+        mol_id_to_smiles (Optional[Dict[str, str]]): Dictionary where keys are molecule IDs
+            and values are SMILES strings.
+
+    Returns:
+        Dict[str, Data]: Dictionary where keys are molecule IDs and values are
+            PyTorch Geometric Data objects.
+
+    Raises:
+        ValueError: If both data_path and mol_id_to_smiles are None, or if data_path is provided but loading fails.
+ """ + print(f"Current execution directory: {os.getcwd()}") + + if data_path is None and mol_id_to_smiles is None: + raise ValueError("Either data_path or mol_id_to_smiles must be provided.") + + if data_path is not None: + cache_file = os.path.join(data_path, "pyg_molecule.pickle") + + # Try to load cached data + if os.path.exists(cache_file): + try: + with open(cache_file, "rb") as f: + return pickle.load(f) + except Exception as e: + print(f"Failed to load cached data: {e}") + + mol_id_to_pyg = {} + + for mol_id, smiles in mol_id_to_smiles.items(): + mol = Chem.MolFromSmiles(smiles) + if mol is None: + raise ValueError(f"Invalid SMILES string for molecule {mol_id}: {smiles}") + + type_idx = [] + heavy_atom_indices = [] + for atom in mol.GetAtoms(): + if atom.GetAtomicNum() != 1: # Exclude hydrogen atoms + type_idx.append( + 119 - 2 if atom.GetSymbol() == "*" else atom.GetAtomicNum() - 2 + ) + heavy_atom_indices.append(atom.GetIdx()) + + x = torch.LongTensor(type_idx) + + edge_index = [] + edge_attr = [] + for bond in mol.GetBonds(): + start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + if start in heavy_atom_indices and end in heavy_atom_indices: + start_new, end_new = heavy_atom_indices.index( + start + ), heavy_atom_indices.index(end) + edge_index.extend([[start_new, end_new], [end_new, start_new]]) + bond_type = BOND_INDEX[bond.GetBondType()] + edge_attr.extend([bond_type, bond_type]) + + edge_index = torch.tensor(edge_index, dtype=torch.long).t().contiguous() + edge_attr = torch.tensor(edge_attr, dtype=torch.long) + + # Create PyG Data object + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr) + + mol_id_to_pyg[mol_id] = data + + # Save cached data if data_path is provided + if data_path is not None: + with open(cache_file, "wb") as f: + pickle.dump(mol_id_to_pyg, f) + + print(f"Saved PyG data to {cache_file}") + + return mol_id_to_pyg + +def encode_supervised_example( + prompt: Sequence[Dict[str, str]], + response: Sequence[Dict[str, str]], + system: Optional[str], + molecule_ids: List[int], + retro_product_ids: List[int], + retro_labels: List[int], + template: "Template", + tokenizer: "PreTrainedTokenizer", + data_args: "DataArguments", +) -> Tuple[List[int], List[int], List[int], List[int], List[int]]: + + messages = prompt + response + input_ids, labels = [], [] + final_molecule_ids = [] + final_product_ids = [] + final_retro_labels = [] + + encoded_pairs = template.encode_multiturn(tokenizer, messages, system) + special_tokens = [ + "", + "", + "", + "", + "", + "", + "", + ] + special_token_ids = template._convert_elements_to_ids(tokenizer, special_tokens) + special_token_dict = dict(zip(special_tokens, special_token_ids)) + + total_length = 1 if template.efficient_eos else 0 + for turn_idx, (source_ids, target_ids) in enumerate(encoded_pairs): + if total_length >= data_args.cutoff_len: + break + + source_len, target_len = infer_seqlen( + len(source_ids), len(target_ids), data_args.cutoff_len - total_length + ) + source_ids = source_ids[:source_len] + + # Ensure balanced retro tags when truncating + retro_start_indices = [ + i + for i, id in enumerate(target_ids) + if id == special_token_dict[""] + ] + retro_end_indices = [ + i + for i, id in enumerate(target_ids) + if id == special_token_dict[""] + ] + + if retro_start_indices and retro_end_indices: + # Find the last matching pair that fits within target_len + last_pair_index = -1 + for start, end in zip(retro_start_indices, retro_end_indices): + if end < target_len: + last_pair_index = end + else: + break 
+ + if last_pair_index >= 0: + target_len = last_pair_index + 1 + else: + # If no complete pair fits, truncate before the first start tag + target_len = ( + min(target_len, retro_start_indices[0]) + if retro_start_indices + else target_len + ) + + target_ids = target_ids[:target_len] + + # Calculate the number of molecules in this turn + molecules_in_turn = target_ids.count(special_token_dict[""]) + retro_start_in_turn = target_ids.count(special_token_dict[""]) + retro_end_in_turn = target_ids.count(special_token_dict[""]) + + assert retro_start_in_turn == retro_end_in_turn + + retro_product_ids_in_turn = retro_product_ids[:retro_end_in_turn] + retro_labels_in_turn = retro_labels[:retro_end_in_turn] + + # Add corresponding retro_labels and retro_product_ids + final_molecule_ids.extend(molecule_ids[:molecules_in_turn]) + final_product_ids.extend(retro_product_ids_in_turn) + final_retro_labels.extend(retro_labels_in_turn) + + total_length += source_len + target_len + + if data_args.train_on_prompt: + source_mask = source_ids + elif turn_idx != 0 and template.efficient_eos: + source_mask = [tokenizer.eos_token_id] + [IGNORE_INDEX] * ( + len(source_ids) - 1 + ) + else: + source_mask = [IGNORE_INDEX] * len(source_ids) + + source_mask = [ + IGNORE_INDEX if id in special_token_dict.values() else id + for id in source_mask + ] + target_ids_mask = [ + id if id in [special_token_dict[""], special_token_dict[""]] + else (IGNORE_INDEX if id in special_token_dict.values() else id) + for id in target_ids + ] + + input_ids += source_ids + target_ids + labels += source_mask + target_ids_mask + + if template.efficient_eos: + input_ids += [tokenizer.eos_token_id] + labels += [tokenizer.eos_token_id] + + return input_ids, labels, final_molecule_ids, final_product_ids, final_retro_labels + + +def preprocess_mmsupervised_dataset( + examples: Dict[str, List[Any]], + template: "Template", + tokenizer: "PreTrainedTokenizer", + data_args: "DataArguments", +) -> Dict[str, List[List[int]]]: + model_inputs = { + "input_ids": [], + "attention_mask": [], + "labels": [], + "molecule_ids": [], + "molecule_properties": [], + "retro_labels": [], + "retro_product_ids": [], + } + + for i in range(len(examples["prompt"])): + if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) != 1: + logger.warning( + "Dropped invalid example: {}".format( + examples["prompt"][i] + examples["response"][i] + ) + ) + continue + + retro_product_ids = examples["retro_products"][i] + retro_labels = [ + NO_LABEL_INDEX if label is None else label + for label in examples["retro_labels"][i] + ] + properties = [ + NO_LABEL_INDEX if prop is None else prop for prop in examples["property"][i] + ] + + input_ids, labels, molecule_ids, retro_product_ids, retro_labels = ( + encode_supervised_example( + prompt=examples["prompt"][i], + response=examples["response"][i], + system=examples["system"][i], + molecule_ids=examples["molecules"][i], + retro_product_ids=retro_product_ids, + retro_labels=retro_labels, + template=template, + tokenizer=tokenizer, + data_args=data_args, + ) + ) + # molecule_ids = examples["molecules"][i] + + model_inputs["input_ids"].append(input_ids) + model_inputs["attention_mask"].append([1] * len(input_ids)) + model_inputs["labels"].append(labels) + model_inputs["molecule_ids"].append(molecule_ids) + model_inputs["molecule_properties"].append(properties) + model_inputs["retro_labels"].append(retro_labels) + model_inputs["retro_product_ids"].append(retro_product_ids) + + return model_inputs + +def 
print_supervised_dataset_example( + example: Dict[str, List[int]], tokenizer: "PreTrainedTokenizer" +) -> None: + valid_labels = list(filter(lambda x: x != IGNORE_INDEX, example["labels"])) + print("Print_supervised_dataset_example") + + print("input_ids:\n{}".format(example["input_ids"])) + print( + "inputs:\n{}".format( + tokenizer.decode(example["input_ids"], skip_special_tokens=False) + ) + ) + print("label_ids:\n{}".format(example["labels"])) + print( + "labels:\n{}".format(tokenizer.decode(valid_labels, skip_special_tokens=False)) + ) + print("molecule_ids:\n{}".format(example["molecule_ids"])) + print("molecule_properties:\n{}".format(example["molecule_properties"])) + print("retro_labels:\n{}".format(example["retro_labels"])) + print("retro_product_ids:\n{}".format(example["retro_product_ids"])) diff --git a/src/data/template.py b/src/data/template.py new file mode 100644 index 0000000000000000000000000000000000000000..aefd5195c57380f43194ed027818c77df437be26 --- /dev/null +++ b/src/data/template.py @@ -0,0 +1,889 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union + +from ..extras.logging import get_logger +from .data_utils import Role +from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter + + +if TYPE_CHECKING: + from transformers import PreTrainedTokenizer + + from .formatter import SLOTS, Formatter + + +logger = get_logger(__name__) + + +@dataclass +class Template: + format_user: "Formatter" + format_assistant: "Formatter" + format_system: "Formatter" + format_function: "Formatter" + format_observation: "Formatter" + format_tools: "Formatter" + format_separator: "Formatter" + format_prefix: "Formatter" + default_system: str + stop_words: List[str] + image_token: str + efficient_eos: bool + replace_eos: bool + + def encode_oneturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> Tuple[List[int], List[int]]: + r""" + Returns a single pair of token ids representing prompt and response respectively. + """ + encoded_messages = self._encode(tokenizer, messages, system, tools) + prompt_ids = [] + for encoded_ids in encoded_messages[:-1]: + prompt_ids += encoded_ids + + answer_ids = encoded_messages[-1] + return prompt_ids, answer_ids + + def encode_multiturn( + self, + tokenizer: "PreTrainedTokenizer", + messages: Sequence[Dict[str, str]], + system: Optional[str] = None, + tools: Optional[str] = None, + ) -> List[Tuple[List[int], List[int]]]: + r""" + Returns multiple pairs of token ids representing prompts and responses respectively. 
+ """ + encoded_messages = self._encode(tokenizer, messages, system, tools) + return [(encoded_messages[i], encoded_messages[i + 1]) for i in range(0, len(encoded_messages), 2)] + + def extract_tool(self, content: str) -> Union[str, List[Tuple[str, str]]]: + r""" + Extracts tool message. + """ + return self.format_tools.extract(content) + + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: Sequence[Dict[str, str]], + system: Optional[str], + tools: Optional[str], + ) -> List[List[int]]: + r""" + Encodes formatted inputs to pairs of token ids. + Turn 0: prefix + system + query resp + Turn t: sep + query resp + """ + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + elements += self.format_system.apply(content=(system + tool_text)) + + if i > 0 and i % 2 == 0: + elements += self.format_separator.apply() + + if message["role"] == Role.USER.value: + elements += self.format_user.apply(content=message["content"], idx=str(i // 2)) + elif message["role"] == Role.ASSISTANT.value: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION.value: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION.value: + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return encoded_messages + + def _convert_elements_to_ids(self, tokenizer: "PreTrainedTokenizer", elements: "SLOTS") -> List[int]: + r""" + Converts elements to token ids. + """ + token_ids = [] + for elem in elements: + if isinstance(elem, str): + if len(elem) != 0: + token_ids += tokenizer.encode(elem, add_special_tokens=False) + elif isinstance(elem, dict): + token_ids += [tokenizer.convert_tokens_to_ids(elem.get("token"))] + elif isinstance(elem, set): + if "bos_token" in elem and tokenizer.bos_token_id is not None: + token_ids += [tokenizer.bos_token_id] + elif "eos_token" in elem and tokenizer.eos_token_id is not None: + token_ids += [tokenizer.eos_token_id] + else: + raise ValueError("Input must be string, set[str] or dict[str, str], got {}".format(type(elem))) + + return token_ids + + +@dataclass +class Llama2Template(Template): + def _encode( + self, + tokenizer: "PreTrainedTokenizer", + messages: Sequence[Dict[str, str]], + system: str, + tools: str, + ) -> List[List[int]]: + r""" + Encodes formatted inputs to pairs of token ids. 
+ Turn 0: prefix + system + query resp + Turn t: sep + query resp + """ + system = system or self.default_system + encoded_messages = [] + for i, message in enumerate(messages): + elements = [] + + system_text = "" + if i == 0: + elements += self.format_prefix.apply() + if system or tools: + tool_text = self.format_tools.apply(content=tools)[0] if tools else "" + system_text = self.format_system.apply(content=(system + tool_text))[0] + + if i > 0 and i % 2 == 0: + elements += self.format_separator.apply() + + if message["role"] == Role.USER.value: + elements += self.format_user.apply(content=system_text + message["content"]) + elif message["role"] == Role.ASSISTANT.value: + elements += self.format_assistant.apply(content=message["content"]) + elif message["role"] == Role.OBSERVATION.value: + elements += self.format_observation.apply(content=message["content"]) + elif message["role"] == Role.FUNCTION.value: + elements += self.format_function.apply(content=message["content"]) + else: + raise NotImplementedError("Unexpected role: {}".format(message["role"])) + + encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements)) + + return encoded_messages + + +TEMPLATES: Dict[str, Template] = {} + + +def _register_template( + name: str, + format_user: Optional["Formatter"] = None, + format_assistant: Optional["Formatter"] = None, + format_system: Optional["Formatter"] = None, + format_function: Optional["Formatter"] = None, + format_observation: Optional["Formatter"] = None, + format_tools: Optional["Formatter"] = None, + format_separator: Optional["Formatter"] = None, + format_prefix: Optional["Formatter"] = None, + default_system: str = "", + stop_words: Sequence[str] = [], + image_token: str = "", + efficient_eos: bool = False, + replace_eos: bool = False, +) -> None: + r""" + Registers a chat template. 
+ + To add the following chat template: + ``` + [HUMAN]: + user prompt here + [AI]: + model response here + + [HUMAN]: + user prompt here + [AI]: + model response here + ``` + + The corresponding code should be: + ``` + _register_template( + name="custom", + format_user=StringFormatter(slots=["[HUMAN]:\n{{content}}\n[AI]:\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + efficient_eos=True, + ) + ``` + """ + eos_slots = [] if efficient_eos else [{"eos_token"}] + template_class = Llama2Template if name.startswith("llama2") else Template + default_user_formatter = StringFormatter(slots=["{{content}}"]) + default_assistant_formatter = StringFormatter(slots=["{{content}}"] + eos_slots) + default_function_formatter = FunctionFormatter(slots=eos_slots, tool_format="default") + default_tool_formatter = ToolFormatter(tool_format="default") + default_separator_formatter = EmptyFormatter() + default_prefix_formatter = EmptyFormatter() + TEMPLATES[name] = template_class( + format_user=format_user or default_user_formatter, + format_assistant=format_assistant or default_assistant_formatter, + format_system=format_system or default_user_formatter, + format_function=format_function or default_function_formatter, + format_observation=format_observation or format_user or default_user_formatter, + format_tools=format_tools or default_tool_formatter, + format_separator=format_separator or default_separator_formatter, + format_prefix=format_prefix or default_prefix_formatter, + default_system=default_system, + stop_words=stop_words, + image_token=image_token, + efficient_eos=efficient_eos, + replace_eos=replace_eos, + ) + + +def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) -> None: + is_added = tokenizer.eos_token_id is None + num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) + + if is_added: + logger.info("Add eos token: {}".format(tokenizer.eos_token)) + else: + logger.info("Replace eos token: {}".format(tokenizer.eos_token)) + + if num_added_tokens > 0: + logger.warning("New tokens have been added, make sure `resize_vocab` is True.") + + +def _jinja_escape(content: str) -> str: + return content.replace("'", r"\'") + + +def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", placeholder: str = "content") -> str: + slot_items = [] + for slot in slots: + if isinstance(slot, str): + slot_pieces = slot.split("{{content}}") + if slot_pieces[0]: + slot_items.append("'" + _jinja_escape(slot_pieces[0]) + "'") + if len(slot_pieces) > 1: + slot_items.append(placeholder) + if slot_pieces[1]: + slot_items.append("'" + _jinja_escape(slot_pieces[1]) + "'") + elif isinstance(slot, set): # do not use {{ eos_token }} since it may be replaced + if "bos_token" in slot and tokenizer.bos_token_id is not None: + slot_items.append("'" + tokenizer.bos_token + "'") + elif "eos_token" in slot and tokenizer.eos_token_id is not None: + slot_items.append("'" + tokenizer.eos_token + "'") + elif isinstance(slot, dict): + raise ValueError("Dict is not supported.") + + return " + ".join(slot_items) + + +def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str: + jinja_template = "" + + prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer) + if prefix: + jinja_template += "{{ " + prefix + " }}" + + if template.default_system: + jinja_template += "{% set system_message = '" + _jinja_escape(template.default_system) + "' %}" + + jinja_template += ( + "{% if messages[0]['role'] == 'system' %}{% set 
system_message = messages[0]['content'] %}{% endif %}" + ) + + system_message = _convert_slots_to_jinja(template.format_system.apply(), tokenizer, placeholder="system_message") + if not isinstance(template, Llama2Template): + jinja_template += "{% if system_message is defined %}{{ " + system_message + " }}{% endif %}" + + jinja_template += "{% for message in messages %}" + jinja_template += "{% set content = message['content'] %}" + if isinstance(template, Llama2Template): + jinja_template += "{% if loop.index0 == 0 and system_message is defined %}" + jinja_template += "{% set content = " + system_message + " + message['content'] %}" + jinja_template += "{% endif %}" + + jinja_template += "{% if message['role'] == 'user' %}" + user_message = _convert_slots_to_jinja(template.format_user.apply(), tokenizer) + jinja_template += "{{ " + user_message + " }}" + + jinja_template += "{% elif message['role'] == 'assistant' %}" + assistant_message = _convert_slots_to_jinja( + template.format_assistant.apply() + template.format_separator.apply(), tokenizer + ) + jinja_template += "{{ " + assistant_message + " }}" + jinja_template += "{% endif %}" + jinja_template += "{% endfor %}" + return jinja_template + + +def get_template_and_fix_tokenizer( + tokenizer: "PreTrainedTokenizer", + name: Optional[str] = None, + tool_format: Optional[str] = None, +) -> Template: + if name is None: + template = TEMPLATES["empty"] # placeholder + else: + template = TEMPLATES.get(name, None) + if template is None: + raise ValueError("Template {} does not exist.".format(name)) + + if tool_format is not None: + logger.info("Using tool format: {}.".format(tool_format)) + eos_slots = [] if template.efficient_eos else [{"eos_token"}] + template.format_tools = ToolFormatter(tool_format=tool_format) + template.format_function = FunctionFormatter(slots=eos_slots, tool_format=tool_format) + + stop_words = template.stop_words + if template.replace_eos: + if not stop_words: + raise ValueError("Stop words are required to replace the EOS token.") + + _add_or_replace_eos_token(tokenizer, eos_token=stop_words[0]) + stop_words = stop_words[1:] + + if tokenizer.eos_token_id is None: + _add_or_replace_eos_token(tokenizer, eos_token="<|endoftext|>") + + if tokenizer.pad_token_id is None: + tokenizer.pad_token = tokenizer.eos_token + logger.info("Add pad token: {}".format(tokenizer.pad_token)) + + if stop_words: + num_added_tokens = tokenizer.add_special_tokens( + dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False + ) + logger.info("Add {} to stop words.".format(",".join(stop_words))) + if num_added_tokens > 0: + logger.warning("New tokens have been added, make sure `resize_vocab` is True.") + + try: + tokenizer.chat_template = _get_jinja_template(template, tokenizer) + except ValueError: + logger.info("Cannot add this chat template to tokenizer.") + + return template + + +_register_template( + name="alpaca", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n\n### Response:\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + default_system=( + "Below is an instruction that describes a task. " + "Write a response that appropriately completes the request.\n\n" + ), +) + + +_register_template( + name="aquila", + format_user=StringFormatter(slots=["Human: {{content}}###Assistant:"]), + format_separator=EmptyFormatter(slots=["###"]), + default_system=( + "A chat between a curious human and an artificial intelligence assistant. 
" + "The assistant gives helpful, detailed, and polite answers to the human's questions." + ), + stop_words=[""], + efficient_eos=True, +) + + +_register_template( + name="atom", + format_user=StringFormatter( + slots=[{"bos_token"}, "Human: {{content}}\n", {"eos_token"}, {"bos_token"}, "Assistant:"] + ), + format_assistant=StringFormatter(slots=["{{content}}\n", {"eos_token"}]), +) + + +_register_template( + name="baichuan", + format_user=StringFormatter(slots=[{"token": ""}, "{{content}}", {"token": ""}]), + efficient_eos=True, +) + + +_register_template( + name="baichuan2", + format_user=StringFormatter(slots=["{{content}}"]), + efficient_eos=True, +) + + +_register_template( + name="belle", + format_user=StringFormatter(slots=["Human: {{content}}\n\nBelle: "]), + format_separator=EmptyFormatter(slots=["\n\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="bluelm", + format_user=StringFormatter(slots=[{"token": "[|Human|]:"}, "{{content}}", {"token": "[|AI|]:"}]), +) + + +_register_template( + name="breeze", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST] "]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + efficient_eos=True, +) + + +_register_template( + name="chatglm2", + format_user=StringFormatter(slots=["[Round {{idx}}]\n\n问:{{content}}\n\n答:"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), + efficient_eos=True, +) + + +_register_template( + name="chatglm3", + format_user=StringFormatter(slots=[{"token": "<|user|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}]), + format_assistant=StringFormatter(slots=["\n", "{{content}}"]), + format_system=StringFormatter(slots=[{"token": "<|system|>"}, "\n", "{{content}}"]), + format_function=FunctionFormatter(slots=[], tool_format="glm4"), + format_observation=StringFormatter( + slots=[{"token": "<|observation|>"}, "\n", "{{content}}", {"token": "<|assistant|>"}] + ), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +_register_template( + name="chatml", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, +) + + +_register_template( + name="chatml_de", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system="Du bist ein freundlicher und hilfsbereiter KI-Assistent.", + stop_words=["<|im_end|>", "<|im_start|>"], + replace_eos=True, +) + + +_register_template( + name="codegeex2", + format_prefix=EmptyFormatter(slots=[{"token": "[gMASK]"}, {"token": "sop"}]), +) + + +_register_template( + name="cohere", + format_user=StringFormatter( + slots=[ + ( + "<|START_OF_TURN_TOKEN|><|USER_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>" + 
"<|START_OF_TURN_TOKEN|><|CHATBOT_TOKEN|>" + ) + ] + ), + format_system=StringFormatter(slots=["<|START_OF_TURN_TOKEN|><|SYSTEM_TOKEN|>{{content}}<|END_OF_TURN_TOKEN|>"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="cpm", + format_user=StringFormatter(slots=["<用户>{{content}}"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="dbrx", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system=( + "You are DBRX, created by Databricks. You were last updated in December 2023. " + "You answer questions based on information available up to that point.\n" + "YOU PROVIDE SHORT RESPONSES TO SHORT QUESTIONS OR STATEMENTS, but provide thorough " + "responses to more complex and open-ended questions.\nYou assist with various tasks, " + "from writing to coding (using markdown for code blocks — remember to use ``` with " + "code, JSON, and tables).\n(You do not have real-time data access or code execution " + "capabilities. You avoid stereotyping and provide balanced perspectives on " + "controversial topics. You do not provide song lyrics, poems, or news articles and " + "do not divulge details of your training data.)\nThis is your system prompt, " + "guiding your responses. Do not reference it, just respond to the user. If you find " + "yourself talking about this message, stop. You should be responding appropriately " + "and usually that means not mentioning this.\nYOU DO NOT MENTION ANY OF THIS INFORMATION " + "ABOUT YOURSELF UNLESS THE INFORMATION IS DIRECTLY PERTINENT TO THE USER'S QUERY." + ), + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +_register_template( + name="deepseek", + format_user=StringFormatter(slots=["User: {{content}}\n\nAssistant:"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="deepseekcoder", + format_user=StringFormatter(slots=["### Instruction:\n{{content}}\n### Response:"]), + format_assistant=StringFormatter(slots=["\n{{content}}\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + default_system=( + "You are an AI programming assistant, utilizing the Deepseek Coder model, " + "developed by Deepseek Company, and you only answer questions related to computer science. 
" + "For politically sensitive questions, security and privacy issues, " + "and other non-computer science questions, you will refuse to answer\n" + ), +) + + +_register_template( + name="default", + format_user=StringFormatter(slots=["Human: {{content}}\nAssistant:"]), + format_system=StringFormatter(slots=["{{content}}\n"]), + format_separator=EmptyFormatter(slots=["\n"]), +) + + +_register_template( + name="empty", + efficient_eos=True, +) + + +_register_template( + name="falcon", + format_user=StringFormatter(slots=["User: {{content}}\nFalcon:"]), + format_separator=EmptyFormatter(slots=["\n"]), + efficient_eos=True, +) + + +_register_template( + name="fewshot", + format_separator=EmptyFormatter(slots=["\n\n"]), + efficient_eos=True, +) + + +_register_template( + name="gemma", + format_user=StringFormatter(slots=["user\n{{content}}\nmodel\n"]), + format_observation=StringFormatter( + slots=["tool\n{{content}}\nmodel\n"] + ), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + efficient_eos=True, +) + + +_register_template( + name="glm4", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}"]), + format_function=FunctionFormatter(slots=[], tool_format="glm4"), + format_observation=StringFormatter(slots=["<|observation|>\n{{content}}<|assistant|>"]), + format_tools=ToolFormatter(tool_format="glm4"), + format_prefix=EmptyFormatter(slots=["[gMASK]"]), + stop_words=["<|user|>", "<|observation|>"], + efficient_eos=True, +) + + +_register_template( + name="intern", + format_user=StringFormatter(slots=["<|User|>:{{content}}\n<|Bot|>:"]), + format_system=StringFormatter(slots=["<|System|>:{{content}}\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=[""], + efficient_eos=True, # internlm tokenizer cannot set eos_token_id +) + + +_register_template( + name="intern2", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["<|im_end|>\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|im_end|>"], + efficient_eos=True, # internlm2 tokenizer cannot set eos_token_id +) + + +_register_template( + name="llama2", + format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), +) + + +_register_template( + name="llama2_zh", + format_user=StringFormatter(slots=[{"bos_token"}, "[INST] {{content}} [/INST]"]), + format_system=StringFormatter(slots=["<>\n{{content}}\n<>\n\n"]), + default_system="You are a helpful assistant. 
你是一个乐于助人的助手。", +) + + +_register_template( + name="llama3", + format_user=StringFormatter( + slots=[ + ( + "<|start_header_id|>user<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + format_system=StringFormatter(slots=["<|start_header_id|>system<|end_header_id|>\n\n{{content}}<|eot_id|>"]), + format_observation=StringFormatter( + slots=[ + ( + "<|start_header_id|>tool<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>assistant<|end_header_id|>\n\n" + ) + ] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|eot_id|>"], + replace_eos=True, +) + + +_register_template( + name="mistral", + format_user=StringFormatter(slots=["[INST] {{content}} [/INST]"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="olmo", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|assistant|>\n"]), + format_prefix=EmptyFormatter(slots=[{"eos_token"}]), +) + + +_register_template( + name="openchat", + format_user=StringFormatter(slots=["GPT4 Correct User: {{content}}", {"eos_token"}, "GPT4 Correct Assistant:"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="openchat-3.6", + format_user=StringFormatter( + slots=[ + ( + "<|start_header_id|>GPT4 Correct User<|end_header_id|>\n\n{{content}}<|eot_id|>" + "<|start_header_id|>GPT4 Correct Assistant<|end_header_id|>\n\n" + ) + ] + ), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|eot_id|>"], + replace_eos=True, +) + + +_register_template( + name="orion", + format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: ", {"eos_token"}]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), +) + + +_register_template( + name="phi", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>\n"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + format_prefix=EmptyFormatter(slots=[{"bos_token"}]), + stop_words=["<|end|>"], + replace_eos=True, +) + + +_register_template( + name="qwen", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_observation=StringFormatter(slots=["<|im_start|>tool\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system="You are a helpful assistant.", + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +_register_template( + name="solar", + format_user=StringFormatter(slots=["### User:\n{{content}}\n\n### Assistant:\n"]), + format_system=StringFormatter(slots=["### System:\n{{content}}\n\n"]), + efficient_eos=True, +) + + +_register_template( + name="starchat", + format_user=StringFormatter(slots=["<|user|>\n{{content}}<|end|>\n<|assistant|>"]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}<|end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=["<|end|>"], + replace_eos=True, +) + + +_register_template( + name="telechat", + format_user=StringFormatter(slots=["<_user>{{content}}<_bot>"]), + format_system=StringFormatter(slots=["<_system>{{content}}<_end>"]), + stop_words=["<_end>"], + replace_eos=True, +) + + +_register_template( + name="vicuna", + format_user=StringFormatter(slots=["USER: {{content}} ASSISTANT:"]), + default_system=( + "A chat between a curious user and an 
artificial intelligence assistant. " + "The assistant gives helpful, detailed, and polite answers to the user's questions." + ), +) + + +_register_template( + name="xuanyuan", + format_user=StringFormatter(slots=["Human: {{content}} Assistant:"]), + default_system=( + "以下是用户和人工智能助手之间的对话。用户以Human开头,人工智能助手以Assistant开头," + "会对人类提出的问题给出有帮助、高质量、详细和礼貌的回答,并且总是拒绝参与与不道德、" + "不安全、有争议、政治敏感等相关的话题、问题和指示。\n" + ), +) + + +_register_template( + name="xverse", + format_user=StringFormatter(slots=["Human: {{content}}\n\nAssistant: "]), +) + + +_register_template( + name="yayi", + format_user=StringFormatter(slots=[{"token": "<|Human|>"}, ":\n{{content}}\n\n", {"token": "<|YaYi|>"}, ":"]), + format_system=StringFormatter(slots=[{"token": "<|System|>"}, ":\n{{content}}\n\n"]), + format_separator=EmptyFormatter(slots=["\n\n"]), + default_system=( + "You are a helpful, respectful and honest assistant named YaYi " + "developed by Beijing Wenge Technology Co.,Ltd. " + "Always answer as helpfully as possible, while being safe. " + "Your answers should not include any harmful, unethical, " + "racist, sexist, toxic, dangerous, or illegal content. " + "Please ensure that your responses are socially unbiased and positive in nature.\n\n" + "If a question does not make any sense, or is not factually coherent, " + "explain why instead of answering something not correct. " + "If you don't know the answer to a question, please don't share false information." + ), + stop_words=["<|End|>"], +) + + +_register_template( + name="yi", + format_user=StringFormatter(slots=["<|im_start|>user\n{{content}}<|im_end|>\n<|im_start|>assistant\n"]), + format_system=StringFormatter(slots=["<|im_start|>system\n{{content}}<|im_end|>\n"]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=["<|im_end|>"], + replace_eos=True, +) + + +_register_template( + name="yi_vl", + format_user=StringFormatter(slots=["### Human: {{content}}\n### Assistant:"]), + format_separator=EmptyFormatter(slots=["\n"]), + default_system=( + "This is a chat between an inquisitive human and an AI assistant. " + "Assume the role of the AI assistant. Read all the images carefully, " + "and respond to the human's questions with informative, helpful, detailed and polite answers. " + "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。" + "仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。\n\n" + ), + stop_words=["###"], + efficient_eos=True, +) + + +_register_template( + name="yuan", + format_user=StringFormatter(slots=["{{content}}", {"token": ""}]), + format_separator=EmptyFormatter(slots=["\n"]), + stop_words=[""], + replace_eos=True, +) + + +_register_template( + name="zephyr", + format_user=StringFormatter(slots=["<|user|>\n{{content}}", {"eos_token"}, "<|assistant|>"]), + format_assistant=StringFormatter(slots=["\n{{content}}", {"eos_token"}]), + format_system=StringFormatter(slots=["<|system|>\n{{content}}", {"eos_token"}]), + default_system="You are Zephyr, a helpful assistant.", +) + + +_register_template( + name="ziya", + format_user=StringFormatter(slots=[":{{content}}\n:"]), + format_separator=EmptyFormatter(slots=["\n"]), +) diff --git a/src/data/tool_utils.py b/src/data/tool_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ac5565d5e6d7e2a906580c12274741d55892ae35 --- /dev/null +++ b/src/data/tool_utils.py @@ -0,0 +1,140 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import re +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import Any, Dict, List, Tuple, Union + +from .data_utils import SLOTS + + +DEFAULT_TOOL_PROMPT = ( + "You have access to the following tools:\n{tool_text}" + "Use the following format if using a tool:\n" + "```\n" + "Action: tool name (one of [{tool_names}]).\n" + "Action Input: the input to the tool, in a JSON format representing the kwargs " + """(e.g. ```{{"input": "hello world", "num_beams": 5}}```).\n""" + "```\n" +) + + +GLM4_TOOL_PROMPT = ( + "你是一个名为 ChatGLM 的人工智能助手。你是基于智谱AI训练的语言模型 GLM-4 模型开发的," + "你的任务是针对用户的问题和要求提供适当的答复和支持。# 可用工具{tool_text}" +) + + +@dataclass +class ToolUtils(ABC): + @staticmethod + @abstractmethod + def get_function_slots() -> SLOTS: ... + + @staticmethod + @abstractmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: ... + + @staticmethod + @abstractmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ... + + +class DefaultToolUtils(ToolUtils): + @staticmethod + def get_function_slots() -> SLOTS: + return ["Action: {{name}}\nAction Input: {{arguments}}\n"] + + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + tool_names = [] + for tool in tools: + param_text = "" + for name, param in tool["parameters"]["properties"].items(): + required, enum, items = "", "", "" + if name in tool["parameters"].get("required", []): + required = ", required" + + if param.get("enum", None): + enum = ", should be one of [{}]".format(", ".join(param["enum"])) + + if param.get("items", None): + items = ", where each item should be {}".format(param["items"].get("type", "")) + + param_text += " - {name} ({type}{required}): {desc}{enum}{items}\n".format( + name=name, + type=param.get("type", ""), + required=required, + desc=param.get("description", ""), + enum=enum, + items=items, + ) + + tool_text += "> Tool Name: {name}\nTool Description: {desc}\nTool Args:\n{args}\n".format( + name=tool["name"], desc=tool.get("description", ""), args=param_text + ) + tool_names.append(tool["name"]) + + return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names)) + + @staticmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: + regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL) + action_match: List[Tuple[str, str]] = re.findall(regex, content) + if not action_match: + return content + + results = [] + for match in action_match: + tool_name = match[0].strip() + tool_input = match[1].strip().strip('"').strip("```") + try: + arguments = json.loads(tool_input) + results.append((tool_name, json.dumps(arguments, ensure_ascii=False))) + except json.JSONDecodeError: + return content + + return results + + +class GLM4ToolUtils(ToolUtils): + @staticmethod + def get_function_slots() -> SLOTS: + return ["{{name}}\n{{arguments}}"] + + @staticmethod + def tool_formatter(tools: List[Dict[str, Any]]) -> str: + tool_text = "" + for tool in tools: + tool_text += "\n\n## {name}\n\n{body}\n在调用上述函数时,请使用 Json 
格式表示调用的参数。".format( + name=tool["name"], body=json.dumps(tool, indent=4, ensure_ascii=False) + ) + + return GLM4_TOOL_PROMPT.format(tool_text=tool_text) + + @staticmethod + def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: + if "\n" not in content: + return content + + tool_name, tool_input = content.split("\n", maxsplit=1) + try: + arguments = json.loads(tool_input) + except json.JSONDecodeError: + return content + + return [(tool_name, json.dumps(arguments, ensure_ascii=False))] diff --git a/src/eval/__init__.py b/src/eval/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/eval/dataset.py b/src/eval/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e84169e9cc997c913e7008cbfb90b6c63bd3a4 --- /dev/null +++ b/src/eval/dataset.py @@ -0,0 +1,78 @@ +# Copyright 2024 Llamole Team +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +from torch.utils.data import Dataset + +from ..extras.constants import BOND_INDEX + +def dict_to_list(data_dict, mol_properties): + return [data_dict.get(prop, float("nan")) for prop in mol_properties] + +class MolQADataset(Dataset): + def __init__(self, data, tokenizer, max_len): + self.data = data + self.tokenizer = tokenizer + self.max_len = max_len + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + mol_properties = [ + "BBBP", + "HIV", + "BACE", + "CO2", + "N2", + "O2", + "FFV", + "TC", + "SC", + "SA", + ] + item = self.data[idx] + instruction = item["instruction"] + input_text = item["input"] + property_data = dict_to_list(item["property"], mol_properties) + property_data = torch.tensor(property_data) + + # Combine instruction and input + combined_input = f"{instruction}\n{input_text}" + + # Create messages for chat template + messages = [ + {"role": "user", "content": combined_input} + ] + + # Apply chat template + chat_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + # Tokenize the chat text + encoding = self.tokenizer( + chat_text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.max_len, + ) + + return { + "input_ids": encoding.input_ids.squeeze(), + "attention_mask": encoding.attention_mask.squeeze(), + "property": property_data, + } \ No newline at end of file diff --git a/src/eval/workflow.py b/src/eval/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..beb78cf99a6c3fc0a642930d7e96bb9bbee6bc1a --- /dev/null +++ b/src/eval/workflow.py @@ -0,0 +1,219 @@ +# Copyright 2024 Llamole Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from typing import TYPE_CHECKING, List, Optional, Dict, Any
+
+from ..data import get_dataset, DataCollatorForSeqGraph, get_template_and_fix_tokenizer
+from ..extras.constants import IGNORE_INDEX, NO_LABEL_INDEX
+from ..extras.misc import get_logits_processor
+from ..extras.ploting import plot_loss
+from ..model import load_tokenizer
+from ..hparams import get_infer_args, get_train_args
+from ..model import GraphLLMForCausalMLM
+from .dataset import MolQADataset
+
+import re
+import os
+import json
+import math
+import torch
+from torch.utils.data import DataLoader
+
+if TYPE_CHECKING:
+    from transformers import Seq2SeqTrainingArguments
+
+    from ..hparams import (
+        DataArguments,
+        FinetuningArguments,
+        GeneratingArguments,
+        ModelArguments,
+    )
+
+def remove_extra_spaces(text):
+    # Replace multiple spaces with a single space
+    cleaned_text = re.sub(r'\s+', ' ', text)
+    # Strip leading and trailing spaces
+    return cleaned_text.strip()
+
+def run_eval(args: Optional[Dict[str, Any]] = None) -> None:
+    model_args, data_args, training_args, finetuning_args, generating_args = (
+        get_train_args(args)
+    )
+
+    if data_args.dataset in ["molqa", "molqa_drug", "molqa_material"]:
+        run_molqa(
+            model_args, data_args, training_args, finetuning_args, generating_args
+        )
+    else:
+        raise ValueError("Unknown dataset: {}.".format(data_args.dataset))
+
+
+def run_molqa(
+    model_args: "ModelArguments",
+    data_args: "DataArguments",
+    training_args: "Seq2SeqTrainingArguments",
+    finetuning_args: "FinetuningArguments",
+    generating_args: "GeneratingArguments",
+):
+    tokenizer = load_tokenizer(model_args, generate_mode=True)["tokenizer"]
+
+    data_info_path = os.path.join(data_args.dataset_dir, "dataset_info.json")
+    with open(data_info_path, "r") as f:
+        dataset_info = json.load(f)
+
+    tokenizer.pad_token = tokenizer.eos_token
+    dataset_name = data_args.dataset.strip()
+    try:
+        filename = dataset_info[dataset_name]["file_name"]
+    except KeyError:
+        raise ValueError(f"Dataset {dataset_name} not found in dataset_info.json")
+    data_path = os.path.join(data_args.dataset_dir, f"{filename}")
+    with open(data_path, "r") as f:
+        original_data = json.load(f)
+
+    # Create dataset and dataloader
+    dataset = MolQADataset(original_data, tokenizer, data_args.cutoff_len)
+    dataloader = DataLoader(
+        dataset, batch_size=training_args.per_device_eval_batch_size, shuffle=False
+    )
+
+    gen_kwargs = generating_args.to_dict()
+    gen_kwargs["eos_token_id"] = [
+        tokenizer.eos_token_id
+    ] + tokenizer.additional_special_tokens_ids
+    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
+    gen_kwargs["logits_processor"] = get_logits_processor()
+
+    model = GraphLLMForCausalMLM.from_pretrained(
+        tokenizer, model_args, data_args, training_args, finetuning_args, load_adapter=True
+    )
+
+    all_results = []
+    property_names = ["BBBP", "HIV", "BACE", "CO2", "N2", "O2", "FFV", "TC", "SC", "SA"]
+
+    # Phase 1: Molecular Design
+    global_idx = 0
+    all_smiles = []
+    for batch_idx, batch in enumerate(dataloader):
+        input_ids = batch["input_ids"].to(model.device)
+        attention_mask =
batch["attention_mask"].to(model.device) + property_data = batch["property"].to(model.device) + model.eval() + with torch.no_grad(): + all_info_dict = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + molecule_properties=property_data, + do_molecular_design=True, + do_retrosynthesis=False, + rollback=True, + **gen_kwargs, + ) + + batch_results = [] + for i in range(len(all_info_dict["smiles_list"])): + original_data_idx = global_idx + i + original_item = original_data[original_data_idx] + + llm_response = "".join(item for item in all_info_dict["text_lists"][i]) + result = { + "qa_idx": original_data_idx, + "instruction": original_item["instruction"], + "input": original_item["input"], + "llm_response": llm_response, + "response_design": remove_extra_spaces(llm_response), + "llm_smiles": all_info_dict["smiles_list"][i], + "property": {}, + } + + # Add non-NaN property values + for j, prop_name in enumerate(property_names): + prop_value = property_data[i][j].item() + if not math.isnan(prop_value): + result["property"][prop_name] = prop_value + + batch_results.append(result) + + all_results.extend(batch_results) + all_smiles.extend([result['llm_smiles'] for result in batch_results]) + global_idx += len(batch_results) + + # Phase 2: Retrosynthesis + retro_batch_start = 0 + for batch_idx, batch in enumerate(dataloader): + + input_ids = batch["input_ids"].to(model.device) + attention_mask = batch["attention_mask"].to(model.device) + batch_size = input_ids.shape[0] + batch_smiles = all_smiles[retro_batch_start : retro_batch_start + batch_size] + + model.eval() + with torch.no_grad(): + all_info_dict = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + do_molecular_design=False, + do_retrosynthesis=True, + input_smiles_list=batch_smiles, + expansion_topk=50, + iterations=100, + max_planning_time=30, + **gen_kwargs, + ) + + batch_results = [] + for i in range(batch_size): + result = all_results[retro_batch_start + i] + retro_plan = all_info_dict["retro_plan_dict"][result["llm_smiles"]] + result["llm_reactions"] = [] + if retro_plan["success"]: + for reaction, template, cost in zip( + retro_plan["reaction_list"], + retro_plan["templates"], + retro_plan["cost"], + ): + result["llm_reactions"].append( + {"reaction": reaction, "template": template, "cost": cost} + ) + + # new_text = "".join(item for item in all_info_dict["text_lists"][i]) + if None in all_info_dict["text_lists"][i]: + print(f"List contains None: {all_info_dict['text_lists'][i]}") + new_text = "".join(item for item in all_info_dict["text_lists"][i] if item is not None) + else: + new_text = "".join(item for item in all_info_dict["text_lists"][i]) + + result["llm_response"] += new_text + result["llm_response"] = remove_extra_spaces(result["llm_response"]) + result["response_retro"] = remove_extra_spaces(new_text) + batch_results.append(result) + + retro_batch_start += batch_size + + print('all_results', all_results) + print("\nSummary of results:") + print_len = min(5, len(all_results)) + for result in all_results[:print_len]: + print(f"\nData point {result['qa_idx']}:") + print(f" Instruction: {result['instruction']}") + print(f" Input: {result['input']}") + print(f" LLM Response: {result['llm_response']}") + print(f" LLM SMILES: {result['llm_smiles']}") + print(f" Number of reactions: {len(result['llm_reactions'])}") + for prop_name, prop_value in result["property"].items(): + print(f" {prop_name}: {prop_value}") + + print("\nAll data processed successfully.") \ No newline at end of file 
diff --git a/src/extras/__init__.py b/src/extras/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/extras/constants.py b/src/extras/constants.py new file mode 100644 index 0000000000000000000000000000000000000000..6b20b24e4e865b997bdce6ae3b1d7329ad048895 --- /dev/null +++ b/src/extras/constants.py @@ -0,0 +1,1550 @@ +# Copyright 2024 the LlamaFactory team and the Llamole team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from collections import OrderedDict, defaultdict +from enum import Enum +from typing import Dict, Optional + +from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME +from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME +from transformers.utils import SAFE_WEIGHTS_INDEX_NAME, SAFE_WEIGHTS_NAME, WEIGHTS_INDEX_NAME, WEIGHTS_NAME + +from rdkit.Chem.rdchem import BondType as BT +BOND_INDEX = {BT.SINGLE: 1, BT.DOUBLE: 2, BT.TRIPLE: 3, BT.AROMATIC: 4} +NO_LABEL_INDEX = -200 + +CHECKPOINT_NAMES = { + SAFE_ADAPTER_WEIGHTS_NAME, + ADAPTER_WEIGHTS_NAME, + SAFE_WEIGHTS_INDEX_NAME, + SAFE_WEIGHTS_NAME, + WEIGHTS_INDEX_NAME, + WEIGHTS_NAME, +} + +CHOICES = ["A", "B", "C", "D"] + +DATA_CONFIG = "dataset_info.json" + +DEFAULT_TEMPLATE = defaultdict(str) + +FILEEXT2TYPE = { + "arrow": "arrow", + "csv": "csv", + "json": "json", + "jsonl": "json", + "parquet": "parquet", + "txt": "text", +} + +IGNORE_INDEX = -100 + +LAYERNORM_NAMES = {"norm", "ln"} + +LLAMABOARD_CONFIG = "llamaboard_config.yaml" + +METHODS = ["full", "freeze", "lora"] + +MOD_SUPPORTED_MODELS = {"bloom", "falcon", "gemma", "llama", "mistral", "mixtral", "phi", "starcoder2"} + +PEFT_METHODS = {"lora"} + +RUNNING_LOG = "running_log.txt" + +SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] + +SUPPORTED_MODELS = OrderedDict() + +TRAINER_LOG = "trainer_log.jsonl" + +TRAINING_ARGS = "training_args.yaml" + +TRAINING_STAGES = { + "Supervised Fine-Tuning": "sft", + "Reward Modeling": "rm", + "PPO": "ppo", + "DPO": "dpo", + "KTO": "kto", + "Pre-Training": "pt", +} + +STAGES_USE_PAIR_DATA = {"rm", "dpo"} + +SUPPORTED_CLASS_FOR_S2ATTN = {"llama"} + +V_HEAD_WEIGHTS_NAME = "value_head.bin" + +V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors" + +VISION_MODELS = set() + + +class DownloadSource(str, Enum): + DEFAULT = "hf" + MODELSCOPE = "ms" + + +def register_model_group( + models: Dict[str, Dict[DownloadSource, str]], + template: Optional[str] = None, + vision: bool = False, +) -> None: + prefix = None + for name, path in models.items(): + if prefix is None: + prefix = name.split("-")[0] + else: + assert prefix == name.split("-")[0], "prefix should be identical." 
+ SUPPORTED_MODELS[name] = path + if template is not None: + DEFAULT_TEMPLATE[prefix] = template + if vision: + VISION_MODELS.add(prefix) + + +register_model_group( + models={ + "Aya-23-8B-Chat": { + DownloadSource.DEFAULT: "CohereForAI/aya-23-8B", + }, + "Aya-23-35B-Chat": { + DownloadSource.DEFAULT: "CohereForAI/aya-23-35B", + }, + }, + template="cohere", +) + + +register_model_group( + models={ + "Baichuan-7B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-7B", + DownloadSource.MODELSCOPE: "baichuan-inc/baichuan-7B", + }, + "Baichuan-13B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Base", + }, + "Baichuan-13B-Chat": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan-13B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan-13B-Chat", + }, + }, + template="baichuan", +) + + +register_model_group( + models={ + "Baichuan2-7B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Base", + }, + "Baichuan2-13B-Base": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Base", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Base", + }, + "Baichuan2-7B-Chat": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-7B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-7B-Chat", + }, + "Baichuan2-13B-Chat": { + DownloadSource.DEFAULT: "baichuan-inc/Baichuan2-13B-Chat", + DownloadSource.MODELSCOPE: "baichuan-inc/Baichuan2-13B-Chat", + }, + }, + template="baichuan2", +) + + +register_model_group( + models={ + "BLOOM-560M": { + DownloadSource.DEFAULT: "bigscience/bloom-560m", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-560m", + }, + "BLOOM-3B": { + DownloadSource.DEFAULT: "bigscience/bloom-3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-3b", + }, + "BLOOM-7B1": { + DownloadSource.DEFAULT: "bigscience/bloom-7b1", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloom-7b1", + }, + }, +) + + +register_model_group( + models={ + "BLOOMZ-560M": { + DownloadSource.DEFAULT: "bigscience/bloomz-560m", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-560m", + }, + "BLOOMZ-3B": { + DownloadSource.DEFAULT: "bigscience/bloomz-3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-3b", + }, + "BLOOMZ-7B1-mt": { + DownloadSource.DEFAULT: "bigscience/bloomz-7b1-mt", + DownloadSource.MODELSCOPE: "AI-ModelScope/bloomz-7b1-mt", + }, + }, +) + + +register_model_group( + models={ + "BlueLM-7B-Base": { + DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Base", + DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Base", + }, + "BlueLM-7B-Chat": { + DownloadSource.DEFAULT: "vivo-ai/BlueLM-7B-Chat", + DownloadSource.MODELSCOPE: "vivo-ai/BlueLM-7B-Chat", + }, + }, + template="bluelm", +) + + +register_model_group( + models={ + "Breeze-7B": { + DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Base-v1_0", + }, + "Breeze-7B-Chat": { + DownloadSource.DEFAULT: "MediaTek-Research/Breeze-7B-Instruct-v1_0", + }, + }, + template="breeze", +) + + +register_model_group( + models={ + "ChatGLM2-6B-Chat": { + DownloadSource.DEFAULT: "THUDM/chatglm2-6b", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm2-6b", + } + }, + template="chatglm2", +) + + +register_model_group( + models={ + "ChatGLM3-6B-Base": { + DownloadSource.DEFAULT: "THUDM/chatglm3-6b-base", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b-base", + }, + "ChatGLM3-6B-Chat": { + DownloadSource.DEFAULT: "THUDM/chatglm3-6b", + DownloadSource.MODELSCOPE: "ZhipuAI/chatglm3-6b", + }, + }, + 
template="chatglm3", +) + + +register_model_group( + models={ + "ChineseLLaMA2-1.3B": { + DownloadSource.DEFAULT: "hfl/chinese-llama-2-1.3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-1.3b", + }, + "ChineseLLaMA2-7B": { + DownloadSource.DEFAULT: "hfl/chinese-llama-2-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-7b", + }, + "ChineseLLaMA2-13B": { + DownloadSource.DEFAULT: "hfl/chinese-llama-2-13b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-llama-2-13b", + }, + "ChineseLLaMA2-1.3B-Chat": { + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-1.3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-1.3b", + }, + "ChineseLLaMA2-7B-Chat": { + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-7b", + }, + "ChineseLLaMA2-13B-Chat": { + DownloadSource.DEFAULT: "hfl/chinese-alpaca-2-13b", + DownloadSource.MODELSCOPE: "AI-ModelScope/chinese-alpaca-2-13b", + }, + }, + template="llama2_zh", +) + + +register_model_group( + models={ + "CodeGemma-7B": { + DownloadSource.DEFAULT: "google/codegemma-7b", + }, + "CodeGemma-7B-Chat": { + DownloadSource.DEFAULT: "google/codegemma-7b-it", + DownloadSource.MODELSCOPE: "AI-ModelScope/codegemma-7b-it", + }, + "CodeGemma-1.1-2B": { + DownloadSource.DEFAULT: "google/codegemma-1.1-2b", + }, + "CodeGemma-1.1-7B-Chat": { + DownloadSource.DEFAULT: "google/codegemma-1.1-7b-it", + }, + }, + template="gemma", +) + + +register_model_group( + models={ + "Codestral-22B-v0.1-Chat": { + DownloadSource.DEFAULT: "mistralai/Codestral-22B-v0.1", + }, + }, + template="mistral", +) + + +register_model_group( + models={ + "CommandR-35B-Chat": { + DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01", + DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-v01", + }, + "CommandR-Plus-104B-Chat": { + DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus", + DownloadSource.MODELSCOPE: "AI-ModelScope/c4ai-command-r-plus", + }, + "CommandR-35B-4bit-Chat": { + DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-v01-4bit", + DownloadSource.MODELSCOPE: "mirror013/c4ai-command-r-v01-4bit", + }, + "CommandR-Plus-104B-4bit-Chat": { + DownloadSource.DEFAULT: "CohereForAI/c4ai-command-r-plus-4bit", + }, + }, + template="cohere", +) + + +register_model_group( + models={ + "DBRX-132B-Base": { + DownloadSource.DEFAULT: "databricks/dbrx-base", + DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-base", + }, + "DBRX-132B-Chat": { + DownloadSource.DEFAULT: "databricks/dbrx-instruct", + DownloadSource.MODELSCOPE: "AI-ModelScope/dbrx-instruct", + }, + }, + template="dbrx", +) + + +register_model_group( + models={ + "DeepSeek-LLM-7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-base", + }, + "DeepSeek-LLM-67B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-base", + }, + "DeepSeek-LLM-7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-7b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-7b-chat", + }, + "DeepSeek-LLM-67B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-llm-67b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-llm-67b-chat", + }, + "DeepSeek-Math-7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-math-7b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-base", + }, + "DeepSeek-Math-7B-Chat": { + DownloadSource.DEFAULT: 
"deepseek-ai/deepseek-math-7b-instruct", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-math-7b-instruct", + }, + "DeepSeek-MoE-16B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-base", + }, + "DeepSeek-MoE-16B-v2-Base": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite", + DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite", + }, + "DeepSeek-MoE-236B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2", + DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2", + }, + "DeepSeek-MoE-16B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-moe-16b-chat", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-moe-16b-chat", + }, + "DeepSeek-MoE-16B-v2-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Lite-Chat", + DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Lite-Chat", + }, + "DeepSeek-MoE-236B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-V2-Chat", + DownloadSource.MODELSCOPE: "deepseek-ai/DeepSeek-V2-Chat", + }, + "DeepSeek-MoE-Coder-16B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Base", + }, + "DeepSeek-MoE-Coder-236B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Base", + }, + "DeepSeek-MoE-Coder-16B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Lite-Instruct", + }, + "DeepSeek-MoE-Coder-236B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/DeepSeek-Coder-V2-Instruct", + }, + }, + template="deepseek", +) + + +register_model_group( + models={ + "DeepSeekCoder-6.7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-base", + }, + "DeepSeekCoder-7B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-base-v1.5", + }, + "DeepSeekCoder-33B-Base": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-base", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-base", + }, + "DeepSeekCoder-6.7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-6.7b-instruct", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-6.7b-instruct", + }, + "DeepSeekCoder-7B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-7b-instruct-v1.5", + }, + "DeepSeekCoder-33B-Chat": { + DownloadSource.DEFAULT: "deepseek-ai/deepseek-coder-33b-instruct", + DownloadSource.MODELSCOPE: "deepseek-ai/deepseek-coder-33b-instruct", + }, + }, + template="deepseekcoder", +) + + +register_model_group( + models={ + "Falcon-7B": { + DownloadSource.DEFAULT: "tiiuae/falcon-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b", + }, + "Falcon-11B": { + DownloadSource.DEFAULT: "tiiuae/falcon-11B", + }, + "Falcon-40B": { + DownloadSource.DEFAULT: "tiiuae/falcon-40b", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b", + }, + "Falcon-180B": { + DownloadSource.DEFAULT: "tiiuae/falcon-180b", + DownloadSource.MODELSCOPE: "modelscope/falcon-180B", + }, + "Falcon-7B-Chat": { + DownloadSource.DEFAULT: "tiiuae/falcon-7b-instruct", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-7b-instruct", + }, + "Falcon-40B-Chat": { + DownloadSource.DEFAULT: "tiiuae/falcon-40b-instruct", + DownloadSource.MODELSCOPE: "AI-ModelScope/falcon-40b-instruct", + }, + "Falcon-180B-Chat": { + DownloadSource.DEFAULT: "tiiuae/falcon-180b-chat", + DownloadSource.MODELSCOPE: "modelscope/falcon-180B-chat", + }, + }, + template="falcon", +) + + +register_model_group( + models={ + "Gemma-2B": { + DownloadSource.DEFAULT: 
"google/gemma-2b", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b", + }, + "Gemma-7B": { + DownloadSource.DEFAULT: "google/gemma-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-2b-it", + }, + "Gemma-2B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2b-it", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b", + }, + "Gemma-7B-Chat": { + DownloadSource.DEFAULT: "google/gemma-7b-it", + DownloadSource.MODELSCOPE: "AI-ModelScope/gemma-7b-it", + }, + "Gemma-1.1-2B-Chat": { + DownloadSource.DEFAULT: "google/gemma-1.1-2b-it", + }, + "Gemma-1.1-7B-Chat": { + DownloadSource.DEFAULT: "google/gemma-1.1-7b-it", + }, + "Gemma-2-9B": { + DownloadSource.DEFAULT: "google/gemma-2-9b", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b", + }, + "Gemma-2-27B": { + DownloadSource.DEFAULT: "google/gemma-2-27b", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b", + }, + "Gemma-2-9B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2-9b-it", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-9b-it", + }, + "Gemma-2-27B-Chat": { + DownloadSource.DEFAULT: "google/gemma-2-27b-it", + DownloadSource.MODELSCOPE: "LLM-Research/gemma-2-27b-it", + }, + }, + template="gemma", +) + + +register_model_group( + models={ + "GLM-4-9B": { + DownloadSource.DEFAULT: "THUDM/glm-4-9b", + DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b", + }, + "GLM-4-9B-Chat": { + DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat", + DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat", + }, + "GLM-4-9B-1M-Chat": { + DownloadSource.DEFAULT: "THUDM/glm-4-9b-chat-1m", + DownloadSource.MODELSCOPE: "ZhipuAI/glm-4-9b-chat-1m", + }, + }, + template="glm4", +) + + +register_model_group( + models={ + "InternLM-7B": { + DownloadSource.DEFAULT: "internlm/internlm-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-7b", + }, + "InternLM-20B": { + DownloadSource.DEFAULT: "internlm/internlm-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-20b", + }, + "InternLM-7B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm-chat-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-7b", + }, + "InternLM-20B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm-chat-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm-chat-20b", + }, + }, + template="intern", +) + + +register_model_group( + models={ + "InternLM2-7B": { + DownloadSource.DEFAULT: "internlm/internlm2-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-7b", + }, + "InternLM2-20B": { + DownloadSource.DEFAULT: "internlm/internlm2-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-20b", + }, + "InternLM2-7B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm2-chat-7b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-7b", + }, + "InternLM2-20B-Chat": { + DownloadSource.DEFAULT: "internlm/internlm2-chat-20b", + DownloadSource.MODELSCOPE: "Shanghai_AI_Laboratory/internlm2-chat-20b", + }, + }, + template="intern2", +) + + +register_model_group( + models={ + "Jamba-v0.1": { + DownloadSource.DEFAULT: "ai21labs/Jamba-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Jamba-v0.1", + } + }, +) + + +register_model_group( + models={ + "LingoWhale-8B": { + DownloadSource.DEFAULT: "deeplang-ai/LingoWhale-8B", + DownloadSource.MODELSCOPE: "DeepLang/LingoWhale-8B", + } + }, +) + + +register_model_group( + models={ + "LLaMA-7B": { + DownloadSource.DEFAULT: "huggyllama/llama-7b", + DownloadSource.MODELSCOPE: "skyline2006/llama-7b", + }, + "LLaMA-13B": { + DownloadSource.DEFAULT: 
"huggyllama/llama-13b", + DownloadSource.MODELSCOPE: "skyline2006/llama-13b", + }, + "LLaMA-30B": { + DownloadSource.DEFAULT: "huggyllama/llama-30b", + DownloadSource.MODELSCOPE: "skyline2006/llama-30b", + }, + "LLaMA-65B": { + DownloadSource.DEFAULT: "huggyllama/llama-65b", + DownloadSource.MODELSCOPE: "skyline2006/llama-65b", + }, + } +) + + +register_model_group( + models={ + "LLaMA2-7B": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-ms", + }, + "LLaMA2-13B": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-ms", + }, + "LLaMA2-70B": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-ms", + }, + "LLaMA2-7B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-7b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-7b-chat-ms", + }, + "LLaMA2-13B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-13b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-13b-chat-ms", + }, + "LLaMA2-70B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Llama-2-70b-chat-hf", + DownloadSource.MODELSCOPE: "modelscope/Llama-2-70b-chat-ms", + }, + }, + template="llama2", +) + + +register_model_group( + models={ + "LLaMA3-8B": { + DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B", + DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B", + }, + "LLaMA3-70B": { + DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B", + DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B", + }, + "LLaMA3-8B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-8B-Instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-8B-Instruct", + }, + "LLaMA3-70B-Chat": { + DownloadSource.DEFAULT: "meta-llama/Meta-Llama-3-70B-Instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Meta-Llama-3-70B-Instruct", + }, + "LLaMA3-8B-Chinese-Chat": { + DownloadSource.DEFAULT: "shenzhi-wang/Llama3-8B-Chinese-Chat", + DownloadSource.MODELSCOPE: "LLM-Research/Llama3-8B-Chinese-Chat", + }, + "LLaMA3-70B-Chinese-Chat": { + DownloadSource.DEFAULT: "shenzhi-wang/Llama3-70B-Chinese-Chat", + }, + }, + template="llama3", +) + + +register_model_group( + models={ + "LLaVA1.5-7B-Chat": { + DownloadSource.DEFAULT: "llava-hf/llava-1.5-7b-hf", + }, + "LLaVA1.5-13B-Chat": { + DownloadSource.DEFAULT: "llava-hf/llava-1.5-13b-hf", + }, + }, + template="vicuna", + vision=True, +) + + +register_model_group( + models={ + "MiniCPM-2B-SFT-Chat": { + DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-sft-bf16", + DownloadSource.MODELSCOPE: "OpenBMB/miniCPM-bf16", + }, + "MiniCPM-2B-DPO-Chat": { + DownloadSource.DEFAULT: "openbmb/MiniCPM-2B-dpo-bf16", + DownloadSource.MODELSCOPE: "OpenBMB/MiniCPM-2B-dpo-bf16", + }, + }, + template="cpm", +) + + +register_model_group( + models={ + "Mistral-7B-v0.1": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.1", + }, + "Mistral-7B-v0.1-Chat": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.1", + }, + "Mistral-7B-v0.2": { + DownloadSource.DEFAULT: "alpindale/Mistral-7B-v0.2-hf", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-v0.2-hf", + }, + "Mistral-7B-v0.2-Chat": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.2", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mistral-7B-Instruct-v0.2", + }, + "Mistral-7B-v0.3": { + DownloadSource.DEFAULT: 
"mistralai/Mistral-7B-v0.3", + }, + "Mistral-7B-v0.3-Chat": { + DownloadSource.DEFAULT: "mistralai/Mistral-7B-Instruct-v0.3", + }, + }, + template="mistral", +) + + +register_model_group( + models={ + "Mixtral-8x7B-v0.1": { + DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-v0.1", + }, + "Mixtral-8x7B-v0.1-Chat": { + DownloadSource.DEFAULT: "mistralai/Mixtral-8x7B-Instruct-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x7B-Instruct-v0.1", + }, + "Mixtral-8x22B-v0.1": { + DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-v0.1", + }, + "Mixtral-8x22B-v0.1-Chat": { + DownloadSource.DEFAULT: "mistralai/Mixtral-8x22B-Instruct-v0.1", + DownloadSource.MODELSCOPE: "AI-ModelScope/Mixtral-8x22B-Instruct-v0.1", + }, + }, + template="mistral", +) + + +register_model_group( + models={ + "OLMo-1B": { + DownloadSource.DEFAULT: "allenai/OLMo-1B-hf", + }, + "OLMo-7B": { + DownloadSource.DEFAULT: "allenai/OLMo-7B-hf", + }, + "OLMo-7B-Chat": { + DownloadSource.DEFAULT: "ssec-uw/OLMo-7B-Instruct-hf", + }, + "OLMo-1.7-7B": { + DownloadSource.DEFAULT: "allenai/OLMo-1.7-7B-hf", + }, + }, +) + + +register_model_group( + models={ + "OpenChat3.5-7B-Chat": { + DownloadSource.DEFAULT: "openchat/openchat-3.5-0106", + DownloadSource.MODELSCOPE: "xcwzxcwz/openchat-3.5-0106", + } + }, + template="openchat", +) + + +register_model_group( + models={ + "OpenChat3.6-8B-Chat": { + DownloadSource.DEFAULT: "openchat/openchat-3.6-8b-20240522", + } + }, + template="openchat-3.6", +) + + +register_model_group( + models={ + "Orion-14B-Base": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Base", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Base", + }, + "Orion-14B-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat", + }, + "Orion-14B-Long-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-LongChat", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-LongChat", + }, + "Orion-14B-RAG-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-RAG", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-RAG", + }, + "Orion-14B-Plugin-Chat": { + DownloadSource.DEFAULT: "OrionStarAI/Orion-14B-Chat-Plugin", + DownloadSource.MODELSCOPE: "OrionStarAI/Orion-14B-Chat-Plugin", + }, + }, + template="orion", +) + + +register_model_group( + models={ + "PaliGemma-3B-pt-224": { + DownloadSource.DEFAULT: "google/paligemma-3b-pt-224", + DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-224", + }, + "PaliGemma-3B-pt-448": { + DownloadSource.DEFAULT: "google/paligemma-3b-pt-448", + DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-448", + }, + "PaliGemma-3B-pt-896": { + DownloadSource.DEFAULT: "google/paligemma-3b-pt-896", + DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-pt-896", + }, + "PaliGemma-3B-mix-224": { + DownloadSource.DEFAULT: "google/paligemma-3b-mix-224", + DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-224", + }, + "PaliGemma-3B-mix-448": { + DownloadSource.DEFAULT: "google/paligemma-3b-mix-448", + DownloadSource.MODELSCOPE: "AI-ModelScope/paligemma-3b-mix-448", + }, + }, + vision=True, +) + + +register_model_group( + models={ + "Phi-1.5-1.3B": { + DownloadSource.DEFAULT: "microsoft/phi-1_5", + DownloadSource.MODELSCOPE: "allspace/PHI_1-5", + }, + "Phi-2-2.7B": { + DownloadSource.DEFAULT: "microsoft/phi-2", + DownloadSource.MODELSCOPE: "AI-ModelScope/phi-2", + }, + } +) + 
+ +register_model_group( + models={ + "Phi3-4B-4k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-mini-4k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-4k-instruct", + }, + "Phi3-4B-128k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-mini-128k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-mini-128k-instruct", + }, + "Phi3-7B-8k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-small-8k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-8k-instruct", + }, + "Phi3-7B-128k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-small-128k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-small-128k-instruct", + }, + "Phi3-14B-8k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-medium-4k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-4k-instruct", + }, + "Phi3-14B-128k-Chat": { + DownloadSource.DEFAULT: "microsoft/Phi-3-medium-128k-instruct", + DownloadSource.MODELSCOPE: "LLM-Research/Phi-3-medium-128k-instruct", + }, + }, + template="phi", +) + + +register_model_group( + models={ + "Qwen-1.8B": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B", + }, + "Qwen-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B", + }, + "Qwen-14B": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B", + }, + "Qwen-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B", + }, + "Qwen-1.8B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat", + }, + "Qwen-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat", + }, + "Qwen-14B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat", + }, + "Qwen-72B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat", + }, + "Qwen-1.8B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int8", + }, + "Qwen-1.8B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-1_8B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-1_8B-Chat-Int4", + }, + "Qwen-7B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int8", + }, + "Qwen-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-7B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-7B-Chat-Int4", + }, + "Qwen-14B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int8", + }, + "Qwen-14B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-14B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-14B-Chat-Int4", + }, + "Qwen-72B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int8", + }, + "Qwen-72B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen-72B-Chat-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen-72B-Chat-Int4", + }, + }, + template="qwen", +) + + +register_model_group( + models={ + "Qwen1.5-0.5B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B", + }, + "Qwen1.5-1.8B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B", + }, + "Qwen1.5-4B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B", + DownloadSource.MODELSCOPE: 
"qwen/Qwen1.5-4B", + }, + "Qwen1.5-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B", + }, + "Qwen1.5-14B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B", + }, + "Qwen1.5-32B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B", + }, + "Qwen1.5-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B", + }, + "Qwen1.5-110B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B", + }, + "Qwen1.5-MoE-A2.7B": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B", + }, + "Qwen1.5-Code-7B": { + DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B", + DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B", + }, + "Qwen1.5-0.5B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat", + }, + "Qwen1.5-1.8B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat", + }, + "Qwen1.5-4B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat", + }, + "Qwen1.5-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat", + }, + "Qwen1.5-14B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat", + }, + "Qwen1.5-32B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat", + }, + "Qwen1.5-72B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat", + }, + "Qwen1.5-110B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat", + }, + "Qwen1.5-MoE-A2.7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat", + }, + "Qwen1.5-Code-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat", + DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat", + }, + "Qwen1.5-0.5B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-GPTQ-Int8", + }, + "Qwen1.5-0.5B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-0.5B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-0.5B-Chat-AWQ", + }, + "Qwen1.5-1.8B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-GPTQ-Int8", + }, + "Qwen1.5-1.8B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-1.8B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-1.8B-Chat-AWQ", + }, + "Qwen1.5-4B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-GPTQ-Int8", + }, + "Qwen1.5-4B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-4B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-4B-Chat-AWQ", + }, + "Qwen1.5-7B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-GPTQ-Int8", + }, + "Qwen1.5-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-7B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-7B-Chat-AWQ", + }, + "Qwen1.5-14B-int8-Chat": { + DownloadSource.DEFAULT: 
"Qwen/Qwen1.5-14B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-GPTQ-Int8", + }, + "Qwen1.5-14B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-14B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-14B-Chat-AWQ", + }, + "Qwen1.5-32B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-32B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-32B-Chat-AWQ", + }, + "Qwen1.5-72B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-GPTQ-Int8", + }, + "Qwen1.5-72B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-72B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-72B-Chat-AWQ", + }, + "Qwen1.5-110B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-110B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-110B-Chat-AWQ", + }, + "Qwen1.5-MoE-A2.7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen1.5-MoE-A2.7B-Chat-GPTQ-Int4", + }, + "Qwen1.5-Code-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/CodeQwen1.5-7B-Chat-AWQ", + DownloadSource.MODELSCOPE: "qwen/CodeQwen1.5-7B-Chat-AWQ", + }, + }, + template="qwen", +) + + +register_model_group( + models={ + "Qwen2-0.5B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B", + }, + "Qwen2-1.5B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B", + }, + "Qwen2-7B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B", + }, + "Qwen2-72B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B", + }, + "Qwen2-MoE-57B": { + DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B", + DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B", + }, + "Qwen2-0.5B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct", + }, + "Qwen2-1.5B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct", + }, + "Qwen2-7B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct", + }, + "Qwen2-72B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct", + }, + "Qwen2-MoE-57B-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct", + DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct", + }, + "Qwen2-0.5B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-GPTQ-Int8", + }, + "Qwen2-0.5B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-0.5B-Instruct-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen2-0.5B-Instruct-AWQ", + }, + "Qwen2-1.5B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-GPTQ-Int8", + }, + "Qwen2-1.5B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-1.5B-Instruct-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen2-1.5B-Instruct-AWQ", + }, + "Qwen2-7B-int8-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-GPTQ-Int8", + }, + "Qwen2-7B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-7B-Instruct-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen2-7B-Instruct-AWQ", + }, + "Qwen2-72B-int8-Chat": { + DownloadSource.DEFAULT: 
"Qwen/Qwen2-72B-Instruct-GPTQ-Int8", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-GPTQ-Int8", + }, + "Qwen2-72B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-72B-Instruct-AWQ", + DownloadSource.MODELSCOPE: "qwen/Qwen2-72B-Instruct-AWQ", + }, + "Qwen2-MoE-57B-int4-Chat": { + DownloadSource.DEFAULT: "Qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4", + DownloadSource.MODELSCOPE: "qwen/Qwen2-57B-A14B-Instruct-GPTQ-Int4", + }, + }, + template="qwen", +) + + +register_model_group( + models={ + "SOLAR-10.7B": { + DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-v1.0", + }, + "SOLAR-10.7B-Chat": { + DownloadSource.DEFAULT: "upstage/SOLAR-10.7B-Instruct-v1.0", + DownloadSource.MODELSCOPE: "AI-ModelScope/SOLAR-10.7B-Instruct-v1.0", + }, + }, + template="solar", +) + + +register_model_group( + models={ + "Skywork-13B-Base": { + DownloadSource.DEFAULT: "Skywork/Skywork-13B-base", + DownloadSource.MODELSCOPE: "skywork/Skywork-13B-base", + } + } +) + + +register_model_group( + models={ + "StarCoder2-3B": { + DownloadSource.DEFAULT: "bigcode/starcoder2-3b", + DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-3b", + }, + "StarCoder2-7B": { + DownloadSource.DEFAULT: "bigcode/starcoder2-7b", + DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-7b", + }, + "StarCoder2-15B": { + DownloadSource.DEFAULT: "bigcode/starcoder2-15b", + DownloadSource.MODELSCOPE: "AI-ModelScope/starcoder2-15b", + }, + } +) + + +register_model_group( + models={ + "TeleChat-1B-Chat": { + DownloadSource.DEFAULT: "Tele-AI/TeleChat-1B", + DownloadSource.MODELSCOPE: "TeleAI/TeleChat-1B", + }, + "TeleChat-7B-Chat": { + DownloadSource.DEFAULT: "Tele-AI/telechat-7B", + DownloadSource.MODELSCOPE: "TeleAI/telechat-7B", + }, + "TeleChat-12B-Chat": { + DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B", + DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B", + }, + "TeleChat-12B-v2-Chat": { + DownloadSource.DEFAULT: "Tele-AI/TeleChat-12B-v2", + DownloadSource.MODELSCOPE: "TeleAI/TeleChat-12B-v2", + }, + }, + template="telechat", +) + + +register_model_group( + models={ + "Vicuna1.5-7B-Chat": { + DownloadSource.DEFAULT: "lmsys/vicuna-7b-v1.5", + DownloadSource.MODELSCOPE: "Xorbits/vicuna-7b-v1.5", + }, + "Vicuna1.5-13B-Chat": { + DownloadSource.DEFAULT: "lmsys/vicuna-13b-v1.5", + DownloadSource.MODELSCOPE: "Xorbits/vicuna-13b-v1.5", + }, + }, + template="vicuna", +) + + +register_model_group( + models={ + "XuanYuan-6B": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B", + }, + "XuanYuan-70B": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B", + }, + "XuanYuan-2-70B": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B", + }, + "XuanYuan-6B-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat", + }, + "XuanYuan-70B-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat", + }, + "XuanYuan-2-70B-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat", + }, + "XuanYuan-6B-int8-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-6B-Chat-8bit", + }, + "XuanYuan-6B-int4-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-6B-Chat-4bit", + DownloadSource.MODELSCOPE: 
"Duxiaoman-DI/XuanYuan-6B-Chat-4bit", + }, + "XuanYuan-70B-int8-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-8bit", + }, + "XuanYuan-70B-int4-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan-70B-Chat-4bit", + }, + "XuanYuan-2-70B-int8-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-8bit", + }, + "XuanYuan-2-70B-int4-Chat": { + DownloadSource.DEFAULT: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit", + DownloadSource.MODELSCOPE: "Duxiaoman-DI/XuanYuan2-70B-Chat-4bit", + }, + }, + template="xuanyuan", +) + + +register_model_group( + models={ + "XVERSE-7B": { + DownloadSource.DEFAULT: "xverse/XVERSE-7B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B", + }, + "XVERSE-13B": { + DownloadSource.DEFAULT: "xverse/XVERSE-13B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B", + }, + "XVERSE-65B": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B", + }, + "XVERSE-65B-2": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B-2", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-2", + }, + "XVERSE-7B-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat", + }, + "XVERSE-13B-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat", + }, + "XVERSE-65B-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat", + }, + "XVERSE-MoE-A4.2B": { + DownloadSource.DEFAULT: "xverse/XVERSE-MoE-A4.2B", + DownloadSource.MODELSCOPE: "xverse/XVERSE-MoE-A4.2B", + }, + "XVERSE-7B-int8-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int8", + }, + "XVERSE-7B-int4-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-7B-Chat-GPTQ-Int4", + DownloadSource.MODELSCOPE: "xverse/XVERSE-7B-Chat-GPTQ-Int4", + }, + "XVERSE-13B-int8-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int8", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int8", + }, + "XVERSE-13B-int4-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-13B-Chat-GPTQ-Int4", + DownloadSource.MODELSCOPE: "xverse/XVERSE-13B-Chat-GPTQ-Int4", + }, + "XVERSE-65B-int4-Chat": { + DownloadSource.DEFAULT: "xverse/XVERSE-65B-Chat-GPTQ-Int4", + DownloadSource.MODELSCOPE: "xverse/XVERSE-65B-Chat-GPTQ-Int4", + }, + }, + template="xverse", +) + + +register_model_group( + models={ + "Yayi-7B": { + DownloadSource.DEFAULT: "wenge-research/yayi-7b-llama2", + DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-7b-llama2", + }, + "Yayi-13B": { + DownloadSource.DEFAULT: "wenge-research/yayi-13b-llama2", + DownloadSource.MODELSCOPE: "AI-ModelScope/yayi-13b-llama2", + }, + }, + template="yayi", +) + + +register_model_group( + models={ + "Yi-6B": { + DownloadSource.DEFAULT: "01-ai/Yi-6B", + DownloadSource.MODELSCOPE: "01ai/Yi-6B", + }, + "Yi-9B": { + DownloadSource.DEFAULT: "01-ai/Yi-9B", + DownloadSource.MODELSCOPE: "01ai/Yi-9B", + }, + "Yi-34B": { + DownloadSource.DEFAULT: "01-ai/Yi-34B", + DownloadSource.MODELSCOPE: "01ai/Yi-34B", + }, + "Yi-6B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat", + }, + "Yi-34B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat", + DownloadSource.MODELSCOPE: 
"01ai/Yi-34B-Chat", + }, + "Yi-6B-int8-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-8bits", + DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-8bits", + }, + "Yi-6B-int4-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-6B-Chat-4bits", + DownloadSource.MODELSCOPE: "01ai/Yi-6B-Chat-4bits", + }, + "Yi-34B-int8-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-8bits", + DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-8bits", + }, + "Yi-34B-int4-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-34B-Chat-4bits", + DownloadSource.MODELSCOPE: "01ai/Yi-34B-Chat-4bits", + }, + "Yi-1.5-6B": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B", + }, + "Yi-1.5-9B": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B", + }, + "Yi-1.5-34B": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B", + }, + "Yi-1.5-6B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-6B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-6B-Chat", + }, + "Yi-1.5-9B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-9B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-9B-Chat", + }, + "Yi-1.5-34B-Chat": { + DownloadSource.DEFAULT: "01-ai/Yi-1.5-34B-Chat", + DownloadSource.MODELSCOPE: "01ai/Yi-1.5-34B-Chat", + }, + }, + template="yi", +) + + +register_model_group( + models={ + "YiVL-6B-Chat": { + DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-6B-hf", + }, + "YiVL-34B-Chat": { + DownloadSource.DEFAULT: "BUAADreamer/Yi-VL-34B-hf", + }, + }, + template="yi_vl", + vision=True, +) + + +register_model_group( + models={ + "Yuan2-2B-Chat": { + DownloadSource.DEFAULT: "IEITYuan/Yuan2-2B-hf", + DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-2B-hf", + }, + "Yuan2-51B-Chat": { + DownloadSource.DEFAULT: "IEITYuan/Yuan2-51B-hf", + DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-51B-hf", + }, + "Yuan2-102B-Chat": { + DownloadSource.DEFAULT: "IEITYuan/Yuan2-102B-hf", + DownloadSource.MODELSCOPE: "YuanLLM/Yuan2.0-102B-hf", + }, + }, + template="yuan", +) + + +register_model_group( + models={ + "Zephyr-7B-Alpha-Chat": { + DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-alpha", + DownloadSource.MODELSCOPE: "AI-ModelScope/zephyr-7b-alpha", + }, + "Zephyr-7B-Beta-Chat": { + DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-7b-beta", + DownloadSource.MODELSCOPE: "modelscope/zephyr-7b-beta", + }, + "Zephyr-141B-ORPO-Chat": { + DownloadSource.DEFAULT: "HuggingFaceH4/zephyr-orpo-141b-A35b-v0.1", + }, + }, + template="zephyr", +) diff --git a/src/extras/env.py b/src/extras/env.py new file mode 100644 index 0000000000000000000000000000000000000000..14876048742b8a155d6058477558fa11e70dcbbd --- /dev/null +++ b/src/extras/env.py @@ -0,0 +1,75 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/commands/env.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import platform + +import accelerate +import datasets +import peft +import torch +import transformers +import trl +from transformers.utils import is_torch_cuda_available, is_torch_npu_available + + +VERSION = "0.8.3.dev0" + + +def print_env() -> None: + info = { + "`llamafactory` version": VERSION, + "Platform": platform.platform(), + "Python version": platform.python_version(), + "PyTorch version": torch.__version__, + "Transformers version": transformers.__version__, + "Datasets version": datasets.__version__, + "Accelerate version": accelerate.__version__, + "PEFT version": peft.__version__, + "TRL version": trl.__version__, + } + + if is_torch_cuda_available(): + info["PyTorch version"] += " (GPU)" + info["GPU type"] = torch.cuda.get_device_name() + + if is_torch_npu_available(): + info["PyTorch version"] += " (NPU)" + info["NPU type"] = torch.npu.get_device_name() + info["CANN version"] = torch.version.cann + + try: + import deepspeed # type: ignore + + info["DeepSpeed version"] = deepspeed.__version__ + except Exception: + pass + + try: + import bitsandbytes + + info["Bitsandbytes version"] = bitsandbytes.__version__ + except Exception: + pass + + try: + import vllm + + info["vLLM version"] = vllm.__version__ + except Exception: + pass + + print("\n" + "\n".join(["- {}: {}".format(key, value) for key, value in info.items()]) + "\n") diff --git a/src/extras/logging.py b/src/extras/logging.py new file mode 100644 index 0000000000000000000000000000000000000000..676222127fc18a7212fbe9d633ccd05e3e156271 --- /dev/null +++ b/src/extras/logging.py @@ -0,0 +1,82 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import logging +import os +import sys +from concurrent.futures import ThreadPoolExecutor + +from .constants import RUNNING_LOG + + +class LoggerHandler(logging.Handler): + r""" + Logger handler used in Web UI. + """ + + def __init__(self, output_dir: str) -> None: + super().__init__() + formatter = logging.Formatter( + fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" + ) + self.setLevel(logging.INFO) + self.setFormatter(formatter) + + os.makedirs(output_dir, exist_ok=True) + self.running_log = os.path.join(output_dir, RUNNING_LOG) + if os.path.exists(self.running_log): + os.remove(self.running_log) + + self.thread_pool = ThreadPoolExecutor(max_workers=1) + + def _write_log(self, log_entry: str) -> None: + with open(self.running_log, "a", encoding="utf-8") as f: + f.write(log_entry + "\n\n") + + def emit(self, record) -> None: + if record.name == "httpx": + return + + log_entry = self.format(record) + self.thread_pool.submit(self._write_log, log_entry) + + def close(self) -> None: + self.thread_pool.shutdown(wait=True) + return super().close() + + +def get_logger(name: str) -> logging.Logger: + r""" + Gets a standard logger with a stream hander to stdout. 
+ """ + formatter = logging.Formatter( + fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" + ) + handler = logging.StreamHandler(sys.stdout) + handler.setFormatter(formatter) + + logger = logging.getLogger(name) + logger.setLevel(logging.INFO) + logger.addHandler(handler) + + return logger + + +def reset_logging() -> None: + r""" + Removes basic config of root logger. (unused in script) + """ + root = logging.getLogger() + list(map(root.removeHandler, root.handlers)) + list(map(root.removeFilter, root.filters)) diff --git a/src/extras/misc.py b/src/extras/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..20c752c56960819a551603ca6f441277f6f62cce --- /dev/null +++ b/src/extras/misc.py @@ -0,0 +1,215 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's PEFT library. +# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/peft_model.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gc +import os +from typing import TYPE_CHECKING, Tuple + +import torch +import transformers.dynamic_module_utils +from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList +from transformers.dynamic_module_utils import get_relative_imports +from transformers.utils import ( + is_torch_bf16_gpu_available, + is_torch_cuda_available, + is_torch_mps_available, + is_torch_npu_available, + is_torch_xpu_available, +) +from transformers.utils.versions import require_version + +from .logging import get_logger + + +_is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() +try: + _is_bf16_available = is_torch_bf16_gpu_available() +except Exception: + _is_bf16_available = False + + +if TYPE_CHECKING: + from ..hparams import ModelArguments + + +logger = get_logger(__name__) + + +class AverageMeter: + r""" + Computes and stores the average and current value. + """ + + def __init__(self): + self.reset() + + def reset(self): + self.val = 0 + self.avg = 0 + self.sum = 0 + self.count = 0 + + def update(self, val, n=1): + self.val = val + self.sum += val * n + self.count += n + self.avg = self.sum / self.count + + +def check_dependencies() -> None: + r""" + Checks the version of the required packages. + """ + if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: + logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") + else: + require_version("transformers>=4.41.2", "To fix: pip install transformers>=4.41.2") + require_version("datasets>=2.16.0", "To fix: pip install datasets>=2.16.0") + require_version("accelerate>=0.30.1", "To fix: pip install accelerate>=0.30.1") + require_version("peft>=0.11.1", "To fix: pip install peft>=0.11.1") + require_version("trl>=0.8.6", "To fix: pip install trl>=0.8.6") + + +def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]: + r""" + Returns the number of trainable parameters and number of all parameters in the model. 
+ """ + trainable_params, all_param = 0, 0 + for param in model.parameters(): + num_params = param.numel() + # if using DS Zero 3 and the weights are initialized empty + if num_params == 0 and hasattr(param, "ds_numel"): + num_params = param.ds_numel + + # Due to the design of 4bit linear layers from bitsandbytes, multiply the number of parameters by itemsize + if param.__class__.__name__ == "Params4bit": + if hasattr(param, "quant_storage") and hasattr(param.quant_storage, "itemsize"): + num_bytes = param.quant_storage.itemsize + elif hasattr(param, "element_size"): # for older pytorch version + num_bytes = param.element_size() + else: + num_bytes = 1 + + num_params = num_params * 2 * num_bytes + + all_param += num_params + if param.requires_grad: + trainable_params += num_params + + return trainable_params, all_param + + +def get_current_device() -> "torch.device": + r""" + Gets the current available device. + """ + if is_torch_xpu_available(): + device = "xpu:{}".format(os.environ.get("LOCAL_RANK", "0")) + elif is_torch_npu_available(): + device = "npu:{}".format(os.environ.get("LOCAL_RANK", "0")) + elif is_torch_mps_available(): + device = "mps:{}".format(os.environ.get("LOCAL_RANK", "0")) + elif is_torch_cuda_available(): + device = "cuda:{}".format(os.environ.get("LOCAL_RANK", "0")) + else: + device = "cpu" + + return torch.device(device) + + +def get_device_count() -> int: + r""" + Gets the number of available GPU or NPU devices. + """ + if is_torch_npu_available(): + return torch.npu.device_count() + elif is_torch_cuda_available(): + return torch.cuda.device_count() + else: + return 0 + + +def get_logits_processor() -> "LogitsProcessorList": + r""" + Gets logits processor that removes NaN and Inf logits. + """ + logits_processor = LogitsProcessorList() + logits_processor.append(InfNanRemoveLogitsProcessor()) + return logits_processor + + +def has_tokenized_data(path: "os.PathLike") -> bool: + r""" + Checks if the path has a tokenized dataset. + """ + return os.path.isdir(path) and len(os.listdir(path)) > 0 + + +def infer_optim_dtype(model_dtype: "torch.dtype") -> "torch.dtype": + r""" + Infers the optimal dtype according to the model_dtype and device compatibility. + """ + if _is_bf16_available and model_dtype == torch.bfloat16: + return torch.bfloat16 + elif _is_fp16_available: + return torch.float16 + else: + return torch.float32 + + +def is_gpu_or_npu_available() -> bool: + r""" + Checks if the GPU or NPU is available. + """ + return is_torch_npu_available() or is_torch_cuda_available() + + +def skip_check_imports() -> None: + if os.environ.get("FORCE_CHECK_IMPORTS", "0").lower() not in ["true", "1"]: + transformers.dynamic_module_utils.check_imports = get_relative_imports + + +def torch_gc() -> None: + r""" + Collects GPU or NPU memory. 
+ """ + gc.collect() + if is_torch_xpu_available(): + torch.xpu.empty_cache() + elif is_torch_npu_available(): + torch.npu.empty_cache() + elif is_torch_mps_available(): + torch.mps.empty_cache() + elif is_torch_cuda_available(): + torch.cuda.empty_cache() + + +def try_download_model_from_ms(model_args: "ModelArguments") -> str: + if not use_modelscope() or os.path.exists(model_args.model_name_or_path): + return model_args.model_name_or_path + + try: + from modelscope import snapshot_download + + revision = "master" if model_args.model_revision == "main" else model_args.model_revision + return snapshot_download(model_args.model_name_or_path, revision=revision, cache_dir=model_args.cache_dir) + except ImportError: + raise ImportError("Please install modelscope via `pip install modelscope -U`") + + +def use_modelscope() -> bool: + return os.environ.get("USE_MODELSCOPE_HUB", "0").lower() in ["true", "1"] diff --git a/src/extras/packages.py b/src/extras/packages.py new file mode 100644 index 0000000000000000000000000000000000000000..0a84a293dbc70f6f45e1b18434101a70fb961a57 --- /dev/null +++ b/src/extras/packages.py @@ -0,0 +1,83 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/import_utils.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import importlib.metadata +import importlib.util +from functools import lru_cache +from typing import TYPE_CHECKING + +from packaging import version + + +if TYPE_CHECKING: + from packaging.version import Version + + +def _is_package_available(name: str) -> bool: + return importlib.util.find_spec(name) is not None + + +def _get_package_version(name: str) -> "Version": + try: + return version.parse(importlib.metadata.version(name)) + except Exception: + return version.parse("0.0.0") + + +def is_fastapi_available(): + return _is_package_available("fastapi") + + +def is_galore_available(): + return _is_package_available("galore_torch") + + +def is_gradio_available(): + return _is_package_available("gradio") + + +def is_matplotlib_available(): + return _is_package_available("matplotlib") + + +def is_pillow_available(): + return _is_package_available("PIL") + + +def is_requests_available(): + return _is_package_available("requests") + + +def is_rouge_available(): + return _is_package_available("rouge_chinese") + + +def is_starlette_available(): + return _is_package_available("sse_starlette") + + +def is_uvicorn_available(): + return _is_package_available("uvicorn") + + +def is_vllm_available(): + return _is_package_available("vllm") + + +@lru_cache +def is_vllm_version_greater_than_0_5(): + return _get_package_version("vllm") >= version.parse("0.5.0") diff --git a/src/extras/ploting.py b/src/extras/ploting.py new file mode 100644 index 0000000000000000000000000000000000000000..596d55e7da89dd234519200b645532059691de3b --- /dev/null +++ b/src/extras/ploting.py @@ -0,0 +1,101 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import math +import os +from typing import Any, Dict, List + +from transformers.trainer import TRAINER_STATE_NAME + +from .logging import get_logger +from .packages import is_matplotlib_available + + +if is_matplotlib_available(): + import matplotlib.figure + import matplotlib.pyplot as plt + + +logger = get_logger(__name__) + + +def smooth(scalars: List[float]) -> List[float]: + r""" + EMA implementation according to TensorBoard. + """ + if len(scalars) == 0: + return [] + + last = scalars[0] + smoothed = [] + weight = 1.8 * (1 / (1 + math.exp(-0.05 * len(scalars))) - 0.5) # a sigmoid function + for next_val in scalars: + smoothed_val = last * weight + (1 - weight) * next_val + smoothed.append(smoothed_val) + last = smoothed_val + return smoothed + + +def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figure": + r""" + Plots loss curves in LlamaBoard. 
+ """ + plt.close("all") + plt.switch_backend("agg") + fig = plt.figure() + ax = fig.add_subplot(111) + steps, losses = [], [] + for log in trainer_log: + if log.get("loss", None): + steps.append(log["current_steps"]) + losses.append(log["loss"]) + + ax.plot(steps, losses, color="#1f77b4", alpha=0.4, label="original") + ax.plot(steps, smooth(losses), color="#1f77b4", label="smoothed") + ax.legend() + ax.set_xlabel("step") + ax.set_ylabel("loss") + return fig + + +def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None: + r""" + Plots loss curves and saves the image. + """ + plt.switch_backend("agg") + with open(os.path.join(save_dictionary, TRAINER_STATE_NAME), "r", encoding="utf-8") as f: + data = json.load(f) + + for key in keys: + steps, metrics = [], [] + for i in range(len(data["log_history"])): + if key in data["log_history"][i]: + steps.append(data["log_history"][i]["step"]) + metrics.append(data["log_history"][i][key]) + + if len(metrics) == 0: + logger.warning(f"No metric {key} to plot.") + continue + + plt.figure() + plt.plot(steps, metrics, color="#1f77b4", alpha=0.4, label="original") + plt.plot(steps, smooth(metrics), color="#1f77b4", label="smoothed") + plt.title("training {} of {}".format(key, save_dictionary)) + plt.xlabel("step") + plt.ylabel(key) + plt.legend() + figure_path = os.path.join(save_dictionary, "training_{}.png".format(key.replace("/", "_"))) + plt.savefig(figure_path, format="png", dpi=100) + print("Figure saved at:", figure_path) diff --git a/src/hparams/__init__.py b/src/hparams/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..d692e2fff481a124a085d020edfe37ac2e26d446 --- /dev/null +++ b/src/hparams/__init__.py @@ -0,0 +1,31 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .data_args import DataArguments +from .evaluation_args import EvaluationArguments +from .finetuning_args import FinetuningArguments +from .generating_args import GeneratingArguments +from .model_args import ModelArguments +from .parser import get_infer_args, get_train_args + + +__all__ = [ + "DataArguments", + "EvaluationArguments", + "FinetuningArguments", + "GeneratingArguments", + "ModelArguments", + "get_infer_args", + "get_train_args", +] diff --git a/src/hparams/data_args.py b/src/hparams/data_args.py new file mode 100644 index 0000000000000000000000000000000000000000..1e331e41d94a75b1a511592db241958a5d40ab94 --- /dev/null +++ b/src/hparams/data_args.py @@ -0,0 +1,110 @@ +# Copyright 2024 HuggingFace Inc., the LlamaFactory team, and the Llamole team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import Literal, Optional + + +@dataclass +class DataArguments: + r""" + Arguments pertaining to what data we are going to input our model for training and evaluation. + """ + + template: Optional[str] = field( + default=None, + metadata={"help": "Which template to use for constructing prompts in training and inference."}, + ) + dataset: Optional[str] = field( + default=None, + metadata={"help": "The name of provided dataset(s) to use. Use commas to separate multiple datasets."}, + ) + dataset_dir: str = field( + default="data", + metadata={"help": "Path to the folder containing the datasets."}, + ) + split: str = field( + default="train", + metadata={"help": "Which dataset split to use for training and evaluation."}, + ) + cutoff_len: int = field( + default=1024, + metadata={"help": "The cutoff length of the tokenized inputs in the dataset."}, + ) + train_on_prompt: bool = field( + default=False, + metadata={"help": "Whether to disable the mask on the prompt or not."}, + ) + # number of learnable queries + learned_query_size: int = field( + default=8, + metadata={"help": "Number of queries to use in training."}, + ) + + # used for evaluation with specific starting idx + start_idx: int = field( + default=0, + metadata={"help": "Start index of the evaluation data for drug"}, + ) + buffer_size: int = field( + default=16384, + metadata={"help": "Size of the buffer to randomly sample examples from in dataset streaming."}, + ) + mix_strategy: Literal["concat", "interleave_under", "interleave_over"] = field( + default="concat", + metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."}, + ) + overwrite_cache: bool = field( + default=False, + metadata={"help": "Overwrite the cached training and evaluation sets."}, + ) + preprocessing_num_workers: Optional[int] = field( + default=None, + metadata={"help": "The number of processes to use for the pre-processing."}, + ) + max_samples: Optional[int] = field( + default=None, + metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."}, + ) + eval_num_beams: Optional[int] = field( + default=None, + metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"}, + ) + ignore_pad_token_for_loss: bool = field( + default=True, + metadata={ + "help": "Whether or not to ignore the tokens corresponding to padded labels in the loss computation." + }, + ) + val_size: float = field( + default=0.0, + metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}, + ) + packing: Optional[bool] = field( + default=None, + metadata={ + "help": "Whether or not to pack the sequences in training. Will automatically enable in pre-training." 
+ }, + ) + tool_format: Optional[str] = field( + default=None, + metadata={"help": "Tool format to use for constructing function calling examples."}, + ) + tokenized_path: Optional[str] = field( + default=None, + metadata={"help": "Path to save or load the tokenized datasets."}, + ) \ No newline at end of file diff --git a/src/hparams/evaluation_args.py b/src/hparams/evaluation_args.py new file mode 100644 index 0000000000000000000000000000000000000000..a7f221ca638ca86d14fa002f814d137b6ca7e917 --- /dev/null +++ b/src/hparams/evaluation_args.py @@ -0,0 +1,62 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from dataclasses import dataclass, field +from typing import Literal, Optional + +from datasets import DownloadMode + + +@dataclass +class EvaluationArguments: + r""" + Arguments pertaining to specify the evaluation parameters. + """ + + task: str = field( + metadata={"help": "Name of the evaluation task."}, + ) + task_dir: str = field( + default="evaluation", + metadata={"help": "Path to the folder containing the evaluation datasets."}, + ) + batch_size: int = field( + default=4, + metadata={"help": "The batch size per GPU for evaluation."}, + ) + seed: int = field( + default=42, + metadata={"help": "Random seed to be used with data loaders."}, + ) + lang: Literal["en", "zh"] = field( + default="en", + metadata={"help": "Language used at evaluation."}, + ) + n_shot: int = field( + default=5, + metadata={"help": "Number of examplars for few-shot learning."}, + ) + save_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to save the evaluation results."}, + ) + download_mode: DownloadMode = field( + default=DownloadMode.REUSE_DATASET_IF_EXISTS, + metadata={"help": "Download mode used for the evaluation datasets."}, + ) + + def __post_init__(self): + if self.save_dir is not None and os.path.exists(self.save_dir): + raise ValueError("`save_dir` already exists, use another one.") diff --git a/src/hparams/finetuning_args.py b/src/hparams/finetuning_args.py new file mode 100644 index 0000000000000000000000000000000000000000..d18c47e539612ca3c3404d8260a69cd98ce25b19 --- /dev/null +++ b/src/hparams/finetuning_args.py @@ -0,0 +1,402 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import dataclass, field +from typing import List, Literal, Optional + + +@dataclass +class FreezeArguments: + r""" + Arguments pertaining to the freeze (partial-parameter) training. 
+ """ + + freeze_trainable_layers: int = field( + default=2, + metadata={ + "help": ( + "The number of trainable layers for freeze (partial-parameter) fine-tuning. " + "Positive numbers mean the last n layers are set as trainable, " + "negative numbers mean the first n layers are set as trainable." + ) + }, + ) + freeze_trainable_modules: str = field( + default="all", + metadata={ + "help": ( + "Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. " + "Use commas to separate multiple modules. " + "Use `all` to specify all the available modules." + ) + }, + ) + freeze_extra_modules: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Name(s) of modules apart from hidden layers to be set as trainable " + "for freeze (partial-parameter) fine-tuning. " + "Use commas to separate multiple modules." + ) + }, + ) + + +@dataclass +class LoraArguments: + r""" + Arguments pertaining to the LoRA training. + """ + + additional_target: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Name(s) of modules apart from LoRA layers to be set as trainable " + "and saved in the final checkpoint. " + "Use commas to separate multiple modules." + ) + }, + ) + lora_alpha: Optional[int] = field( + default=None, + metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."}, + ) + lora_dropout: float = field( + default=0.0, + metadata={"help": "Dropout rate for the LoRA fine-tuning."}, + ) + lora_rank: int = field( + default=8, + metadata={"help": "The intrinsic dimension for LoRA fine-tuning."}, + ) + lora_target: str = field( + default="all", + metadata={ + "help": ( + "Name(s) of target modules to apply LoRA. " + "Use commas to separate multiple modules. " + "Use `all` to specify all the linear modules." + ) + }, + ) + loraplus_lr_ratio: Optional[float] = field( + default=None, + metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."}, + ) + loraplus_lr_embedding: float = field( + default=1e-6, + metadata={"help": "LoRA plus learning rate for lora embedding layers."}, + ) + use_rslora: bool = field( + default=False, + metadata={"help": "Whether or not to use the rank stabilization scaling factor for LoRA layer."}, + ) + use_dora: bool = field( + default=False, + metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."}, + ) + pissa_init: bool = field( + default=False, + metadata={"help": "Whether or not to initialize a PiSSA adapter."}, + ) + pissa_iter: int = field( + default=16, + metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."}, + ) + pissa_convert: bool = field( + default=False, + metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."}, + ) + create_new_adapter: bool = field( + default=False, + metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."}, + ) + + +@dataclass +class RLHFArguments: + r""" + Arguments pertaining to the PPO, DPO and KTO training. 
+ """ + + pref_beta: float = field( + default=0.1, + metadata={"help": "The beta parameter in the preference loss."}, + ) + pref_ftx: float = field( + default=0.0, + metadata={"help": "The supervised fine-tuning loss coefficient in DPO training."}, + ) + pref_loss: Literal["sigmoid", "hinge", "ipo", "kto_pair", "orpo", "simpo"] = field( + default="sigmoid", + metadata={"help": "The type of DPO loss to use."}, + ) + dpo_label_smoothing: float = field( + default=0.0, + metadata={"help": "The robust DPO label smoothing parameter in cDPO that should be between 0 and 0.5."}, + ) + kto_chosen_weight: float = field( + default=1.0, + metadata={"help": "The weight factor of the desirable losses in KTO training."}, + ) + kto_rejected_weight: float = field( + default=1.0, + metadata={"help": "The weight factor of the undesirable losses in KTO training."}, + ) + simpo_gamma: float = field( + default=0.5, + metadata={"help": "The target reward margin term in SimPO loss."}, + ) + ppo_buffer_size: int = field( + default=1, + metadata={"help": "The number of mini-batches to make experience buffer in a PPO optimization step."}, + ) + ppo_epochs: int = field( + default=4, + metadata={"help": "The number of epochs to perform in a PPO optimization step."}, + ) + ppo_score_norm: bool = field( + default=False, + metadata={"help": "Use score normalization in PPO training."}, + ) + ppo_target: float = field( + default=6.0, + metadata={"help": "Target KL value for adaptive KL control in PPO training."}, + ) + ppo_whiten_rewards: bool = field( + default=False, + metadata={"help": "Whiten the rewards before compute advantages in PPO training."}, + ) + ref_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the reference model used for the PPO or DPO training."}, + ) + ref_model_adapters: Optional[str] = field( + default=None, + metadata={"help": "Path to the adapters of the reference model."}, + ) + ref_model_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the reference model."}, + ) + reward_model: Optional[str] = field( + default=None, + metadata={"help": "Path to the reward model used for the PPO training."}, + ) + reward_model_adapters: Optional[str] = field( + default=None, + metadata={"help": "Path to the adapters of the reward model."}, + ) + reward_model_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the reward model."}, + ) + reward_model_type: Literal["lora", "full", "api"] = field( + default="lora", + metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."}, + ) + + +@dataclass +class GaloreArguments: + r""" + Arguments pertaining to the GaLore algorithm. + """ + + use_galore: bool = field( + default=False, + metadata={"help": "Whether or not to use the gradient low-Rank projection (GaLore)."}, + ) + galore_target: str = field( + default="all", + metadata={ + "help": ( + "Name(s) of modules to apply GaLore. Use commas to separate multiple modules. " + "Use `all` to specify all the linear modules." 
+ ) + }, + ) + galore_rank: int = field( + default=16, + metadata={"help": "The rank of GaLore gradients."}, + ) + galore_update_interval: int = field( + default=200, + metadata={"help": "Number of steps to update the GaLore projection."}, + ) + galore_scale: float = field( + default=0.25, + metadata={"help": "GaLore scaling coefficient."}, + ) + galore_proj_type: Literal["std", "reverse_std", "right", "left", "full"] = field( + default="std", + metadata={"help": "Type of GaLore projection."}, + ) + galore_layerwise: bool = field( + default=False, + metadata={"help": "Whether or not to enable layer-wise update to further save memory."}, + ) + + +@dataclass +class BAdamArgument: + r""" + Arguments pertaining to the BAdam optimizer. + """ + + use_badam: bool = field( + default=False, + metadata={"help": "Whether or not to use the BAdam optimizer."}, + ) + badam_mode: Literal["layer", "ratio"] = field( + default="layer", + metadata={"help": "Whether to use the layer-wise or ratio-wise BAdam optimizer."}, + ) + badam_start_block: Optional[int] = field( + default=None, + metadata={"help": "The starting block index for layer-wise BAdam."}, + ) + badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field( + default="ascending", + metadata={"help": "The strategy of picking the block to update for layer-wise BAdam."}, + ) + badam_switch_interval: Optional[int] = field( + default=50, + metadata={ + "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update." + }, + ) + badam_update_ratio: float = field( + default=0.05, + metadata={"help": "The ratio of the update for ratio-wise BAdam."}, + ) + badam_mask_mode: Literal["adjacent", "scatter"] = field( + default="adjacent", + metadata={ + "help": ( + "The mode of the mask for the BAdam optimizer. " + "`adjacent` means that the trainable parameters are adjacent to each other, " + "`scatter` means that the trainable parameters are randomly chosen from the weights." + ) + }, + ) + badam_verbose: int = field( + default=0, + metadata={ + "help": ( + "The verbosity level of the BAdam optimizer. " + "0 for no output, 1 to print the block prefix, 2 to print the trainable parameters." + ) + }, + ) + + +@dataclass +class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments, BAdamArgument): + r""" + Arguments pertaining to which fine-tuning techniques we are going to use.
+ """ + + loss_weight_lm: int = field( + default=1, + metadata={"help": "Loss weight for the language next token prediction."}, + ) + loss_weight_design: int = field( + default=1, + metadata={"help": "Loss weight for molecular design."}, + ) + loss_weight_retro: int = field( + default=1, + metadata={"help": "Loss weight for retrosynthesis."}, + ) + + pure_bf16: bool = field( + default=False, + metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."}, + ) + stage: Literal["pt", "sft", "rm", "ppo", "dpo", "kto"] = field( + default="sft", + metadata={"help": "Which stage will be performed in training."}, + ) + finetuning_type: Literal["lora", "freeze", "full"] = field( + default="lora", + metadata={"help": "Which fine-tuning method to use."}, + ) + use_llama_pro: bool = field( + default=False, + metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."}, + ) + freeze_vision_tower: bool = field( + default=True, + metadata={"help": "Whether ot not to freeze vision tower in MLLM training."}, + ) + train_mm_proj_only: bool = field( + default=False, + metadata={"help": "Whether or not to train the multimodal projector for MLLM only."}, + ) + plot_loss: bool = field( + default=False, + metadata={"help": "Whether or not to save the training loss curves."}, + ) + + def __post_init__(self): + def split_arg(arg): + if isinstance(arg, str): + return [item.strip() for item in arg.split(",")] + return arg + + self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules) + self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules) + self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2 + self.lora_target: List[str] = split_arg(self.lora_target) + self.additional_target: Optional[List[str]] = split_arg(self.additional_target) + self.galore_target: List[str] = split_arg(self.galore_target) + self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only + self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"] + + assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method." + assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." + assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization." 
+ + if self.stage == "ppo" and self.reward_model is None: + raise ValueError("`reward_model` is necessary for PPO training.") + + if self.stage == "ppo" and self.reward_model_type == "lora" and self.finetuning_type != "lora": + raise ValueError("`reward_model_type` cannot be lora for Freeze/Full PPO training.") + + if self.stage == "dpo" and self.pref_loss != "sigmoid" and self.dpo_label_smoothing > 1e-6: + raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.") + + if self.use_llama_pro and self.finetuning_type == "full": + raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.") + + if self.finetuning_type == "lora" and (self.use_galore or self.use_badam): + raise ValueError("Cannot use LoRA with GaLore or BAdam together.") + + if self.use_galore and self.use_badam: + raise ValueError("Cannot use GaLore with BAdam together.") + + if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora": + raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.") + + if self.pissa_init and self.finetuning_type != "lora": + raise ValueError("`pissa_init` is only valid for LoRA training.") + + if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model): + raise ValueError("Cannot use PiSSA for current training stage.") + + if self.train_mm_proj_only and self.finetuning_type != "full": + raise ValueError("`train_mm_proj_only` is only valid for full training.") diff --git a/src/hparams/generating_args.py b/src/hparams/generating_args.py new file mode 100644 index 0000000000000000000000000000000000000000..7ebb4eed980e20f44ffe084e26e91a4def91c513 --- /dev/null +++ b/src/hparams/generating_args.py @@ -0,0 +1,74 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import asdict, dataclass, field +from typing import Any, Dict, Optional + + +@dataclass +class GeneratingArguments: + r""" + Arguments pertaining to specify the decoding parameters. + """ + + do_sample: bool = field( + default=True, + metadata={"help": "Whether or not to use sampling, use greedy decoding otherwise."}, + ) + temperature: float = field( + default=0.95, + metadata={"help": "The value used to modulate the next token probabilities."}, + ) + top_p: float = field( + default=0.7, + metadata={ + "help": "The smallest set of most probable tokens with probabilities that add up to top_p or higher are kept." + }, + ) + top_k: int = field( + default=50, + metadata={"help": "The number of highest probability vocabulary tokens to keep for top-k filtering."}, + ) + num_beams: int = field( + default=1, + metadata={"help": "Number of beams for beam search. 1 means no beam search."}, + ) + max_length: int = field( + default=1024, + metadata={"help": "The maximum length the generated tokens can have. 
It can be overridden by max_new_tokens."}, + ) + max_new_tokens: int = field( + default=1024, + metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}, + ) + repetition_penalty: float = field( + default=1.0, + metadata={"help": "The parameter for repetition penalty. 1.0 means no penalty."}, + ) + length_penalty: float = field( + default=1.0, + metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, + ) + default_system: Optional[str] = field( + default=None, + metadata={"help": "Default system message to use in chat completion."}, + ) + + def to_dict(self) -> Dict[str, Any]: + args = asdict(self) + if args.get("max_new_tokens", -1) > 0: + args.pop("max_length", None) + else: + args.pop("max_new_tokens", None) + return args diff --git a/src/hparams/model_args.py b/src/hparams/model_args.py new file mode 100644 index 0000000000000000000000000000000000000000..0bd0e246f11451a9f8ab8c1f0669eb994fa23816 --- /dev/null +++ b/src/hparams/model_args.py @@ -0,0 +1,284 @@ +# Copyright 2024 HuggingFace Inc., the LlamaFactory team, and the Llamole team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from dataclasses import asdict, dataclass, field +from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union + +from typing_extensions import Self + + +if TYPE_CHECKING: + import torch + + +@dataclass +class ModelArguments: + r""" + Arguments pertaining to which model/config/tokenizer we are going to fine-tune or infer. + """ + + ## custom arguments + graph_decoder_path: str = field( + metadata={ + "help": "Path to the model weight for graph decoder model " + }, + ) + graph_encoder_path: str = field( + metadata={ + "help": "Path to the model weight for graph encoder model " + }, + ) + graph_predictor_path: str = field( + metadata={ + "help": "Path to the model weight for graph predictor model " + }, + ) + graph_lm_connector_path: str = field( + metadata={ + "help": "Path to the model weight for graph and language model connector " + }, + ) + + model_name_or_path: str = field( + metadata={ + "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models." + }, + ) + adapter_name_or_path: Optional[str] = field( + default=None, + metadata={ + "help": ( + "Path to the adapter weight or identifier from huggingface.co/models. " + "Use commas to separate multiple adapters." 
+ ) + }, + ) + adapter_folder: Optional[str] = field( + default=None, + metadata={"help": "The folder containing the adapter weights to load."}, + ) + cache_dir: Optional[str] = field( + default=None, + metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."}, + ) + use_fast_tokenizer: bool = field( + default=True, + metadata={"help": "Whether or not to use one of the fast tokenizers (backed by the tokenizers library)."}, + ) + + disable_graph_model_gradient: bool = field( + default=True, + metadata={"help": "Whether or not to disable the training of the graph models."}, + ) + + resize_vocab: bool = field( + default=False, + metadata={"help": "Whether or not to resize the tokenizer vocab and the embedding layers."}, + ) + split_special_tokens: bool = field( + default=False, + metadata={"help": "Whether or not the special tokens should be split during the tokenization process."}, + ) + new_special_tokens: Optional[str] = field( + default=None, + metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."}, + ) + model_revision: str = field( + default="main", + metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}, + ) + low_cpu_mem_usage: bool = field( + default=True, + metadata={"help": "Whether or not to use memory-efficient model loading."}, + ) + quantization_method: Literal["bitsandbytes", "hqq", "eetq"] = field( + default="bitsandbytes", + metadata={"help": "Quantization method to use for on-the-fly quantization."}, + ) + quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the model using bitsandbytes."}, + ) + quantization_type: Literal["fp4", "nf4"] = field( + default="nf4", + metadata={"help": "Quantization data type to use in int4 training."}, + ) + double_quantization: bool = field( + default=True, + metadata={"help": "Whether or not to use double quantization in int4 training."}, + ) + quantization_device_map: Optional[Literal["auto"]] = field( + default=None, + metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."}, + ) + rope_scaling: Optional[Literal["linear", "dynamic"]] = field( + default=None, + metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."}, + ) + flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field( + default="auto", + metadata={"help": "Enable FlashAttention for faster training and inference."}, + ) + shift_attn: bool = field( + default=False, + metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."}, + ) + mixture_of_depths: Optional[Literal["convert", "load"]] = field( + default=None, + metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."}, + ) + use_unsloth: bool = field( + default=False, + metadata={"help": "Whether or not to use Unsloth's optimization for the LoRA training."}, + ) + visual_inputs: bool = field( + default=False, + metadata={"help": "Whether or not to use a multimodal LLM that accepts visual inputs."}, + ) + moe_aux_loss_coef: Optional[float] = field( + default=None, + metadata={"help": "Coefficient of the auxiliary router loss in the mixture-of-experts model."}, + ) + disable_gradient_checkpointing: bool = field( + default=False, + metadata={"help": "Whether or not to disable gradient checkpointing."}, + ) + upcast_layernorm: bool = field( + default=False, + metadata={"help": "Whether or not to upcast the
layernorm weights in fp32."}, + ) + upcast_lmhead_output: bool = field( + default=False, + metadata={"help": "Whether or not to upcast the output of lm_head in fp32."}, + ) + train_from_scratch: bool = field( + default=False, + metadata={"help": "Whether or not to randomly initialize the model weights."}, + ) + infer_backend: Literal["huggingface", "vllm"] = field( + default="huggingface", + metadata={"help": "Backend engine used at inference."}, + ) + vllm_maxlen: int = field( + default=2048, + metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."}, + ) + vllm_gpu_util: float = field( + default=0.9, + metadata={"help": "The fraction of GPU memory in (0,1) to be used for the vLLM engine."}, + ) + vllm_enforce_eager: bool = field( + default=False, + metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."}, + ) + vllm_max_lora_rank: int = field( + default=32, + metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."}, + ) + offload_folder: str = field( + default="offload", + metadata={"help": "Path to offload model weights."}, + ) + use_cache: bool = field( + default=True, + metadata={"help": "Whether or not to use KV cache in generation."}, + ) + infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field( + default="auto", + metadata={"help": "Data type for model weights and activations at inference."}, + ) + hf_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with Hugging Face Hub."}, + ) + ms_hub_token: Optional[str] = field( + default=None, + metadata={"help": "Auth token to log in with ModelScope Hub."}, + ) + export_dir: Optional[str] = field( + default=None, + metadata={"help": "Path to the directory to save the exported model."}, + ) + export_size: int = field( + default=1, + metadata={"help": "The file shard size (in GB) of the exported model."}, + ) + export_device: Literal["cpu", "auto"] = field( + default="cpu", + metadata={"help": "The device used in model export, use `auto` to accelerate exporting."}, + ) + export_quantization_bit: Optional[int] = field( + default=None, + metadata={"help": "The number of bits to quantize the exported model."}, + ) + export_quantization_dataset: Optional[str] = field( + default=None, + metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."}, + ) + export_quantization_nsamples: int = field( + default=128, + metadata={"help": "The number of samples used for quantization."}, + ) + export_quantization_maxlen: int = field( + default=1024, + metadata={"help": "The maximum length of the model inputs used for quantization."}, + ) + export_legacy_format: bool = field( + default=False, + metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."}, + ) + export_hub_model_id: Optional[str] = field( + default=None, + metadata={"help": "The name of the repository if push the model to the Hugging Face hub."}, + ) + print_param_status: bool = field( + default=False, + metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, + ) + + def __post_init__(self): + self.compute_dtype: Optional["torch.dtype"] = None + self.device_map: Optional[Union[str, Dict[str, Any]]] = None + self.model_max_length: Optional[int] = None + + if self.split_special_tokens and self.use_fast_tokenizer: + raise ValueError("`split_special_tokens` is only supported for slow tokenizers.") + + if self.visual_inputs and self.use_unsloth: + raise ValueError("Unsloth does not 
support MLLM yet. Stay tuned.") + + if self.adapter_name_or_path is not None: # support merging multiple lora weights + self.adapter_name_or_path = [path.strip() for path in self.adapter_name_or_path.split(",")] + + if self.new_special_tokens is not None: # support multiple special tokens + self.new_special_tokens = [token.strip() for token in self.new_special_tokens.split(",")] + + if self.export_quantization_bit is not None and self.export_quantization_dataset is None: + raise ValueError("Quantization dataset is necessary for exporting.") + + def to_dict(self) -> Dict[str, Any]: + return asdict(self) + + @classmethod + def copyfrom(cls, old_arg: Self, **kwargs) -> Self: + arg_dict = old_arg.to_dict() + arg_dict.update(**kwargs) + new_arg = cls(**arg_dict) + new_arg.compute_dtype = old_arg.compute_dtype + new_arg.device_map = old_arg.device_map + new_arg.model_max_length = old_arg.model_max_length + return new_arg diff --git a/src/hparams/parser.py b/src/hparams/parser.py new file mode 100644 index 0000000000000000000000000000000000000000..b2167f356f56420572278fa0145deb9d21de4af6 --- /dev/null +++ b/src/hparams/parser.py @@ -0,0 +1,355 @@ +# Copyright 2024 HuggingFace Inc., the LlamaFactory team, and the Llamole team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
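As a usage note on the dataclass above: `ModelArguments.copyfrom` rebuilds the arguments from `to_dict()` and then carries over the attributes assigned in `__post_init__` (`compute_dtype`, `device_map`, `model_max_length`). A hedged sketch, where every path and the model ID are placeholders rather than values prescribed by this patch:

base_args = ModelArguments(
    graph_decoder_path="saves/graph_decoder",         # hypothetical local path
    graph_encoder_path="saves/graph_encoder",         # hypothetical local path
    graph_predictor_path="saves/graph_predictor",     # hypothetical local path
    graph_lm_connector_path="saves/graph_connector",  # hypothetical local path
    model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder model ID
)
# Derive an inference-time copy without mutating the original arguments.
infer_args = ModelArguments.copyfrom(base_args, infer_dtype="bfloat16", use_cache=True)
assert infer_args.infer_dtype == "bfloat16" and base_args.infer_dtype == "auto"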
+ +import logging +import os +import sys +from typing import Any, Dict, Optional, Tuple + +import torch +import transformers +from transformers import HfArgumentParser, Seq2SeqTrainingArguments +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.trainer_utils import get_last_checkpoint +from transformers.training_args import ParallelMode +from transformers.utils import is_torch_bf16_gpu_available +from transformers.utils.versions import require_version + +from ..extras.constants import CHECKPOINT_NAMES +from ..extras.logging import get_logger +from ..extras.misc import check_dependencies, get_current_device +from .data_args import DataArguments +from .evaluation_args import EvaluationArguments +from .finetuning_args import FinetuningArguments +from .generating_args import GeneratingArguments +from .model_args import ModelArguments + + +logger = get_logger(__name__) + + +check_dependencies() + + +_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] +_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments] +_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] +_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments] + + +def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: + if args is not None: + return parser.parse_dict(args) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".yaml"): + return parser.parse_yaml_file(os.path.abspath(sys.argv[1])) + + if len(sys.argv) == 2 and sys.argv[1].endswith(".json"): + return parser.parse_json_file(os.path.abspath(sys.argv[1])) + + (*parsed_args, unknown_args) = parser.parse_args_into_dataclasses(return_remaining_strings=True) + + if unknown_args: + print(parser.format_help()) + print("Got unknown args, potentially deprecated arguments: {}".format(unknown_args)) + raise ValueError("Some specified arguments are not used by the HfArgumentParser: {}".format(unknown_args)) + + return (*parsed_args,) + + +def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None: + transformers.utils.logging.set_verbosity(log_level) + transformers.utils.logging.enable_default_handler() + transformers.utils.logging.enable_explicit_format() + + +def _verify_model_args(model_args: "ModelArguments", finetuning_args: "FinetuningArguments") -> None: + if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Adapter is only valid for the LoRA method.") + + if model_args.quantization_bit is not None: + if finetuning_args.finetuning_type != "lora": + raise ValueError("Quantization is only compatible with the LoRA method.") + + if finetuning_args.pissa_init: + raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.") + + if model_args.resize_vocab: + raise ValueError("Cannot resize embedding layers of a quantized model.") + + if model_args.adapter_name_or_path is not None and finetuning_args.create_new_adapter: + raise ValueError("Cannot create new adapter upon a quantized model.") + + if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: + raise ValueError("Quantized model only accepts a single adapter. 
Merge them first.") + + +def _check_extra_dependencies( + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + training_args: Optional["Seq2SeqTrainingArguments"] = None, +) -> None: + if model_args.use_unsloth: + require_version("unsloth", "Please install unsloth: https://github.com/unslothai/unsloth") + + if model_args.mixture_of_depths is not None: + require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6") + + if model_args.infer_backend == "vllm": + require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3") + + if finetuning_args.use_galore: + require_version("galore_torch", "To fix: pip install galore_torch") + + if finetuning_args.use_badam: + require_version("badam>=1.2.1", "To fix: pip install badam>=1.2.1") + + if finetuning_args.plot_loss: + require_version("matplotlib", "To fix: pip install matplotlib") + + if training_args is not None and training_args.predict_with_generate: + require_version("jieba", "To fix: pip install jieba") + require_version("nltk", "To fix: pip install nltk") + require_version("rouge_chinese", "To fix: pip install rouge-chinese") + +def _parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + parser = HfArgumentParser(_TRAIN_ARGS) + return _parse_args(parser, args) + +def _parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: + parser = HfArgumentParser(_INFER_ARGS) + return _parse_args(parser, args) + +def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_args(args) + + # Setup logging + if training_args.should_log: + _set_transformers_logging() + + # Check arguments + if finetuning_args.stage != "pt" and data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if finetuning_args.stage != "sft" and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True except SFT.") + + if finetuning_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate: + raise ValueError("Please enable `predict_with_generate` to save model predictions.") + + if finetuning_args.stage in ["rm", "ppo"] and training_args.load_best_model_at_end: + raise ValueError("RM and PPO stages do not support `load_best_model_at_end`.") + + if finetuning_args.stage == "ppo" and not training_args.do_train: + raise ValueError("PPO training does not support evaluation, use the SFT stage to evaluate models.") + + if finetuning_args.stage == "ppo" and model_args.shift_attn: + raise ValueError("PPO training is incompatible with S^2-Attn.") + + if finetuning_args.stage == "ppo" and finetuning_args.reward_model_type == "lora" and model_args.use_unsloth: + raise ValueError("Unsloth does not support lora reward model.") + + if ( + finetuning_args.stage == "ppo" + and training_args.report_to + and training_args.report_to[0] not in ["wandb", "tensorboard"] + ): + raise ValueError("PPO only accepts wandb or tensorboard logger.") + + if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED: + raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.") + + if training_args.deepspeed and training_args.parallel_mode != ParallelMode.DISTRIBUTED: + raise ValueError("Please use `FORCE_TORCHRUN=1` to launch DeepSpeed training.") + + # if training_args.max_steps == -1 and data_args.streaming: + # raise ValueError("Please specify `max_steps` in streaming mode.") + 
+ if training_args.do_train and training_args.predict_with_generate: + raise ValueError("`predict_with_generate` cannot be set as True while training.") + + if training_args.do_train and model_args.quantization_device_map == "auto": + raise ValueError("Cannot use device map for quantized models in training.") + + if finetuning_args.pissa_init and is_deepspeed_zero3_enabled(): + raise ValueError("PiSSA is incompatible with DeepSpeed ZeRO-3.") + + if finetuning_args.pure_bf16: + if not is_torch_bf16_gpu_available(): + raise ValueError("This device does not support `pure_bf16`.") + + if is_deepspeed_zero3_enabled(): + raise ValueError("`pure_bf16` is incompatible with DeepSpeed ZeRO-3.") + + if ( + finetuning_args.use_galore + and finetuning_args.galore_layerwise + and training_args.parallel_mode == ParallelMode.DISTRIBUTED + ): + raise ValueError("Distributed training does not support layer-wise GaLore.") + + if finetuning_args.use_badam and training_args.parallel_mode == ParallelMode.DISTRIBUTED: + if finetuning_args.badam_mode == "ratio": + raise ValueError("Ratio-based BAdam does not yet support distributed training, use layer-wise BAdam.") + elif not is_deepspeed_zero3_enabled(): + raise ValueError("Layer-wise BAdam only supports DeepSpeed ZeRO-3 training.") + + if finetuning_args.use_galore and training_args.deepspeed is not None: + raise ValueError("GaLore is not yet compatible with DeepSpeed.") + + if model_args.infer_backend == "vllm": + raise ValueError("vLLM backend is only available for API, CLI and Web.") + + if model_args.visual_inputs and data_args.packing: + raise ValueError("Cannot use packing in MLLM fine-tuning.") + + if model_args.use_unsloth and is_deepspeed_zero3_enabled(): + raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") + + _verify_model_args(model_args, finetuning_args) + _check_extra_dependencies(model_args, finetuning_args, training_args) + + if ( + training_args.do_train + and finetuning_args.finetuning_type == "lora" + and model_args.quantization_bit is None + and model_args.resize_vocab + and finetuning_args.additional_target is None + ): + logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.") + + if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): + logger.warning("We recommend enabling `upcast_layernorm` in quantized training.") + + if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): + logger.warning("We recommend enabling mixed precision training.") + + if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16: + logger.warning("Using GaLore with mixed precision training may significantly increase GPU memory usage.") + + if (not training_args.do_train) and model_args.quantization_bit is not None: + logger.warning("Evaluating the model in 4/8-bit mode may cause lower scores.") + + if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None: + logger.warning("Specify `ref_model` for computing rewards at evaluation.") + + # Post-process training arguments + if ( + training_args.parallel_mode == ParallelMode.DISTRIBUTED + and training_args.ddp_find_unused_parameters is None + and finetuning_args.finetuning_type == "lora" + ): + logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") + training_args.ddp_find_unused_parameters = False + + if finetuning_args.stage in ["rm", "ppo"] and
finetuning_args.finetuning_type in ["full", "freeze"]: + can_resume_from_checkpoint = False + if training_args.resume_from_checkpoint is not None: + logger.warning("Cannot resume from checkpoint in current stage.") + training_args.resume_from_checkpoint = None + else: + can_resume_from_checkpoint = True + + if ( + training_args.resume_from_checkpoint is None + and training_args.do_train + and os.path.isdir(training_args.output_dir) + and not training_args.overwrite_output_dir + and can_resume_from_checkpoint + ): + last_checkpoint = get_last_checkpoint(training_args.output_dir) + if last_checkpoint is None and any( + os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES + ): + raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.") + + if last_checkpoint is not None: + training_args.resume_from_checkpoint = last_checkpoint + logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint)) + logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.") + + if ( + finetuning_args.stage in ["rm", "ppo"] + and finetuning_args.finetuning_type == "lora" + and training_args.resume_from_checkpoint is not None + ): + logger.warning( + "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( + training_args.resume_from_checkpoint + ) + ) + + # Post-process model arguments + if training_args.bf16 or finetuning_args.pure_bf16: + model_args.compute_dtype = torch.bfloat16 + elif training_args.fp16: + model_args.compute_dtype = torch.float16 + + model_args.device_map = {"": get_current_device()} + model_args.model_max_length = data_args.cutoff_len + data_args.packing = data_args.packing if data_args.packing is not None else finetuning_args.stage == "pt" + + # Log on each process the small summary + logger.info( + "Process rank: {}, device: {}, n_gpu: {}, distributed training: {}, compute dtype: {}".format( + training_args.local_rank, + training_args.device, + training_args.n_gpu, + training_args.parallel_mode == ParallelMode.DISTRIBUTED, + str(model_args.compute_dtype), + ) + ) + + transformers.set_seed(training_args.seed) + + return model_args, data_args, training_args, finetuning_args, generating_args + + +def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: + model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args) + + _set_transformers_logging() + + if data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if model_args.infer_backend == "vllm": + if finetuning_args.stage != "sft": + raise ValueError("vLLM engine only supports auto-regressive models.") + + if model_args.quantization_bit is not None: + raise ValueError("vLLM engine does not support bnb quantization (GPTQ and AWQ are supported).") + + if model_args.rope_scaling is not None: + raise ValueError("vLLM engine does not support RoPE scaling.") + + if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1: + raise ValueError("vLLM only accepts a single adapter. Merge them first.") + + if finetuning_args.stage == "rm" and model_args.visual_inputs: + raise ValueError("Reward server does not support MLLM yet. 
Stay tuned.") + + _verify_model_args(model_args, finetuning_args) + _check_extra_dependencies(model_args, finetuning_args) + + if model_args.export_dir is not None and model_args.export_device == "cpu": + model_args.device_map = {"": torch.device("cpu")} + model_args.model_max_length = data_args.cutoff_len + else: + model_args.device_map = "auto" + + return model_args, data_args, finetuning_args, generating_args \ No newline at end of file diff --git a/src/model/__init__.py b/src/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..9ba761eb921dfe912b9f6df154379ce7b3775a93 --- /dev/null +++ b/src/model/__init__.py @@ -0,0 +1,34 @@ +# Copyright 2024 the LlamaFactory team and the Llamole team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .loader import load_config, load_tokenizer, load_language_model +from .loader import load_graph_decoder, load_graph_encoder, load_graph_predictor +from .model_utils.misc import find_all_linear_modules +from .model_utils.quantization import QuantizationMethod +from .model_utils.valuehead import load_valuehead_params + +from .modeling_llamole import GraphLLMForCausalMLM + +__all__ = [ + "QuantizationMethod", + "load_config", + "load_language_model", + "load_graph_decoder", + "load_graph_encoder", + "load_graph_predictor", + "load_tokenizer", + "find_all_linear_modules", + "load_valuehead_params", + "GraphLLMForCausalMLM", +] diff --git a/src/model/adapter.py b/src/model/adapter.py new file mode 100644 index 0000000000000000000000000000000000000000..e06862a4de7d0ece82005df7d45dd0b70cffed57 --- /dev/null +++ b/src/model/adapter.py @@ -0,0 +1,316 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import re +from typing import TYPE_CHECKING + +import torch +from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.modeling_utils import is_fsdp_enabled + +from ..extras.logging import get_logger +from .model_utils.misc import find_all_linear_modules, find_expanded_modules +from .model_utils.quantization import QuantizationMethod +from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ..hparams import FinetuningArguments, ModelArguments + + +logger = get_logger(__name__) + + +def _setup_full_tuning( + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + cast_trainable_params_to_fp32: bool, +) -> None: + if not is_trainable: + return + + logger.info("Fine-tuning method: Full") + forbidden_modules = set() + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + forbidden_modules.add("vision_tower") + + if model_args.visual_inputs and finetuning_args.train_mm_proj_only: + forbidden_modules.add("language_model") + + for name, param in model.named_parameters(): + if not any(forbidden_module in name for forbidden_module in forbidden_modules): + if cast_trainable_params_to_fp32: + param.data = param.data.to(torch.float32) + else: + param.requires_grad_(False) + + +def _setup_freeze_tuning( + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + cast_trainable_params_to_fp32: bool, +) -> None: + if not is_trainable: + return + + logger.info("Fine-tuning method: Freeze") + if model_args.visual_inputs: + config = model.config.text_config + else: + config = model.config + + num_layers = ( + getattr(config, "num_hidden_layers", None) + or getattr(config, "num_layers", None) + or getattr(config, "n_layer", None) + ) + if not num_layers: + raise ValueError("Current model does not support freeze tuning.") + + if finetuning_args.use_llama_pro: + if num_layers % finetuning_args.freeze_trainable_layers != 0: + raise ValueError( + "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format( + num_layers, finetuning_args.freeze_trainable_layers + ) + ) + + stride = num_layers // finetuning_args.freeze_trainable_layers + trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) + elif finetuning_args.freeze_trainable_layers > 0: # fine-tuning the last n layers if num_layer_trainable > 0 + trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers) + else: # fine-tuning the first n layers if num_layer_trainable < 0 + trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers)) + + hidden_modules = set() + non_hidden_modules = set() + for name, _ in model.named_parameters(): + if ".0." in name: + hidden_modules.add(name.split(".0.")[-1].split(".")[0]) + elif ".1." 
in name: # MoD starts from layer 1 + hidden_modules.add(name.split(".1.")[-1].split(".")[0]) + + if re.search(r"\.\d+\.", name) is None: + non_hidden_modules.add(name.split(".")[-2]) + + trainable_layers = [] + for module_name in finetuning_args.freeze_trainable_modules: + if module_name != "all" and module_name not in hidden_modules: + raise ValueError( + "Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules)) + ) + + for idx in trainable_layer_ids: + trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else "")) + + if finetuning_args.freeze_extra_modules: + for module_name in finetuning_args.freeze_extra_modules: + if module_name not in non_hidden_modules: + raise ValueError( + "Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules)) + ) + + trainable_layers.append(module_name) + + forbidden_modules = set() + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + forbidden_modules.add("vision_tower") + + for name, param in model.named_parameters(): + if any(trainable_layer in name for trainable_layer in trainable_layers) and not any( + forbidden_module in name for forbidden_module in forbidden_modules + ): + if cast_trainable_params_to_fp32: + param.data = param.data.to(torch.float32) + else: + param.requires_grad_(False) + + logger.info("Set trainable layers: {}".format(",".join(trainable_layers))) + + +def _setup_lora_tuning( + config: "PretrainedConfig", + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, + cast_trainable_params_to_fp32: bool, +) -> "PeftModel": + if is_trainable: + logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) + + adapter_to_resume = None + + if model_args.adapter_name_or_path is not None: + is_mergeable = True + if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable + assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter." + is_mergeable = False + + if is_deepspeed_zero3_enabled(): + assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." + is_mergeable = False + + if model_args.use_unsloth: + assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter." 
+ is_mergeable = False + + if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable): + adapter_to_merge = model_args.adapter_name_or_path[:-1] + adapter_to_resume = model_args.adapter_name_or_path[-1] + else: + adapter_to_merge = model_args.adapter_name_or_path + + init_kwargs = { + "subfolder": model_args.adapter_folder, + "offload_folder": model_args.offload_folder, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.hf_hub_token, + } + + for adapter in adapter_to_merge: + model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs) + model = model.merge_and_unload() + + if len(adapter_to_merge) > 0: + logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) + + if adapter_to_resume is not None: # resume lora training + if model_args.use_unsloth: + model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable) + else: + model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) + + logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) + + if is_trainable and adapter_to_resume is None: # create new lora weights while training + if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": + target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower) + else: + target_modules = finetuning_args.lora_target + + if finetuning_args.use_llama_pro: + target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers) + + if model_args.visual_inputs and finetuning_args.freeze_vision_tower: + target_modules = "^(?!.*vision_tower).*(?:{}).*".format("|".join(target_modules)) + + if ( + finetuning_args.use_dora + and getattr(model, "quantization_method", None) is not None + and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES + ): + raise ValueError("DoRA is not compatible with PTQ-quantized models.") + + if model_args.resize_vocab and finetuning_args.additional_target is None: + input_embeddings = model.get_input_embeddings() + output_embeddings = model.get_output_embeddings() + module_names = set() + for name, module in model.named_modules(): + if module in [input_embeddings, output_embeddings]: + module_names.add(name.split(".")[-1]) + + finetuning_args.additional_target = module_names + logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) + + peft_kwargs = { + "r": finetuning_args.lora_rank, + "target_modules": target_modules, + "lora_alpha": finetuning_args.lora_alpha, + "lora_dropout": finetuning_args.lora_dropout, + "use_rslora": finetuning_args.use_rslora, + "use_dora": finetuning_args.use_dora, + "modules_to_save": finetuning_args.additional_target, + } + + if model_args.use_unsloth: + model = get_unsloth_peft_model(model, model_args, peft_kwargs) + else: + if finetuning_args.pissa_init: + if finetuning_args.pissa_iter == -1: + logger.info("Using PiSSA initialization.") + peft_kwargs["init_lora_weights"] = "pissa" + else: + logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter)) + peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter) + + lora_config = LoraConfig( + task_type=TaskType.CAUSAL_LM, + inference_mode=False, + **peft_kwargs, + ) + model = get_peft_model(model, lora_config) + + if is_trainable and cast_trainable_params_to_fp32: + for param in filter(lambda p: 
p.requires_grad, model.parameters()): + param.data = param.data.to(torch.float32) + + return model + + +def init_adapter( + config: "PretrainedConfig", + model: "PreTrainedModel", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool, +) -> "PreTrainedModel": + r""" + Initializes the adapters. + + Support full-parameter, freeze and LoRA training. + + Note that the trainable parameters must be cast to float32. + """ + if is_trainable and getattr(model, "quantization_method", None) is not None: + if finetuning_args.finetuning_type != "lora": + raise ValueError("Quantized models can only be used for the LoRA tuning.") + + if finetuning_args.pissa_init: + raise ValueError("Cannot initialize PiSSA adapter on quantized models.") + + # cast trainable parameters to float32 if: + # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora) + # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32) + cast_trainable_params_to_fp32 = False + if not is_trainable: + pass + elif finetuning_args.pure_bf16 or finetuning_args.use_badam: + logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.") + elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()): + logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.") + else: + logger.info("Upcasting trainable params to float32.") + cast_trainable_params_to_fp32 = True + + if finetuning_args.finetuning_type == "full": + _setup_full_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32) + elif finetuning_args.finetuning_type == "freeze": + _setup_freeze_tuning(model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32) + elif finetuning_args.finetuning_type == "lora": + model = _setup_lora_tuning( + config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32 + ) + else: + raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type)) + + return model diff --git a/src/model/graph_decoder/conditions.py b/src/model/graph_decoder/conditions.py new file mode 100644 index 0000000000000000000000000000000000000000..418693d17e8e8482591e5af1409567e3c0bf1eb4 --- /dev/null +++ b/src/model/graph_decoder/conditions.py @@ -0,0 +1,123 @@ +# Copyright 2024 the Llamole team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import math + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. 
+ """ + def __init__(self, hidden_size, frequency_embedding_size=256): + super().__init__() + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, hidden_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + @staticmethod + def timestep_embedding(t, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + :param t: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an (N, D) Tensor of positional embeddings. + """ + # https://github.com/openai/glide-text2im/blob/main/glide_text2im/nn.py + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=t.device) + args = t[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + def forward(self, t): + t = t.view(-1) + t_freq = self.timestep_embedding(t, self.frequency_embedding_size) + t_freq = t_freq.to(dtype=next(self.mlp.parameters()).dtype) + t_emb = self.mlp(t_freq) + return t_emb + +class ConditionEmbedder(nn.Module): + def __init__(self, input_size, hidden_size, dropout_prob): + super().__init__() + self.embedding_drop = nn.Embedding(input_size, hidden_size) + + self.mlps = nn.ModuleList([ + nn.Sequential( + nn.Linear(1, hidden_size, bias=True), + nn.Softmax(dim=1), + nn.Linear(hidden_size, hidden_size, bias=False) + ) for _ in range(input_size) + ]) + + self.hidden_size = hidden_size + self.dropout_prob = dropout_prob + + def forward(self, labels, train, unconditioned): + embeddings = 0 + for dim in range(labels.shape[1]): + label = labels[:, dim] + if unconditioned: + drop_ids = torch.ones_like(label).bool() + else: + drop_ids = torch.isnan(label) + if train: + random_tensor = torch.rand(label.shape).type_as(labels) + probability_mask = random_tensor < self.dropout_prob + drop_ids = drop_ids | probability_mask + + label = label.unsqueeze(1) + embedding = torch.zeros((label.shape[0], self.hidden_size)).type_as(labels) + mlp_out = self.mlps[dim](label[~drop_ids]) + embedding[~drop_ids] = mlp_out.type_as(embedding) + embedding[drop_ids] += self.embedding_drop.weight[dim] + if train: + embedding = embedding + torch.randn_like(embedding) + embeddings += embedding + + return embeddings + +class TextEmbedder(nn.Module): + def __init__(self, input_size, hidden_size, dropout_prob): + super().__init__() + self.embedding_drop = nn.Embedding(1, hidden_size) + self.linear = nn.Linear(input_size, hidden_size) + self.dropout_prob = dropout_prob + self.hidden_size = hidden_size + + def forward(self, text_emb, train, unconditioned): + if unconditioned: + drop_ids = torch.ones(text_emb.shape[0]).bool().to(text_emb.device) + else: + drop_ids = torch.isnan(text_emb.sum(dim=1)) + if train: + random_tensor = torch.rand(text_emb.shape[0]).type_as(text_emb) + probability_mask = random_tensor < self.dropout_prob + drop_ids = drop_ids | probability_mask + + embeddings = torch.zeros((text_emb.shape[0], self.hidden_size)).type_as(text_emb) + linear_out = self.linear(text_emb[~drop_ids]) + embeddings[~drop_ids] = linear_out.type_as(embeddings) + embeddings[drop_ids] += self.embedding_drop.weight[0] + + return embeddings \ No newline at end of file diff --git 
a/src/model/graph_decoder/diffusion_model.py b/src/model/graph_decoder/diffusion_model.py new file mode 100644 index 0000000000000000000000000000000000000000..35da082a8125c16d3eaf0cca5209ae8c07af28a5 --- /dev/null +++ b/src/model/graph_decoder/diffusion_model.py @@ -0,0 +1,438 @@ +# Copyright 2024 the Llamole team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import yaml +import json + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from . import diffusion_utils as utils +from .molecule_utils import graph_to_smiles, check_valid +from .transformer import Transformer + +class GraphDiT(nn.Module): + def __init__( + self, + model_config_path, + data_info_path, + model_dtype, + ): + super().__init__() + + dm_cfg, data_info = utils.load_config(model_config_path, data_info_path) + + input_dims = data_info.input_dims + output_dims = data_info.output_dims + nodes_dist = data_info.nodes_dist + active_index = data_info.active_index + + self.model_config = dm_cfg + self.data_info = data_info + self.T = dm_cfg.diffusion_steps + self.guide_scale = dm_cfg.guide_scale + self.Xdim = input_dims["X"] + self.Edim = input_dims["E"] + self.ydim = input_dims["y"] + self.Xdim_output = output_dims["X"] + self.Edim_output = output_dims["E"] + self.ydim_output = output_dims["y"] + self.node_dist = nodes_dist + self.active_index = active_index + self.max_n_nodes = data_info.max_n_nodes + self.train_loss = TrainLossDiscrete(dm_cfg.lambda_train) + self.atom_decoder = data_info.atom_decoder + self.hidden_size = dm_cfg.hidden_size + self.text_input_size = 768 + + self.denoiser = Transformer( + max_n_nodes=self.max_n_nodes, + hidden_size=dm_cfg.hidden_size, + depth=dm_cfg.depth, + num_heads=dm_cfg.num_heads, + mlp_ratio=dm_cfg.mlp_ratio, + drop_condition=dm_cfg.drop_condition, + Xdim=self.Xdim, + Edim=self.Edim, + ydim=self.ydim, + text_dim=self.text_input_size + ) + self.model_dtype = model_dtype + + self.noise_schedule = utils.PredefinedNoiseScheduleDiscrete( + dm_cfg.diffusion_noise_schedule, timesteps=dm_cfg.diffusion_steps + ) + x_marginals = data_info.node_types.to(self.model_dtype) / torch.sum( + data_info.node_types.to(self.model_dtype) + ) + e_marginals = data_info.edge_types.to(self.model_dtype) / torch.sum( + data_info.edge_types.to(self.model_dtype) + ) + x_marginals = x_marginals / x_marginals.sum() + e_marginals = e_marginals / e_marginals.sum() + + xe_conditions = data_info.transition_E.to(self.model_dtype) + xe_conditions = xe_conditions[self.active_index][:, self.active_index] + + xe_conditions = xe_conditions.sum(dim=1) + ex_conditions = xe_conditions.t() + xe_conditions = xe_conditions / xe_conditions.sum(dim=-1, keepdim=True) + ex_conditions = ex_conditions / ex_conditions.sum(dim=-1, keepdim=True) + + self.transition_model = utils.MarginalTransition( + x_marginals=x_marginals, + e_marginals=e_marginals, + xe_conditions=xe_conditions, + ex_conditions=ex_conditions, + y_classes=self.ydim_output, + n_nodes=self.max_n_nodes, + ) + self.limit_dist = 
utils.PlaceHolder(X=x_marginals, E=e_marginals, y=None) + + def init_model(self, model_dir, verbose=False): + model_file = os.path.join(model_dir, 'model.pt') + if os.path.exists(model_file): + self.denoiser.load_state_dict(torch.load(model_file, map_location='cpu', weights_only=True)) + else: + raise FileNotFoundError(f"Model file not found: {model_file}") + + if verbose: + print('GraphDiT Denoiser Model initialized.') + print('Denoiser model:\n', self.denoiser) + + def save_pretrained(self, output_dir): + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + # Save model + model_path = os.path.join(output_dir, 'model.pt') + torch.save(self.denoiser.state_dict(), model_path) + + # Save model config + config_path = os.path.join(output_dir, 'model_config.yaml') + with open(config_path, 'w') as f: + yaml.dump(vars(self.model_config), f) + + # Save data info + data_info_path = os.path.join(output_dir, 'data.meta.json') + data_info_dict = { + "active_atoms": self.data_info.active_atoms, + "max_node": self.data_info.max_n_nodes, + "n_atoms_per_mol_dist": self.data_info.n_nodes.tolist(), + "bond_type_dist": self.data_info.edge_types.tolist(), + "transition_E": self.data_info.transition_E.tolist(), + "atom_type_dist": self.data_info.node_types.tolist(), + "valencies": self.data_info.valency_distribution.tolist() + } + with open(data_info_path, 'w') as f: + json.dump(data_info_dict, f, indent=2) + + print('GraphDiT Model and configurations saved to:', output_dir) + + def disable_grads(self): + self.denoiser.disable_grads() + + def forward( + self, x, edge_index, edge_attr, graph_batch, properties, text_embedding, no_label_index + ): + properties = torch.where(properties == no_label_index, float("nan"), properties) + data_x = F.one_hot(x, num_classes=118).to(self.model_dtype)[ + :, self.active_index + ] + data_edge_attr = F.one_hot(edge_attr, num_classes=5).to(self.model_dtype) + + dense_data, node_mask = utils.to_dense( + data_x, edge_index, data_edge_attr, graph_batch, self.max_n_nodes + ) + X, E = dense_data.X, dense_data.E + + dense_data = dense_data.mask(node_mask) + noisy_data = self.apply_noise(X, E, properties, node_mask) + pred = self._forward(noisy_data, text_embedding) + loss = self.train_loss( + masked_pred_X=pred.X, + masked_pred_E=pred.E, + true_X=X, + true_E=E, + node_mask=node_mask, + ) + return loss + + def _forward(self, noisy_data, text_embedding, unconditioned=False): + noisy_x, noisy_e, properties = ( + noisy_data["X_t"].to(self.model_dtype), + noisy_data["E_t"].to(self.model_dtype), + noisy_data["y_t"].to(self.model_dtype).clone(), + ) + node_mask, timestep, text_embedding = ( + noisy_data["node_mask"], + noisy_data["t"], + text_embedding.to(self.model_dtype), + ) + + pred = self.denoiser( + noisy_x, + noisy_e, + node_mask, + properties, + text_embedding, + timestep, + unconditioned=unconditioned, + ) + return pred + + def apply_noise(self, X, E, y, node_mask): + """Sample noise and apply it to the data.""" + + # Sample a timestep t. 
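+ # (t_int is drawn uniformly per graph in the batch; t = 0 corresponds to clean data.)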
+ # When evaluating, the loss for t=0 is computed separately + lowest_t = 0 if self.training else 1 + t_int = torch.randint( + lowest_t, self.T + 1, size=(X.size(0), 1), device=X.device + ).to( + self.model_dtype + ) # (bs, 1) + s_int = t_int - 1 + + t_float = t_int / self.T + s_float = s_int / self.T + + # beta_t and alpha_s_bar are used for denoising/loss computation + beta_t = self.noise_schedule(t_normalized=t_float) # (bs, 1) + alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s_float) # (bs, 1) + alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t_float) # (bs, 1) + + Qtb = self.transition_model.get_Qt_bar( + alpha_t_bar, X.device + ) # (bs, dx_in, dx_out), (bs, de_in, de_out) + + bs, n, d = X.shape + X_all = torch.cat([X, E.reshape(bs, n, -1)], dim=-1) + prob_all = X_all @ Qtb.X + probX = prob_all[:, :, : self.Xdim_output] + probE = prob_all[:, :, self.Xdim_output :].reshape(bs, n, n, -1) + + sampled_t = utils.sample_discrete_features( + probX=probX, probE=probE, node_mask=node_mask + ) + + X_t = F.one_hot(sampled_t.X, num_classes=self.Xdim_output) + E_t = F.one_hot(sampled_t.E, num_classes=self.Edim_output) + assert (X.shape == X_t.shape) and (E.shape == E_t.shape) + + y_t = y + z_t = utils.PlaceHolder(X=X_t, E=E_t, y=y_t).type_as(X_t).mask(node_mask) + + noisy_data = { + "t_int": t_int, + "t": t_float, + "beta_t": beta_t, + "alpha_s_bar": alpha_s_bar, + "alpha_t_bar": alpha_t_bar, + "X_t": z_t.X, + "E_t": z_t.E, + "y_t": z_t.y, + "node_mask": node_mask, + } + return noisy_data + + @torch.no_grad() + def generate( + self, + properties, + text_embedding, + no_label_index, + ): + properties = torch.where(properties == no_label_index, float("nan"), properties) + batch_size = properties.size(0) + device = torch.device("cuda" if torch.cuda.is_available() else "cpu") + n_nodes = self.node_dist.sample_n(batch_size, device) + arange = ( + torch.arange(self.max_n_nodes, device=device) + .unsqueeze(0) + .expand(batch_size, -1) + ) + node_mask = arange < n_nodes.unsqueeze(1) + + z_T = utils.sample_discrete_feature_noise( + limit_dist=self.limit_dist, node_mask=node_mask + ) + X, E = z_T.X, z_T.E + + assert (E == torch.transpose(E, 1, 2)).all() + + # Iteratively sample p(z_s | z_t) for t = 1, ..., T, with s = t - 1. + y = properties + for s_int in reversed(range(0, self.T)): + s_array = s_int * torch.ones((batch_size, 1)).type_as(y) + t_array = s_array + 1 + s_norm = s_array / self.T + t_norm = t_array / self.T + + # Sample z_s + sampled_s, discrete_sampled_s = self.sample_p_zs_given_zt( + s_norm, t_norm, X, E, y, text_embedding, node_mask + ) + X, E, y = sampled_s.X, sampled_s.E, sampled_s.y + + # Sample + sampled_s = sampled_s.mask(node_mask, collapse=True) + X, E, y = sampled_s.X, sampled_s.E, sampled_s.y + + molecule_list = [] + for i in range(batch_size): + n = n_nodes[i] + atom_types = X[i, :n].cpu() + edge_types = E[i, :n, :n].cpu() + molecule_list.append([atom_types, edge_types]) + + smiles_list = graph_to_smiles(molecule_list, self.atom_decoder) + + return smiles_list + + def check_valid(self, smiles): + return check_valid(smiles) + + def sample_p_zs_given_zt( + self, s, t, X_t, E_t, properties, text_embedding, node_mask + ): + """Samples from zs ~ p(zs | zt). Only used during sampling. 
+ if last_step, return the graph prediction as well""" + bs, n, _ = X_t.shape + beta_t = self.noise_schedule(t_normalized=t) # (bs, 1) + alpha_s_bar = self.noise_schedule.get_alpha_bar(t_normalized=s) + alpha_t_bar = self.noise_schedule.get_alpha_bar(t_normalized=t) + + # Neural net predictions + noisy_data = { + "X_t": X_t, + "E_t": E_t, + "y_t": properties, + "t": t, + "node_mask": node_mask, + } + + def get_prob(noisy_data, text_embedding, unconditioned=False): + pred = self._forward(noisy_data, text_embedding, unconditioned=unconditioned) + + # Normalize predictions + pred_X = F.softmax(pred.X, dim=-1) # bs, n, d0 + pred_E = F.softmax(pred.E, dim=-1) # bs, n, n, d0 + + device = text_embedding.device + # Retrieve transitions matrix + Qtb = self.transition_model.get_Qt_bar(alpha_t_bar, device) + Qsb = self.transition_model.get_Qt_bar(alpha_s_bar, device) + Qt = self.transition_model.get_Qt(beta_t, device) + + Xt_all = torch.cat([X_t, E_t.reshape(bs, n, -1)], dim=-1) + predX_all = torch.cat([pred_X, pred_E.reshape(bs, n, -1)], dim=-1) + + unnormalized_probX_all = utils.reverse_diffusion( + predX_0=predX_all, X_t=Xt_all, Qt=Qt.X, Qsb=Qsb.X, Qtb=Qtb.X + ) + + unnormalized_prob_X = unnormalized_probX_all[:, :, : self.Xdim_output] + unnormalized_prob_E = unnormalized_probX_all[ + :, :, self.Xdim_output : + ].reshape(bs, n * n, -1) + + unnormalized_prob_X[torch.sum(unnormalized_prob_X, dim=-1) == 0] = 1e-5 + unnormalized_prob_E[torch.sum(unnormalized_prob_E, dim=-1) == 0] = 1e-5 + + prob_X = unnormalized_prob_X / torch.sum( + unnormalized_prob_X, dim=-1, keepdim=True + ) # bs, n, d_t-1 + prob_E = unnormalized_prob_E / torch.sum( + unnormalized_prob_E, dim=-1, keepdim=True + ) # bs, n, d_t-1 + prob_E = prob_E.reshape(bs, n, n, pred_E.shape[-1]) + + return prob_X, prob_E + + prob_X, prob_E = get_prob(noisy_data, text_embedding) + + ### Guidance + if self.guide_scale is not None and self.guide_scale != 1: + uncon_prob_X, uncon_prob_E = get_prob( + noisy_data, text_embedding, unconditioned=True + ) + prob_X = ( + uncon_prob_X + * (prob_X / uncon_prob_X.clamp_min(1e-5)) ** self.guide_scale + ) + prob_E = ( + uncon_prob_E + * (prob_E / uncon_prob_E.clamp_min(1e-5)) ** self.guide_scale + ) + prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True).clamp_min(1e-5) + prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True).clamp_min(1e-5) + + sampled_s = utils.sample_discrete_features( + prob_X, prob_E, node_mask=node_mask, step=s[0, 0].item() + ) + + X_s = F.one_hot(sampled_s.X, num_classes=self.Xdim_output).to(self.model_dtype) + E_s = F.one_hot(sampled_s.E, num_classes=self.Edim_output).to(self.model_dtype) + + assert (E_s == torch.transpose(E_s, 1, 2)).all() + assert (X_t.shape == X_s.shape) and (E_t.shape == E_s.shape) + + out_one_hot = utils.PlaceHolder(X=X_s, E=E_s, y=properties) + out_discrete = utils.PlaceHolder(X=X_s, E=E_s, y=properties) + + return out_one_hot.mask(node_mask).type_as(properties), out_discrete.mask( + node_mask, collapse=True + ).type_as(properties) + + +class TrainLossDiscrete(nn.Module): + """Train with Cross entropy""" + + def __init__(self, lambda_train): + super().__init__() + self.lambda_train = lambda_train + + def forward(self, masked_pred_X, masked_pred_E, true_X, true_E, node_mask): + + true_X = torch.reshape(true_X, (-1, true_X.size(-1))) # (bs * n, dx) + true_E = torch.reshape(true_E, (-1, true_E.size(-1))) # (bs * n * n, de) + masked_pred_X = torch.reshape( + masked_pred_X, (-1, masked_pred_X.size(-1)) + ) # (bs * n, dx) + masked_pred_E = torch.reshape( + masked_pred_E, (-1, 
masked_pred_E.size(-1)) + ) # (bs * n * n, de) + + # Remove masked rows + mask_X = (true_X != 0.0).any(dim=-1) + mask_E = (true_E != 0.0).any(dim=-1) + + flat_true_X = true_X[mask_X, :] + flat_pred_X = masked_pred_X[mask_X, :] + + flat_true_E = true_E[mask_E, :] + flat_pred_E = masked_pred_E[mask_E, :] + + target_X = torch.argmax(flat_true_X, dim=-1) + loss_X = F.cross_entropy(flat_pred_X, target_X, reduction="mean") + + target_E = torch.argmax(flat_true_E, dim=-1) + loss_E = F.cross_entropy(flat_pred_E, target_E, reduction="mean") + + total_loss = self.lambda_train[0] * loss_X + self.lambda_train[1] * loss_E + + return total_loss diff --git a/src/model/graph_decoder/diffusion_utils.py b/src/model/graph_decoder/diffusion_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..fc64671bb60fe958a5f66eea0b46fa68279a2a50 --- /dev/null +++ b/src/model/graph_decoder/diffusion_utils.py @@ -0,0 +1,537 @@ +# Copyright 2024 the Llamole team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import numpy as np +from torch.nn import functional as F +from torch_geometric.utils import to_dense_adj, to_dense_batch, remove_self_loops +import os +import json +import yaml +from types import SimpleNamespace + +def dict_to_namespace(d): + return SimpleNamespace( + **{k: dict_to_namespace(v) if isinstance(v, dict) else v for k, v in d.items()} + ) + +class DataInfos: + def __init__(self, meta_filename="data.meta.json"): + self.all_targets = ["BBBP", "HIV", "BACE", "CO2", "N2", "O2", "FFV", "TC"] + self.task_type = "pretrain" + if os.path.exists(meta_filename): + with open(meta_filename, "r") as f: + meta_dict = json.load(f) + else: + raise FileNotFoundError(f"Meta file {meta_filename} not found.") + + self.active_atoms = meta_dict["active_atoms"] + self.max_n_nodes = meta_dict["max_node"] + self.original_max_n_nodes = meta_dict["max_node"] + self.n_nodes = torch.Tensor(meta_dict["n_atoms_per_mol_dist"]) + self.edge_types = torch.Tensor(meta_dict["bond_type_dist"]) + self.transition_E = torch.Tensor(meta_dict["transition_E"]) + + self.atom_decoder = meta_dict["active_atoms"] + node_types = torch.Tensor(meta_dict["atom_type_dist"]) + active_index = (node_types > 0).nonzero().squeeze() + self.node_types = torch.Tensor(meta_dict["atom_type_dist"])[active_index] + self.nodes_dist = DistributionNodes(self.n_nodes) + self.active_index = active_index + + val_len = 3 * self.original_max_n_nodes - 2 + meta_val = torch.Tensor(meta_dict["valencies"]) + self.valency_distribution = torch.zeros(val_len) + val_len = min(val_len, len(meta_val)) + self.valency_distribution[:val_len] = meta_val[:val_len] + self.input_dims = {"X": 16, "E": 5, "y": 10} + self.output_dims = {"X": 16, "E": 5, "y": 10} + + +def load_config(config_path, data_meta_info_path): + if not os.path.exists(config_path): + raise FileNotFoundError(f"Configuration file not found: {config_path}") + + if not os.path.exists(data_meta_info_path): + raise FileNotFoundError(f"Data meta info file not found: {data_meta_info_path}") + + 
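+ # The YAML file supplies the diffusion/denoiser hyperparameters (read into a namespace below); the JSON meta file supplies the dataset statistics parsed by DataInfos.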
with open(config_path, "r") as file: + cfg_dict = yaml.safe_load(file) + + cfg = dict_to_namespace(cfg_dict) + + data_info = DataInfos(data_meta_info_path) + return cfg, data_info + + +### graph utils +class PlaceHolder: + def __init__(self, X, E, y): + self.X = X + self.E = E + self.y = y + + def type_as(self, x: torch.Tensor, categorical: bool = False): + """Changes the device and dtype of X, E, y.""" + self.X = self.X.type_as(x) + self.E = self.E.type_as(x) + if categorical: + self.y = self.y.type_as(x) + return self + + def mask(self, node_mask, collapse=False): + x_mask = node_mask.unsqueeze(-1) # bs, n, 1 + e_mask1 = x_mask.unsqueeze(2) # bs, n, 1, 1 + e_mask2 = x_mask.unsqueeze(1) # bs, 1, n, 1 + + if collapse: + self.X = torch.argmax(self.X, dim=-1) + self.E = torch.argmax(self.E, dim=-1) + + self.X[node_mask == 0] = -1 + self.E[(e_mask1 * e_mask2).squeeze(-1) == 0] = -1 + else: + self.X = self.X * x_mask + self.E = self.E * e_mask1 * e_mask2 + assert torch.allclose(self.E, torch.transpose(self.E, 1, 2)) + return self + + +def to_dense(x, edge_index, edge_attr, batch, max_num_nodes=None): + X, node_mask = to_dense_batch(x=x, batch=batch, max_num_nodes=max_num_nodes) + # node_mask = node_mask.float() + edge_index, edge_attr = remove_self_loops(edge_index, edge_attr) + if max_num_nodes is None: + max_num_nodes = X.size(1) + E = to_dense_adj( + edge_index=edge_index, + batch=batch, + edge_attr=edge_attr, + max_num_nodes=max_num_nodes, + ) + E = encode_no_edge(E) + return PlaceHolder(X=X, E=E, y=None), node_mask + + +def encode_no_edge(E): + assert len(E.shape) == 4 + if E.shape[-1] == 0: + return E + no_edge = torch.sum(E, dim=3) == 0 + first_elt = E[:, :, :, 0] + first_elt[no_edge] = 1 + E[:, :, :, 0] = first_elt + diag = ( + torch.eye(E.shape[1], dtype=torch.bool).unsqueeze(0).expand(E.shape[0], -1, -1) + ) + E[diag] = 0 + return E + + +### diffusion utils +class DistributionNodes: + def __init__(self, histogram): + """Compute the distribution of the number of nodes in the dataset, and sample from this distribution. + historgram: dict. 
The keys are num_nodes, the values are counts + """ + + if type(histogram) == dict: + max_n_nodes = max(histogram.keys()) + prob = torch.zeros(max_n_nodes + 1) + for num_nodes, count in histogram.items(): + prob[num_nodes] = count + else: + prob = histogram + + self.prob = prob / prob.sum() + self.m = torch.distributions.Categorical(prob) + + def sample_n(self, n_samples, device): + idx = self.m.sample((n_samples,)) + return idx.to(device) + + def log_prob(self, batch_n_nodes): + assert len(batch_n_nodes.size()) == 1 + p = self.prob.to(batch_n_nodes.device) + + probas = p[batch_n_nodes] + log_p = torch.log(probas + 1e-30) + return log_p + +class PredefinedNoiseScheduleDiscrete(torch.nn.Module): + def __init__(self, noise_schedule, timesteps): + super(PredefinedNoiseScheduleDiscrete, self).__init__() + self.timesteps = timesteps + + betas = cosine_beta_schedule_discrete(timesteps) + self.register_buffer("betas", torch.from_numpy(betas).float()) + + # 0.9999 + self.alphas = 1 - torch.clamp(self.betas, min=0, max=1) + + log_alpha = torch.log(self.alphas) + log_alpha_bar = torch.cumsum(log_alpha, dim=0) + self.alphas_bar = torch.exp(log_alpha_bar) + + def forward(self, t_normalized=None, t_int=None): + assert int(t_normalized is None) + int(t_int is None) == 1 + if t_int is None: + t_int = torch.round(t_normalized * self.timesteps) + self.betas = self.betas.type_as(t_int) + return self.betas[t_int.long()] + + def get_alpha_bar(self, t_normalized=None, t_int=None): + assert int(t_normalized is None) + int(t_int is None) == 1 + if t_int is None: + t_int = torch.round(t_normalized * self.timesteps) + self.alphas_bar = self.alphas_bar.type_as(t_int) + return self.alphas_bar[t_int.long()] + + +class DiscreteUniformTransition: + def __init__(self, x_classes: int, e_classes: int, y_classes: int): + self.X_classes = x_classes + self.E_classes = e_classes + self.y_classes = y_classes + self.u_x = torch.ones(1, self.X_classes, self.X_classes) + if self.X_classes > 0: + self.u_x = self.u_x / self.X_classes + + self.u_e = torch.ones(1, self.E_classes, self.E_classes) + if self.E_classes > 0: + self.u_e = self.u_e / self.E_classes + + self.u_y = torch.ones(1, self.y_classes, self.y_classes) + if self.y_classes > 0: + self.u_y = self.u_y / self.y_classes + + def get_Qt(self, beta_t, device, X=None, flatten_e=None): + """Returns one-step transition matrices for X and E, from step t - 1 to step t. + Qt = (1 - beta_t) * I + beta_t / K + + beta_t: (bs) noise level between 0 and 1 + returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). + """ + beta_t = beta_t.unsqueeze(1) + beta_t = beta_t.to(device) + self.u_x = self.u_x.to(device) + self.u_e = self.u_e.to(device) + self.u_y = self.u_y.to(device) + + q_x = beta_t * self.u_x + (1 - beta_t) * torch.eye( + self.X_classes, device=device + ).unsqueeze(0) + q_e = beta_t * self.u_e + (1 - beta_t) * torch.eye( + self.E_classes, device=device + ).unsqueeze(0) + q_y = beta_t * self.u_y + (1 - beta_t) * torch.eye( + self.y_classes, device=device + ).unsqueeze(0) + + return PlaceHolder(X=q_x, E=q_e, y=q_y) + + def get_Qt_bar(self, alpha_bar_t, device, X=None, flatten_e=None): + """Returns t-step transition matrices for X and E, from step 0 to step t. + Qt = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) / K + + alpha_bar_t: (bs) Product of the (1 - beta_t) for each time step from 0 to t. + returns: qx (bs, dx, dx), qe (bs, de, de), qy (bs, dy, dy). 
+ """ + alpha_bar_t = alpha_bar_t.unsqueeze(1) + alpha_bar_t = alpha_bar_t.to(device) + self.u_x = self.u_x.to(device) + self.u_e = self.u_e.to(device) + self.u_y = self.u_y.to(device) + + q_x = ( + alpha_bar_t * torch.eye(self.X_classes, device=device).unsqueeze(0) + + (1 - alpha_bar_t) * self.u_x + ) + q_e = ( + alpha_bar_t * torch.eye(self.E_classes, device=device).unsqueeze(0) + + (1 - alpha_bar_t) * self.u_e + ) + q_y = ( + alpha_bar_t * torch.eye(self.y_classes, device=device).unsqueeze(0) + + (1 - alpha_bar_t) * self.u_y + ) + + return PlaceHolder(X=q_x, E=q_e, y=q_y) + + +class MarginalTransition: + def __init__( + self, x_marginals, e_marginals, xe_conditions, ex_conditions, y_classes, n_nodes + ): + self.X_classes = len(x_marginals) + self.E_classes = len(e_marginals) + self.y_classes = y_classes + self.x_marginals = x_marginals # Dx + self.e_marginals = e_marginals # Dx, De + self.xe_conditions = xe_conditions + # print('e_marginals.dtype', e_marginals.dtype) + # print('x_marginals.dtype', x_marginals.dtype) + # print('xe_conditions.dtype', xe_conditions.dtype) + + self.u_x = ( + x_marginals.unsqueeze(0).expand(self.X_classes, -1).unsqueeze(0) + ) # 1, Dx, Dx + self.u_e = ( + e_marginals.unsqueeze(0).expand(self.E_classes, -1).unsqueeze(0) + ) # 1, De, De + self.u_xe = xe_conditions.unsqueeze(0) # 1, Dx, De + self.u_ex = ex_conditions.unsqueeze(0) # 1, De, Dx + self.u = self.get_union_transition( + self.u_x, self.u_e, self.u_xe, self.u_ex, n_nodes + ) # 1, Dx + n*De, Dx + n*De + + def get_union_transition(self, u_x, u_e, u_xe, u_ex, n_nodes): + u_e = u_e.repeat(1, n_nodes, n_nodes) # (1, n*de, n*de) + u_xe = u_xe.repeat(1, 1, n_nodes) # (1, dx, n*de) + u_ex = u_ex.repeat(1, n_nodes, 1) # (1, n*de, dx) + u0 = torch.cat([u_x, u_xe], dim=2) # (1, dx, dx + n*de) + u1 = torch.cat([u_ex, u_e], dim=2) # (1, n*de, dx + n*de) + u = torch.cat([u0, u1], dim=1) # (1, dx + n*de, dx + n*de) + return u + + def index_edge_margin(self, X, q_e, n_bond=5): + # q_e: (bs, dx, de) --> (bs, n, de) + bs, n, n_atom = X.shape + node_indices = X.argmax(-1) # (bs, n) + ind = node_indices[:, :, None].expand(bs, n, n_bond) + q_e = torch.gather(q_e, 1, ind) + return q_e + + def get_Qt(self, beta_t, device): + """Returns one-step transition matrices for X and E, from step t - 1 to step t. + Qt = (1 - beta_t) * I + beta_t / K + beta_t: (bs) + returns: q (bs, d0, d0) + """ + bs = beta_t.size(0) + d0 = self.u.size(-1) + self.u = self.u.to(device) + u = self.u.expand(bs, d0, d0) + + beta_t = beta_t.to(device) + beta_t = beta_t.view(bs, 1, 1) + q = beta_t * u + (1 - beta_t) * torch.eye(d0, device=device, dtype=self.u.dtype).unsqueeze(0) + + return PlaceHolder(X=q, E=None, y=None) + + def get_Qt_bar(self, alpha_bar_t, device): + """Returns t-step transition matrices for X and E, from step 0 to step t. + Qt = prod(1 - beta_t) * I + (1 - prod(1 - beta_t)) * K + alpha_bar_t: (bs, 1) roduct of the (1 - beta_t) for each time step from 0 to t. 
+ returns: q (bs, d0, d0) + """ + bs = alpha_bar_t.size(0) + d0 = self.u.size(-1) + alpha_bar_t = alpha_bar_t.to(device) + alpha_bar_t = alpha_bar_t.view(bs, 1, 1) + self.u = self.u.to(device) + q = ( + alpha_bar_t * torch.eye(d0, device=device, dtype=self.u.dtype).unsqueeze(0) + + (1 - alpha_bar_t) * self.u + ) + + return PlaceHolder(X=q, E=None, y=None) + + +def sum_except_batch(x): + return x.reshape(x.size(0), -1).sum(dim=-1) + +def reverse_tensor(x): + return x[torch.arange(x.size(0) - 1, -1, -1)] + +def assert_correctly_masked(variable, node_mask): + assert ( + variable * (1 - node_mask.long()) + ).abs().max().item() < 1e-4, "Variables not masked properly." + + +def cosine_beta_schedule_discrete(timesteps, s=0.008): + """Cosine schedule as proposed in https://openreview.net/forum?id=-NEXDKk8gZ.""" + steps = timesteps + 2 + x = np.linspace(0, steps, steps) + + alphas_cumprod = np.cos(0.5 * np.pi * ((x / steps) + s) / (1 + s)) ** 2 + alphas_cumprod = alphas_cumprod / alphas_cumprod[0] + alphas = alphas_cumprod[1:] / alphas_cumprod[:-1] + betas = 1 - alphas + return betas.squeeze() + + +def sample_discrete_features(probX, probE, node_mask, step=None, add_nose=True): + """Sample features from multinomial distribution with given probabilities (probX, probE, proby) + :param probX: bs, n, dx_out node features + :param probE: bs, n, n, de_out edge features + :param proby: bs, dy_out global features. + """ + bs, n, _ = probX.shape + + # Noise X + # The masked rows should define probability distributions as well + probX[~node_mask] = 1 / probX.shape[-1] + + # Flatten the probability tensor to sample with multinomial + probX = probX.reshape(bs * n, -1) # (bs * n, dx_out) + + # Sample X + probX = probX.clamp_min(1e-5) + probX = probX / probX.sum(dim=-1, keepdim=True) + X_t = probX.multinomial(1) # (bs * n, 1) + X_t = X_t.reshape(bs, n) # (bs, n) + + # Noise E + # The masked rows should define probability distributions as well + inverse_edge_mask = ~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2)) + diag_mask = torch.eye(n).unsqueeze(0).expand(bs, -1, -1) + + probE[inverse_edge_mask] = 1 / probE.shape[-1] + probE[diag_mask.bool()] = 1 / probE.shape[-1] + probE = probE.reshape(bs * n * n, -1) # (bs * n * n, de_out) + probE = probE.clamp_min(1e-5) + probE = probE / probE.sum(dim=-1, keepdim=True) + + # Sample E + E_t = probE.multinomial(1).reshape(bs, n, n) # (bs, n, n) + E_t = torch.triu(E_t, diagonal=1) + E_t = E_t + torch.transpose(E_t, 1, 2) + + return PlaceHolder(X=X_t, E=E_t, y=torch.zeros(bs, 0).type_as(X_t)) + + +def mask_distributions(true_X, true_E, pred_X, pred_E, node_mask): + # Add a small value everywhere to avoid nans + pred_X = pred_X.clamp_min(1e-5) + pred_X = pred_X / torch.sum(pred_X, dim=-1, keepdim=True) + + pred_E = pred_E.clamp_min(1e-5) + pred_E = pred_E / torch.sum(pred_E, dim=-1, keepdim=True) + + # Set masked rows to arbitrary distributions, so it doesn't contribute to loss + row_X = torch.ones(true_X.size(-1), dtype=true_X.dtype, device=true_X.device) + row_E = torch.zeros( + true_E.size(-1), dtype=true_E.dtype, device=true_E.device + ).clamp_min(1e-5) + row_E[0] = 1.0 + + diag_mask = ~torch.eye( + node_mask.size(1), device=node_mask.device, dtype=torch.bool + ).unsqueeze(0) + true_X[~node_mask] = row_X + true_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = row_E + pred_X[~node_mask] = row_X.type_as(pred_X) + pred_E[~(node_mask.unsqueeze(1) * node_mask.unsqueeze(2) * diag_mask), :] = ( + row_E.type_as(pred_E) + ) + + return true_X, true_E, pred_X, 
pred_E + + +def forward_diffusion(X, X_t, Qt, Qsb, Qtb, X_dim): + bs, n, d = X.shape + + Qt_X_T = torch.transpose(Qt.X, -2, -1) # (bs, d, d) + left_term = X_t @ Qt_X_T # (bs, N, d) + right_term = X @ Qsb.X # (bs, N, d) + + numerator = left_term * right_term # (bs, N, d) + denominator = X @ Qtb.X # (bs, N, d) @ (bs, d, d) = (bs, N, d) + denominator = denominator * X_t + + num_X = numerator[:, :, :X_dim] + num_E = numerator[:, :, X_dim:].reshape(bs, n * n, -1) + + deno_X = denominator[:, :, :X_dim] + deno_E = denominator[:, :, X_dim:].reshape(bs, n * n, -1) + + denominator = denominator.unsqueeze(-1) # (bs, N, 1) + + deno_X = deno_X.sum(dim=-1, keepdim=True) + deno_E = deno_E.sum(dim=-1, keepdim=True) + + deno_X[deno_X == 0.0] = 1 + deno_E[deno_E == 0.0] = 1 + prob_X = num_X / deno_X + prob_E = num_E / deno_E + + prob_E = prob_E / prob_E.sum(dim=-1, keepdim=True) + prob_X = prob_X / prob_X.sum(dim=-1, keepdim=True) + return PlaceHolder(X=prob_X, E=prob_E, y=None) + + +def reverse_diffusion(predX_0, X_t, Qt, Qsb, Qtb): + """M: X or E + Compute xt @ Qt.T * x0 @ Qsb / x0 @ Qtb @ xt.T for each possible value of x0 + X_t: bs, n, dt or bs, n, n, dt + Qt: bs, d_t-1, dt + Qsb: bs, d0, d_t-1 + Qtb: bs, d0, dt. + """ + Qt_T = Qt.transpose(-1, -2) # bs, N, dt + assert Qt.dim() == 3 + left_term = X_t @ Qt_T # bs, N, d_t-1 + right_term = predX_0 @ Qsb + numerator = left_term * right_term # bs, N, d_t-1 + + denominator = Qtb @ X_t.transpose(-1, -2) # bs, d0, N + denominator = denominator.transpose(-1, -2) # bs, N, d0 + return numerator / denominator.clamp_min(1e-5) + + +def sample_discrete_feature_noise(limit_dist, node_mask): + """Sample from the limit distribution of the diffusion process""" + bs, n_max = node_mask.shape + x_limit = limit_dist.X[None, None, :].expand(bs, n_max, -1) + U_X = x_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max) + U_X = F.one_hot(U_X.long(), num_classes=x_limit.shape[-1]).type_as(x_limit) + + e_limit = limit_dist.E[None, None, None, :].expand(bs, n_max, n_max, -1) + U_E = e_limit.flatten(end_dim=-2).multinomial(1).reshape(bs, n_max, n_max) + U_E = F.one_hot(U_E.long(), num_classes=e_limit.shape[-1]).type_as(x_limit) + + U_X = U_X.to(node_mask.device) + U_E = U_E.to(node_mask.device) + + # Get upper triangular part of edge noise, without main diagonal + upper_triangular_mask = torch.zeros_like(U_E) + indices = torch.triu_indices(row=U_E.size(1), col=U_E.size(2), offset=1) + upper_triangular_mask[:, indices[0], indices[1], :] = 1 + + U_E = U_E * upper_triangular_mask + U_E = U_E + torch.transpose(U_E, 1, 2) + + assert (U_E == torch.transpose(U_E, 1, 2)).all() + return PlaceHolder(X=U_X, E=U_E, y=None).mask(node_mask) + + +def index_QE(X, q_e, n_bond=5): + bs, n, n_atom = X.shape + node_indices = X.argmax(-1) # (bs, n) + + exp_ind1 = node_indices[:, :, None, None, None].expand( + bs, n, n_atom, n_bond, n_bond + ) + exp_ind2 = node_indices[:, :, None, None, None].expand(bs, n, n, n_bond, n_bond) + + q_e = torch.gather(q_e, 1, exp_ind1) + q_e = torch.gather(q_e, 2, exp_ind2) # (bs, n, n, n_bond, n_bond) + + node_mask = X.sum(-1) != 0 + no_edge = (~node_mask)[:, :, None] & (~node_mask)[:, None, :] + q_e[no_edge] = torch.tensor([1, 0, 0, 0, 0]).type_as(q_e) + + return q_e diff --git a/src/model/graph_decoder/layers.py b/src/model/graph_decoder/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..d80d4ba1a41aacb42413a57109066e37c2415de6 --- /dev/null +++ b/src/model/graph_decoder/layers.py @@ -0,0 +1,128 @@ +# Copyright 2024 the Llamole team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from torch.jit import Final +import torch.nn.functional as F +from itertools import repeat +import collections.abc + +import torch +import torch.nn as nn + +class Attention(nn.Module): + fast_attn: Final[bool] + + def __init__( + self, + dim, + num_heads=8, + qkv_bias=False, + qk_norm=False, + attn_drop=0, + proj_drop=0, + norm_layer=nn.LayerNorm, + ): + super().__init__() + assert dim % num_heads == 0, "dim should be divisible by num_heads" + self.num_heads = num_heads + self.head_dim = dim // num_heads + + self.scale = self.head_dim**-0.5 + self.fast_attn = hasattr( + torch.nn.functional, "scaled_dot_product_attention" + ) # FIXME + assert self.fast_attn, "scaled_dot_product_attention Not implemented" + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + + self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity() + self.attn_drop = nn.Dropout(attn_drop) + + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x, node_mask): + B, N, D = x.shape + + # B, head, N, head_dim + qkv = ( + self.qkv(x) + .reshape(B, N, 3, self.num_heads, self.head_dim) + .permute(2, 0, 3, 1, 4) + ) + q, k, v = qkv.unbind(0) # B, head, N, head_dim + q, k = self.q_norm(q), self.k_norm(k) + + attn_mask = (node_mask[:, None, :, None] & node_mask[:, None, None, :]).expand( + -1, self.num_heads, N, N + ) + extended_nodes = (attn_mask.sum(dim=-1) == 0) + attn_mask = attn_mask.clone() + attn_mask[extended_nodes] = True + + x = F.scaled_dot_product_attention( + q, + k, + v, + dropout_p=self.attn_drop.p, + attn_mask=attn_mask, + ) + + x = x.transpose(1, 2).reshape(B, N, -1) + x = self.proj(x) + x = self.proj_drop(x) + + return x + + +class MLP(nn.Module): + def __init__( + self, + in_features, + hidden_features=None, + out_features=None, + act_layer=nn.GELU, + bias=True, + drop=0.0, + ): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + bias = to_2tuple(bias) + linear_layer = nn.Linear + + self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0]) + self.act = act_layer() + self.drop1 = nn.Dropout(drop) + self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1]) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + return x + + +# From PyTorch internals +def _ntuple(n): + def parse(x): + if isinstance(x, collections.abc.Iterable) and not isinstance(x, str): + return tuple(x) + return tuple(repeat(x, n)) + + return parse + +to_2tuple = _ntuple(2) diff --git a/src/model/graph_decoder/molecule_utils.py b/src/model/graph_decoder/molecule_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..c0f76cf1cc7cd02b0b50f9d9bdd165792f1422f4 --- /dev/null +++ b/src/model/graph_decoder/molecule_utils.py @@ -0,0 +1,352 @@ +# Copyright 2024 the Llamole team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from rdkit import Chem, RDLogger + +RDLogger.DisableLog("rdApp.*") + +import re +import random +import logging +from rdkit import Chem +from typing import List, Tuple, Optional +random.seed(0) +import torch + +bond_dict = [ + None, + Chem.rdchem.BondType.SINGLE, + Chem.rdchem.BondType.DOUBLE, + Chem.rdchem.BondType.TRIPLE, + Chem.rdchem.BondType.AROMATIC, +] + +ATOM_VALENCY = {6: 4, 7: 3, 8: 2, 9: 1, 15: 3, 16: 2, 17: 1, 35: 1, 53: 1} + +logger = logging.getLogger(__name__) + +def check_polymer(smiles): + if "*" in smiles: + monomer = smiles.replace("*", "[H]") + if mol2smiles(get_mol(monomer)) is None: + logger.warning(f"Invalid polymerization point") + return False + else: + return True + return True + +def graph_to_smiles(molecule_list: List[Tuple], atom_decoder: list) -> List[Optional[str]]: + + smiles_list = [] + for index, graph in enumerate(molecule_list): + try: + atom_types, edge_types = graph + mol_init = build_molecule_with_partial_charges(atom_types, edge_types, atom_decoder) + + # Try to correct the molecule with connection=True, then False if needed + for connection in (True, False): + mol_conn, _ = correct_mol(mol_init, connection=connection) + if mol_conn is not None: + break + else: + logger.warning(f"Failed to correct molecule {index}") + mol_conn = mol_init # Fallback to initial molecule + + # Convert to SMILES + smiles = mol2smiles(mol_conn) + if not smiles: + logger.warning(f"Failed to convert molecule {index} to SMILES, falling back to RDKit MolToSmiles") + smiles = Chem.MolToSmiles(mol_conn) + + if smiles: + mol = get_mol(smiles) + if mol is not None: + # Get the largest fragment + mol_frags = Chem.rdmolops.GetMolFrags(mol, asMols=True, sanitizeFrags=False) + largest_mol = max(mol_frags, key=lambda m: m.GetNumAtoms()) + + largest_smiles = mol2smiles(largest_mol) + if largest_smiles and len(largest_smiles) > 1: + if check_polymer(largest_smiles): + smiles_list.append(largest_smiles) + else: + smiles_list.append(None) + elif check_polymer(smiles): + smiles_list.append(smiles) + else: + smiles_list.append(None) + else: + logger.warning(f"Failed to convert SMILES back to molecule for index {index}") + smiles_list.append(None) + else: + logger.warning(f"Failed to generate SMILES for molecule {index}, appending None") + smiles_list.append(None) + + except Exception as e: + logger.error(f"Error processing molecule {index}: {str(e)}") + try: + # Fallback to RDKit's MolToSmiles if everything else fails + fallback_smiles = Chem.MolToSmiles(mol_init) + if fallback_smiles: + smiles_list.append(fallback_smiles) + logger.warning(f"Used RDKit MolToSmiles fallback for molecule {index}") + else: + smiles_list.append(None) + logger.warning(f"RDKit MolToSmiles fallback failed for molecule {index}, appending None") + except Exception as e2: + logger.error(f"All attempts failed for molecule {index}: {str(e2)}") + smiles_list.append(None) + + return smiles_list + +def build_molecule_with_partial_charges( + atom_types, edge_types, 
atom_decoder, verbose=False +): + if verbose: + print("\nbuilding new molecule") + + mol = Chem.RWMol() + for atom in atom_types: + a = Chem.Atom(atom_decoder[atom.item()]) + mol.AddAtom(a) + if verbose: + print("Atom added: ", atom.item(), atom_decoder[atom.item()]) + + edge_types = torch.triu(edge_types) + all_bonds = torch.nonzero(edge_types) + + for i, bond in enumerate(all_bonds): + if bond[0].item() != bond[1].item(): + mol.AddBond( + bond[0].item(), + bond[1].item(), + bond_dict[edge_types[bond[0], bond[1]].item()], + ) + if verbose: + print( + "bond added:", + bond[0].item(), + bond[1].item(), + edge_types[bond[0], bond[1]].item(), + bond_dict[edge_types[bond[0], bond[1]].item()], + ) + # add formal charge to atom: e.g. [O+], [N+], [S+] + # not support [O-], [N-], [S-], [NH+] etc. + flag, atomid_valence = check_valency(mol) + if verbose: + print("flag, valence", flag, atomid_valence) + if flag: + continue + else: + if len(atomid_valence) == 2: + idx = atomid_valence[0] + v = atomid_valence[1] + an = mol.GetAtomWithIdx(idx).GetAtomicNum() + if verbose: + print("atomic num of atom with a large valence", an) + if an in (7, 8, 16) and (v - ATOM_VALENCY[an]) == 1: + mol.GetAtomWithIdx(idx).SetFormalCharge(1) + # print("Formal charge added") + else: + continue + return mol + + +def correct_mol(mol, connection=False): + ##### + no_correct = False + flag, _ = check_valency(mol) + if flag: + no_correct = True + + while True: + if connection: + mol_conn = connect_fragments(mol) + mol = mol_conn + if mol is None: + return None, no_correct + flag, atomid_valence = check_valency(mol) + if flag: + break + else: + try: + assert len(atomid_valence) == 2 + idx = atomid_valence[0] + v = atomid_valence[1] + queue = [] + check_idx = 0 + for b in mol.GetAtomWithIdx(idx).GetBonds(): + type = int(b.GetBondType()) + queue.append( + (b.GetIdx(), type, b.GetBeginAtomIdx(), b.GetEndAtomIdx()) + ) + if type == 12: + check_idx += 1 + queue.sort(key=lambda tup: tup[1], reverse=True) + + if queue[-1][1] == 12: + return None, no_correct + elif len(queue) > 0: + start = queue[check_idx][2] + end = queue[check_idx][3] + t = queue[check_idx][1] - 1 + mol.RemoveBond(start, end) + if t >= 1: + mol.AddBond(start, end, bond_dict[t]) + except Exception as e: + # print(f"An error occurred in correction: {e}") + return None, no_correct + return mol, no_correct + +def check_valid(smiles): + mol = get_mol(smiles) + if mol is None: + return False + smiles = mol2smiles(mol) + if smiles is None: + return False + return True + +def get_mol(smiles_or_mol): + """ + Loads SMILES/molecule into RDKit's object + """ + if isinstance(smiles_or_mol, str): + if len(smiles_or_mol) == 0: + return None + mol = Chem.MolFromSmiles(smiles_or_mol) + if mol is None: + return None + try: + Chem.SanitizeMol(mol) + except ValueError: + return None + return mol + return smiles_or_mol + + +def mol2smiles(mol): + if mol is None: + return None + try: + Chem.SanitizeMol(mol) + except ValueError: + return None + return Chem.MolToSmiles(mol) + + +def check_valency(mol): + try: + # First attempt to sanitize with specific properties + Chem.SanitizeMol(mol, sanitizeOps=Chem.SanitizeFlags.SANITIZE_PROPERTIES) + return True, None + except ValueError as e: + e = str(e) + p = e.find("#") + e_sub = e[p:] + atomid_valence = list(map(int, re.findall(r"\d+", e_sub))) + return False, atomid_valence + except Exception as e: + # print(f"An unexpected error occurred: {e}") + return False, [] + + +##### connect fragements +def select_atom_with_available_valency(frag): + 
atoms = list(frag.GetAtoms()) + random.shuffle(atoms) + for atom in atoms: + if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0: + return atom + return None + + +def select_atoms_with_available_valency(frag): + return [ + atom + for atom in frag.GetAtoms() + if atom.GetAtomicNum() > 1 and atom.GetImplicitValence() > 0 + ] + + +def try_to_connect_fragments(combined_mol, frag, atom1, atom2): + # Make copies of the molecules to try the connection + trial_combined_mol = Chem.RWMol(combined_mol) + trial_frag = Chem.RWMol(frag) + + # Add the new fragment to the combined molecule with new indices + new_indices = { + atom.GetIdx(): trial_combined_mol.AddAtom(atom) + for atom in trial_frag.GetAtoms() + } + + # Add the bond between the suitable atoms from each fragment + trial_combined_mol.AddBond( + atom1.GetIdx(), new_indices[atom2.GetIdx()], Chem.BondType.SINGLE + ) + + # Adjust the hydrogen count of the connected atoms + for atom_idx in [atom1.GetIdx(), new_indices[atom2.GetIdx()]]: + atom = trial_combined_mol.GetAtomWithIdx(atom_idx) + num_h = atom.GetTotalNumHs() + atom.SetNumExplicitHs(max(0, num_h - 1)) + + # Add bonds for the new fragment + for bond in trial_frag.GetBonds(): + trial_combined_mol.AddBond( + new_indices[bond.GetBeginAtomIdx()], + new_indices[bond.GetEndAtomIdx()], + bond.GetBondType(), + ) + + # Convert to a Mol object and try to sanitize it + new_mol = Chem.Mol(trial_combined_mol) + try: + Chem.SanitizeMol(new_mol) + return new_mol # Return the new valid molecule + except Chem.MolSanitizeException: + return None # If the molecule is not valid, return None + + +def connect_fragments(mol): + # Get the separate fragments + frags = Chem.GetMolFrags(mol, asMols=True, sanitizeFrags=False) + if len(frags) < 2: + return mol + + combined_mol = Chem.RWMol(frags[0]) + + for frag in frags[1:]: + # Select all atoms with available valency from both molecules + atoms1 = select_atoms_with_available_valency(combined_mol) + atoms2 = select_atoms_with_available_valency(frag) + + # Try to connect using all combinations of available valency atoms + for atom1 in atoms1: + for atom2 in atoms2: + new_mol = try_to_connect_fragments(combined_mol, frag, atom1, atom2) + if new_mol is not None: + # If a valid connection is made, update the combined molecule and break + combined_mol = new_mol + break + else: + # Continue if the inner loop didn't break (no valid connection found for atom1) + continue + # Break if the inner loop did break (valid connection found) + break + else: + # If no valid connections could be made with any of the atoms, return None + return None + + return combined_mol diff --git a/src/model/graph_decoder/transformer.py b/src/model/graph_decoder/transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..70c1fe2afa780c849df49d15cf8b714e2529a837 --- /dev/null +++ b/src/model/graph_decoder/transformer.py @@ -0,0 +1,187 @@ +# Copyright 2024 the Llamole team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
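+
+# Architecture note for the module below: the denoiser is a DiT-style graph
+# transformer. Node and flattened edge features are embedded jointly, and each
+# Block is conditioned through adaptive LayerNorm,
+#   modulate(x, shift, scale) = x * (1 + scale) + shift,
+# where per-block (shift, scale, gate) values for attention and MLP are
+# predicted from c = timestep embedding + property embedding + text embedding.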
+ +import torch +import torch.nn as nn +from .layers import Attention, MLP +from .conditions import TimestepEmbedder, ConditionEmbedder, TextEmbedder +from .diffusion_utils import PlaceHolder + +def modulate(x, shift, scale): + return x * (1 + scale.unsqueeze(1)) + shift.unsqueeze(1) + +class Transformer(nn.Module): + def __init__( + self, + max_n_nodes=50, + hidden_size=1024, + depth=28, + num_heads=16, + mlp_ratio=4.0, + drop_condition=0., + Xdim=16, + Edim=5, + ydim=10, + text_dim=768, + ): + super().__init__() + self.num_heads = num_heads + self.ydim = ydim + self.x_embedder = nn.Sequential( + nn.Linear(Xdim + max_n_nodes * Edim, hidden_size, bias=False), + nn.LayerNorm(hidden_size) + ) + + self.t_embedder = TimestepEmbedder(hidden_size) + self.y_embedder = ConditionEmbedder(ydim, hidden_size, drop_condition) + self.txt_embedder = TextEmbedder(text_dim, hidden_size, drop_condition) + + self.blocks = nn.ModuleList( + [ + Block(hidden_size, num_heads, mlp_ratio=mlp_ratio) + for _ in range(depth) + ] + ) + self.output_layer = OutputLayer( + max_n_nodes=max_n_nodes, + hidden_size=hidden_size, + atom_type=Xdim, + bond_type=Edim, + mlp_ratio=mlp_ratio, + num_heads=num_heads, + ) + self.initialize_weights() + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def _constant_init(module, i): + if isinstance(module, nn.Linear): + nn.init.constant_(module.weight, i) + if module.bias is not None: + nn.init.constant_(module.bias, i) + + self.apply(_basic_init) + + for block in self.blocks: + _constant_init(block.adaLN_modulation[0], 0) + _constant_init(self.output_layer.adaLN_modulation[0], 0) + + def disable_grads(self): + """ + Disable gradients for all parameters in the model. 
+ """ + for param in self.parameters(): + param.requires_grad = False + + def forward(self, X_in, E_in, node_mask, y_in, txt, t, unconditioned): + bs, n, _ = X_in.size() + X = torch.cat([X_in, E_in.reshape(bs, n, -1)], dim=-1) + X = self.x_embedder(X) + + c1 = self.t_embedder(t) + c2 = self.y_embedder(y_in, self.training, unconditioned) + c3 = self.txt_embedder(txt, self.training, unconditioned) + c = c1 + c2 + c3 + + for i, block in enumerate(self.blocks): + X = block(X, c, node_mask) + + # X: B * N * dx, E: B * N * N * de + X, E = self.output_layer(X, X_in, E_in, c, t, node_mask) + return PlaceHolder(X=X, E=E, y=None).mask(node_mask) + +class Block(nn.Module): + def __init__(self, hidden_size, num_heads, mlp_ratio=4.0, **block_kwargs): + super().__init__() + self.attn_norm = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=False) + self.mlp_norm = nn.LayerNorm(hidden_size, eps=1e-05, elementwise_affine=False) + + self.attn = Attention( + hidden_size, num_heads=num_heads, qkv_bias=False, qk_norm=True, **block_kwargs + ) + + self.mlp = MLP( + in_features=hidden_size, + hidden_features=int(hidden_size * mlp_ratio), + ) + + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, 6 * hidden_size, bias=True), + nn.Softsign() + ) + + def forward(self, x, c, node_mask): + ( + shift_msa, + scale_msa, + gate_msa, + shift_mlp, + scale_mlp, + gate_mlp, + ) = self.adaLN_modulation(c).chunk(6, dim=1) + + x = x + gate_msa.unsqueeze(1) * modulate(self.attn_norm(self.attn(x, node_mask=node_mask)), shift_msa, scale_msa) + x = x + gate_mlp.unsqueeze(1) * modulate(self.mlp_norm(self.mlp(x)), shift_mlp, scale_mlp) + + return x + +class OutputLayer(nn.Module): + def __init__(self, max_n_nodes, hidden_size, atom_type, bond_type, mlp_ratio, num_heads=None): + super().__init__() + self.atom_type = atom_type + self.bond_type = bond_type + final_size = atom_type + max_n_nodes * bond_type + self.xedecoder = MLP(in_features=hidden_size, + out_features=final_size, drop=0) + + self.norm_final = nn.LayerNorm(final_size, eps=1e-05, elementwise_affine=False) + self.adaLN_modulation = nn.Sequential( + nn.Linear(hidden_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, 2 * final_size, bias=True) + ) + + def forward(self, x, x_in, e_in, c, t, node_mask): + x_all = self.xedecoder(x) + B, N, D = x_all.size() + shift, scale = self.adaLN_modulation(c).chunk(2, dim=1) + x_all = modulate(self.norm_final(x_all), shift, scale) + + atom_out = x_all[:, :, :self.atom_type] + atom_out = x_in + atom_out + + bond_out = x_all[:, :, self.atom_type:].reshape(B, N, N, self.bond_type) + bond_out = e_in + bond_out + + ##### standardize adj_out + edge_mask = (~node_mask)[:, :, None] & (~node_mask)[:, None, :] + diag_mask = ( + torch.eye(N, dtype=torch.bool) + .unsqueeze(0) + .expand(B, -1, -1) + .type_as(edge_mask) + ) + bond_out.masked_fill_(edge_mask[:, :, :, None], 0) + bond_out.masked_fill_(diag_mask[:, :, :, None], 0) + bond_out = 1 / 2 * (bond_out + torch.transpose(bond_out, 1, 2)) + + return atom_out, bond_out \ No newline at end of file diff --git a/src/model/graph_encoder/model.py b/src/model/graph_encoder/model.py new file mode 100644 index 0000000000000000000000000000000000000000..b460316deca3199ee9e320feb370a40a4d39b965 --- /dev/null +++ b/src/model/graph_encoder/model.py @@ -0,0 +1,205 @@ +# Copyright 2024 the Llamole team. 
+# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import global_add_pool, global_max_pool +from torch_geometric.nn import MessagePassing +import json + +class GraphCLIP(nn.Module): + def __init__( + self, + graph_num_layer, + graph_hidden_size, + dropout, + model_config, + ): + super().__init__() + self.model_config = model_config + self.hidden_size = graph_hidden_size + self.molecule_encoder = GNNEncoder(num_layer=graph_num_layer, hidden_size=graph_hidden_size, drop_ratio=dropout) + self.molecule_projection = ProjectionHead(embedding_dim=graph_hidden_size, projection_dim=graph_hidden_size, dropout=dropout) + + def forward(self, x, edge_index, edge_attr, batch): + molecule_features = self.molecule_encoder(x, edge_index, edge_attr, batch) + molecule_embeddings = self.molecule_projection(molecule_features) + molecule_embeddings = molecule_embeddings / molecule_embeddings.norm(dim=-1, keepdim=True) + return molecule_embeddings + + def save_pretrained(self, output_dir): + """ + Save the molecule encoder, projection models, and model_config to the output directory. + """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + molecule_path = os.path.join(output_dir, 'model.pt') + proj_path = molecule_path.replace('model', 'model_proj') + config_path = os.path.join(output_dir, 'model_config.json') + + torch.save(self.molecule_encoder.state_dict(), molecule_path) + torch.save(self.molecule_projection.state_dict(), proj_path) + + # Save model_config to JSON file + with open(config_path, 'w') as f: + json.dump(self.model_config, f, indent=2) + + def disable_grads(self): + """ + Disable gradients for all parameters in the model. + """ + for param in self.parameters(): + param.requires_grad = False + + def init_model(self, model_path, verbose=True): + molecule_path = os.path.join(model_path, 'model.pt') + proj_path = molecule_path.replace('model', 'model_proj') + if os.path.exists(molecule_path): + self.molecule_encoder.load_state_dict(torch.load(molecule_path, map_location='cpu', weights_only=False)) + else: + raise FileNotFoundError(f"Molecule encoder file not found: {molecule_path}") + + if os.path.exists(proj_path): + self.molecule_projection.load_state_dict(torch.load(proj_path, map_location='cpu', weights_only=False)) + else: + raise FileNotFoundError(f"Molecule projection file not found: {proj_path}") + + if verbose: + print('GraphCLIP Models initialized.') + print('Molecule model:\n', self.molecule_encoder) + print('Molecule projection:\n', self.molecule_projection) + +class GNNEncoder(nn.Module): + def __init__(self, num_layer, hidden_size, drop_ratio): + + super(GNNEncoder, self).__init__() + + self.num_layer = num_layer + self.drop_ratio = drop_ratio + if self.num_layer < 2: + raise ValueError("Number of GNN layers must be greater than 1.") + + self.atom_encoder = nn.Embedding(118, hidden_size) + + ### set the initial virtual node embedding to 0. 
+ self.virtualnode_embedding = nn.Embedding(1, hidden_size) + nn.init.constant_(self.virtualnode_embedding.weight.data, 0) + + ### List of GNNs + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + self.mlp_virtualnode_list = nn.ModuleList() + + for layer in range(num_layer): + self.convs.append(GINConv(hidden_size, drop_ratio)) + self.norms.append(nn.LayerNorm(hidden_size, elementwise_affine=True)) + if layer < num_layer - 1: + self.mlp_virtualnode_list.append(nn.Sequential(nn.Linear(hidden_size, 4*hidden_size), nn.LayerNorm(4*hidden_size), nn.GELU(), nn.Dropout(drop_ratio), \ + nn.Linear(4*hidden_size, hidden_size))) + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + self.apply(_basic_init) + + def forward(self, x, edge_index, edge_attr, batch): + + ### virtual node embeddings for graphs + virtualnode_embedding = self.virtualnode_embedding(torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device)) + + h_list = [self.atom_encoder(x)] + + for layer in range(self.num_layer): + ### add message from virtual nodes to graph nodes + h_list[layer] = h_list[layer] + virtualnode_embedding[batch] + + ### Message passing among graph nodes + h = self.convs[layer](h_list[layer], edge_index, edge_attr) + h = self.norms[layer](h) + + if layer < self.num_layer - 1: + h = F.gelu(h) + h = F.dropout(h, self.drop_ratio, training = self.training) + + h = h + h_list[layer] + h_list.append(h) + + if layer < self.num_layer - 1: + ### add message from graph nodes to virtual nodes + virtual_pool = global_max_pool(h_list[layer], batch) + virtualnode_embedding = virtualnode_embedding + F.dropout(self.mlp_virtualnode_list[layer](virtual_pool), self.drop_ratio, training = self.training) + + h_node = h_list[-1] + h_graph = global_add_pool(h_node, batch) + + return h_graph + +class GINConv(MessagePassing): + def __init__(self, hidden_size, drop_ratio): + ''' + hidden_size (int) + ''' + super(GINConv, self).__init__(aggr = "add") + + self.mlp = nn.Sequential(nn.Linear(hidden_size, 4*hidden_size), nn.LayerNorm(4*hidden_size), nn.GELU(), nn.Dropout(drop_ratio), nn.Linear(4*hidden_size, hidden_size)) + self.eps = torch.nn.Parameter(torch.Tensor([0])) + self.bond_encoder = nn.Embedding(5, hidden_size) + + def forward(self, x, edge_index, edge_attr): + edge_embedding = self.bond_encoder(edge_attr) + out = self.mlp((1 + self.eps) *x + self.propagate(edge_index, x=x, edge_attr=edge_embedding)) + return out + + def message(self, x_j, edge_attr): + return F.gelu(x_j + edge_attr) + + def update(self, aggr_out): + return aggr_out + +class ProjectionHead(nn.Module): + def __init__( + self, + embedding_dim, + projection_dim, + dropout, + act_layer=nn.GELU, + hidden_features=None, + bias=True + ): + super().__init__() + projection_dim = projection_dim or embedding_dim + hidden_features = hidden_features or embedding_dim + linear_layer = nn.Linear + + self.fc1 = linear_layer(embedding_dim, hidden_features, bias=bias) + self.norm1 = nn.LayerNorm(hidden_features) + self.act = act_layer() + self.drop1 = nn.Dropout(dropout) + self.fc2 = linear_layer(hidden_features, projection_dim, bias=bias) + + def forward(self, x): + x = self.fc1(x) + x = self.norm1(x) + x = self.act(x) + x = self.drop1(x) + x = self.fc2(x) + return x diff --git a/src/model/graph_predictor/model.py b/src/model/graph_predictor/model.py new file mode 
100644 index 0000000000000000000000000000000000000000..06c660d2b89bdd5d2564ce9ad917b7533c8bf621 --- /dev/null +++ b/src/model/graph_predictor/model.py @@ -0,0 +1,423 @@ +# Copyright 2024 Llamole Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch_geometric.nn import global_add_pool, global_mean_pool, global_max_pool +from torch_geometric.nn import MessagePassing + +import os +import json +from collections import defaultdict +from rdchiral.main import rdchiralRunText +import pandas as pd + +import numpy as np +from rdkit import Chem +from rdkit.Chem import AllChem + +def modulate(x, shift, scale): + return x * (1 + scale) + shift + +class GraphPredictor(nn.Module): + def __init__( + self, + num_layer, + hidden_size, + drop_ratio, + out_dim, + model_config, + label_to_template, + available=None, + ): + super().__init__() + self.model_config = model_config + self.text_input_size = model_config.get("text_input_size", 768) + self.available = available + self.text_drop = drop_ratio + + # Process label_to_template + if isinstance(label_to_template, pd.DataFrame): + self.label_to_template = dict( + zip( + label_to_template["rule_label"], + label_to_template["retro_templates"], + ) + ) + else: + self.label_to_template = label_to_template + + self.predictor = GNNRetrosynthsizer( + num_layer, hidden_size, self.text_input_size, drop_ratio, out_dim + ) + self.neural_cost = None + + def save_pretrained(self, output_dir): + """ + Save the predictor model, model_config, label_to_template, and available to the output directory. 
+ """ + if not os.path.exists(output_dir): + os.makedirs(output_dir) + + model_path = os.path.join(output_dir, "model.pt") + config_path = os.path.join(output_dir, "model_config.json") + label_to_template_path = os.path.join(output_dir, "label_to_template.csv.gz") + available_path = os.path.join(output_dir, "available.csv.gz") + + # Save predictor model + torch.save(self.predictor.state_dict(), model_path) + + # Save cost model + if self.neural_cost is not None: + neural_cost_path = os.path.join(output_dir, "cost_model.pt") + torch.save(self.neural_cost.state_dict(), neural_cost_path) + + # Save model_config to JSON file + with open(config_path, "w") as f: + json.dump(self.model_config, f, indent=2) + + # Save label_to_template to gzipped CSV file + label_to_template_df = pd.DataFrame( + list(self.label_to_template.items()), + columns=["rule_label", "retro_templates"], + ) + label_to_template_df.to_csv( + label_to_template_path, index=False, compression="gzip" + ) + + # Save available to gzipped CSV file if it's not None + if self.available is not None: + if isinstance(self.available, list): + available_df = pd.DataFrame(self.available, columns=["smiles"]) + elif isinstance(self.available, pd.DataFrame): + available_df = self.available + else: + raise ValueError( + "available must be either a list of SMILES strings or a pandas DataFrame" + ) + + available_df.to_csv(available_path, index=False, compression="gzip") + + def disable_grads(self): + """ + Disable gradients for all parameters in the model. + """ + for param in self.predictor.parameters(): + param.requires_grad = False + + def init_neural_cost(self, model_path, verbose=False): + model_file = os.path.join(model_path, "cost_model.pt") + if os.path.exists(model_file): + self.neural_cost = CostMLP( + n_layers=1, fp_dim=2048, latent_dim=128, dropout_rate=0.1 + ) + self.neural_cost.load_state_dict(torch.load(model_file, map_location="cpu", weights_only=True)) + else: + raise FileNotFoundError(f"Model file not found: {model_file}") + + for param in self.neural_cost.parameters(): + param.requires_grad = False + + if verbose: + print("Neural Cost Model initialized.") + print("Neural Cost Model:\n", self.neural_cost) + + def init_model(self, model_path, verbose=False): + model_file = os.path.join(model_path, "model.pt") + if os.path.exists(model_file): + self.predictor.load_state_dict(torch.load(model_file, map_location="cpu", weights_only=True)) + else: + raise FileNotFoundError(f"Model file not found: {model_file}") + + if verbose: + print("GraphPredictor Model initialized.") + print("Predictor model:\n", self.predictor) + + def forward(self, x, edge_index, edge_attr, batch, c): + return self.predictor(x, edge_index, edge_attr, batch, c) + + def estimate_cost(self, smiles): + if self.neural_cost is None: + raise ValueError("Cost model is not initialized.") + + fp = self.neural_cost.smiles_to_fp(smiles) + dtype, device = ( + next(self.neural_cost.parameters()).dtype, + next(self.neural_cost.parameters()).device, + ) + fp = torch.tensor(fp, dtype=dtype, device=device).view(1, -1) + return self.neural_cost(fp).squeeze().item() + + + def sample_templates(self, product_graph, c, product_smiles, topk=10): + + x, edge_index, edge_attr = ( + product_graph.x, + product_graph.edge_index, + product_graph.edge_attr, + ) + batch = torch.zeros(x.size(0), dtype=torch.long, device=x.device) + + # Sample from main predictor + logits_main = self.predictor(x, edge_index, edge_attr, batch, c) + logits_drop = self.predictor(x, edge_index, edge_attr, batch, 
None) + probs_main = logits_main + logits_drop * self.text_drop + probs_main = F.softmax(logits_main, dim=1) + + topk_probs, topk_indices = torch.topk(probs_main, k=topk, dim=1) + + # Convert to numpy for easier handling + topk_probs = topk_probs.float().cpu().numpy() + topk_indices = topk_indices.cpu().numpy() + + # Get the corresponding templates + templates = [] + for idx in topk_indices[0]: + templates.append(self.label_to_template[idx]) + + reactants_d = defaultdict(list) + for prob, template in zip(topk_probs[0], templates): + try: + outcomes = rdchiralRunText(template, product_smiles) + if len(outcomes) == 0: + continue + outcomes = sorted(outcomes) + for reactant in outcomes: + if "." in reactant: + str_list = sorted(reactant.strip().split(".")) + reactants_d[".".join(str_list)].append( + (prob.item() / len(outcomes), template) + ) + else: + reactants_d[reactant].append( + (prob.item() / len(outcomes), template) + ) + except Exception: + pass + + if len(reactants_d) == 0: + return [], [], [] + + def merge(reactant_d): + ret = [] + for reactant, l in reactant_d.items(): + ss, ts = zip(*l) + ret.append((reactant, sum(ss), list(ts)[0])) + reactants, scores, templates = zip( + *sorted(ret, key=lambda item: item[1], reverse=True) + ) + return list(reactants), list(scores), list(templates) + + reactants, scores, templates = merge(reactants_d) + + total = sum(scores) + scores = [s / total for s in scores] + + return reactants, scores, templates + +class GNNRetrosynthsizer(torch.nn.Module): + def __init__(self, num_layer, hidden_size, text_input_size, drop_ratio, out_dim): + super(GNNRetrosynthsizer, self).__init__() + self.num_layer = num_layer + self.drop_ratio = drop_ratio + self.text_input_size = text_input_size + if self.num_layer < 2: + raise ValueError("Number of GNN layers must be greater than 1.") + + self.atom_encoder = nn.Embedding(118, hidden_size) + + ### set the initial virtual node embedding to 0. 
+ self.virtualnode_embedding = nn.Embedding(1, hidden_size) + nn.init.constant_(self.virtualnode_embedding.weight.data, 0) + + ### List of GNNs + self.convs = nn.ModuleList() + self.norms = nn.ModuleList() + self.adapters = nn.ModuleList() + self.mlp_virtualnode_list = nn.ModuleList() + + self.text_dropping = nn.Embedding(1, text_input_size) + for layer in range(num_layer): + self.convs.append(GINConv(hidden_size, drop_ratio)) + self.adapters.append( + nn.Sequential( + nn.SiLU(), + nn.Linear(self.text_input_size, 3 * hidden_size, bias=True), + ) + ) + self.norms.append(nn.LayerNorm(hidden_size, elementwise_affine=False)) + if layer < num_layer - 1: + self.mlp_virtualnode_list.append( + nn.Sequential( + nn.Linear(hidden_size, 4 * hidden_size), + nn.LayerNorm(4 * hidden_size), + nn.GELU(), + nn.Dropout(drop_ratio), + nn.Linear(4 * hidden_size, hidden_size), + ) + ) + + self.decoder = nn.Sequential( + nn.Linear(hidden_size, 4 * hidden_size), + nn.LayerNorm(4 * hidden_size), + nn.GELU(), + nn.Dropout(drop_ratio), + nn.Linear(4 * hidden_size, out_dim), + ) + + def initialize_weights(self): + # Initialize transformer layers: + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + + def _constant_init(module, i): + if isinstance(module, nn.Linear): + nn.init.constant_(module.weight, i) + if module.bias is not None: + nn.init.constant_(module.bias, i) + + self.apply(_basic_init) + + for adapter in self.adapters: + _constant_init(adapter[-1], 0) + + def disable_grads(self): + """ + Disable gradients for all parameters in the model. + """ + for param in self.parameters(): + param.requires_grad = False + + def forward(self, x, edge_index, edge_attr, batch, c): + + ### virtual node embeddings for graphs + virtualnode_embedding = self.virtualnode_embedding( + torch.zeros(batch[-1].item() + 1).to(edge_index.dtype).to(edge_index.device) + ) + + h_list = [self.atom_encoder(x)] + + if c is None: + c = self.text_dropping.weight.expand(batch.max().item() + 1, -1) + + for layer in range(self.num_layer): + ### add message from virtual nodes to graph nodes + h_list[layer] = h_list[layer] + virtualnode_embedding[batch] + + shift, scale, gate = self.adapters[layer](c).chunk(3, dim=1) + # B = batch.max().item() + 1 + node_counts = torch.bincount(batch, minlength=batch.max().item() + 1) + shift = shift.repeat_interleave(node_counts, dim=0) + scale = scale.repeat_interleave(node_counts, dim=0) + gate = gate.repeat_interleave(node_counts, dim=0) + + ### Message passing among graph nodes + h = self.convs[layer](h_list[layer], edge_index, edge_attr) + # h = self.norms[layer](h) + h = modulate(self.norms[layer](h), shift, scale) + + if layer < self.num_layer - 1: + h = F.gelu(h) + h = F.dropout(h, self.drop_ratio, training=self.training) + + h = gate * h + h_list[layer] + h_list.append(h) + + if layer < self.num_layer - 1: + ### add message from graph nodes to virtual nodes + virtual_pool = global_max_pool(h_list[layer], batch) + virtualnode_embedding = virtualnode_embedding + F.dropout( + self.mlp_virtualnode_list[layer](virtual_pool), + self.drop_ratio, + training=self.training, + ) + + h_node = h_list[-1] + h_graph = global_add_pool(h_node, batch) + output = self.decoder(h_graph) + return output + + +class CostMLP(nn.Module): + def __init__(self, n_layers, fp_dim, latent_dim, dropout_rate): + super(CostMLP, self).__init__() + self.n_layers = n_layers + self.fp_dim = fp_dim + self.latent_dim = 
latent_dim + self.dropout_rate = dropout_rate + + layers = [] + layers.append(nn.Linear(fp_dim, latent_dim)) + layers.append(nn.ReLU()) + layers.append(nn.Dropout(self.dropout_rate)) + for _ in range(self.n_layers - 1): + layers.append(nn.Linear(latent_dim, latent_dim)) + layers.append(nn.ReLU()) + layers.append(nn.Dropout(self.dropout_rate)) + layers.append(nn.Linear(latent_dim, 1)) + self.layers = nn.Sequential(*layers) + + def smiles_to_fp(self, smiles: str, fp_dim: int = 2048) -> np.ndarray: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + raise ValueError(f"Invalid SMILES string: {smiles}") + + fp = AllChem.GetMorganFingerprintAsBitVect(mol, 2, nBits=fp_dim) + onbits = list(fp.GetOnBits()) + arr = np.zeros(fp.GetNumBits(), dtype=bool) + arr[onbits] = 1 + + return arr + + def forward(self, fps): + x = fps + x = self.layers(x) + x = torch.log(1 + torch.exp(x)) + return x + + +class GINConv(MessagePassing): + def __init__(self, hidden_size, drop_ratio): + """ + hidden_size (int) + """ + super(GINConv, self).__init__(aggr="add") + + self.mlp = nn.Sequential( + nn.Linear(hidden_size, 4 * hidden_size), + nn.LayerNorm(4 * hidden_size), + nn.GELU(), + nn.Dropout(drop_ratio), + nn.Linear(4 * hidden_size, hidden_size), + ) + self.eps = torch.nn.Parameter(torch.Tensor([0])) + self.bond_encoder = nn.Embedding(5, hidden_size) + + def forward(self, x, edge_index, edge_attr): + edge_embedding = self.bond_encoder(edge_attr) + out = self.mlp( + (1 + self.eps) * x + + self.propagate(edge_index, x=x, edge_attr=edge_embedding) + ) + return out + + def message(self, x_j, edge_attr): + return F.gelu(x_j + edge_attr) + + def update(self, aggr_out): + return aggr_out diff --git a/src/model/loader.py b/src/model/loader.py new file mode 100644 index 0000000000000000000000000000000000000000..de706911430e5ac2da3b549fd13407da3b026d83 --- /dev/null +++ b/src/model/loader.py @@ -0,0 +1,364 @@ +# Copyright 2024 the LlamaFactory team and the Llamole team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
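Before the loader below, a hedged usage sketch for the `CostMLP` defined above: `smiles_to_fp` converts a SMILES string into a 2048-bit Morgan fingerprint, and the MLP maps it to a non-negative scalar through a softplus-style output (`log(1 + exp(x))`). The SMILES is illustrative and the weights here are random, so the printed value is meaningless; in practice `GraphPredictor.estimate_cost` performs the same steps with the pretrained `cost_model.pt`.

import torch

cost_model = CostMLP(n_layers=1, fp_dim=2048, latent_dim=128, dropout_rate=0.1)
cost_model.eval()

fp = cost_model.smiles_to_fp("CC(=O)Oc1ccccc1C(=O)O")        # aspirin, illustrative
fp = torch.tensor(fp, dtype=torch.float32).view(1, -1)       # (1, 2048)
with torch.no_grad():
    print(cost_model(fp).item())                              # non-negative "synthesis cost"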
+ +from typing import TYPE_CHECKING, Any, Dict, Optional, TypedDict +from pathlib import Path +import json +import pandas as pd +import os + +import torch +from transformers import ( + AutoConfig, + AutoModelForCausalLM, + AutoModelForVision2Seq, + AutoProcessor, + AutoTokenizer, +) +from trl import AutoModelForCausalLMWithValueHead +from huggingface_hub import hf_hub_download + +from ..extras.logging import get_logger +from ..extras.misc import ( + count_parameters, + skip_check_imports, + try_download_model_from_ms, +) +from .adapter import init_adapter +from .model_utils.misc import register_autoclass +from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model + +from .model_utils.valuehead import load_valuehead_params +from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model + +from .graph_decoder.diffusion_model import GraphDiT +from .graph_encoder.model import GraphCLIP +from .graph_predictor.model import GraphPredictor + +if TYPE_CHECKING: + from transformers import ( + PretrainedConfig, + PreTrainedModel, + PreTrainedTokenizer, + ProcessorMixin, + ) + + from ..hparams import FinetuningArguments, ModelArguments + + +logger = get_logger(__name__) + + +def download_from_hf(repo_id, filename, local_dir): + os.makedirs(local_dir, exist_ok=True) + return hf_hub_download(repo_id=repo_id, filename=filename, local_dir=local_dir) + + +class TokenizerModule(TypedDict): + tokenizer: "PreTrainedTokenizer" + processor: Optional["ProcessorMixin"] + + +def _get_init_kwargs(model_args: "ModelArguments") -> Dict[str, Any]: + r""" + Gets arguments to load config/tokenizer/model. + + Note: including inplace operation of model_args. + """ + skip_check_imports() + model_args.model_name_or_path = try_download_model_from_ms(model_args) + return { + "trust_remote_code": True, + "cache_dir": model_args.cache_dir, + "revision": model_args.model_revision, + "token": model_args.hf_hub_token, + } + +def load_tokenizer(model_args: "ModelArguments", generate_mode=False) -> "TokenizerModule": + r""" + Loads pretrained tokenizer or a pre-saved tokenizer. + + Note: including inplace operation of model_args. + """ + init_kwargs = _get_init_kwargs(model_args) + + padding_size = 'left' if generate_mode else 'right' + try: + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + split_special_tokens=model_args.split_special_tokens, + padding_side=padding_size, + **init_kwargs, + ) + except ValueError: # try the fast one + tokenizer = AutoTokenizer.from_pretrained( + model_args.model_name_or_path, + use_fast=True, + padding_side=padding_size, + **init_kwargs, + ) + + if model_args.new_special_tokens is not None: + num_added_tokens = tokenizer.add_special_tokens( + dict(additional_special_tokens=model_args.new_special_tokens), + replace_additional_special_tokens=False, + ) + logger.info( + "Add {} to special tokens.".format(",".join(model_args.new_special_tokens)) + ) + + if num_added_tokens > 0 and not model_args.resize_vocab: + model_args.resize_vocab = True + logger.warning( + "New tokens have been added, changed `resize_vocab` to True." 
+ ) + + patch_tokenizer(tokenizer) + + if model_args.new_special_tokens is not None: + token_id_dict = {} + for elem in model_args.new_special_tokens: + if isinstance(elem, str) and len(elem) != 0: + elem_token_ids = tokenizer.encode(elem, add_special_tokens=False) + token_id_dict[elem] = elem_token_ids + logger.info(f"Dictionary of added tokens and their IDs: {token_id_dict}") + + return {"tokenizer": tokenizer, "processor": None} + +def load_config(model_args: "ModelArguments") -> "PretrainedConfig": + r""" + Loads model config. + """ + init_kwargs = _get_init_kwargs(model_args) + return AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) + + +def load_language_model( + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool = False, + add_valuehead: bool = False, +) -> "PreTrainedModel": + r""" + Loads pretrained model. + """ + init_kwargs = _get_init_kwargs(model_args) + config = load_config(model_args) + patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) + + model = None + lazy_load = False + + # if model is None and not lazy_load: + init_kwargs["config"] = config + init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path + model = AutoModelForCausalLM.from_pretrained(**init_kwargs) + if not lazy_load: + patch_model(model, tokenizer, model_args, is_trainable, add_valuehead) + register_autoclass(config, model, tokenizer) + + model = init_adapter(config, model, model_args, finetuning_args, is_trainable) + + if add_valuehead: + model = AutoModelForCausalLMWithValueHead.from_pretrained(model) + patch_valuehead_model(model) + + if model_args.adapter_name_or_path is not None: + vhead_path = model_args.adapter_name_or_path[-1] + else: + vhead_path = model_args.model_name_or_path + + vhead_params = load_valuehead_params(vhead_path, model_args) + if vhead_params is not None: + model.load_state_dict(vhead_params, strict=False) + logger.info("Loaded valuehead from checkpoint: {}".format(vhead_path)) + + if not is_trainable: + model.requires_grad_(False) + for param in model.parameters(): + if ( + param.data.dtype == torch.float32 + and model_args.compute_dtype != torch.float32 + ): + param.data = param.data.to(model_args.compute_dtype) + + model.eval() + else: + model.train() + + trainable_params, all_param = count_parameters(model) + if is_trainable: + param_stats = "lm trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param + ) + else: + param_stats = "lm all params: {:,}".format(all_param) + + logger.info(param_stats) + + if model_args.print_param_status: + for name, param in model.named_parameters(): + print( + "name: {}, dtype: {}, device: {}, trainable: {}".format( + name, param.dtype, param.device, param.requires_grad + ) + ) + + return model + +def load_graph_decoder(model_args: "ModelArguments", path: str, device: str): + path = Path(path) + config_path = path / "config.yaml" + + if not config_path.exists(): + logger.info(f"Config not found in {path}. 
Downloading from Hugging Face.") + repo_id = "liuganghuggingface/Llamole-Pretrained-GraphDiT" + config_path = download_from_hf(repo_id, "config.yaml", path) + download_from_hf(repo_id, "data.meta.json", path) + download_from_hf(repo_id, "model.pt", path) + + data_info_path = path / "data.meta.json" + + model = GraphDiT( + model_config_path=config_path, + data_info_path=data_info_path, + model_dtype=model_args.compute_dtype, + ) + model.init_model(path) + if model_args.disable_graph_model_gradient: + model.disable_grads() + model.to(device) + + for param in model.parameters(): + if param.dtype == torch.float32 and model_args.compute_dtype != torch.float32: + param.data = param.data.to(model_args.compute_dtype) + + trainable_params, all_param = count_parameters(model) + param_stats = "Graph DiT trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param + ) + logger.info(param_stats) + + if model_args.print_param_status: + for name, param in model.named_parameters(): + logger.info( + f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}" + ) + + return model + +def load_graph_predictor(model_args: "ModelArguments", path: str, device: str): + path = Path(path) + config_path = path / "config.json" + + if not config_path.exists(): + logger.info(f"Config not found in {path}. Downloading from Hugging Face.") + repo_id = "liuganghuggingface/Llamole-Pretrained-GNNPredictor" + config_path = download_from_hf(repo_id, "config.json", path) + download_from_hf(repo_id, "model.pt", path) + download_from_hf(repo_id, "cost_model.pt", path) + download_from_hf(repo_id, "label_to_template.csv.gz", path) + download_from_hf(repo_id, "available.csv.gz", path) + + with open(config_path, "r") as f: + config = json.load(f) + + label_to_template_path = path / "label_to_template.csv.gz" + label_to_template_df = pd.read_csv(label_to_template_path, compression='gzip') + label_to_template = dict(zip(label_to_template_df['rule_label'], label_to_template_df['retro_templates'])) + + available_path = path / "available.csv.gz" + available = pd.read_csv(available_path, compression='gzip') + + model = GraphPredictor( + num_layer=config["num_layer"], + hidden_size=config["hidden_size"], + drop_ratio=config["drop_ratio"], + out_dim=config["num_task"], + model_config=config, + label_to_template=label_to_template, + available=available, + ) + + model.init_model(path) + model.init_neural_cost(path) + + if model_args.disable_graph_model_gradient: + model.disable_grads() + + model.to(device) + + for param in model.parameters(): + if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32: + param.data = param.data.to(model_args.compute_dtype) + + trainable_params, all_param = count_parameters(model) + param_stats = "Graph Predictor trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param + ) + logger.info(param_stats) + + if model_args.print_param_status: + for name, param in model.named_parameters(): + logger.info( + f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}" + ) + + return model + +def load_graph_encoder(model_args: "ModelArguments", path: str, device: str): + path = Path(path) + config_path = path / "config.json" + + if not config_path.exists(): + logger.info(f"Config not found in {path}. 
Downloading from Hugging Face.") + repo_id = "liuganghuggingface/Llamole-Pretrained-GraphEncoder" + config_path = download_from_hf(repo_id, "config.json", path) + download_from_hf(repo_id, "model.pt", path) + download_from_hf(repo_id, "model_proj.pt", path) + + with open(config_path, "r") as f: + config = json.load(f) + + model = GraphCLIP( + graph_num_layer=config["num_layer"], + graph_hidden_size=config["hidden_size"], + dropout=config["drop_ratio"], + model_config=config, + ) + model.init_model(path, verbose=False) + if model_args.disable_graph_model_gradient: + model.disable_grads() + model.to(device) + + for param in model.parameters(): + if param.data.dtype == torch.float32 and model_args.compute_dtype != torch.float32: + param.data = param.data.to(model_args.compute_dtype) + + trainable_params, all_param = count_parameters(model) + param_stats = "Graph CLIP Encoder trainable params: {:,} || all params: {:,} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param + ) + + logger.info(param_stats) + + if model_args.print_param_status: + for name, param in model.named_parameters(): + logger.info( + f"name: {name}, dtype: {param.dtype}, device: {param.device}, trainable: {param.requires_grad}" + ) + + return model \ No newline at end of file diff --git a/src/model/model_utils/__init__.py b/src/model/model_utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/model/model_utils/attention.py b/src/model/model_utils/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..4bed7e2193515ec2f8908e2bfb145895ffc78560 --- /dev/null +++ b/src/model/model_utils/attention.py @@ -0,0 +1,82 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def configure_attn_implementation( + config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool +) -> None: + if getattr(config, "model_type", None) == "gemma2" and is_trainable: # gemma2 adopts soft-cap attention + if model_args.flash_attn == "auto": + logger.warning("Gemma-2 models should use eager attention in training, change `flash_attn` to disabled.") + model_args.flash_attn = "disabled" + elif model_args.flash_attn != "disabled": + logger.warning( + "Gemma-2 models should use eager attention in training, but you set `flash_attn: {}`. 
" + "Will proceed at your own risk.".format(model_args.flash_attn) + ) + + if model_args.flash_attn == "auto": + return + + elif model_args.flash_attn == "disabled": + requested_attn_implementation = "eager" + + elif model_args.flash_attn == "sdpa": + if not is_torch_sdpa_available(): + logger.warning("torch>=2.1.1 is required for SDPA attention.") + return + + requested_attn_implementation = "sdpa" + elif model_args.flash_attn == "fa2": + if not is_flash_attn_2_available(): + logger.warning("FlashAttention-2 is not installed.") + return + + requested_attn_implementation = "flash_attention_2" + else: + raise NotImplementedError("Unknown attention type: {}".format(model_args.flash_attn)) + + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + setattr(config, "attn_implementation", requested_attn_implementation) + else: + setattr(config, "_attn_implementation", requested_attn_implementation) + + +def print_attn_implementation(config: "PretrainedConfig") -> None: + if getattr(config, "model_type", None) == "internlm2": # special case for custom models + attn_implementation = getattr(config, "attn_implementation", None) + else: + attn_implementation = getattr(config, "_attn_implementation", None) + + if attn_implementation == "flash_attention_2": + logger.info("Using FlashAttention-2 for faster training and inference.") + elif attn_implementation == "sdpa": + logger.info("Using torch SDPA for faster training and inference.") + else: + logger.info("Using vanilla attention implementation.") diff --git a/src/model/model_utils/checkpointing.py b/src/model/model_utils/checkpointing.py new file mode 100644 index 0000000000000000000000000000000000000000..f4f3d8a5139c5061e6d47bbf822372831599feba --- /dev/null +++ b/src/model/model_utils/checkpointing.py @@ -0,0 +1,109 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers and PEFT library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/modeling_utils.py +# https://github.com/huggingface/peft/blob/v0.10.0/src/peft/utils/other.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import inspect +from functools import partial +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple + +import torch + +from ...extras.constants import LAYERNORM_NAMES +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def _gradient_checkpointing_enable( + self: "PreTrainedModel", gradient_checkpointing_kwargs: Optional[Dict[str, Any]] = None +) -> None: + r""" + Activates gradient checkpointing for the current model. + + Modification of the original method to enable gradient checkpointing for block-wise optimizer. 
+ """ + from torch.utils.checkpoint import checkpoint + + if not self.supports_gradient_checkpointing: + raise ValueError("{} does not support gradient checkpointing.".format(self.__class__.__name__)) + + if gradient_checkpointing_kwargs is None: + gradient_checkpointing_kwargs = {"use_reentrant": True} + + gradient_checkpointing_func = partial(checkpoint, **gradient_checkpointing_kwargs) + + def custom_gradient_checkpointing_func(func, *args, **kwargs): + module: "torch.nn.Module" = func.__self__ + + if any(param.requires_grad for param in module.parameters()): + for arg in args: + if torch.is_tensor(arg) and torch.is_floating_point(arg): + arg.requires_grad_(True) + + return gradient_checkpointing_func(func, *args, **kwargs) + + if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format + self.apply(partial(self._set_gradient_checkpointing, value=True)) + self.enable_input_require_grads() + logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.") + else: # have already enabled input require gradients + self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=custom_gradient_checkpointing_func) + + +def _fp32_forward_post_hook( + module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor" +) -> "torch.Tensor": + return output.to(torch.float32) + + +def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArguments") -> None: + r""" + Includes: + (1) cast the layernorm in fp32 + (2) make output embedding layer require grads + (3) add the upcasting of the lm_head in fp32 + """ + if model_args.upcast_layernorm: + logger.info("Upcasting layernorm weights in float32.") + for name, param in model.named_parameters(): + if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): + param.data = param.data.to(torch.float32) + + if not model_args.disable_gradient_checkpointing: + if not getattr(model, "supports_gradient_checkpointing", False): + logger.warning("Current model does not support gradient checkpointing.") + else: + # use_reentrant=False might increase VRAM usage (have not been empirically verified yet) + # According to: https://github.com/huggingface/transformers/issues/28339 + model.gradient_checkpointing_enable = MethodType(_gradient_checkpointing_enable, model) + model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) + setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled + logger.info("Gradient checkpointing enabled.") + + if model_args.upcast_lmhead_output: + output_layer = model.get_output_embeddings() + if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: + logger.info("Upcasting lm_head outputs in float32.") + output_layer.register_forward_hook(_fp32_forward_post_hook) diff --git a/src/model/model_utils/embedding.py b/src/model/model_utils/embedding.py new file mode 100644 index 0000000000000000000000000000000000000000..3ff79828271f16d112733ba62798d0c02dc67d4a --- /dev/null +++ b/src/model/model_utils/embedding.py @@ -0,0 +1,72 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from contextlib import nullcontext +from typing import TYPE_CHECKING + +import torch +from transformers.integrations import is_deepspeed_zero3_enabled + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, PreTrainedTokenizer + + +logger = get_logger(__name__) + + +def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: + embedding_dim = embed_weight.size(1) + avg_weight = embed_weight[:-num_new_tokens].mean(dim=0, keepdim=True) + noise_weight = torch.empty_like(embed_weight[-num_new_tokens:]) + noise_weight.normal_(mean=0, std=(1.0 / math.sqrt(embedding_dim))) + embed_weight[-num_new_tokens:] = avg_weight + noise_weight + + +def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer") -> None: + r""" + Resize token embeddings. + """ + if is_deepspeed_zero3_enabled(): + import deepspeed # type: ignore + + params = [model.get_input_embeddings().weight] + if model.get_output_embeddings() is not None and not model.config.tie_word_embeddings: + params.append(model.get_output_embeddings().weight) + + context_maybe_zero3 = deepspeed.zero.GatheredParameters(params, modifier_rank=0) + else: + context_maybe_zero3 = nullcontext() + + with context_maybe_zero3: + current_embedding_size = model.get_input_embeddings().weight.size(0) + + if len(tokenizer) > current_embedding_size: + if getattr(model, "quantization_method", None): + raise ValueError("Cannot resize embedding layers of a quantized model.") + + if not isinstance(model.get_output_embeddings(), torch.nn.Linear): + raise ValueError("Current model does not support resizing embedding layers.") + + model.resize_token_embeddings(len(tokenizer), pad_to_multiple_of=64) + with context_maybe_zero3: + new_embedding_size = model.get_input_embeddings().weight.size(0) + num_new_tokens = new_embedding_size - current_embedding_size + _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) + _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) + + logger.info("Resized token embeddings from {} to {}.".format(current_embedding_size, new_embedding_size)) diff --git a/src/model/model_utils/longlora.py b/src/model/model_utils/longlora.py new file mode 100644 index 0000000000000000000000000000000000000000..5d52c475ff44bbe8755305bc3dd2e45bfe2c7532 --- /dev/null +++ b/src/model/model_utils/longlora.py @@ -0,0 +1,346 @@ +# Copyright 2024 EleutherAI, HuggingFace Inc., Yukang Chen, and the LlamaFactory team. +# +# This code is based on the EleutherAI's GPT-NeoX and the HuggingFace's Transformers libraries. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +# This code is also inspired by the original LongLoRA implementation. +# https://github.com/dvlab-research/LongLoRA/blob/main/llama_attn_replace.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
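A hedged usage sketch for the embedding resizing above, assuming the module is importable; the model name and the special tokens are illustrative placeholders, not necessarily the ones Llamole registers:

from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")        # illustrative
model = AutoModelForCausalLM.from_pretrained("gpt2")

tokenizer.add_special_tokens({"additional_special_tokens": ["<mol_start>", "<mol_end>"]})
# Appended embedding rows are padded to a multiple of 64 and filled with the
# noisy-mean initialization from _noisy_mean_initialization above.
resize_embedding_layer(model, tokenizer)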
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING, Optional, Tuple + +import torch +import torch.nn as nn +from transformers.models.llama.modeling_llama import ( + Cache, + LlamaAttention, + LlamaFlashAttention2, + LlamaSdpaAttention, + apply_rotary_pos_emb, + repeat_kv, +) +from transformers.utils import logging +from transformers.utils.versions import require_version + +from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +transformers_logger = logging.get_logger(__name__) + + +# Modified from: +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_attention_forward( + self: "LlamaAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + bsz, q_len, _ = hidden_states.size() + + query_states: "torch.Tensor" = self.q_proj(hidden_states) + key_states: "torch.Tensor" = self.k_proj(hidden_states) + value_states: "torch.Tensor" = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: "torch.Tensor") -> "torch.Tensor": + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 
1, 1) + + attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(self.head_dim) + + if attention_mask is not None: # no matter the length, we just slice it + causal_mask = attention_mask[:, :, :, : key_states.shape[-2]] + attn_weights = attn_weights + causal_mask + + # upcast attention to fp32 + attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype) + attn_weights = nn.functional.dropout(attn_weights, p=self.attention_dropout, training=self.training) + attn_output = torch.matmul(attn_weights, value_states) # (bsz, :, seq_len, :) or (bsz * n_group, :, groupsz, :) + attn_output = attn_output.transpose(1, 2).contiguous() + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), + ), + dim=2, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Modified from: +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_flash_attention_2_forward( + self: "LlamaFlashAttention2", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + # LlamaFlashAttention2 attention does not support output_attentions + output_attentions = False + + bsz, q_len, _ = hidden_states.size() + + query_states: "torch.Tensor" = self.q_proj(hidden_states) + key_states: "torch.Tensor" = self.k_proj(hidden_states) + value_states: "torch.Tensor" = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + past_key_value = getattr(self, "past_key_value", past_key_value) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + # FlashAttention requires the input to have the shape (bsz, seq_len, n_heads, head_dim) + query_states = query_states.transpose(1, 2) + key_states = key_states.transpose(1, 2) + value_states = value_states.transpose(1, 2) + + dropout_rate = self.attention_dropout if self.training else 0.0 + + input_dtype = query_states.dtype + if input_dtype == torch.float32: + if torch.is_autocast_enabled(): + target_dtype = torch.get_autocast_gpu_dtype() + elif hasattr(self.config, "_pre_quantization_dtype"): + target_dtype = self.config._pre_quantization_dtype + else: + target_dtype = 
self.q_proj.weight.dtype + + transformers_logger.warning_once("The input hidden states seems to be silently casted in float32.") + query_states = query_states.to(target_dtype) + key_states = key_states.to(target_dtype) + value_states = value_states.to(target_dtype) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: "torch.Tensor") -> "torch.Tensor": + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :groupsz].repeat(num_groups, 1) + + attn_output: "torch.Tensor" = self._flash_attention_forward( + query_states, key_states, value_states, attention_mask, query_states.size(1), dropout=dropout_rate + ) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), + ), + dim=2, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size).contiguous() + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attn_weights = None + + return attn_output, attn_weights, past_key_value + + +# Modified from: +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/models/llama/modeling_llama.py +def llama_sdpa_attention_forward( + self: "LlamaSdpaAttention", + hidden_states: torch.Tensor, + attention_mask: Optional[torch.Tensor] = None, + position_ids: Optional[torch.LongTensor] = None, + past_key_value: Optional["Cache"] = None, + output_attentions: bool = False, + cache_position: Optional[torch.LongTensor] = None, + **kwargs, +) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + if output_attentions: + transformers_logger.warning_once( + "SDPA does not support `output_attentions=True`. 
Falling back to the vanilla attention" + ) + return llama_attention_forward( + self, + hidden_states=hidden_states, + attention_mask=attention_mask, + position_ids=position_ids, + past_key_value=past_key_value, + output_attentions=output_attentions, + cache_position=cache_position, + **kwargs, + ) + + bsz, q_len, _ = hidden_states.size() + + query_states: "torch.Tensor" = self.q_proj(hidden_states) + key_states: "torch.Tensor" = self.k_proj(hidden_states) + value_states: "torch.Tensor" = self.v_proj(hidden_states) + + query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2) + key_states = key_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + value_states = value_states.view(bsz, q_len, self.num_key_value_heads, self.head_dim).transpose(1, 2) + + cos, sin = self.rotary_emb(value_states, position_ids) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin) + + if past_key_value is not None: + cache_kwargs = {"sin": sin, "cos": cos, "cache_position": cache_position} + key_states, value_states = past_key_value.update(key_states, value_states, self.layer_idx, cache_kwargs) + + key_states = repeat_kv(key_states, self.num_key_value_groups) + value_states = repeat_kv(value_states, self.num_key_value_groups) + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift + groupsz = int(q_len * getattr(self.config, "group_size_ratio")) + assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz) + num_groups = q_len // groupsz + + def shift(state: "torch.Tensor") -> "torch.Tensor": + state = state.transpose(1, 2) # output: (bsz, seq_len, n_heads, head_dim) + state = torch.cat( + (state[:, :, : self.num_heads // 2], state[:, :, self.num_heads // 2 :].roll(-groupsz // 2, dims=1)), + dim=2, + ) + return state.reshape(bsz * num_groups, groupsz, self.num_heads, self.head_dim).transpose(1, 2) + + query_states, key_states, value_states = shift(query_states), shift(key_states), shift(value_states) + if attention_mask is not None: + attention_mask = attention_mask[:, :, :groupsz, :groupsz].repeat(num_groups, 1, 1, 1) + + causal_mask = attention_mask + if attention_mask is not None: + causal_mask = causal_mask[:, :, :, : key_states.shape[-2]] + + if query_states.device.type == "cuda" and causal_mask is not None: # avoid pytorch bug + query_states = query_states.contiguous() + key_states = key_states.contiguous() + value_states = value_states.contiguous() + + is_causal = True if causal_mask is None and q_len > 1 else False + attn_output = torch.nn.functional.scaled_dot_product_attention( + query_states, + key_states, + value_states, + attn_mask=causal_mask, + dropout_p=self.attention_dropout if self.training else 0.0, + is_causal=is_causal, + ) + attn_output = attn_output.transpose(1, 2).contiguous() + + if getattr(self.config, "group_size_ratio", None) and self.training: # shift back + attn_output.reshape(bsz, q_len, self.num_heads, self.head_dim) + attn_output = torch.cat( + ( + attn_output[:, :, : self.num_heads // 2], + attn_output[:, :, self.num_heads // 2 :].roll(groupsz // 2, dims=1), + ), + dim=2, + ) + + attn_output = attn_output.reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + return attn_output, None, past_key_value + + +def _apply_llama_patch() -> None: + require_version("transformers>=4.41.2,<=4.42.3", "To fix: pip install transformers>=4.41.2,<=4.42.3") + LlamaAttention.forward = llama_attention_forward + 
LlamaFlashAttention2.forward = llama_flash_attention_2_forward + LlamaSdpaAttention.forward = llama_sdpa_attention_forward + + +def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if not is_trainable or not model_args.shift_attn: + return + + logger = get_logger(__name__) + + if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: + setattr(config, "group_size_ratio", 0.25) + _apply_llama_patch() + logger.info("Using shift short attention with group_size_ratio=1/4.") + else: + logger.warning("Current model does not support shift short attention.") diff --git a/src/model/model_utils/misc.py b/src/model/model_utils/misc.py new file mode 100644 index 0000000000000000000000000000000000000000..a2812228ea70ee5ddce513591c9f6c9cfb91ff36 --- /dev/null +++ b/src/model/model_utils/misc.py @@ -0,0 +1,88 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, List + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer + + +logger = get_logger(__name__) + + +def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: + r""" + Finds all available modules to apply lora or galore. + """ + forbidden_modules = {"lm_head"} + + if model.config.model_type == "chatglm": + forbidden_modules.add("output_layer") + elif model.config.model_type == "internlm2": + forbidden_modules.add("output") + elif model.config.model_type in ["llava", "paligemma"]: + forbidden_modules.add("multi_modal_projector") + + if freeze_vision_tower: + forbidden_modules.add("vision_tower") + + module_names = set() + for name, module in model.named_modules(): + if any(forbidden_module in name for forbidden_module in forbidden_modules): + continue + + if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__: + module_names.add(name.split(".")[-1]) + + logger.info("Found linear modules: {}".format(",".join(module_names))) + return list(module_names) + + +def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], num_layer_trainable: int) -> List[str]: + r""" + Finds the modules in the expanded blocks to apply lora. 
+ """ + num_layers = getattr(model.config, "num_hidden_layers", None) + if not num_layers: + raise ValueError("Model was not supported.") + + if num_layers % num_layer_trainable != 0: + raise ValueError( + "`num_layers` {} should be divisible by `num_layer_trainable` {}.".format(num_layers, num_layer_trainable) + ) + + stride = num_layers // num_layer_trainable + trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride) + trainable_layers = [".{:d}.".format(idx) for idx in trainable_layer_ids] + module_names = [] + for name, _ in model.named_modules(): + if any(target_module in name for target_module in target_modules) and any( + trainable_layer in name for trainable_layer in trainable_layers + ): + module_names.append(name) + + logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids)))) + return module_names + + +def register_autoclass(config: "PretrainedConfig", model: "PreTrainedModel", tokenizer: "PreTrainedTokenizer"): + if "AutoConfig" in getattr(config, "auto_map", {}): + config.__class__.register_for_auto_class() + if "AutoModelForCausalLM" in getattr(config, "auto_map", {}): + model.__class__.register_for_auto_class() + if "AutoTokenizer" in tokenizer.init_kwargs.get("auto_map", {}): + tokenizer.__class__.register_for_auto_class() diff --git a/src/model/model_utils/mod.py b/src/model/model_utils/mod.py new file mode 100644 index 0000000000000000000000000000000000000000..ec73af0059c4542f304e08ad451b6572b60e2aa7 --- /dev/null +++ b/src/model/model_utils/mod.py @@ -0,0 +1,42 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING + +from ...extras.constants import MOD_SUPPORTED_MODELS + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel": + from MoD import AutoMoDModelForCausalLM + + return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs) + + +def convert_pretrained_model_to_mod( + model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments" +) -> "PreTrainedModel": + from MoD import apply_mod_to_hf + + if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS: + raise ValueError("Current model is not supported by mixture-of-depth.") + + model = apply_mod_to_hf(model) + model = model.to(model_args.compute_dtype) + return model diff --git a/src/model/model_utils/moe.py b/src/model/model_utils/moe.py new file mode 100644 index 0000000000000000000000000000000000000000..5c7473aae18cce84837fb8290e3a013e63da51e1 --- /dev/null +++ b/src/model/model_utils/moe.py @@ -0,0 +1,80 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Sequence + +import torch +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.utils.versions import require_version + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +def _set_z3_leaf_modules(model: "PreTrainedModel", leaf_modules: Sequence["torch.nn.Module"]) -> None: + require_version("deepspeed>=0.13.0", "To fix: pip install deepspeed>=0.13.0") + from deepspeed.utils import set_z3_leaf_modules # type: ignore + + set_z3_leaf_modules(model, leaf_modules) + + +def add_z3_leaf_module(model: "PreTrainedModel") -> None: + r""" + Sets module as a leaf module to skip partitioning in deepspeed zero3. + """ + if not is_deepspeed_zero3_enabled(): + return + + if getattr(model.config, "model_type", None) == "dbrx": + from transformers.models.dbrx.modeling_dbrx import DbrxFFN + + _set_z3_leaf_modules(model, [DbrxFFN]) + + if getattr(model.config, "model_type", None) == "jamba": + from transformers.models.jamba.modeling_jamba import JambaSparseMoeBlock + + _set_z3_leaf_modules(model, [JambaSparseMoeBlock]) + + if getattr(model.config, "model_type", None) == "jetmoe": + from transformers.models.jetmoe.modeling_jetmoe import JetMoeMoA, JetMoeMoE + + _set_z3_leaf_modules(model, [JetMoeMoA, JetMoeMoE]) + + if getattr(model.config, "model_type", None) == "mixtral": + from transformers.models.mixtral.modeling_mixtral import MixtralSparseMoeBlock + + _set_z3_leaf_modules(model, [MixtralSparseMoeBlock]) + + if getattr(model.config, "model_type", None) == "qwen2moe": + from transformers.models.qwen2_moe.modeling_qwen2_moe import Qwen2MoeSparseMoeBlock + + _set_z3_leaf_modules(model, [Qwen2MoeSparseMoeBlock]) + + +def configure_moe(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if model_args.moe_aux_loss_coef is not None: + if getattr(config, "model_type", None) in ["jamba", "mixtral", "qwen2_moe"]: + setattr(config, "router_aux_loss_coef", model_args.moe_aux_loss_coef) + + elif getattr(config, "model_type", None) == "deepseek": + setattr(config, "aux_loss_alpha", model_args.moe_aux_loss_coef) + + elif getattr(config, "model_type", None) == "jetmoe": + setattr(config, "aux_loss_coef", model_args.moe_aux_loss_coef) + + if getattr(config, "model_type", None) in ["dbrx", "jamba", "jetmoe", "mixtral", "qwen2_moe"]: + setattr(config, "output_router_logits", is_trainable) diff --git a/src/model/model_utils/quantization.py b/src/model/model_utils/quantization.py new file mode 100644 index 0000000000000000000000000000000000000000..317646e0173e3090fce2e7e65614f7c1ffa72297 --- /dev/null +++ b/src/model/model_utils/quantization.py @@ -0,0 +1,205 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the HuggingFace's Transformers and Optimum library. 
+# https://github.com/huggingface/transformers/blob/v4.41.0/src/transformers/utils/quantization_config.py +# https://github.com/huggingface/optimum/blob/v1.20.0/optimum/gptq/data.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import random +from enum import Enum, unique +from typing import TYPE_CHECKING, Any, Dict, List + +import torch +from datasets import load_dataset +from transformers import BitsAndBytesConfig, EetqConfig, GPTQConfig, HqqConfig +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.modeling_utils import is_fsdp_enabled +from transformers.utils.versions import require_version + +from ...extras.constants import FILEEXT2TYPE +from ...extras.logging import get_logger +from ...extras.misc import get_current_device + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedTokenizer + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +@unique +class QuantizationMethod(str, Enum): + r""" + Borrowed from `transformers.utils.quantization_config.QuantizationMethod`. + """ + + BITS_AND_BYTES = "bitsandbytes" + GPTQ = "gptq" + AWQ = "awq" + AQLM = "aqlm" + QUANTO = "quanto" + EETQ = "eetq" + HQQ = "hqq" + + +def _get_quantization_dataset(tokenizer: "PreTrainedTokenizer", model_args: "ModelArguments") -> List[Dict[str, Any]]: + r""" + Prepares the tokenized dataset to perform AutoGPTQ. Do not use tensor output for JSON serialization. 
+ """ + if os.path.isfile(model_args.export_quantization_dataset): + data_path = FILEEXT2TYPE.get(model_args.export_quantization_dataset.split(".")[-1], None) + data_files = model_args.export_quantization_dataset + else: + data_path = model_args.export_quantization_dataset + data_files = None + + dataset = load_dataset( + path=data_path, + data_files=data_files, + split="train", + cache_dir=model_args.cache_dir, + token=model_args.hf_hub_token, + ) + + samples = [] + maxlen = model_args.export_quantization_maxlen + for _ in range(model_args.export_quantization_nsamples): + n_try = 0 + while True: + if n_try > 100: + raise ValueError("Cannot find satisfying example, considering decrease `export_quantization_maxlen`.") + + sample_idx = random.randint(0, len(dataset) - 1) + sample: Dict[str, "torch.Tensor"] = tokenizer(dataset[sample_idx]["text"], return_tensors="pt") + n_try += 1 + if sample["input_ids"].size(1) > maxlen: + break # TODO: fix large maxlen + + word_idx = random.randint(0, sample["input_ids"].size(1) - maxlen - 1) + input_ids = sample["input_ids"][:, word_idx : word_idx + maxlen] + attention_mask = sample["attention_mask"][:, word_idx : word_idx + maxlen] + samples.append({"input_ids": input_ids.tolist(), "attention_mask": attention_mask.tolist()}) + + return samples + + +def configure_quantization( + config: "PretrainedConfig", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + init_kwargs: Dict[str, Any], +) -> None: + r""" + Priority: PTQ-quantized (train/infer) > AutoGPTQ (export) > On-the-fly quantization (train/infer) + """ + if getattr(config, "quantization_config", None): # ptq + if model_args.quantization_bit is not None: + logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.") + + if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): + raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.") + + quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None) + quant_method = quantization_config.get("quant_method", "") + + if quant_method == QuantizationMethod.GPTQ: + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + quantization_config.pop("disable_exllama", None) # remove deprecated args + quantization_config["use_exllama"] = False # disable exllama + + if quant_method == QuantizationMethod.AWQ: + require_version("autoawq", "To fix: pip install autoawq") + + if quant_method == QuantizationMethod.AQLM: + require_version("transformers>=4.39.0", "To fix: pip install transformers>=4.39.0") + require_version("aqlm>=1.1.0", "To fix: pip install aqlm[gpu]>=1.1.0") + quantization_config["bits"] = 2 + + quant_bits = quantization_config.get("bits", "?") + logger.info("Loading {}-bit {}-quantized model.".format(quant_bits, quant_method.upper())) + + elif model_args.export_quantization_bit is not None: # auto-gptq + if model_args.export_quantization_bit not in [8, 4, 3, 2]: + raise ValueError("AutoGPTQ only accepts 2/3/4/8-bit quantization.") + + require_version("optimum>=1.17.0", "To fix: pip install optimum>=1.17.0") + require_version("auto_gptq>=0.5.0", "To fix: pip install auto_gptq>=0.5.0") + from accelerate.utils import get_max_memory + + if getattr(config, "model_type", None) == "chatglm": + raise ValueError("ChatGLM model is not supported yet.") + + init_kwargs["quantization_config"] = GPTQConfig( + bits=model_args.export_quantization_bit, + dataset=_get_quantization_dataset(tokenizer, model_args), + ) + init_kwargs["device_map"] = "auto" + 
init_kwargs["max_memory"] = get_max_memory() + logger.info("Quantizing model to {} bit with AutoGPTQ.".format(model_args.export_quantization_bit)) + + elif model_args.quantization_bit is not None: # on-the-fly + if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value: + if model_args.quantization_bit == 8: + require_version("bitsandbytes>=0.37.0", "To fix: pip install bitsandbytes>=0.37.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig(load_in_8bit=True) + elif model_args.quantization_bit == 4: + require_version("bitsandbytes>=0.39.0", "To fix: pip install bitsandbytes>=0.39.0") + init_kwargs["quantization_config"] = BitsAndBytesConfig( + load_in_4bit=True, + bnb_4bit_compute_dtype=model_args.compute_dtype, + bnb_4bit_use_double_quant=model_args.double_quantization, + bnb_4bit_quant_type=model_args.quantization_type, + bnb_4bit_quant_storage=model_args.compute_dtype, # crucial for fsdp+qlora + ) + else: + raise ValueError("Bitsandbytes only accepts 4-bit or 8-bit quantization.") + + # Do not assign device map if: + # 1. deepspeed zero3 or fsdp (train) + # 2. auto quantization device map (inference) + if is_deepspeed_zero3_enabled() or is_fsdp_enabled() or model_args.quantization_device_map == "auto": + if model_args.quantization_bit != 4: + raise ValueError("Only 4-bit quantized model can use fsdp+qlora or auto device map.") + + require_version("bitsandbytes>=0.43.0", "To fix: pip install bitsandbytes>=0.43.0") + else: + init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference + + logger.info("Quantizing model to {} bit with bitsandbytes.".format(model_args.quantization_bit)) + elif model_args.quantization_method == QuantizationMethod.HQQ.value: + if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]: + raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.") + + if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): + raise ValueError("HQQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.") + + require_version("hqq", "To fix: pip install hqq") + init_kwargs["quantization_config"] = HqqConfig( + nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0 + ) # use ATEN kernel (axis=0) for performance + logger.info("Quantizing model to {} bit with HQQ.".format(model_args.quantization_bit)) + elif model_args.quantization_method == QuantizationMethod.EETQ.value: + if model_args.quantization_bit != 8: + raise ValueError("EETQ only accepts 8-bit quantization.") + + if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): + raise ValueError("EETQ quantization is incompatible with DeepSpeed ZeRO-3 or FSDP.") + + require_version("eetq", "To fix: pip install eetq") + init_kwargs["quantization_config"] = EetqConfig() + logger.info("Quantizing model to {} bit with EETQ.".format(model_args.quantization_bit)) diff --git a/src/model/model_utils/rope.py b/src/model/model_utils/rope.py new file mode 100644 index 0000000000000000000000000000000000000000..4373ee19d6a511f00842d8a41aa0a607d3f5dfdb --- /dev/null +++ b/src/model/model_utils/rope.py @@ -0,0 +1,65 @@ +# Copyright 2024 LMSYS and the LlamaFactory team. +# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li +# +# This code is inspired by the LMSYS's FastChat library. +# https://github.com/lm-sys/FastChat/blob/v0.2.30/fastchat/train/train.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import math +from typing import TYPE_CHECKING + +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PretrainedConfig + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: + if model_args.rope_scaling is None: + return + + if not hasattr(config, "rope_scaling"): + logger.warning("Current model does not support RoPE scaling.") + return + + if model_args.model_max_length is not None: + if is_trainable and model_args.rope_scaling == "dynamic": + logger.warning( + "Dynamic NTK scaling may not work well with fine-tuning. " + "See: https://github.com/huggingface/transformers/pull/24653" + ) + + current_max_length = getattr(config, "max_position_embeddings", None) + if current_max_length and model_args.model_max_length > current_max_length: + logger.info( + "Enlarge max model length from {} to {}.".format(current_max_length, model_args.model_max_length) + ) + setattr(config, "max_position_embeddings", model_args.model_max_length) + scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) + else: + logger.warning("Input length is smaller than max length. Consider increase input length.") + scaling_factor = 1.0 + else: + scaling_factor = 2.0 + + setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) + logger.info( + "Using {} scaling strategy and setting scaling factor to {}".format(model_args.rope_scaling, scaling_factor) + ) diff --git a/src/model/model_utils/unsloth.py b/src/model/model_utils/unsloth.py new file mode 100644 index 0000000000000000000000000000000000000000..9cfaec61c5cffda325402178e3b473b344b0ddc9 --- /dev/null +++ b/src/model/model_utils/unsloth.py @@ -0,0 +1,102 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
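+# Helpers for the optional `unsloth` integration: `_get_unsloth_kwargs` assembles the
+# FastLanguageModel keyword arguments, and the loaders below wrap pretrained and PEFT
+# model loading for training and inference.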
+ +from typing import TYPE_CHECKING, Any, Dict, Optional + +from ...extras.logging import get_logger +from ...extras.misc import get_current_device + + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def _get_unsloth_kwargs( + config: "PretrainedConfig", model_name_or_path: str, model_args: "ModelArguments" +) -> Dict[str, Any]: + return { + "model_name": model_name_or_path, + "max_seq_length": model_args.model_max_length or 4096, + "dtype": model_args.compute_dtype, + "load_in_4bit": model_args.quantization_bit == 4, + "token": model_args.hf_hub_token, + "device_map": {"": get_current_device()}, + "rope_scaling": getattr(config, "rope_scaling", None), + "fix_tokenizer": False, + "trust_remote_code": True, + "use_gradient_checkpointing": "unsloth", + } + + +def load_unsloth_pretrained_model( + config: "PretrainedConfig", model_args: "ModelArguments" +) -> Optional["PreTrainedModel"]: + r""" + Optionally loads pretrained model with unsloth. Used in training. + """ + from unsloth import FastLanguageModel + + unsloth_kwargs = _get_unsloth_kwargs(config, model_args.model_name_or_path, model_args) + try: + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + model = None + model_args.use_unsloth = False + + return model + + +def get_unsloth_peft_model( + model: "PreTrainedModel", model_args: "ModelArguments", peft_kwargs: Dict[str, Any] +) -> "PreTrainedModel": + r""" + Gets the peft model for the pretrained model with unsloth. Used in training. + """ + from unsloth import FastLanguageModel + + unsloth_peft_kwargs = { + "model": model, + "max_seq_length": model_args.model_max_length, + "use_gradient_checkpointing": "unsloth", + } + return FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + + +def load_unsloth_peft_model( + config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool +) -> "PreTrainedModel": + r""" + Loads peft model with unsloth. Used in both training and inference. + """ + from unsloth import FastLanguageModel + + unsloth_kwargs = _get_unsloth_kwargs(config, model_args.adapter_name_or_path[0], model_args) + try: + if not is_trainable: + unsloth_kwargs["use_gradient_checkpointing"] = False + + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + raise ValueError("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + + if not is_trainable: + FastLanguageModel.for_inference(model) + + return model diff --git a/src/model/model_utils/valuehead.py b/src/model/model_utils/valuehead.py new file mode 100644 index 0000000000000000000000000000000000000000..9ab3d45ac0db1fec264aab8632f9078d9bdd2472 --- /dev/null +++ b/src/model/model_utils/valuehead.py @@ -0,0 +1,73 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Dict + +import torch +from transformers.utils import cached_file + +from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ...extras.logging import get_logger + + +if TYPE_CHECKING: + from transformers import PreTrainedModel + + from ...hparams import ModelArguments + + +logger = get_logger(__name__) + + +def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: + r""" + Loads value head parameters from Hugging Face Hub or local disk. + + Returns: dict with keys `v_head.summary.weight` and `v_head.summary.bias`. + """ + kwargs = {"path_or_repo_id": path_or_repo_id, "cache_dir": model_args.cache_dir, "token": model_args.hf_hub_token} + err_text = "" + + try: + from safetensors import safe_open + + vhead_file = cached_file(filename=V_HEAD_SAFE_WEIGHTS_NAME, **kwargs) + with safe_open(vhead_file, framework="pt", device="cpu") as f: + return {key: f.get_tensor(key) for key in f.keys()} + except Exception as err: + err_text = str(err) + + try: + vhead_file = cached_file(filename=V_HEAD_WEIGHTS_NAME, **kwargs) + return torch.load(vhead_file, map_location="cpu") + except Exception as err: + err_text = str(err) + + logger.info("Provided path ({}) does not contain value head weights: {}.".format(path_or_repo_id, err_text)) + logger.info("Ignore the above message if you are not resuming the training of a value head model.") + return None + + +def prepare_valuehead_model(model: "PreTrainedModel") -> None: + if getattr(model.config, "model_type", None) == "llava": + setattr(model, "lm_head", model.language_model.get_output_embeddings()) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) + + if getattr(model.config, "model_type", None) == "chatglm": + setattr(model, "lm_head", model.transformer.output_layer) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) + + if getattr(model.config, "model_type", None) == "internlm2": + setattr(model, "lm_head", model.output) + setattr(model, "_keys_to_ignore_on_save", ["lm_head.weight"]) diff --git a/src/model/modeling_llamole.py b/src/model/modeling_llamole.py new file mode 100644 index 0000000000000000000000000000000000000000..cb86078a897ce54d5be202afbca5a0a577f60a51 --- /dev/null +++ b/src/model/modeling_llamole.py @@ -0,0 +1,1287 @@ +# Copyright 2024 the Llamole Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
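+# GraphLLMForCausalMLM couples a causal language model with a graph encoder, a graph
+# decoder, and a reaction-template predictor through small connector MLPs, supporting
+# multimodal molecular design, one-step retrosynthesis, and retrosynthetic planning
+# (see `retrosynthesize` below, which drives the molstar planner).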
+ +import torch +import torch.nn as nn +import torch.nn.functional as F +from transformers import PreTrainedModel, PreTrainedTokenizerBase +from transformers.utils import ModelOutput +from transformers.generation.utils import LogitsProcessorList, GenerationConfig +from huggingface_hub import snapshot_download + +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +import os +import json +import time +from dataclasses import dataclass + +from typing import Union, Tuple, Optional +from .loader import load_language_model, load_tokenizer +from .loader import load_graph_decoder, load_graph_predictor, load_graph_encoder +from ..extras.constants import NO_LABEL_INDEX, IGNORE_INDEX, BOND_INDEX + +from .planner import molstar +from rdkit import Chem +from torch_geometric.data import Data +from torch_geometric.data import Batch as PyGBatch +from torch_geometric.utils import remove_isolated_nodes + +# Save configuration +def convert_to_dict(obj): + if isinstance(obj, (int, float, str, bool, type(None))): + return obj + elif isinstance(obj, (list, tuple)): + return [convert_to_dict(item) for item in obj] + elif isinstance(obj, dict): + return {k: convert_to_dict(v) for k, v in obj.items()} + elif hasattr(obj, "__dict__"): + return { + k: convert_to_dict(v) + for k, v in obj.__dict__.items() + if not k.startswith("_") + } + else: + return str(obj) # Convert any other objects to string + +@dataclass +class GraphLMOutput(ModelOutput): + loss: Optional[torch.FloatTensor] = None + logits: torch.FloatTensor = None + last_hidden_state: Optional[torch.FloatTensor] = None + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None + hidden_states: Optional[Tuple[torch.FloatTensor]] = None + attentions: Optional[Tuple[torch.FloatTensor]] = None + additional_log_info: Optional[Dict[str, float]] = None + + +class GraphLLMForCausalMLM(PreTrainedModel): + def __init__( + self, + model_args, + finetuning_args, + data_args, + language_model, + graph_decoder, + graph_predictor, + graph_encoder, + token_id_dict, + tokenizer, + ): + super().__init__(language_model.config) + self.language_model = language_model + self.graph_decoder = graph_decoder + self.graph_predictor = graph_predictor + self.graph_encoder = graph_encoder + + self.token_id_dict = token_id_dict + self.num_body_tokens = data_args.learned_query_size + + self.loss_weight_lm = finetuning_args.loss_weight_lm + self.loss_weight_design = finetuning_args.loss_weight_design + self.loss_weight_retro = finetuning_args.loss_weight_retro + + self.model_args = model_args + self.finetuning_args = finetuning_args + self.data_args = data_args + self.tokenizer = tokenizer + + # Initialize weights and apply final processing + self.post_init() + + @classmethod + def from_pretrained( + cls, + tokenizer: PreTrainedTokenizerBase, + model_args, + data_args, + training_args, + finetuning_args, + load_adapter=False, + add_valuehead=False, + ): + if load_adapter: + if model_args.adapter_name_or_path is None: + raise ValueError("Please specify the adapter_name_or_path when load_adapter is True.") + + if len(model_args.adapter_name_or_path) != 1: + raise ValueError("Only one adapter is supported at a time.") + + adapter_path = model_args.adapter_name_or_path[0] + + if not os.path.exists(os.path.join(adapter_path, "adapter_config.json")): + # Download from HuggingFace + adapter_name = os.path.basename(adapter_path) + valid_adapters = [ + "Llama-3.1-8B-Instruct-Adapter", + "Qwen2-7B-Instruct-Adapter", + "Mistral-7B-Instruct-v0.3-Adapter" + ] + + if 
adapter_name not in valid_adapters: + raise ValueError(f"Invalid adapter name. Supported adapters are: {', '.join(valid_adapters)}") + + repo_id = f"liuganghuggingface/Llamole-{adapter_name}" + print(f"Downloading adapter {adapter_name} from HuggingFace repo: {repo_id}") + + try: + # Download all files including subfolders to the adapter_path + snapshot_download( + repo_id=repo_id, + local_dir=adapter_path, + local_dir_use_symlinks=False, + ignore_patterns=["*.md", "*.txt"] # Optionally ignore certain file types + ) + + print(f"Successfully downloaded all adapter files to {adapter_path}") + except Exception as e: + raise RuntimeError(f"Failed to download adapter files: {str(e)}") + + + language_model = load_language_model( + tokenizer, + model_args, + finetuning_args, + training_args.do_train, + add_valuehead, + ) + + device = next(language_model.parameters()).device + + graph_decoder = load_graph_decoder( + model_args, + path=model_args.graph_decoder_path, + device=device, + ) + + graph_predictor = load_graph_predictor( + model_args, + path=model_args.graph_predictor_path, + device=device, + ) + + graph_encoder = load_graph_encoder( + model_args, + path=model_args.graph_encoder_path, + device=device, + ) + + if ( + getattr(language_model, "is_quantized", False) + and not training_args.do_train + ): + setattr( + language_model, "_hf_peft_config_loaded", True + ) # hack here: make model compatible with prediction + + token_id_dict = {} + for elem in model_args.new_special_tokens: + if isinstance(elem, str) and len(elem) != 0: + elem_token_ids = tokenizer.encode(elem, add_special_tokens=False) + token_id_dict[elem] = elem_token_ids[0] + + model = cls( + model_args=model_args, + finetuning_args=finetuning_args, + data_args=data_args, + language_model=language_model, + graph_decoder=graph_decoder, + graph_predictor=graph_predictor, + graph_encoder=graph_encoder, + token_id_dict=token_id_dict, + tokenizer=tokenizer, + ) + + graph_to_lm_connector = nn.Sequential( + nn.Linear(graph_encoder.hidden_size, language_model.config.hidden_size), + nn.SiLU(), + ) + + # Language Model to Graph Decoder connector + lm_to_graph_decoder = nn.Sequential( + nn.Linear(language_model.config.hidden_size, graph_decoder.text_input_size), + nn.SiLU(), + ) + + # Language Model to Graph Predictor connector + lm_to_graph_predictor = nn.Sequential( + nn.Linear( + language_model.config.hidden_size, graph_predictor.text_input_size + ), + nn.SiLU(), + ) + + for param in graph_to_lm_connector.parameters(): + if ( + param.dtype == torch.float32 + and model_args.compute_dtype != torch.float32 + ): + param.data = param.data.to(model_args.compute_dtype) + + for param in lm_to_graph_decoder.parameters(): + if ( + param.dtype == torch.float32 + and model_args.compute_dtype != torch.float32 + ): + param.data = param.data.to(model_args.compute_dtype) + + for param in lm_to_graph_predictor.parameters(): + if ( + param.dtype == torch.float32 + and model_args.compute_dtype != torch.float32 + ): + param.data = param.data.to(model_args.compute_dtype) + + # Check if connector path is provided and load if available + if load_adapter: + if ( + hasattr(model_args, "graph_lm_connector_path") + and model_args.graph_lm_connector_path + ): + connector_path = model_args.graph_lm_connector_path + + graph_to_lm_connector.load_state_dict( + torch.load( + os.path.join(connector_path, "graph_to_lm_connector.pt"), + map_location=device, + weights_only=True, + ) + ) + + lm_to_graph_decoder.load_state_dict( + torch.load( + os.path.join(connector_path, 
"lm_to_graph_decoder.pt"), + map_location=device, + weights_only=True, + ) + ) + + lm_to_graph_predictor.load_state_dict( + torch.load( + os.path.join(connector_path, "lm_to_graph_predictor.pt"), + map_location=device, + weights_only=True, + ) + ) + else: + raise ValueError(f"Connector should be automatically downloaded with the adapter. Please manually download to the path {connector_path}") + + model.graph_to_lm_connector = graph_to_lm_connector + model.lm_to_graph_decoder = lm_to_graph_decoder + model.lm_to_graph_predictor = lm_to_graph_predictor + model.graph_to_lm_connector.to(device) + model.lm_to_graph_decoder.to(device) + model.lm_to_graph_predictor.to(device) + + return model + + def to(self, device): + super().to(device) + self.language_model.to(device) + self.graph_decoder.to(device) + self.graph_predictor.to(device) + self.graph_encoder.to(device) + self.graph_to_lm_connector.to(device) + self.lm_to_graph_decoder.to(device) + self.lm_to_graph_predictor.to(device) + return self + + def forward( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + molecule_graphs: Optional[PyGBatch] = None, + molecule_properties: Optional[torch.FloatTensor] = None, + design_graphs: Optional[PyGBatch] = None, + retro_labels: Optional[torch.LongTensor] = None, + retro_product_graphs: Optional[PyGBatch] = None, + past_key_values: Optional[Tuple[Tuple[torch.FloatTensor]]] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = True, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, GraphLMOutput]: + return_dict = ( + return_dict if return_dict is not None else self.config.use_return_dict + ) + + mol_token_id = self.token_id_dict[""] + design_start_token_id = self.token_id_dict[""] + retro_start_token_id = self.token_id_dict[""] + + # PeftModelForCausalLM -> LlamaForCausalLM -> LlamaModel + base_llm = self.language_model.model.model + inputs_embeds = base_llm.embed_tokens(input_ids) + mol_positions = (input_ids == mol_token_id).nonzero() + + mol_embeds = self.graph_encoder( + molecule_graphs.x, + molecule_graphs.edge_index, + molecule_graphs.edge_attr, + molecule_graphs.batch, + ) + mol_embeds = self.graph_to_lm_connector(mol_embeds) + + assert ( + mol_positions.shape[0] == mol_embeds.shape[0] + ), f"Number of molecule tokens ({mol_positions.shape[0]}) does not match number of molecule embeddings ({mol_embeds.shape[0]})" + + inputs_embeds[mol_positions[:, 0], mol_positions[:, 1]] = mol_embeds.to( + inputs_embeds.dtype + ) + + lm_outputs = self.language_model( + input_ids=None, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + labels=labels, + ) + + lm_loss = lm_outputs.loss + lm_hidden_states = lm_outputs.hidden_states[-1] + + design_loss = 0 + if design_graphs is not None: + design_start_positions = (input_ids == design_start_token_id).nonzero() + design_body_start = design_start_positions[:, 1] + 1 + design_body_indices = design_body_start.unsqueeze(1) + torch.arange( + self.num_body_tokens, device=input_ids.device + ) + design_hidden = lm_hidden_states[ + design_start_positions[:, 0].unsqueeze(1), design_body_indices[:, 1] + ].mean(dim=1) + if design_start_positions.numel() > 0: + design_hidden = self.lm_to_graph_decoder(design_hidden) 
+ design_loss = self.graph_decoder(
+ design_graphs.x,
+ design_graphs.edge_index,
+ design_graphs.edge_attr,
+ design_graphs.batch,
+ molecule_properties,
+ design_hidden,
+ NO_LABEL_INDEX,
+ )
+
+ # Process retro labels
+ retro_loss = 0
+ if retro_labels is not None:
+ # Get retro start positions for valid retro labels: (batch, step)
+ retro_start_positions = (input_ids == retro_start_token_id).nonzero()
+ retro_labels = retro_labels[retro_labels != IGNORE_INDEX]
+ valid_retro_mask = retro_labels != NO_LABEL_INDEX
+ retro_start_positions = retro_start_positions[valid_retro_mask]
+ retro_labels = retro_labels[valid_retro_mask]
+
+ if len(retro_labels) > 0:
+ # Get the query hidden states for each retro prediction
+ retro_body_start = retro_start_positions[:, 1] + 1
+ retro_body_indices = retro_body_start.unsqueeze(1) + torch.arange(
+ self.num_body_tokens, device=input_ids.device
+ )
+ retro_hidden = lm_hidden_states[
+ retro_start_positions[:, 0].unsqueeze(1), retro_body_indices
+ ].mean(dim=1)
+
+ # Prepare graph inputs
+ retro_product_graphs = retro_product_graphs[
+ valid_retro_mask.nonzero().view(-1)
+ ]
+ retro_product_graphs = PyGBatch.from_data_list(retro_product_graphs)
+
+ # Transform hidden states and make predictions
+ retro_hidden = self.lm_to_graph_predictor(retro_hidden)
+ retro_pred = self.graph_predictor(
+ retro_product_graphs.x,
+ retro_product_graphs.edge_index,
+ retro_product_graphs.edge_attr,
+ retro_product_graphs.batch,
+ retro_hidden,
+ )
+ retro_loss = F.cross_entropy(
+ retro_pred,
+ retro_labels,
+ )
+
+ total_loss = (
+ self.loss_weight_lm * lm_loss
+ + self.loss_weight_design * design_loss
+ + self.loss_weight_retro * retro_loss
+ )
+
+ if not return_dict:
+ output = (lm_outputs.logits,) + lm_outputs[1:]
+ return ((total_loss,) + output) if total_loss is not None else output
+
+ return GraphLMOutput(
+ loss=total_loss,
+ logits=lm_outputs.logits,
+ past_key_values=lm_outputs.past_key_values,
+ hidden_states=lm_outputs.hidden_states,
+ attentions=lm_outputs.attentions,
+ )
+
+ def save_pretrained(
+ self,
+ save_directory: Union[str, os.PathLike],
+ is_main_process: bool = True,
+ state_dict: Optional[dict] = None,
+ save_function: Callable = torch.save,
+ push_to_hub: bool = False,
+ max_shard_size: Union[int, str] = "5GB",
+ safe_serialization: bool = True,
+ variant: Optional[str] = None,
+ token: Optional[Union[str, bool]] = None,
+ save_peft_format: bool = True,
+ save_graph_modules: bool = False,
+ **kwargs,
+ ):
+ """
+ Save the model and its configuration file to a directory. 
+ """ + if os.path.isfile(save_directory): + raise ValueError( + f"Provided path ({save_directory}) should be a directory, not a file" + ) + + os.makedirs(save_directory, exist_ok=True) + + # Save language model + language_model_path = os.path.join(save_directory) + self.language_model.save_pretrained( + language_model_path, + is_main_process=is_main_process, + state_dict=state_dict, + save_function=save_function, + push_to_hub=False, # set to false + max_shard_size=max_shard_size, + safe_serialization=safe_serialization, + variant=variant, + token=token, + save_peft_format=save_peft_format, + ) + + if save_graph_modules: + # Save graph models + graph_models = { + "graph_decoder": self.graph_decoder, + "graph_predictor": self.graph_predictor, + "graph_encoder": self.graph_encoder, + } + for name, model in graph_models.items(): + model_path = os.path.join(save_directory, name) + model.save_pretrained(model_path) + + # Save additional components + additional_components = { + "graph_to_lm_connector": self.graph_to_lm_connector, + "lm_to_graph_decoder": self.lm_to_graph_decoder, + "lm_to_graph_predictor": self.lm_to_graph_predictor, + } + connector_path = os.path.join(save_directory, "connector") + for name, component in additional_components.items(): + os.makedirs(connector_path, exist_ok=True) + component_path = os.path.join(connector_path, f"{name}.pt") + torch.save(component.state_dict(), component_path) + + config_dict = { + "model_args": convert_to_dict(self.model_args), + "finetuning_args": convert_to_dict(self.finetuning_args), + "data_args": convert_to_dict(self.data_args), + "token_id_dict": self.token_id_dict, + "num_body_tokens": self.num_body_tokens, + "loss_weight_lm": self.loss_weight_lm, + "loss_weight_design": self.loss_weight_design, + "loss_weight_retro": self.loss_weight_retro, + } + + config_path = os.path.join(save_directory, "graphllm_config.json") + with open(config_path, "w") as f: + json.dump(config_dict, f, indent=2) + + # Push to hub if required + if push_to_hub: + raise NotImplementedError("Push to hub not implemented yet") + + def add_special_body_tokens( + self, + input_ids: torch.LongTensor, + body_token_id: int, + num_body_tokens: int, + start_token_id: Optional[int] = None, + ) -> torch.LongTensor: + batch_size, seq_length = input_ids.shape + start_len = 1 if start_token_id is not None else 0 + if seq_length < num_body_tokens + start_len: + seq_length = seq_length + num_body_tokens + start_len + + # Create a tensor to hold start positions for each batch item + start_positions = torch.full( + (batch_size,), + seq_length - start_len - num_body_tokens, + device=input_ids.device, + ) + # Calculate how many tokens to keep from the original input + tokens_to_keep = seq_length - num_body_tokens + + # Find start positions + if start_token_id is not None: + start_pos_rows, start_pos_cols = (input_ids == start_token_id).nonzero( + as_tuple=True + ) + for row, col in zip(start_pos_rows, start_pos_cols): + start_positions[row] = col + tokens_to_keep = seq_length - num_body_tokens - 1 + + # Create body tokens + body_tokens = torch.full( + (batch_size, num_body_tokens), body_token_id, device=input_ids.device + ) + + # Create new input_ids with left padding + new_input_ids = torch.full( + (batch_size, seq_length), + self.tokenizer.eos_token_id, + device=input_ids.device, + ) + + for i in range(batch_size): + start_pos = start_positions[i] + # Keep the rightmost tokens_to_keep tokens before the start token + keep_start = max(0, start_pos - tokens_to_keep) + + if 
start_token_id is not None: + new_input_ids[ + i, -(num_body_tokens + 1 + (start_pos - keep_start)) : + ] = torch.cat( + [ + input_ids[i, keep_start:start_pos], + torch.LongTensor([start_token_id]).to(input_ids.device), + body_tokens[i], + ] + ) + else: + new_input_ids[ + i, -(num_body_tokens + 1 + (start_pos - keep_start)) : + ] = torch.cat([input_ids[i, keep_start:start_pos], body_tokens[i]]) + return new_input_ids + + @torch.no_grad() + def design_molecule( + self, + input_ids: torch.LongTensor, + attention_mask: torch.FloatTensor, + molecule_properties: Optional[torch.FloatTensor] = None, + molecule_graphs: Optional[PyGBatch] = None, + rollback: bool = False, + **kwargs, + ) -> List[Optional[str]]: + design_start_token_id = self.token_id_dict[""] + design_body_token_id = self.token_id_dict[""] + + # 1. Generate molecular design analysis + if molecule_graphs is None: + analysis_tokens = self.language_model.generate( + inputs=input_ids, + attention_mask=attention_mask, + **kwargs, + ) + analysis_tokens = analysis_tokens[:, input_ids.shape[1] :] + else: + mol_token_id = self.token_id_dict[""] + base_llm = self.language_model.model + inputs_embeds = base_llm.embed_tokens(input_ids) + + mol_positions = (input_ids == mol_token_id).nonzero() + mol_embeds = self.graph_encoder( + molecule_graphs.x, + molecule_graphs.edge_index, + molecule_graphs.edge_attr, + molecule_graphs.batch, + ) + mol_embeds = self.graph_to_lm_connector(mol_embeds) + + assert ( + mol_positions.shape[0] == mol_embeds.shape[0] + ), f"Number of molecule tokens ({mol_positions.shape[0]}) does not match number of molecule embeddings ({mol_embeds.shape[0]})" + inputs_embeds[mol_positions[:, 0], mol_positions[:, 1]] = mol_embeds.to( + inputs_embeds.dtype + ) + analysis_tokens = self.language_model.generate( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) # no input + + # 2. Add special tokens for design body + design_input_ids = self.add_special_body_tokens( + analysis_tokens, + design_body_token_id, + self.num_body_tokens, + start_token_id=design_start_token_id, + ) + design_input_ids = torch.cat([input_ids, design_input_ids], dim=1) + + # 3. Get LLM embeddings for design body + lm_outputs = self.language_model( + input_ids=design_input_ids, + attention_mask=torch.ones_like(design_input_ids), + output_hidden_states=True, + return_dict=True, + ) + lm_hidden_states = lm_outputs.hidden_states[-1] + design_hidden = lm_hidden_states[:, -self.num_body_tokens :].mean(dim=1) + + # 4. 
Generate molecules using graph decoder + design_hidden = self.lm_to_graph_decoder(design_hidden) + molecule_properties = molecule_properties.type_as(design_hidden) + smiles_list = self.graph_decoder.generate( + molecule_properties, + design_hidden, + NO_LABEL_INDEX, + ) + + # Handle None values in smiles_list + if rollback and None in smiles_list: + smiles_list = self.design_rollback(design_input_ids, smiles_list, **kwargs) + + return analysis_tokens, smiles_list + + def design_rollback( + self, + analysis_tokens: torch.LongTensor, + smiles_list: List[Optional[str]], + **kwargs, + ) -> List[Optional[str]]: + rollback_token_id = self.token_id_dict.get("") + rollback_end_token_id = self.token_id_dict.get("") + none_indices = [i for i, smiles in enumerate(smiles_list) if smiles is None] + + if not none_indices: + return smiles_list # No None values, return original list + + # Get corresponding analysis tokens for None indices + none_indices = torch.LongTensor(none_indices) + rollback_analysis_tokens = analysis_tokens[none_indices] + + # Add rollback token to the end of each analysis token sequence + rollback_input_ids = self.add_special_body_tokens( + rollback_analysis_tokens, + rollback_token_id, + 1, + ) + + if "max_new_tokens" in kwargs: + kwargs["max_new_tokens"] *= 2 + + # Generate new tokens + new_tokens = self.language_model.generate( + inputs=rollback_input_ids, + attention_mask=torch.ones_like(rollback_input_ids), + **kwargs, + ) + + # Process and decode new tokens + new_smiles = [] + for seq in new_tokens[:, rollback_input_ids.shape[1] :]: + decoded_seq = self.tokenizer.decode(seq, skip_special_tokens=False) + end_smiles_pos = decoded_seq.find( + self.tokenizer.decode([rollback_end_token_id]) + ) + + if end_smiles_pos != -1: + # If end token is found, truncate the sequence + new_smiles.append(decoded_seq[:end_smiles_pos].strip()) + else: + # If end token is not found, append None + new_smiles.append(None) + + # Update smiles_list with new decoded tokens + for i, new_smiles_str in zip(none_indices, new_smiles): + smiles_list[i] = new_smiles_str + + return smiles_list + + def smiles_to_graph(self, smiles: str) -> Optional[Data]: + mol = Chem.MolFromSmiles(smiles) + if mol is None: + print(f"Invalid SMILES string: {smiles}") + return None + + type_idx = [] + for atom in mol.GetAtoms(): + if atom.GetAtomicNum() != 1: # Exclude hydrogen atoms + type_idx.append( + 119 - 2 if atom.GetSymbol() == "*" else atom.GetAtomicNum() - 2 + ) + + x = torch.LongTensor(type_idx) + num_nodes = x.size(0) + + # Initialize edge_index and edge_attr as empty tensors + edge_index = torch.empty((2, 0), dtype=torch.long) + edge_attr = torch.empty((0,), dtype=torch.long) + + # Only process bonds if they exist + if mol.GetNumBonds() > 0: + bond_src = [] + bond_dst = [] + bond_type = [] + for bond in mol.GetBonds(): + start, end = bond.GetBeginAtomIdx(), bond.GetEndAtomIdx() + # Exclude bonds involving hydrogen atoms + if mol.GetAtomWithIdx(start).GetAtomicNum() != 1 and mol.GetAtomWithIdx(end).GetAtomicNum() != 1: + bond_src.extend([start, end]) + bond_dst.extend([end, start]) + bond_type.extend([BOND_INDEX.get(bond.GetBondType(), 1)] * 2) + + if bond_src: # Only create edge_index and edge_attr if there are valid bonds + edge_index = torch.tensor([bond_src, bond_dst], dtype=torch.long) + edge_attr = torch.tensor(bond_type, dtype=torch.long) + + # Create PyG Data object + data = Data(x=x, edge_index=edge_index, edge_attr=edge_attr, num_nodes=num_nodes) + + return data + + def retrosynthesize_rollback(self, 
input_ids, design_text, smiles, **kwargs): + input_text = f"{design_text} To synthesize {smiles}, follow these procedures: " + input_tokens = self.tokenizer.encode( + input_text, add_special_tokens=False, return_tensors="pt" + ) + input_tokens = input_tokens.to(self.device) + + if "max_new_tokens" in kwargs: + kwargs["max_new_tokens"] = 256 + + # Generate tokens + generated_tokens = self.language_model.generate( + inputs=input_tokens, + **kwargs, + ) + generated_tokens = generated_tokens[:, input_tokens.shape[1] :] + generated_tokens = generated_tokens.cpu().squeeze().tolist() + new_input_text = f"To synthesize {smiles}, follow these procedures: " + new_input_tokens = self.tokenizer.encode(new_input_text) + generated_tokens = new_input_tokens + generated_tokens + return generated_tokens + + def one_step_reaction( + self, + product_smiles, + input_ids, + design_text, + molecule_graphs, + topk, + **kwargs, + ): + # 1. Generate retrosynthesis analysis + retro_start_token_id = self.token_id_dict[""] + retro_body_token_id = self.token_id_dict[""] + mol_token_id = self.token_id_dict[""] + + input_text = f"{design_text} To synthesize , follow these procedures: " + + prompt_tokens = self.tokenizer.encode( + input_text, add_special_tokens=False, return_tensors="pt" + ) + prompt_tokens = prompt_tokens.to(self.device) + + # Combine input_ids with new_prompt_tokens if input_ids is provided + if input_ids is not None and molecule_graphs is not None: + input_ids = input_ids.view(1, -1) + prompt_tokens = torch.cat([input_ids, prompt_tokens], dim=-1) + + base_llm = self.language_model.model + inputs_embeds = base_llm.embed_tokens(prompt_tokens) + + product_graph = self.smiles_to_graph(product_smiles) + if product_graph is None: + return { + "reactants": [], + "scores": [], + "templates": [], + "analysis": self.tokenizer.encode( + "Invalid product SMILES", add_special_tokens=False + ), + } + product_graph.to(self.device) + + if input_ids is not None and molecule_graphs is not None: + all_graphs = PyGBatch.from_data_list(molecule_graphs.to_data_list() + [product_graph]) + else: + all_graphs = PyGBatch.from_data_list([product_graph]) + mol_embeds = self.graph_encoder( + all_graphs.x, + all_graphs.edge_index, + all_graphs.edge_attr, + all_graphs.batch, + ) + mol_embeds = self.graph_to_lm_connector(mol_embeds) + + mol_positions = (prompt_tokens == mol_token_id).nonzero() + assert ( + mol_positions.shape[0] == mol_embeds.shape[0] + ), f"Number of molecule tokens ({mol_positions.shape[0]}) does not match number of molecule embeddings ({mol_embeds.shape[0]})" + inputs_embeds[mol_positions[:, 0], mol_positions[:, 1]] = mol_embeds.to( + inputs_embeds.dtype + ) + attention_mask = torch.ones_like(prompt_tokens) + + if "max_new_tokens" in kwargs: + kwargs["max_new_tokens"] = 512 + + analysis_tokens = self.language_model.generate( + attention_mask=attention_mask, + inputs_embeds=inputs_embeds, + **kwargs, + ) + + # 2. Encode analysis with query tokens + retro_input_ids = self.add_special_body_tokens( + analysis_tokens, + retro_body_token_id, + self.num_body_tokens, + start_token_id=retro_start_token_id, + ) + # Get LLM embeddings for retro body + lm_outputs = self.language_model( + input_ids=retro_input_ids, + attention_mask=torch.ones_like(retro_input_ids), + output_hidden_states=True, + return_dict=True, + ) + lm_hidden_states = lm_outputs.hidden_states[-1] + retro_hidden = lm_hidden_states[:, -self.num_body_tokens :].mean(dim=1) + retro_hidden = self.lm_to_graph_predictor(retro_hidden) + + # 3. 
Sample retrosynthetic templates + reactants, scores, templates = self.graph_predictor.sample_templates( + product_graph, retro_hidden, product_smiles, topk + ) + + # 4. Prepend the original prompt tokens to the generated analysis tokens + analysis_tokens = analysis_tokens.cpu().squeeze().tolist() + input_text = f"To synthesize {product_smiles}, follow these procedures: " + new_input_tokens = self.tokenizer.encode(input_text) + analysis_tokens = new_input_tokens + analysis_tokens + + return { + "reactants": reactants, + "scores": scores, + "templates": templates, + "analysis": analysis_tokens, + } + + @torch.no_grad() + def estimate_synthesis_complexity( + self, + smiles: str, + input_ids=None, + reaction=None, + molecule_cost_weight: float = 0, + language_cost_weight: float = 1, + reference_tokens: Optional[torch.LongTensor] = None, + ): + cost = 0 + + if molecule_cost_weight is not None and molecule_cost_weight > 0: + mol_cost = self.graph_predictor.estimate_cost(smiles) + cost += mol_cost * molecule_cost_weight + + if language_cost_weight is not None and language_cost_weight > 0: + language_cost = 0 + if reaction is None: + message_content = f""" + Estimate remaining steps for the target {smiles}, considering the following factors: + 1. Intermediate complexity + 2. Reagent availability + 3. Side reactions + 4. Stereochemistry challenges""" + else: + step = reaction.depth + 1 + template = reaction.template + # analysis_tokens = reaction.analysis_tokens + reactants = reaction.children + reactants = ", ".join([r.mol for r in reactants]) + message_content = f""" + Estimate remaining steps for the target {smiles} given the following parameters: + Current step {step}, + Current template: {template}, + Reactants: {reactants}. + Consider the following factors: + 1. Intermediate complexity + 2. Reagent availability + 3. Side reactions + 4. 
Stereochemistry challenges""" + + # Create the messages list for the chat template + messages = [{"role": "user", "content": message_content}] + + # Apply the chat template + chat_text = self.tokenizer.apply_chat_template( + messages, tokenize=False, add_generation_prompt=True + ) + + answers = [ + "All readily available", + "Some commercial, some need 1-2 steps", + "Mix of commercial and multi-step synthesis", + "Mostly require complex synthesis", + "All require extensive multi-step synthesis", + ] + + answer_costs = [0, 1, 2.5, 4.5, 7] + answer_messages = [ + [ + { + "role": "user", + "content": "Estimate the synthesis complexity:", + }, + {"role": "assistant", "content": answer}, + ] + for answer in answers + ] + answer_chat_texts = [ + self.tokenizer.apply_chat_template( + msg, tokenize=False, add_generation_prompt=False + ) + for msg in answer_messages + ] + + # Encode chat texts + input_ids = self.tokenizer.encode(chat_text, return_tensors="pt").to( + self.device + ) + answer_tokens = [ + self.tokenizer.encode(text) for text in answer_chat_texts + ] + + # Get logits from the language model + outputs = self.language_model(input_ids) + logits = outputs.logits[:, -1, :] + + # Calculate softmax probabilities for each answer + answer_logits = torch.stack( + [logits[:, tokens].mean(dim=1) for tokens in answer_tokens] + ) + probs = torch.nn.functional.softmax(answer_logits, dim=0) + language_cost = ( + (probs * torch.tensor(answer_costs, device=probs.device)) + .sum() + .item() + ) + + language_cost = language_cost * language_cost_weight + cost += language_cost + + return cost + + @torch.no_grad() + def retrosynthesize( + self, + input_ids: torch.LongTensor, + smiles: Optional[str] = None, + molecule_graphs: Optional[PyGBatch] = None, + expansion_topk: int = 50, + iterations: int = 100, + starting_mols: Optional[List[str]] = None, + molecule_cost_weight: float = 0, + language_cost_weight: float = 1, + max_planning_time: int = 300, + rollback: bool = True, + design_text: Optional[str] = None, + **kwargs, + ) -> Dict[str, Any]: + # Initialize variables + target_smiles = None + success = False + reaction_list = None + template_list = None + analysis_tokens_list = None + route_length = None + total_time = 0.0 + cost = None + + # Handle starting molecules + if starting_mols is None: + if self.graph_predictor.available is None: + raise ValueError( + "No starting molecules provided and no available starting molecules found." 
+ ) + starting_mols = self.graph_predictor.available["smiles"].tolist() + + # Handle case when no SMILES is provided + if smiles is None and rollback: + generated_tokens = self.retrosynthesize_rollback(input_ids, design_text, None, **kwargs) + return self._create_failure_result(None, generated_tokens) + + # Preprocess SMILES + target_smiles = smiles.replace("*", "[H]") if "*" in smiles else smiles + + # Check validity and handle rollback if necessary + if not self.graph_decoder.check_valid(target_smiles) and rollback: + generated_tokens = self.retrosynthesize_rollback( + input_ids, design_text, target_smiles, **kwargs + ) + return self._create_failure_result(target_smiles, generated_tokens) + + # Perform retrosynthesis + t0 = time.time() + + def expand_fn(s): + return self.one_step_reaction( + s, input_ids=input_ids, design_text=design_text, molecule_graphs=molecule_graphs, topk=expansion_topk, **kwargs + ) + + def value_fn(s, r): + return self.estimate_synthesis_complexity( + s, input_ids, r, molecule_cost_weight, language_cost_weight + ) + + if target_smiles is None: + return self._create_failure_result(None) + + success, best_route, iterations = molstar( + target_mol=target_smiles, + target_mol_id=0, + starting_mols=starting_mols, + expand_fn=expand_fn, + value_fn=value_fn, + iterations=iterations, + max_time=max_planning_time, + ) + + total_time = time.time() - t0 + + # Handle successful retrosynthesis + if success: + reaction_list, template_list, cost, analysis_tokens_list = best_route.get_reaction_list() + route_length = best_route.length + # Handle failed retrosynthesis with rollback + elif rollback: + generated_tokens = self.retrosynthesize_rollback( + input_ids, design_text, target_smiles, **kwargs + ) + return self._create_failure_result(target_smiles, generated_tokens) + + # Prepare and return result + return { + "target": target_smiles, + "success": success, + "time": total_time, + "reaction_list": reaction_list, + "cost": cost, + "templates": template_list, + "analysis_tokens": analysis_tokens_list, + "route_length": route_length, + } + + def _create_failure_result( + self, + target_smiles: Optional[str], + generated_tokens: Optional[Union[torch.Tensor, list]] = None, + ) -> Dict[str, Any]: + return { + "target": target_smiles, + "success": False, + "time": 0.0, + "reaction_list": None, + "cost": None, + "templates": None, + "analysis_tokens": ( + generated_tokens + if generated_tokens is not None + else "" + ), + "route_length": None, + } + + @torch.no_grad() + def generate( + self, + input_ids: Optional[torch.LongTensor] = None, + attention_mask: Optional[torch.FloatTensor] = None, + molecule_properties: Optional[torch.FloatTensor] = None, + molecule_graphs: Optional[PyGBatch] = None, + rollback: bool = False, + starting_mols: Optional[List[str]] = None, + expansion_topk: int = 50, + iterations: int = 100, + molecule_cost_weight: float = 0, + language_cost_weight: float = 1, + do_molecular_design: Optional[bool] = True, + do_retrosynthesis: bool = True, + input_smiles_list: Optional[List[str]] = None, + max_planning_time: int = 30, + design_text_list: Optional[List[str]] = None, + **kwargs, + ) -> Dict: + if attention_mask is None: + attention_mask = input_ids.new_ones(input_ids.shape) + + all_info_dict = { + "token_lists": [], + "text_lists": [], + "design_analysis_tokens": None, + "smiles_list": None, + "retro_plan_dict": None, + } + + # Molecular design + if do_molecular_design is True: + design_analysis_tokens, smiles_list = self.design_molecule( + input_ids, + 
attention_mask, + molecule_properties, + molecule_graphs, + rollback, + **kwargs, + ) + all_info_dict["design_analysis_tokens"] = design_analysis_tokens.cpu() + all_info_dict["smiles_list"] = smiles_list + elif input_smiles_list is not None: + all_info_dict["smiles_list"] = input_smiles_list + else: + raise ValueError( + "Either do_molecular_design must be True/False or input_smiles_list must be provided." + ) + + # Retrosynthesis + if do_retrosynthesis: + if all_info_dict["smiles_list"] is None: + raise ValueError( + "Either molecular design must be performed or input_smiles_list must be provided for retrosynthesis." + ) + + all_info_dict["retro_plan_dict"] = {} + for i, smiles in enumerate(all_info_dict["smiles_list"]): + if design_text_list is not None: + design_text = design_text_list[0] + else: + design_text = None + all_info_dict["retro_plan_dict"][smiles] = self.retrosynthesize( + input_ids[i] if input_ids.dim() > 1 else input_ids, + smiles, + molecule_graphs=molecule_graphs, + starting_mols=starting_mols, + expansion_topk=expansion_topk, + iterations=iterations, + molecule_cost_weight=molecule_cost_weight, + language_cost_weight=language_cost_weight, + max_planning_time=max_planning_time, + design_text=design_text, + **kwargs, + ) + else: + all_info_dict["retro_plan_dict"] = { + smile: {"success": None} for smile in all_info_dict["smiles_list"] + } + + for batch_idx, generated_mol in enumerate(all_info_dict["smiles_list"]): + token_list = [] + text_list = [] + ignore_positions = {} + if do_molecular_design: + design_tokens = all_info_dict["design_analysis_tokens"][ + batch_idx + ].tolist() + token_list = design_tokens + [IGNORE_INDEX] + if generated_mol is None: + generated_mol = "" + text_list = [ + self.tokenizer.decode( + design_tokens, + skip_special_tokens=True, + clean_up_tokenization_spaced=True, + ), + generated_mol + ". ", + ] + + ignore_positions = {0: generated_mol} + + if do_retrosynthesis: + available_mols = self.graph_predictor.available["smiles"].tolist() + retro_plan = all_info_dict["retro_plan_dict"][generated_mol] + if retro_plan["success"] is not None and retro_plan["success"]: + for i, (reaction, template, cost, analysis_tokens) in enumerate( + zip( + retro_plan["reaction_list"], + retro_plan["templates"], + retro_plan["cost"], + retro_plan["analysis_tokens"], + ) + ): + if isinstance(analysis_tokens, torch.Tensor): + analysis_tokens = analysis_tokens.tolist() + token_list.extend(analysis_tokens + [IGNORE_INDEX]) + text_list.extend( + [ + self.tokenizer.decode( + analysis_tokens, + skip_special_tokens=True, + clean_up_tokenization_spaced=True, + ), + reaction if reaction is not None else "", + " with the template ", + template if template is not None else "", + " which requires the reactants: ", + ] + ) + # Add these two lines to extract and add reactants + if reaction is not None: + reactants = reaction.split(">>")[1].split(".") + formatted_reactants = [] + for reactant in reactants: + if reactant in available_mols: + formatted_reactants.append( + f"{reactant} (available)" + ) + else: + formatted_reactants.append(reactant) + text_list.extend([", ".join(formatted_reactants), ". "]) + else: + text_list.extend([". 
"]) + ignore_positions[len(token_list) - 1] = ( + reaction, + template, + cost, + ) + else: + analysis_tokens = retro_plan["analysis_tokens"] + if isinstance(analysis_tokens, torch.Tensor): + analysis_tokens = analysis_tokens.tolist() + + token_list.extend(analysis_tokens) + text_list.extend( + [ + self.tokenizer.decode( + analysis_tokens, + skip_special_tokens=True, + clean_up_tokenization_spaced=True, + ), + " ", + ] + ) + + all_info_dict["token_lists"].append(token_list) + all_info_dict["text_lists"].append(text_list) + all_info_dict[f"batch_{batch_idx}_ignore_positions"] = ignore_positions + + all_info_dict["IGNORE_INDEX"] = IGNORE_INDEX + return all_info_dict \ No newline at end of file diff --git a/src/model/patcher.py b/src/model/patcher.py new file mode 100644 index 0000000000000000000000000000000000000000..80d44205c4faab5073dc24f93aee37f8fb83268c --- /dev/null +++ b/src/model/patcher.py @@ -0,0 +1,161 @@ +# Copyright 2024 the LlamaFactory team, the Llamole team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict + +import torch +from peft import PeftModel +from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_available +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.modeling_utils import is_fsdp_enabled + +from ..extras.logging import get_logger +from ..extras.misc import infer_optim_dtype +from .model_utils.attention import configure_attn_implementation, print_attn_implementation +from .model_utils.checkpointing import prepare_model_for_training +from .model_utils.embedding import resize_embedding_layer +from .model_utils.longlora import configure_longlora +from .model_utils.moe import add_z3_leaf_module, configure_moe +from .model_utils.quantization import configure_quantization +from .model_utils.rope import configure_rope +from .model_utils.valuehead import prepare_valuehead_model + +if TYPE_CHECKING: + from transformers import PretrainedConfig, PreTrainedTokenizer + from trl import AutoModelForCausalLMWithValueHead + + from ..hparams import ModelArguments + + +logger = get_logger(__name__) + + +def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None: + if "PreTrainedTokenizerBase" not in str(tokenizer._pad.__func__): + tokenizer._pad = MethodType(PreTrainedTokenizerBase._pad, tokenizer) + + +def patch_config( + config: "PretrainedConfig", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + init_kwargs: Dict[str, Any], + is_trainable: bool, +) -> None: + if model_args.compute_dtype is None: # priority: bf16 > fp16 > fp32 + if model_args.infer_dtype != "auto" and not is_trainable: + model_args.compute_dtype = getattr(torch, model_args.infer_dtype) + else: + model_args.compute_dtype = infer_optim_dtype(model_dtype=getattr(config, "torch_dtype", None)) + + if is_torch_npu_available(): + use_jit_compile = os.environ.get("JIT_COMPILE", "0").lower() in ["true", "1"] + 
torch.npu.set_compile_mode(jit_compile=use_jit_compile) + + configure_attn_implementation(config, model_args, is_trainable) + configure_rope(config, model_args, is_trainable) + configure_longlora(config, model_args, is_trainable) + configure_quantization(config, tokenizer, model_args, init_kwargs) + configure_moe(config, model_args, is_trainable) + + # configure_visual_model(config) + + if model_args.use_cache and not is_trainable: + setattr(config, "use_cache", True) + logger.info("Using KV cache for faster generation.") + + if getattr(config, "model_type", None) == "qwen": + setattr(config, "use_flash_attn", model_args.flash_attn == "fa2") + for dtype_name, dtype in [("fp16", torch.float16), ("bf16", torch.bfloat16), ("fp32", torch.float32)]: + setattr(config, dtype_name, model_args.compute_dtype == dtype) + + if getattr(config, "model_type", None) == "qwen2" and is_trainable and model_args.flash_attn == "fa2": + setattr(config, "use_cache", False) # qwen2 does not support use_cache when using flash attn + + # deepspeed zero3 is not compatible with low_cpu_mem_usage + init_kwargs["low_cpu_mem_usage"] = model_args.low_cpu_mem_usage and (not is_deepspeed_zero3_enabled()) + + # cast data type of the model if: + # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32) + # 2. quantization_bit is not None (qlora) + if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None: + init_kwargs["torch_dtype"] = model_args.compute_dtype + + if init_kwargs["low_cpu_mem_usage"]: # device map requires low_cpu_mem_usage=True + if "device_map" not in init_kwargs and model_args.device_map: + init_kwargs["device_map"] = model_args.device_map + + if init_kwargs.get("device_map", None) == "auto": + init_kwargs["offload_folder"] = model_args.offload_folder + + +def patch_model( + model: "PreTrainedModel", + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + is_trainable: bool, + add_valuehead: bool, +) -> None: + gen_config = model.generation_config # check and fix generation config + if not gen_config.do_sample and ( + (gen_config.temperature is not None and gen_config.temperature != 1.0) + or (gen_config.top_p is not None and gen_config.top_p != 1.0) + or (gen_config.typical_p is not None and gen_config.typical_p != 1.0) + ): + gen_config.do_sample = True + + if "GenerationMixin" not in str(model.generate.__func__): + model.generate = MethodType(PreTrainedModel.generate, model) + + if add_valuehead: + prepare_valuehead_model(model) + + if model_args.resize_vocab: + resize_embedding_layer(model, tokenizer) + + if model_args.visual_inputs: + autocast_projector_dtype(model, model_args) + + if is_trainable: + prepare_model_for_training(model, model_args) + add_z3_leaf_module(model) + + if not model_args.use_unsloth: + print_attn_implementation(model.config) + +def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: + def tie_weights(self: "AutoModelForCausalLMWithValueHead") -> None: + if isinstance(self.pretrained_model, PreTrainedModel): + self.pretrained_model.tie_weights() + + def get_input_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_input_embeddings() + + def get_output_embeddings(self: "AutoModelForCausalLMWithValueHead") -> torch.nn.Module: + if isinstance(self.pretrained_model, PreTrainedModel): + return self.pretrained_model.get_output_embeddings() + + def 
create_or_update_model_card(self: "AutoModelForCausalLMWithValueHead", output_dir: str) -> None: + if isinstance(self.pretrained_model, PeftModel): + self.pretrained_model.create_or_update_model_card(output_dir) + + ignore_modules = [name for name, _ in model.named_parameters() if "pretrained_model" in name] + setattr(model, "_keys_to_ignore_on_save", ignore_modules) + setattr(model, "tie_weights", MethodType(tie_weights, model)) + setattr(model, "get_input_embeddings", MethodType(get_input_embeddings, model)) + setattr(model, "get_output_embeddings", MethodType(get_output_embeddings, model)) + setattr(model, "create_or_update_model_card", MethodType(create_or_update_model_card, model)) diff --git a/src/model/planner/__init__.py b/src/model/planner/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b4ce3dec99ad06e9656f64f7b4ccbc6dd6492da4 --- /dev/null +++ b/src/model/planner/__init__.py @@ -0,0 +1,5 @@ +from .molstar import molstar + +__all__ = [ + 'molstar', +] diff --git a/src/model/planner/mol_node.py b/src/model/planner/mol_node.py new file mode 100644 index 0000000000000000000000000000000000000000..aeee71c20c491504eb81f0cc9bf6bdeace555171 --- /dev/null +++ b/src/model/planner/mol_node.py @@ -0,0 +1,103 @@ +# Adapted from: https://github.com/binghong-ml/retro_star + +import numpy as np +import logging + +class MolNode: + def __init__(self, mol, init_value, parent=None, is_known=False, + zero_known_value=True): + self.mol = mol + self.pred_value = init_value + self.value = init_value + self.succ_value = np.inf # total cost for existing solution + self.parent = parent + + self.id = -1 + if self.parent is None: + self.depth = 0 + else: + self.depth = self.parent.depth + + self.is_known = is_known + self.children = [] + self.succ = is_known + self.open = True # before expansion: True, after expansion: False + if is_known: + self.open = False + if zero_known_value: + self.value = 0 + self.succ_value = self.value + + if parent is not None: + parent.children.append(self) + + def v_self(self): + """ + :return: V_self(self | subtree) + """ + return self.value + + def v_target(self): + """ + :return: V_target(self | whole tree) + """ + if self.parent is None: + return self.value + else: + return self.parent.v_target() + + def init_values(self, no_child=False): + assert self.open and (no_child or self.children) + + new_value = np.inf + self.succ = False + for reaction in self.children: + new_value = np.min((new_value, reaction.v_self())) + self.succ |= reaction.succ + + v_delta = new_value - self.value + self.value = new_value + + if self.succ: + for reaction in self.children: + self.succ_value = np.min((self.succ_value, + reaction.succ_value)) + + self.open = False + + return v_delta + + def backup(self, succ): + assert not self.is_known + + new_value = np.inf + for reaction in self.children: + new_value = np.min((new_value, reaction.v_self())) + new_succ = self.succ | succ + updated = (self.value != new_value) or (self.succ != new_succ) + + new_succ_value = np.inf + if new_succ: + for reaction in self.children: + new_succ_value = np.min((new_succ_value, reaction.succ_value)) + updated = updated or (self.succ_value != new_succ_value) + + v_delta = new_value - self.value + self.value = new_value + self.succ = new_succ + self.succ_value = new_succ_value + + if updated and self.parent: + return self.parent.backup(v_delta, from_mol=self.mol) + + def serialize(self): + text = '%d | %s' % (self.id, self.mol) + return text + + def get_ancestors(self): + if self.parent is 
None: + return {self.mol} + + ancestors = self.parent.parent.get_ancestors() + ancestors.add(self.mol) + return ancestors \ No newline at end of file diff --git a/src/model/planner/mol_tree.py b/src/model/planner/mol_tree.py new file mode 100644 index 0000000000000000000000000000000000000000..c69e2bb5679c26a38895cf3ff92aabb7d7c66238 --- /dev/null +++ b/src/model/planner/mol_tree.py @@ -0,0 +1,130 @@ +# Adapted from: https://github.com/binghong-ml/retro_star + +import numpy as np +from queue import Queue +import logging +from .mol_node import MolNode +from .reaction_node import ReactionNode +from .syn_route import SynRoute + +class MolTree: + def __init__(self, target_mol, known_mols, value_fn, zero_known_value=True): + self.target_mol = target_mol + self.known_mols = known_mols + self.value_fn = value_fn + self.zero_known_value = zero_known_value + self.mol_nodes = [] + self.reaction_nodes = [] + + self.root = self._add_mol_node(target_mol, None) + self.succ = False + self.search_status = 0 + + if target_mol in known_mols: + logging.info('Warning: target in starting molecules. We still try to find another route.') + + def _add_mol_node(self, mol, parent): + is_known = mol in self.known_mols + + init_value = self.value_fn(mol, parent) + + mol_node = MolNode( + mol=mol, + init_value=init_value, + parent=parent, + is_known=is_known, + zero_known_value=self.zero_known_value + ) + self.mol_nodes.append(mol_node) + mol_node.id = len(self.mol_nodes) + + return mol_node + + def _add_reaction_and_mol_nodes(self, cost, mols, parent, template, analysis_tokens, ancestors): + assert cost >= 0 + + for mol in mols: + if mol in ancestors: + return + + reaction_node = ReactionNode(parent, cost, template, analysis_tokens) + for mol in mols: + self._add_mol_node(mol, reaction_node) + reaction_node.init_values() + self.reaction_nodes.append(reaction_node) + reaction_node.id = len(self.reaction_nodes) + + return reaction_node + + def expand(self, mol_node, reactant_lists, costs, templates, analysis_tokens): + assert not mol_node.is_known and not mol_node.children + + if costs is None: # No expansion results + assert mol_node.init_values(no_child=True) == np.inf + if mol_node.parent: + mol_node.parent.backup(np.inf, from_mol=mol_node.mol) + return self.succ + + assert mol_node.open + ancestors = mol_node.get_ancestors() + for i in range(len(costs)): + self._add_reaction_and_mol_nodes(costs[i], reactant_lists[i], + mol_node, templates[i], analysis_tokens, ancestors) + + if len(mol_node.children) == 0: # No valid expansion results + assert mol_node.init_values(no_child=True) == np.inf + if mol_node.parent: + mol_node.parent.backup(np.inf, from_mol=mol_node.mol) + return self.succ + + v_delta = mol_node.init_values() + if mol_node.parent: + mol_node.parent.backup(v_delta, from_mol=mol_node.mol) + + if not self.succ and self.root.succ: + logging.info('Synthesis route found!') + self.succ = True + + return self.succ + + def get_best_route(self): + if not self.succ: + return None + + syn_route = SynRoute( + target_mol=self.root.mol, + succ_value=self.root.succ_value, + search_status=self.search_status + ) + + mol_queue = Queue() + mol_queue.put(self.root) + while not mol_queue.empty(): + mol = mol_queue.get() + if mol.is_known: + syn_route.set_value(mol.mol, mol.succ_value) + continue + + best_reaction = None + for reaction in mol.children: + if reaction.succ: + if best_reaction is None or \ + reaction.succ_value < best_reaction.succ_value: + best_reaction = reaction + assert best_reaction.succ_value == 
mol.succ_value + + reactants = [] + for reactant in best_reaction.children: + mol_queue.put(reactant) + reactants.append(reactant.mol) + + syn_route.add_reaction( + mol=mol.mol, + value=mol.succ_value, + template=best_reaction.template, + analysis_tokens=best_reaction.analysis_tokens, + reactants=reactants, + cost=best_reaction.cost + ) + + return syn_route \ No newline at end of file diff --git a/src/model/planner/molstar.py b/src/model/planner/molstar.py new file mode 100644 index 0000000000000000000000000000000000000000..4afff9f1af9f01742945e687f426cb760e2e7d04 --- /dev/null +++ b/src/model/planner/molstar.py @@ -0,0 +1,77 @@ +# Adapted from: https://github.com/binghong-ml/retro_star + +import os +import numpy as np +import logging +import time +from .mol_tree import MolTree + +def molstar(target_mol, target_mol_id, starting_mols, expand_fn, value_fn, + iterations, viz=False, viz_dir=None, max_time=300): + + mol_tree = MolTree( + target_mol=target_mol, + known_mols=starting_mols, + value_fn=value_fn + ) + + i = -1 + start_time = time.time() + + if not mol_tree.succ: + for i in range(iterations): + if time.time() - start_time > max_time: + break + + scores = [] + for m in mol_tree.mol_nodes: + if m.open: + scores.append(m.v_target()) + else: + scores.append(np.inf) + scores = np.array(scores) + + if np.min(scores) == np.inf: + break + + metric = scores + + mol_tree.search_status = np.min(metric) + m_next = mol_tree.mol_nodes[np.argmin(metric)] + assert m_next.open + + result = expand_fn(m_next.mol) + + if result is not None and (len(result['scores']) > 0): + reactants = result['reactants'] + scores = result['scores'] + analysis_tokens = result['analysis'] + costs = 0.0 - np.log(np.clip(np.array(scores), 1e-3, 1.0)) + templates = result['templates'] + + reactant_lists = [] + for j in range(len(scores)): + reactant_list = list(set(reactants[j].split('.'))) + reactant_lists.append(reactant_list) + + assert m_next.open + succ = mol_tree.expand(m_next, reactant_lists, costs, templates, analysis_tokens) + + if succ: + break + + # found optimal route + if mol_tree.root.succ_value <= mol_tree.search_status: + break + + else: + mol_tree.expand(m_next, None, None, None, None) + + search_time = time.time() - start_time + + best_route = None + if mol_tree.succ: + best_route = mol_tree.get_best_route() + assert best_route is not None + + return mol_tree.succ, best_route, i+1 \ No newline at end of file diff --git a/src/model/planner/reaction_node.py b/src/model/planner/reaction_node.py new file mode 100644 index 0000000000000000000000000000000000000000..8e556c2fc508599eefb78a316176d92a338f70ee --- /dev/null +++ b/src/model/planner/reaction_node.py @@ -0,0 +1,83 @@ +# Adapted from: https://github.com/binghong-ml/retro_star + +import numpy as np +import logging + +class ReactionNode: + def __init__(self, parent, cost, template, analysis_tokens): + self.parent = parent + + self.depth = self.parent.depth + 1 + self.id = -1 + + self.cost = cost + self.template = template + self.analysis_tokens = analysis_tokens + self.children = [] + self.value = None # [V(m | subtree_m) for m in children].sum() + cost + self.succ_value = np.inf # total cost for existing solution + self.target_value = None # V_target(self | whole tree) + self.succ = None # successfully found a valid synthesis route + self.open = True # before expansion: True, after expansion: False + parent.children.append(self) + + def v_self(self): + """ + :return: V_self(self | subtree) + """ + return self.value + + def v_target(self): + """ + :return: 
V_target(self | whole tree) + """ + return self.target_value + + def init_values(self): + assert self.open + + self.value = self.cost + self.succ = True + for mol in self.children: + self.value += mol.value + self.succ &= mol.succ + + if self.succ: + self.succ_value = self.cost + for mol in self.children: + self.succ_value += mol.succ_value + + self.target_value = self.parent.v_target() - self.parent.v_self() + \ + self.value + self.open = False + + def backup(self, v_delta, from_mol=None): + self.value += v_delta + self.target_value += v_delta + + self.succ = True + for mol in self.children: + self.succ &= mol.succ + + if self.succ: + self.succ_value = self.cost + for mol in self.children: + self.succ_value += mol.succ_value + + if v_delta != 0: + assert from_mol + self.propagate(v_delta, exclude=from_mol) + + return self.parent.backup(self.succ) + + def propagate(self, v_delta, exclude=None): + if exclude is None: + self.target_value += v_delta + + for child in self.children: + if exclude is None or child.mol != exclude: + for grandchild in child.children: + grandchild.propagate(v_delta) + + def serialize(self): + return '%d' % (self.id) \ No newline at end of file diff --git a/src/model/planner/syn_route.py b/src/model/planner/syn_route.py new file mode 100644 index 0000000000000000000000000000000000000000..8e7360d80ed4792303d0e8e8779f09fe83eee4ee --- /dev/null +++ b/src/model/planner/syn_route.py @@ -0,0 +1,89 @@ +# Adapted from: https://github.com/binghong-ml/retro_star + +import numpy as np +from queue import Queue + +class SynRoute: + def __init__(self, target_mol, succ_value, search_status): + self.target_mol = target_mol + self.mols = [target_mol] + self.values = [None] + self.templates = [None] + self.parents = [-1] + self.children = [None] + self.optimal = False + self.costs = {} + self.analysis_dict = {} + + self.succ_value = succ_value + self.total_cost = 0 + self.length = 0 + self.search_status = search_status + if self.succ_value <= self.search_status: + self.optimal = True + + def _add_mol(self, mol, parent_id): + self.mols.append(mol) + self.values.append(None) + self.templates.append(None) + self.parents.append(parent_id) + self.children.append(None) + + self.children[parent_id].append(len(self.mols) - 1) + + def set_value(self, mol, value): + assert mol in self.mols + + mol_id = self.mols.index(mol) + self.values[mol_id] = value + + def add_reaction(self, mol, value, template, analysis_tokens, reactants, cost): + assert mol in self.mols + + self.total_cost += cost + self.length += 1 + + parent_id = self.mols.index(mol) + self.values[parent_id] = value + self.templates[parent_id] = template + self.children[parent_id] = [] + self.costs[parent_id] = cost + self.analysis_dict[parent_id] = analysis_tokens + + for reactant in reactants: + self._add_mol(reactant, parent_id) + + def serialize_reaction(self, idx): + s = self.mols[idx] + if self.children[idx] is None: + return s, 0.0 + s += ">>" + cost = np.exp(-self.costs[idx]) + analysis = self.analysis_dict[idx] + template = self.templates[idx] + s += self.mols[self.children[idx][0]] + for i in range(1, len(self.children[idx])): + s += "." 
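+            # remaining reactants are joined with '.', so the serialized reaction reads "product>>reactant1.reactant2..."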
+ s += self.mols[self.children[idx][i]] + return s, cost, analysis, template + + def get_reaction_list(self): + total_reactions, total_cost = [], [] + total_analysis = [] + total_templates = [] + reaction, cost, analysis, template = self.serialize_reaction(0) + total_reactions.append(reaction) + total_cost.append(cost) + total_analysis.append(analysis) + total_templates.append(template) + for i in range(1, len(self.mols)): + if self.children[i] is not None: + reaction, cost, analysis, template = self.serialize_reaction(i) + total_cost.append(cost) + total_reactions.append(reaction) + total_analysis.append(analysis) + total_templates.append(template) + return total_reactions, total_templates, total_cost, total_analysis + + def get_template_list(self): + return self.templates \ No newline at end of file diff --git a/src/train/__init__.py b/src/train/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/train/callbacks.py b/src/train/callbacks.py new file mode 100644 index 0000000000000000000000000000000000000000..4d024278142598c2d0d1c9d309d7c2ef352633d1 --- /dev/null +++ b/src/train/callbacks.py @@ -0,0 +1,349 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import logging +import os +import signal +import sys +import time +from concurrent.futures import ThreadPoolExecutor +from datetime import timedelta +from typing import TYPE_CHECKING, Any, Dict, Optional + +import torch +import transformers +from peft import PeftModel +from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback +from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length +from transformers.utils import ( + SAFE_WEIGHTS_NAME, + WEIGHTS_NAME, + is_safetensors_available, +) + +from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ..extras.logging import LoggerHandler, get_logger + + +if is_safetensors_available(): + from safetensors import safe_open + from safetensors.torch import save_file + +if TYPE_CHECKING: + from transformers import TrainerControl, TrainerState, TrainingArguments + from trl import AutoModelForCausalLMWithValueHead + + +logger = get_logger(__name__) + + +def fix_valuehead_checkpoint( + model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool +) -> None: + r""" + The model is already unwrapped. + + There are three cases: + 1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...} + 2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...} + 3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...} + + We assume `stage3_gather_16bit_weights_on_model_save=true`. 
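+    For example, a key matching "pretrained_model.model.layers.*" is written back as
+    "model.layers.*" in the decoder checkpoint, while "v_head.summary.*" keys are saved
+    to the separate value-head checkpoint file.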
+ """ + if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)): + return + + if safe_serialization: + path_to_checkpoint = os.path.join(output_dir, SAFE_WEIGHTS_NAME) + with safe_open(path_to_checkpoint, framework="pt", device="cpu") as f: + state_dict: Dict[str, torch.Tensor] = {key: f.get_tensor(key) for key in f.keys()} + else: + path_to_checkpoint = os.path.join(output_dir, WEIGHTS_NAME) + state_dict: Dict[str, torch.Tensor] = torch.load(path_to_checkpoint, map_location="cpu") + + decoder_state_dict = {} + v_head_state_dict = {} + for name, param in state_dict.items(): + if name.startswith("v_head."): + v_head_state_dict[name] = param + else: + decoder_state_dict[name.replace("pretrained_model.", "")] = param + + os.remove(path_to_checkpoint) + model.pretrained_model.save_pretrained( + output_dir, state_dict=decoder_state_dict or None, safe_serialization=safe_serialization + ) + + if safe_serialization: + save_file(v_head_state_dict, os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME), metadata={"format": "pt"}) + else: + torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) + + logger.info("Value head model saved at: {}".format(output_dir)) + + +class FixValueHeadModelCallback(TrainerCallback): + def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after a checkpoint save. + """ + if args.should_save: + fix_valuehead_checkpoint( + model=kwargs.pop("model"), + output_dir=os.path.join(args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)), + safe_serialization=args.save_safetensors, + ) + + +class SaveProcessorCallback(TrainerCallback): + def __init__(self, processor: "ProcessorMixin") -> None: + r""" + Initializes a callback for saving the processor. + """ + self.processor = processor + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + if args.should_save: + getattr(self.processor, "image_processor").save_pretrained(args.output_dir) + + +class PissaConvertCallback(TrainerCallback): + r""" + Initializes a callback for converting the PiSSA adapter to a normal one. + """ + + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the beginning of training. + """ + if args.should_save: + model = kwargs.pop("model") + pissa_init_dir = os.path.join(args.output_dir, "pissa_init") + logger.info("Initial PiSSA adatper will be saved at: {}.".format(pissa_init_dir)) + if isinstance(model, PeftModel): + init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") + setattr(model.peft_config["default"], "init_lora_weights", True) + model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors) + setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + if args.should_save: + model = kwargs.pop("model") + pissa_init_dir = os.path.join(args.output_dir, "pissa_init") + pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup") + pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted") + logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir)) + # 1. save a pissa backup with init_lora_weights: True + # 2. 
save a converted lora with init_lora_weights: pissa + # 3. load the pissa backup with init_lora_weights: True + # 4. delete the initial adapter and change init_lora_weights to pissa + if isinstance(model, PeftModel): + init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") + setattr(model.peft_config["default"], "init_lora_weights", True) + model.save_pretrained(pissa_backup_dir, safe_serialization=args.save_safetensors) + setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) + model.save_pretrained( + pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir + ) + model.load_adapter(pissa_backup_dir, "default", is_trainable=True) + model.set_adapter("default") + model.delete_adapter("pissa_init") + setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) + + +class LogCallback(TrainerCallback): + def __init__(self) -> None: + r""" + Initializes a callback for logging training and evaluation status. + """ + """ Progress """ + self.start_time = 0 + self.cur_steps = 0 + self.max_steps = 0 + self.elapsed_time = "" + self.remaining_time = "" + self.thread_pool: Optional["ThreadPoolExecutor"] = None + """ Status """ + self.aborted = False + self.do_train = False + """ Web UI """ + self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] + if self.webui_mode: + signal.signal(signal.SIGABRT, self._set_abort) + self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) + logging.root.addHandler(self.logger_handler) + transformers.logging.add_handler(self.logger_handler) + + def _set_abort(self, signum, frame) -> None: + self.aborted = True + + def _reset(self, max_steps: int = 0) -> None: + self.start_time = time.time() + self.cur_steps = 0 + self.max_steps = max_steps + self.elapsed_time = "" + self.remaining_time = "" + + def _timing(self, cur_steps: int) -> None: + cur_time = time.time() + elapsed_time = cur_time - self.start_time + avg_time_per_step = elapsed_time / cur_steps if cur_steps != 0 else 0 + remaining_time = (self.max_steps - cur_steps) * avg_time_per_step + self.cur_steps = cur_steps + self.elapsed_time = str(timedelta(seconds=int(elapsed_time))) + self.remaining_time = str(timedelta(seconds=int(remaining_time))) + + def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None: + with open(os.path.join(output_dir, TRAINER_LOG), "a", encoding="utf-8") as f: + f.write(json.dumps(logs) + "\n") + + def _create_thread_pool(self, output_dir: str) -> None: + os.makedirs(output_dir, exist_ok=True) + self.thread_pool = ThreadPoolExecutor(max_workers=1) + + def _close_thread_pool(self) -> None: + if self.thread_pool is not None: + self.thread_pool.shutdown(wait=True) + self.thread_pool = None + + def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of the initialization of the `Trainer`. + """ + if ( + args.should_save + and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG)) + and args.overwrite_output_dir + ): + logger.warning("Previous trainer log in this folder will be deleted.") + os.remove(os.path.join(args.output_dir, TRAINER_LOG)) + + def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the beginning of training. 
+ """ + if args.should_save: + self.do_train = True + self._reset(max_steps=state.max_steps) + self._create_thread_pool(output_dir=args.output_dir) + + def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of training. + """ + self._close_thread_pool() + + def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of an substep during gradient accumulation. + """ + if self.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called at the end of a training step. + """ + if self.aborted: + control.should_epoch_stop = True + control.should_training_stop = True + + def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after an evaluation phase. + """ + if not self.do_train: + self._close_thread_pool() + + def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after a successful prediction. + """ + if not self.do_train: + self._close_thread_pool() + + def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): + r""" + Event called after logging the last logs. + """ + if not args.should_save: + return + + self._timing(cur_steps=state.global_step) + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + loss=state.log_history[-1].get("loss", None), + eval_loss=state.log_history[-1].get("eval_loss", None), + predict_loss=state.log_history[-1].get("predict_loss", None), + reward=state.log_history[-1].get("reward", None), + accuracy=state.log_history[-1].get("rewards/accuracies", None), + learning_rate=state.log_history[-1].get("learning_rate", None), + epoch=state.log_history[-1].get("epoch", None), + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time, + throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)), + total_tokens=state.num_input_tokens_seen, + ) + logs = {k: v for k, v in logs.items() if v is not None} + if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]): + logger.info( + "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format( + logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"] + ) + ) + + if self.thread_pool is not None: + self.thread_pool.submit(self._write_log, args.output_dir, logs) + + def on_prediction_step( + self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs + ): + r""" + Event called after a prediction step. 
+ """ + if self.do_train: + return + + if self.aborted: + sys.exit(0) + + if not args.should_save: + return + + eval_dataloader = kwargs.pop("eval_dataloader", None) + if has_length(eval_dataloader): + if self.max_steps == 0: + self._reset(max_steps=len(eval_dataloader)) + self._create_thread_pool(output_dir=args.output_dir) + + self._timing(cur_steps=self.cur_steps + 1) + if self.cur_steps % 5 == 0 and self.thread_pool is not None: + logs = dict( + current_steps=self.cur_steps, + total_steps=self.max_steps, + percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100, + elapsed_time=self.elapsed_time, + remaining_time=self.remaining_time, + ) + self.thread_pool.submit(self._write_log, args.output_dir, logs) diff --git a/src/train/mmsft/__init__.py b/src/train/mmsft/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92c6c2a375f1f7d4af2956daf31231dd9edd25b1 --- /dev/null +++ b/src/train/mmsft/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2024 the LlamaFactory team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from .workflow import run_mmsft + + +__all__ = ["run_mmsft"] diff --git a/src/train/mmsft/metric.py b/src/train/mmsft/metric.py new file mode 100644 index 0000000000000000000000000000000000000000..84f5dbf80dd7296c9e4b4ebfa617f5123597fab3 --- /dev/null +++ b/src/train/mmsft/metric.py @@ -0,0 +1,96 @@ +# Copyright 2024 Llamole Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict + +import numpy as np +import torch +from transformers.utils import is_jieba_available, is_nltk_available + +from ...extras.constants import IGNORE_INDEX +from ...extras.packages import is_rouge_available + + +if TYPE_CHECKING: + from transformers import EvalPrediction, PreTrainedTokenizer + + +if is_jieba_available(): + import jieba # type: ignore + + +if is_nltk_available(): + from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu + + +if is_rouge_available(): + from rouge_chinese import Rouge + + +def compute_accuracy(eval_preds: "EvalPrediction") -> Dict[str, float]: + preds, labels = eval_preds.predictions, eval_preds.label_ids + accuracies = [] + for i in range(len(preds)): + pred, label = preds[i, :-1], labels[i, 1:] + label_mask = label != IGNORE_INDEX + accuracies.append(np.mean(pred[label_mask] == label[label_mask])) + + return {"accuracy": float(np.mean(accuracies))} + + +def eval_logit_processor(logits: "torch.Tensor", labels: "torch.Tensor") -> "torch.Tensor": + logits = logits[0] if isinstance(logits, (list, tuple)) else logits + return torch.argmax(logits, dim=-1) + + +@dataclass +class ComputeMetrics: + r""" + Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. + """ + + tokenizer: "PreTrainedTokenizer" + + def __call__(self, eval_preds: "EvalPrediction") -> Dict[str, float]: + r""" + Uses the model predictions to compute metrics. + """ + preds, labels = eval_preds.predictions, eval_preds.label_ids + score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} + + preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) + labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) + + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + + for pred, label in zip(decoded_preds, decoded_labels): + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + + if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: + result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} + else: + rouge = Rouge() + scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v["f"] * 100, 4)) + + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + + return {k: float(np.mean(v)) for k, v in score_dict.items()} diff --git a/src/train/mmsft/trainer.py b/src/train/mmsft/trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..7b3f34cf1e2fe83fb9516082eb01c0ef40b7f154 --- /dev/null +++ b/src/train/mmsft/trainer.py @@ -0,0 +1,255 @@ +# Copyright 2024 HuggingFace Inc., the LlamaFactory team, and the Llamole team. +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer_seq2seq.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. 
+# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import json +import os +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import warnings +import numpy as np +import torch +from transformers import Seq2SeqTrainer + +from ...extras.constants import IGNORE_INDEX +from ...extras.logging import get_logger +from ..callbacks import PissaConvertCallback, SaveProcessorCallback +from ..trainer_utils import create_custom_optimzer, create_custom_scheduler + + +if TYPE_CHECKING: + import optuna + from torch.utils.data import Dataset + from transformers import ProcessorMixin + from transformers.trainer import PredictionOutput + + from ...hparams import FinetuningArguments + +from transformers.trainer_utils import ( + enable_full_determinism, + find_executable_batch_size, + get_last_checkpoint, + set_seed, +) + +import huggingface_hub.utils as hf_hub_utils +from transformers.utils import is_sagemaker_mp_enabled +from transformers.trainer_callback import TrainerState + +TRAINER_STATE_NAME = "trainer_state.json" + +logger = get_logger(__name__) + + +class CustomSeq2SeqTrainer(Seq2SeqTrainer): + r""" + Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. + """ + + def __init__( + self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs + ) -> None: + super().__init__(**kwargs) + self.finetuning_args = finetuning_args + + if processor is not None: + self.add_callback(SaveProcessorCallback(processor)) + + if finetuning_args.pissa_convert: + self.add_callback(PissaConvertCallback) + + + def create_optimizer(self) -> "torch.optim.Optimizer": + if self.optimizer is None: + self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) + return super().create_optimizer() + + def create_scheduler( + self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(self.args, num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, optimizer) + + def prediction_step( + self, + model: "torch.nn.Module", + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + r""" + Removes the prompt part in the generated tokens. + + Subclass and override to inject custom behavior. + """ + labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels + if self.args.predict_with_generate: + assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." 
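+            # With left padding, prompts and labels are right-aligned: pad labels up to the
+            # prompt length (or truncate them) so the generation inputs and label tensors
+            # line up, then mask out the prompt part of the generated sequence below.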
+ prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) + if prompt_len > label_len: + inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) + if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility) + inputs["labels"] = inputs["labels"][:, :prompt_len] + + loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated) + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + ) + if generated_tokens is not None and self.args.predict_with_generate: + generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id + generated_tokens = generated_tokens.contiguous() + + return loss, generated_tokens, labels + + def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor: + r""" + Pads the tensor to the same length as the target tensor. + """ + assert self.tokenizer.pad_token_id is not None, "Pad token is required." + padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) + padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor # adopt left-padding + return padded_tensor.contiguous() # in contiguous memory + + def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None: + r""" + Saves model predictions to `output_dir`. + + A custom behavior that not contained in Seq2SeqTrainer. + """ + if not self.is_world_process_zero(): + return + + output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") + logger.info(f"Saving prediction results to {output_prediction_file}") + + labels = np.where( + predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id + ) + preds = np.where( + predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id + ) + + for i in range(len(preds)): + pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0] + if len(pad_len): # move pad token to last + preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1) + + decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + + with open(output_prediction_file, "w", encoding="utf-8") as writer: + res: List[str] = [] + for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds): + res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False)) + + writer.write("\n".join(res)) + + def train( + self, + resume_from_checkpoint: Optional[Union[str, bool]] = None, + trial: Union["optuna.Trial", Dict[str, Any]] = None, + ignore_keys_for_eval: Optional[List[str]] = None, + **kwargs, + ): + """ + Main training entry point. + + Args: + resume_from_checkpoint (`str` or `bool`, *optional*): + If a `str`, local path to a saved checkpoint as saved by a previous instance of [`Trainer`]. If a + `bool` and equals `True`, load the last checkpoint in *args.output_dir* as saved by a previous instance + of [`Trainer`]. If present, training will resume from the model/optimizer/scheduler states loaded here. + trial (`optuna.Trial` or `Dict[str, Any]`, *optional*): + The trial run or the hyperparameter dictionary for hyperparameter search. 
+ ignore_keys_for_eval (`List[str]`, *optional*) + A list of keys in the output of your model (if it is a dictionary) that should be ignored when + gathering predictions for evaluation during the training. + kwargs (`Dict[str, Any]`, *optional*): + Additional keyword arguments used to hide deprecated arguments + """ + if resume_from_checkpoint is False: + resume_from_checkpoint = None + + # memory metrics - must set up as early as possible + self._memory_tracker.start() + + args = self.args + + self.is_in_train = True + + # Attach NEFTune hooks if necessary + if self.neftune_noise_alpha is not None: + self.model = self._activate_neftune(self.model) + + # do_train is not a reliable argument, as it might not be set and .train() still called, so + # the following is a workaround: + if (args.fp16_full_eval or args.bf16_full_eval) and not args.do_train: + self._move_model_to_device(self.model, args.device) + + if "model_path" in kwargs: + resume_from_checkpoint = kwargs.pop("model_path") + warnings.warn( + "`model_path` is deprecated and will be removed in a future version. Use `resume_from_checkpoint` " + "instead.", + FutureWarning, + ) + if len(kwargs) > 0: + raise TypeError(f"train() received got unexpected keyword arguments: {', '.join(list(kwargs.keys()))}.") + # This might change the seed so needs to run first. + self._hp_search_setup(trial) + self._train_batch_size = self.args.train_batch_size + + # Model re-init + model_reloaded = False + if self.model_init is not None: + # Seed must be set before instantiating the model when using model_init. + enable_full_determinism(self.args.seed) if self.args.full_determinism else set_seed(self.args.seed) + self.model = self.call_model_init(trial) + model_reloaded = True + # Reinitializes optimizer and scheduler + self.optimizer, self.lr_scheduler = None, None + + # Load potential model checkpoint + if isinstance(resume_from_checkpoint, bool) and resume_from_checkpoint: + resume_from_checkpoint = get_last_checkpoint(args.output_dir) + if resume_from_checkpoint is None: + raise ValueError(f"No valid checkpoint found in output directory ({args.output_dir})") + + if resume_from_checkpoint is not None: + if not is_sagemaker_mp_enabled() and not self.is_deepspeed_enabled and not self.is_fsdp_enabled: + self._load_from_checkpoint(resume_from_checkpoint, self.model.language_model) + # In case of repeating the find_executable_batch_size, set `self._train_batch_size` properly + state = TrainerState.load_from_json(os.path.join(resume_from_checkpoint, TRAINER_STATE_NAME)) + if state.train_batch_size is not None: + self._train_batch_size = state.train_batch_size + + # If model was re-initialized, put it on the right device and update self.model_wrapped + if model_reloaded: + if self.place_model_on_device: + self._move_model_to_device(self.model, args.device) + self.model_wrapped = self.model + + inner_training_loop = find_executable_batch_size( + self._inner_training_loop, self._train_batch_size, args.auto_find_batch_size + ) + + return inner_training_loop( + args=args, + resume_from_checkpoint=resume_from_checkpoint, + trial=trial, + ignore_keys_for_eval=ignore_keys_for_eval, + ) \ No newline at end of file diff --git a/src/train/mmsft/workflow.py b/src/train/mmsft/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..c4c0564e299233d1d445e7ca454c2a8059fce0d6 --- /dev/null +++ b/src/train/mmsft/workflow.py @@ -0,0 +1,119 @@ +# Copyright 2024 Llamole Team +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you 
may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, List, Optional + +from ...data import get_dataset, split_dataset, DataCollatorForSeqGraph +from ...extras.constants import IGNORE_INDEX +from ...extras.misc import get_logits_processor +from ...extras.ploting import plot_loss +from ...model import load_language_model, load_tokenizer + +from ...model import GraphLLMForCausalMLM + +from .metric import ComputeMetrics, compute_accuracy, eval_logit_processor +from .trainer import CustomSeq2SeqTrainer + +from torch.utils.data import DataLoader + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import ( + DataArguments, + FinetuningArguments, + GeneratingArguments, + ModelArguments, + ) + + +def run_mmsft( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + mol_id_to_pyg, dataset = get_dataset( + model_args, data_args, training_args, tokenizer=tokenizer + ) + + data_collator = DataCollatorForSeqGraph( + tokenizer=tokenizer, + mol_id_to_pyg=mol_id_to_pyg, + pad_to_multiple_of=( + 8 if tokenizer.padding_side == "right" else None + ), # for shift short attention + label_pad_token_id=( + IGNORE_INDEX + if data_args.ignore_pad_token_for_loss + else tokenizer.pad_token_id + ), + ) + + model = GraphLLMForCausalMLM.from_pretrained( + tokenizer, model_args, data_args, training_args, finetuning_args + ) + + # Override the decoding parameters of Seq2SeqTrainer + training_args.generation_max_length = ( + training_args.generation_max_length or data_args.cutoff_len + ) + training_args.generation_num_beams = ( + data_args.eval_num_beams or training_args.generation_num_beams + ) + training_args.remove_unused_columns = False + + trainer = CustomSeq2SeqTrainer( + model=model, + args=training_args, + finetuning_args=finetuning_args, + data_collator=data_collator, + callbacks=callbacks, + compute_metrics=( + ComputeMetrics(tokenizer) + if training_args.predict_with_generate + else compute_accuracy + ), + preprocess_logits_for_metrics=( + None if training_args.predict_with_generate else eval_logit_processor + ), + **tokenizer_module, + **split_dataset(dataset, data_args, training_args), + ) + + # Keyword arguments for `model.generate` + gen_kwargs = generating_args.to_dict() + gen_kwargs["eos_token_id"] = [ + tokenizer.eos_token_id + ] + tokenizer.additional_special_tokens_ids + gen_kwargs["pad_token_id"] = tokenizer.pad_token_id + gen_kwargs["logits_processor"] = get_logits_processor() + + # Training + if training_args.do_train: + train_result = trainer.train( + resume_from_checkpoint=training_args.resume_from_checkpoint + ) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and 
finetuning_args.plot_loss: + plot_loss( + training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"] + ) \ No newline at end of file diff --git a/src/train/trainer_utils.py b/src/train/trainer_utils.py new file mode 100644 index 0000000000000000000000000000000000000000..653273e655e3addeec87547650c0deeb9a76de1d --- /dev/null +++ b/src/train/trainer_utils.py @@ -0,0 +1,427 @@ +# Copyright 2024 HuggingFace Inc. and the LlamaFactory team. +# +# This code is inspired by the original GaLore's implementation: https://github.com/jiaweizzhao/GaLore +# and the original LoRA+'s implementation: https://github.com/nikhil-ghosh-berkeley/loraplus +# and the original BAdam's implementation: https://github.com/Ledzy/BAdam +# and the HuggingFace's TRL library: https://github.com/huggingface/trl +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple, Union + +import torch +from transformers import Trainer +from transformers.integrations import is_deepspeed_zero3_enabled +from transformers.optimization import get_scheduler +from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS +from transformers.trainer_pt_utils import get_parameter_names + +from ..extras.constants import IGNORE_INDEX +from ..extras.logging import get_logger +from ..extras.packages import is_galore_available +from ..hparams import FinetuningArguments, ModelArguments +from ..model import find_all_linear_modules, load_language_model, load_tokenizer, load_valuehead_params + + +if is_galore_available(): + from galore_torch import GaLoreAdafactor, GaLoreAdamW, GaLoreAdamW8bit + + +if TYPE_CHECKING: + from transformers import PreTrainedModel, Seq2SeqTrainingArguments + from trl import AutoModelForCausalLMWithValueHead + + from ..hparams import DataArguments + + +logger = get_logger(__name__) + + +class DummyOptimizer(torch.optim.Optimizer): + r""" + A dummy optimizer used for the GaLore algorithm. 
+ """ + + def __init__( + self, lr: float = 1e-3, optimizer_dict: Optional[Dict["torch.nn.Parameter", "torch.optim.Optimizer"]] = None + ) -> None: + dummy_tensor = torch.randn(1, 1) + self.optimizer_dict = optimizer_dict + super().__init__([dummy_tensor], {"lr": lr}) + + def zero_grad(self, set_to_none: bool = True) -> None: + pass + + def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: + pass + + +def create_modelcard_and_push( + trainer: "Trainer", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> None: + kwargs = { + "tasks": "text-generation", + "finetuned_from": model_args.model_name_or_path, + "tags": [finetuning_args.finetuning_type], + } + if data_args.dataset is not None: + kwargs["dataset"] = [dataset.strip() for dataset in data_args.dataset.split(",")] + + if model_args.use_unsloth: + kwargs["tags"] = kwargs["tags"] + ["unsloth"] + + if not training_args.do_train: + pass + elif training_args.push_to_hub: + trainer.push_to_hub(**kwargs) + else: + trainer.create_model_card(license="other", **kwargs) # prevent from connecting to hub + + +def create_ref_model( + model_args: "ModelArguments", finetuning_args: "FinetuningArguments", add_valuehead: bool = False +) -> Optional[Union["PreTrainedModel", "AutoModelForCausalLMWithValueHead"]]: + r""" + Creates reference model for PPO/DPO training. Evaluation mode is not supported. + + The valuehead parameter is randomly initialized since it is useless for PPO training. + """ + if finetuning_args.ref_model is not None: + ref_model_args = ModelArguments.copyfrom( + model_args, + model_name_or_path=finetuning_args.ref_model, + adapter_name_or_path=finetuning_args.ref_model_adapters, + quantization_bit=finetuning_args.ref_model_quantization_bit, + ) + ref_finetuning_args = FinetuningArguments() + tokenizer = load_tokenizer(ref_model_args)["tokenizer"] + ref_model = load_language_model( + tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead + ) + logger.info("Created reference model from {}".format(finetuning_args.ref_model)) + else: + if finetuning_args.finetuning_type == "lora": + ref_model = None + else: + ref_model_args = ModelArguments.copyfrom(model_args) + ref_finetuning_args = FinetuningArguments() + tokenizer = load_tokenizer(ref_model_args)["tokenizer"] + ref_model = load_language_model( + tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead + ) + logger.info("Created reference model from the model itself.") + + return ref_model + + +def create_reward_model( + model: "AutoModelForCausalLMWithValueHead", model_args: "ModelArguments", finetuning_args: "FinetuningArguments" +) -> Optional["AutoModelForCausalLMWithValueHead"]: + r""" + Creates reward model for PPO training. + """ + if finetuning_args.reward_model_type == "api": + assert finetuning_args.reward_model.startswith("http"), "Please provide full url." 
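+        # For the "api" type no local model is instantiated: the URL string itself is
+        # returned so the PPO trainer can query it as a remote reward endpoint.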
+ logger.info("Use reward server {}".format(finetuning_args.reward_model)) + return finetuning_args.reward_model + elif finetuning_args.reward_model_type == "lora": + model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward") + for name, param in model.named_parameters(): # https://github.com/huggingface/peft/issues/1090 + if "default" in name: + param.data = param.data.to(torch.float32) # trainable params should in fp32 + vhead_params = load_valuehead_params(finetuning_args.reward_model, model_args) + assert vhead_params is not None, "Reward model is not correctly loaded." + model.register_buffer("reward_head_weight", vhead_params["v_head.summary.weight"], persistent=False) + model.register_buffer("reward_head_bias", vhead_params["v_head.summary.bias"], persistent=False) + model.register_buffer( + "default_head_weight", torch.zeros_like(vhead_params["v_head.summary.weight"]), persistent=False + ) + model.register_buffer( + "default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False + ) + logger.info("Loaded adapter weights of reward model from {}".format(finetuning_args.reward_model)) + return None + else: + reward_model_args = ModelArguments.copyfrom( + model_args, + model_name_or_path=finetuning_args.reward_model, + adapter_name_or_path=finetuning_args.reward_model_adapters, + quantization_bit=finetuning_args.reward_model_quantization_bit, + ) + reward_finetuning_args = FinetuningArguments() + tokenizer = load_tokenizer(reward_model_args)["tokenizer"] + reward_model = load_language_model( + tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True + ) + logger.info("Loaded full weights of reward model from {}".format(finetuning_args.reward_model)) + logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.") + return reward_model + + +def _get_decay_parameter_names(model: "PreTrainedModel") -> List[str]: + r""" + Returns a list of names of parameters with weight decay. 
(weights in non-layernorm layers) + """ + decay_parameters = get_parameter_names(model, ALL_LAYERNORM_LAYERS) + decay_parameters = [name for name in decay_parameters if "bias" not in name] + return decay_parameters + + +def _create_galore_optimizer( + model: "PreTrainedModel", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> "torch.optim.Optimizer": + if len(finetuning_args.galore_target) == 1 and finetuning_args.galore_target[0] == "all": + galore_targets = find_all_linear_modules(model, finetuning_args.freeze_vision_tower) + else: + galore_targets = finetuning_args.galore_target + + galore_params: List["torch.nn.Parameter"] = [] + for name, module in model.named_modules(): + if isinstance(module, torch.nn.Linear) and any(target in name for target in galore_targets): + for param in module.parameters(): + if param.requires_grad and len(param.shape) > 1: + galore_params.append(param) + + galore_kwargs = { + "rank": finetuning_args.galore_rank, + "update_proj_gap": finetuning_args.galore_update_interval, + "scale": finetuning_args.galore_scale, + "proj_type": finetuning_args.galore_proj_type, + } + + id_galore_params = {id(param) for param in galore_params} + decay_params, nodecay_params = [], [] # they are non-galore parameters + trainable_params: List["torch.nn.Parameter"] = [] # galore_params + decay_params + nodecay_params + decay_param_names = _get_decay_parameter_names(model) + for name, param in model.named_parameters(): + if param.requires_grad: + trainable_params.append(param) + if id(param) not in id_galore_params: + if name in decay_param_names: + decay_params.append(param) + else: + nodecay_params.append(param) + + _, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + + if training_args.optim == "adamw_torch": + optim_class = GaLoreAdamW + elif training_args.optim in ["adamw_bnb_8bit", "adamw_8bit", "paged_adamw_8bit"]: + optim_class = GaLoreAdamW8bit + elif training_args.optim == "adafactor": + optim_class = GaLoreAdafactor + else: + raise NotImplementedError("Unknow optim: {}".format(training_args.optim)) + + if finetuning_args.galore_layerwise: + if training_args.gradient_accumulation_steps != 1: + raise ValueError("Per-layer GaLore does not support gradient accumulation.") + + optimizer_dict: Dict["torch.Tensor", "torch.optim.Optimizer"] = {} + for param in nodecay_params: + param_groups = [dict(params=[param], weight_decay=0.0)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + for param in decay_params: + param_groups = [dict(params=[param], weight_decay=training_args.weight_decay)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + for param in galore_params: # galore params have weight decay + param_groups = [dict(params=[param], weight_decay=training_args.weight_decay, **galore_kwargs)] + optimizer_dict[param] = optim_class(param_groups, **optim_kwargs) + + def optimizer_hook(param: "torch.nn.Parameter"): + if param.grad is not None: + optimizer_dict[param].step() + optimizer_dict[param].zero_grad() + + for param in trainable_params: + param.register_post_accumulate_grad_hook(optimizer_hook) + + optimizer = DummyOptimizer(lr=training_args.learning_rate, optimizer_dict=optimizer_dict) + else: + param_groups = [ + dict(params=nodecay_params, weight_decay=0.0), + dict(params=decay_params, weight_decay=training_args.weight_decay), + dict(params=galore_params, weight_decay=training_args.weight_decay, **galore_kwargs), + ] + optimizer = optim_class(param_groups, **optim_kwargs) 
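+        # Non-layerwise mode: a single GaLore optimizer over three parameter groups
+        # (no-decay params, decay params, and the GaLore-projected weight matrices).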
+ + logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.") + return optimizer + + +def _create_loraplus_optimizer( + model: "PreTrainedModel", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> "torch.optim.Optimizer": + default_lr = training_args.learning_rate + loraplus_lr = training_args.learning_rate * finetuning_args.loraplus_lr_ratio + embedding_lr = finetuning_args.loraplus_lr_embedding + + decay_param_names = _get_decay_parameter_names(model) + param_dict: Dict[str, List["torch.nn.Parameter"]] = { + "lora_a": [], + "lora_b": [], + "lora_b_nodecay": [], + "embedding": [], + } + for name, param in model.named_parameters(): + if param.requires_grad: + if "lora_embedding_B" in name: + param_dict["embedding"].append(param) + elif "lora_B" in name or param.ndim == 1: + if name in decay_param_names: + param_dict["lora_b"].append(param) + else: + param_dict["lora_b_nodecay"].append(param) + else: + param_dict["lora_a"].append(param) + + optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + param_groups = [ + dict(params=param_dict["lora_a"], lr=default_lr, weight_decay=training_args.weight_decay), + dict(params=param_dict["lora_b"], lr=loraplus_lr, weight_decay=training_args.weight_decay), + dict(params=param_dict["lora_b_nodecay"], lr=loraplus_lr, weight_decay=0.0), + dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay), + ] + optimizer = optim_class(param_groups, **optim_kwargs) + logger.info("Using LoRA+ optimizer with loraplus lr ratio {:.2f}.".format(finetuning_args.loraplus_lr_ratio)) + return optimizer + + +def _create_badam_optimizer( + model: "PreTrainedModel", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> "torch.optim.Optimizer": + decay_params, nodecay_params = [], [] + decay_param_names = _get_decay_parameter_names(model) + for name, param in model.named_parameters(): + if param.requires_grad: + if name in decay_param_names: + decay_params.append(param) + else: + nodecay_params.append(param) + + optim_class, optim_kwargs = Trainer.get_optimizer_cls_and_kwargs(training_args) + param_groups = [ + dict(params=nodecay_params, weight_decay=0.0), + dict(params=decay_params, weight_decay=training_args.weight_decay), + ] + + if finetuning_args.badam_mode == "layer": + from badam import BlockOptimizer + + base_optimizer = optim_class(param_groups, **optim_kwargs) + optimizer = BlockOptimizer( + base_optimizer=base_optimizer, + named_parameters_list=list(model.named_parameters()), + block_prefix_list=None, + switch_block_every=finetuning_args.badam_switch_interval, + start_block=finetuning_args.badam_start_block, + switch_mode=finetuning_args.badam_switch_mode, + verbose=finetuning_args.badam_verbose, + ds_zero3_enabled=is_deepspeed_zero3_enabled(), + ) + logger.info( + f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, " + f"switch block every {finetuning_args.badam_switch_interval} steps, " + f"default start block is {finetuning_args.badam_start_block}" + ) + + elif finetuning_args.badam_mode == "ratio": + from badam import BlockOptimizerRatio + + assert finetuning_args.badam_update_ratio > 1e-6 + optimizer = BlockOptimizerRatio( + param_groups=param_groups, + named_parameters_list=list(model.named_parameters()), + update_ratio=finetuning_args.badam_update_ratio, + mask_mode=finetuning_args.badam_mask_mode, + 
verbose=finetuning_args.badam_verbose, + include_embedding=False, + **optim_kwargs, + ) + logger.info( + f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, " + f"mask mode is {finetuning_args.badam_mask_mode}" + ) + + return optimizer + + +def create_custom_optimzer( + model: "PreTrainedModel", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", +) -> Optional["torch.optim.Optimizer"]: + if finetuning_args.use_galore: + return _create_galore_optimizer(model, training_args, finetuning_args) + + if finetuning_args.loraplus_lr_ratio is not None: + return _create_loraplus_optimizer(model, training_args, finetuning_args) + + if finetuning_args.use_badam: + return _create_badam_optimizer(model, training_args, finetuning_args) + + +def create_custom_scheduler( + training_args: "Seq2SeqTrainingArguments", + num_training_steps: int, + optimizer: Optional["torch.optim.Optimizer"] = None, +) -> None: + if optimizer is not None and isinstance(optimizer, DummyOptimizer): + optimizer_dict = optimizer.optimizer_dict + scheduler_dict: Dict["torch.nn.Parameter", "torch.optim.lr_scheduler.LRScheduler"] = {} + + for param in optimizer_dict.keys(): + scheduler_dict[param] = get_scheduler( + training_args.lr_scheduler_type, + optimizer=optimizer_dict[param], + num_warmup_steps=training_args.get_warmup_steps(num_training_steps), + num_training_steps=num_training_steps, + scheduler_specific_kwargs=training_args.lr_scheduler_kwargs, + ) + + def scheduler_hook(param: "torch.nn.Parameter"): + scheduler_dict[param].step() + + for param in optimizer_dict.keys(): + param.register_post_accumulate_grad_hook(scheduler_hook) + + +def get_batch_logps( + logits: "torch.Tensor", labels: "torch.Tensor", label_pad_token_id: int = IGNORE_INDEX +) -> Tuple["torch.Tensor", "torch.Tensor"]: + r""" + Computes the log probabilities of the given labels under the given logits. + + Returns: + logps: A tensor of shape (batch_size,) containing the sum of log probabilities. + valid_length: A tensor of shape (batch_size,) containing the number of non-masked tokens. + """ + if logits.shape[:-1] != labels.shape: + raise ValueError("Logits (batchsize x seqlen) and labels must have the same shape.") + + labels = labels[:, 1:].clone() + logits = logits[:, :-1, :] + loss_mask = labels != label_pad_token_id + labels[labels == label_pad_token_id] = 0 # dummy token + per_token_logps = torch.gather(logits.log_softmax(-1), dim=2, index=labels.unsqueeze(2)).squeeze(2) + return (per_token_logps * loss_mask).sum(-1), loss_mask.sum(-1) diff --git a/src/train/tuner.py b/src/train/tuner.py new file mode 100644 index 0000000000000000000000000000000000000000..db03b016b3516c05adc1716627aaae6f83a820f1 --- /dev/null +++ b/src/train/tuner.py @@ -0,0 +1,115 @@ +# Copyright 2024 the LlamaFactory team and the Llamole team. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
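+# High-level entry points: `run_train` routes the parsed arguments into the multimodal
+# SFT workflow (`run_mmsft`), while `merge_adapter` merges adapters and exports the
+# model, tokenizer, and optional value head to `export_dir`.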
+ +import os +import shutil +from typing import TYPE_CHECKING, Any, Dict, List, Optional + +import torch +from transformers import PreTrainedModel + +from ..data import get_template_and_fix_tokenizer +from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME +from ..extras.logging import get_logger +from ..hparams import get_infer_args, get_train_args +from ..model import load_language_model, load_tokenizer +from .callbacks import LogCallback +from .mmsft import run_mmsft + +if TYPE_CHECKING: + from transformers import TrainerCallback + +logger = get_logger(__name__) + +def run_train(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: + callbacks.append(LogCallback()) + model_args, data_args, training_args, finetuning_args, generating_args = get_train_args(args) + run_mmsft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) + +def merge_adapter(args: Optional[Dict[str, Any]] = None) -> None: + model_args, data_args, finetuning_args, _ = get_infer_args(args) + + if model_args.export_dir is None: + raise ValueError("Please specify `export_dir` to save model.") + + if model_args.adapter_name_or_path is not None and model_args.export_quantization_bit is not None: + raise ValueError("Please merge adapters before quantizing the model.") + + tokenizer_module = load_tokenizer(model_args) + tokenizer = tokenizer_module["tokenizer"] + processor = tokenizer_module["processor"] + get_template_and_fix_tokenizer(tokenizer, data_args.template) + model = load_language_model(tokenizer, model_args, finetuning_args) # must after fixing tokenizer to resize vocab + + if getattr(model, "quantization_method", None) and model_args.adapter_name_or_path is not None: + raise ValueError("Cannot merge adapters to a quantized model.") + + if not isinstance(model, PreTrainedModel): + raise ValueError("The model is not a `PreTrainedModel`, export aborted.") + + if getattr(model, "quantization_method", None) is None: # cannot convert dtype of a quantized model + output_dtype = getattr(model.config, "torch_dtype", torch.float16) + setattr(model.config, "torch_dtype", output_dtype) + model = model.to(output_dtype) + else: + setattr(model.config, "torch_dtype", torch.float16) + + model.save_pretrained( + save_directory=model_args.export_dir, + max_shard_size="{}GB".format(model_args.export_size), + safe_serialization=(not model_args.export_legacy_format), + ) + if model_args.export_hub_model_id is not None: + model.push_to_hub( + model_args.export_hub_model_id, + token=model_args.hf_hub_token, + max_shard_size="{}GB".format(model_args.export_size), + safe_serialization=(not model_args.export_legacy_format), + ) + + if finetuning_args.stage == "rm": + if model_args.adapter_name_or_path is not None: + vhead_path = model_args.adapter_name_or_path[-1] + else: + vhead_path = model_args.model_name_or_path + + if os.path.exists(os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME)): + shutil.copy( + os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME), + os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME), + ) + logger.info("Copied valuehead to {}.".format(model_args.export_dir)) + elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)): + shutil.copy( + os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME), + os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME), + ) + logger.info("Copied valuehead to {}.".format(model_args.export_dir)) + + try: + tokenizer.padding_side = "left" # restore padding side + 
tokenizer.init_kwargs["padding_side"] = "left" + tokenizer.save_pretrained(model_args.export_dir) + if model_args.export_hub_model_id is not None: + tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) + + if model_args.visual_inputs and processor is not None: + getattr(processor, "image_processor").save_pretrained(model_args.export_dir) + if model_args.export_hub_model_id is not None: + getattr(processor, "image_processor").push_to_hub( + model_args.export_hub_model_id, token=model_args.hf_hub_token + ) + + except Exception: + logger.warning("Cannot save tokenizer, please copy the files manually.") diff --git a/src/webui/__init__.py b/src/webui/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/webui/dataset.py b/src/webui/dataset.py new file mode 100644 index 0000000000000000000000000000000000000000..a3e84169e9cc997c913e7008cbfb90b6c63bd3a4 --- /dev/null +++ b/src/webui/dataset.py @@ -0,0 +1,78 @@ +# Copyright 2024 Llamole Team +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
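+# Single-sample dataset for web-UI inference: each MolQA record is rendered with the
+# tokenizer's chat template and returned as input_ids/attention_mask plus a tensor of
+# ten molecular property targets (NaN for properties that were not specified).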
+ +import torch +from torch.utils.data import Dataset + +from ..extras.constants import BOND_INDEX + +def dict_to_list(data_dict, mol_properties): + return [data_dict.get(prop, float("nan")) for prop in mol_properties] + +class MolQADataset(Dataset): + def __init__(self, data, tokenizer, max_len): + self.data = data + self.tokenizer = tokenizer + self.max_len = max_len + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + mol_properties = [ + "BBBP", + "HIV", + "BACE", + "CO2", + "N2", + "O2", + "FFV", + "TC", + "SC", + "SA", + ] + item = self.data[idx] + instruction = item["instruction"] + input_text = item["input"] + property_data = dict_to_list(item["property"], mol_properties) + property_data = torch.tensor(property_data) + + # Combine instruction and input + combined_input = f"{instruction}\n{input_text}" + + # Create messages for chat template + messages = [ + {"role": "user", "content": combined_input} + ] + + # Apply chat template + chat_text = self.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True) + + # Tokenize the chat text + encoding = self.tokenizer( + chat_text, + return_tensors="pt", + padding="max_length", + truncation=True, + max_length=self.max_len, + ) + + return { + "input_ids": encoding.input_ids.squeeze(), + "attention_mask": encoding.attention_mask.squeeze(), + "property": property_data, + } \ No newline at end of file diff --git a/src/webui/elements.py b/src/webui/elements.py new file mode 100644 index 0000000000000000000000000000000000000000..85889e803b50870b7835c8c65066341ba5ed7547 --- /dev/null +++ b/src/webui/elements.py @@ -0,0 +1,95 @@ +# Copyright 2024 Llamole Team +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import gradio as gr +import numpy as np + +def create_input_components(default_values, property_names, property_ranges): + initial_property_type = default_values.get('type', 'Material') + + with gr.Blocks() as interface: + with gr.Row(): + with gr.Column(scale=2): + instruction = gr.Textbox( + label="Instruction", + value=default_values.get('instruction', ''), + lines=3, + placeholder="Enter your instruction here..." 
+ ) + with gr.Column(scale=1): + property_type = gr.Radio( + ["Drug", "Material"], + label="Property Type", + value=initial_property_type, + interactive=True + ) + + with gr.Row(): + drug_properties = {} + for prop in ["HIV", "BBBP", "BACE"]: + default_value = default_values.get('property', {}).get(prop, "NAN") + drug_properties[prop] = gr.Radio( + [0, 1, "NAN"], + label=f"{property_names[prop]} ({prop})", + value=default_value if default_value in [0, 1, "NAN"] else "NAN", + visible=initial_property_type == "Drug", + interactive=True + ) + + material_properties = {} + for prop in ["CO2", "N2", "O2", "FFV", "TC"]: + min_val = property_ranges[prop]['min'] + max_val = property_ranges[prop]['max'] + default_value = default_values.get('property', {}).get(prop, 0) + material_properties[prop] = gr.Slider( + label=f"{property_names[prop]} (0 for uncondition)", + minimum=0, + maximum=max_val, + value=default_value if default_value != "NAN" else 0, + step=0.1, + visible=initial_property_type == "Material", + interactive=True + ) + + with gr.Row(): + synthetic_properties = {} + for prop in ["SC", "SA"]: + min_val = property_ranges[prop]['min'] + max_val = property_ranges[prop]['max'] + default_value = default_values.get('property', {}).get(prop, (min_val + max_val) / 2) + synthetic_properties[prop] = gr.Slider( + label=f"{property_names[prop]} ({prop})", + minimum=min_val, + maximum=max_val, + value=default_value if not np.isnan(default_value) else (min_val + max_val) / 2, + step=(max_val - min_val) / 100, + interactive=True + ) + + def update_visibility(property_type): + return ( + [gr.update(visible=(property_type == "Drug")) for _ in drug_properties.values()] + + [gr.update(visible=(property_type == "Material")) for _ in material_properties.values()] + ) + + property_type.change( + update_visibility, + inputs=[property_type], + outputs=list(drug_properties.values()) + list(material_properties.values()) + ) + + return interface, instruction, property_type, drug_properties, material_properties, synthetic_properties \ No newline at end of file diff --git a/src/webui/workflow.py b/src/webui/workflow.py new file mode 100644 index 0000000000000000000000000000000000000000..7bb803da88c7a51e78c43d67ad8b33f7d615df5f --- /dev/null +++ b/src/webui/workflow.py @@ -0,0 +1,122 @@ +# Copyright 2024 Llamole Team +# +# This code is inspired by the HuggingFace's transformers library. +# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/summarization/run_summarization.py +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
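+# Web-UI inference pipeline: load the graph-language model and tokenizer, wrap a single
+# MolQA request in a DataLoader, then generate a molecule together with an optional
+# retrosynthesis plan.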
+ +import re +import os +import json +import math +import torch +from torch.utils.data import DataLoader +from typing import TYPE_CHECKING, List, Optional, Dict, Any + +from ..data import get_dataset, DataCollatorForSeqGraph, get_template_and_fix_tokenizer +from ..extras.constants import IGNORE_INDEX, NO_LABEL_INDEX +from ..extras.misc import get_logits_processor +from ..extras.ploting import plot_loss +from ..model import load_tokenizer, GraphLLMForCausalMLM +from ..hparams import get_train_args +from .dataset import MolQADataset + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments + from ..hparams import ( + DataArguments, + FinetuningArguments, + GeneratingArguments, + ModelArguments, + ) + +def remove_extra_spaces(text): + cleaned_text = re.sub(r'\s+', ' ', text) + return cleaned_text.strip() + +def load_model_and_tokenizer(args): + model_args, data_args, training_args, finetuning_args, generating_args = ( + get_train_args(args) + ) + tokenizer = load_tokenizer(model_args, generate_mode=True)["tokenizer"] + tokenizer.pad_token = tokenizer.eos_token + + model = GraphLLMForCausalMLM.from_pretrained( + tokenizer, model_args, data_args, training_args, finetuning_args, load_adapter=True + ) + + return model, tokenizer, generating_args + +def process_input(input_data: Dict[str, Any], model, tokenizer, generating_args: "GeneratingArguments"): + + dataset = MolQADataset([input_data], tokenizer, generating_args.max_length) + dataloader = DataLoader( + dataset, batch_size=1, shuffle=False + ) + + gen_kwargs = generating_args.to_dict() + gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids + gen_kwargs["pad_token_id"] = tokenizer.pad_token_id + gen_kwargs["logits_processor"] = get_logits_processor() + + return dataloader, gen_kwargs + +def generate(model, dataloader, gen_kwargs): + property_names = ["BBBP", "HIV", "BACE", "CO2", "N2", "O2", "FFV", "TC", "SC", "SA"] + + for batch in dataloader: + input_ids = batch["input_ids"].to(model.device) + attention_mask = batch["attention_mask"].to(model.device) + property_data = batch["property"].to(model.device) + + model.eval() + with torch.no_grad(): + all_info_dict = model.generate( + input_ids=input_ids, + attention_mask=attention_mask, + molecule_properties=property_data, + do_molecular_design=True, + do_retrosynthesis=True, + expansion_topk=50, + iterations=100, + max_planning_time=30, + rollback=True, + **gen_kwargs, + ) + + assert len(all_info_dict["smiles_list"]) == 1 + + for i in range(len(all_info_dict["smiles_list"])): + llm_response = "".join(item for item in all_info_dict["text_lists"][i] if item is not None) + result = { + "llm_smiles": all_info_dict["smiles_list"][i], + "property": {}, + } + for j, prop_name in enumerate(property_names): + prop_value = property_data[i][j].item() + if not math.isnan(prop_value): + result["property"][prop_name] = prop_value + + retro_plan = all_info_dict["retro_plan_dict"][result["llm_smiles"]] + result["llm_reactions"] = [] + if retro_plan["success"]: + for reaction, template, cost in zip( + retro_plan["reaction_list"], + retro_plan["templates"], + retro_plan["cost"], + ): + result["llm_reactions"].append( + {"reaction": reaction, "template": template, "cost": cost} + ) + result["llm_response"] = remove_extra_spaces(llm_response) + return result \ No newline at end of file
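+
+# Illustrative usage sketch (an assumption, not part of the original module): the three
+# helpers above are intended to be composed roughly as follows; the argument dict and
+# request fields shown here are hypothetical placeholders.
+#
+#     args = {...}  # model/data/training/finetuning/generating arguments
+#     model, tokenizer, generating_args = load_model_and_tokenizer(args)
+#     request = {
+#         "instruction": "Design a molecule ...",
+#         "input": "",
+#         "property": {"SA": 3.0, "SC": 2.5},
+#     }
+#     dataloader, gen_kwargs = process_input(request, model, tokenizer, generating_args)
+#     result = generate(model, dataloader, gen_kwargs)
+#     print(result["llm_smiles"], result["llm_response"])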