mirror of https://github.com/hpcaitech/ColossalAI
[Coati] first commit (#3283)
parent
fd6add575d
commit
b0ce5a1032
|
@ -0,0 +1,146 @@
|
|||
# Byte-compiled / optimized / DLL files
|
||||
__pycache__/
|
||||
*.py[cod]
|
||||
*$py.class
|
||||
|
||||
# C extensions
|
||||
*.so
|
||||
|
||||
# Distribution / packaging
|
||||
.Python
|
||||
build/
|
||||
develop-eggs/
|
||||
dist/
|
||||
downloads/
|
||||
eggs/
|
||||
.eggs/
|
||||
lib/
|
||||
lib64/
|
||||
parts/
|
||||
sdist/
|
||||
var/
|
||||
wheels/
|
||||
pip-wheel-metadata/
|
||||
share/python-wheels/
|
||||
*.egg-info/
|
||||
.installed.cfg
|
||||
*.egg
|
||||
MANIFEST
|
||||
|
||||
# PyInstaller
|
||||
# Usually these files are written by a python script from a template
|
||||
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
||||
*.manifest
|
||||
*.spec
|
||||
|
||||
# Installer logs
|
||||
pip-log.txt
|
||||
pip-delete-this-directory.txt
|
||||
|
||||
# Unit test / coverage reports
|
||||
htmlcov/
|
||||
.tox/
|
||||
.nox/
|
||||
.coverage
|
||||
.coverage.*
|
||||
.cache
|
||||
nosetests.xml
|
||||
coverage.xml
|
||||
*.cover
|
||||
*.py,cover
|
||||
.hypothesis/
|
||||
.pytest_cache/
|
||||
|
||||
# Translations
|
||||
*.mo
|
||||
*.pot
|
||||
|
||||
# Django stuff:
|
||||
*.log
|
||||
local_settings.py
|
||||
db.sqlite3
|
||||
db.sqlite3-journal
|
||||
|
||||
# Flask stuff:
|
||||
instance/
|
||||
.webassets-cache
|
||||
|
||||
# Scrapy stuff:
|
||||
.scrapy
|
||||
|
||||
# Sphinx documentation
|
||||
docs/_build/
|
||||
docs/.build/
|
||||
|
||||
# PyBuilder
|
||||
target/
|
||||
|
||||
# Jupyter Notebook
|
||||
.ipynb_checkpoints
|
||||
|
||||
# IPython
|
||||
profile_default/
|
||||
ipython_config.py
|
||||
|
||||
# pyenv
|
||||
.python-version
|
||||
|
||||
# pipenv
|
||||
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
||||
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
||||
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
||||
# install all needed dependencies.
|
||||
#Pipfile.lock
|
||||
|
||||
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
||||
__pypackages__/
|
||||
|
||||
# Celery stuff
|
||||
celerybeat-schedule
|
||||
celerybeat.pid
|
||||
|
||||
# SageMath parsed files
|
||||
*.sage.py
|
||||
|
||||
# Environments
|
||||
.env
|
||||
.venv
|
||||
env/
|
||||
venv/
|
||||
ENV/
|
||||
env.bak/
|
||||
venv.bak/
|
||||
|
||||
# Spyder project settings
|
||||
.spyderproject
|
||||
.spyproject
|
||||
|
||||
# Rope project settings
|
||||
.ropeproject
|
||||
|
||||
# mkdocs documentation
|
||||
/site
|
||||
|
||||
# mypy
|
||||
.mypy_cache/
|
||||
.dmypy.json
|
||||
dmypy.json
|
||||
|
||||
# Pyre type checker
|
||||
.pyre/
|
||||
|
||||
# IDE
|
||||
.idea/
|
||||
.vscode/
|
||||
|
||||
# macos
|
||||
*.DS_Store
|
||||
#data/
|
||||
|
||||
docs/.build
|
||||
|
||||
# pytorch checkpoint
|
||||
*.pt
|
||||
|
||||
# wandb log
|
||||
example/wandb/
|
|
@ -0,0 +1,202 @@
|
|||
Copyright 2021- HPC-AI Technology Inc. All rights reserved.
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
|
||||
APPENDIX: How to apply the Apache License to your work.
|
||||
|
||||
To apply the Apache License to your work, attach the following
|
||||
boilerplate notice, with the fields enclosed by brackets "[]"
|
||||
replaced with your own identifying information. (Don't include
|
||||
the brackets!) The text should be enclosed in the appropriate
|
||||
comment syntax for the file format. We also recommend that a
|
||||
file or class name and description of purpose be included on the
|
||||
same "printed page" as the copyright notice for easier
|
||||
identification within third-party archives.
|
||||
|
||||
Copyright 2021- HPC-AI Technology Inc.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
|
@ -0,0 +1,269 @@
|
|||
<h1 align="center">
|
||||
<span>Coati - ColossalAI Talking Intelligence</span>
|
||||
<img width="auto" height="50px", src="assets/logo_coati.png"/>
|
||||
</h1>
|
||||
|
||||
|
||||
## Table of Contents
|
||||
|
||||
- [Table of Contents](#table-of-contents)
|
||||
- [What is Coati ?](#what-is-coati-)
|
||||
- [Online demo](#online-demo)
|
||||
- [Install](#install)
|
||||
- [Install the environment](#install-the-environment)
|
||||
- [Install the Transformers](#install-the-transformers)
|
||||
- [How to use?](#how-to-use)
|
||||
- [Supervised datasets collection](#supervised-datasets-collection)
|
||||
- [Stage1 - Supervised instructs tuning](#stage1---supervised-instructs-tuning)
|
||||
- [Stage2 - Training reward model](#stage2---training-reward-model)
|
||||
- [Stage3 - Training model with reinforcement learning by human feedback](#stage3---training-model-with-reinforcement-learning-by-human-feedback)
|
||||
- [Coati7B examples](#coati7b-examples)
|
||||
- [FAQ](#faq)
|
||||
- [How to save/load checkpoint](#how-to-saveload-checkpoint)
|
||||
- [The Plan](#the-plan)
|
||||
- [Real-time progress](#real-time-progress)
|
||||
- [Invitation to open-source contribution](#invitation-to-open-source-contribution)
|
||||
- [Quick Preview](#quick-preview)
|
||||
- [Authors](#authors)
|
||||
- [Citations](#citations)
|
||||
- [Licenses](#licenses)
|
||||
---
|
||||
## What is Coati ?
|
||||
|
||||
Coati is a large language model developed by Colossal-AI, which is also a unified large language model framework that has implemented the following functions
|
||||
- Supports comprehensive large-model training acceleration capabilities for ColossalAI, without requiring knowledge of complex distributed training algorithms
|
||||
- Supervised datasets collection
|
||||
- Supervised insturcts fine-tuning
|
||||
- Training reward model
|
||||
- Reinforcement learning with human feedback
|
||||
- Quantization inference
|
||||
- Fast model deploying
|
||||
- Perfectly integration with the Hugging Face ecosystem, high degree of model customization
|
||||
|
||||
|
||||
More details can be found in the [blog](https://www.hpc-ai.tech/blog/colossal-ai-chatgpt).
|
||||
|
||||
<p align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/chatgpt.png" width=700/>
|
||||
</p>
|
||||
|
||||
## Online demo
|
||||
You can experience the performance of Coati7B on this page.
|
||||
|
||||
[chat.colossalai.org](https://chat.colossalai.org/)
|
||||
|
||||
> Warning: Due to model and dataset size limitations, Coati is just a baby model, Coati7B may output incorrect information and lack the ability for multi-turn dialogue. There is still significant room for improvement.
|
||||
## Install
|
||||
|
||||
### Install the environment
|
||||
|
||||
```shell
|
||||
conda creat -n coati
|
||||
conda activate coati
|
||||
pip install .
|
||||
```
|
||||
|
||||
### Install the Transformers
|
||||
Given Hugging Face hasn't officially supported the LLaMA models, We fork a branch of Transformers that can be compatible with our code
|
||||
|
||||
```shell
|
||||
git clone https://github.com/hpcaitech/transformers
|
||||
cd transformers
|
||||
pip install .
|
||||
```
|
||||
|
||||
## How to use?
|
||||
|
||||
### Supervised datasets collection
|
||||
|
||||
we colllected 104K bilingual dataset of Chinese and English, and you can find the datasets in this repo
|
||||
|
||||
Here is how we collected the data
|
||||
<p align="center">
|
||||
<img src="assets/data-collect.png" width=500/>
|
||||
</p>
|
||||
|
||||
### Stage1 - Supervised instructs tuning
|
||||
|
||||
Stage1 is supervised instructs fine-tuning, which uses the datasets mentioned earlier to fine-tune the model
|
||||
|
||||
you can run the `examples/train_sft.sh` to start a supervised instructs fine-tuning
|
||||
|
||||
```
|
||||
torchrun --standalone --nproc_per_node=4 train_sft.py \
|
||||
--pretrain "/path/to/LLaMa-7B/" \
|
||||
--model 'llama' \
|
||||
--strategy colossalai_zero2 \
|
||||
--log_interval 10 \
|
||||
--save_path /path/to/Coati-7B \
|
||||
--dataset /path/to/data.json \
|
||||
--batch_size 4 \
|
||||
--accimulation_steps 8 \
|
||||
--lr 2e-5 \
|
||||
--max_datasets_size 512 \
|
||||
--max_epochs 1 \
|
||||
```
|
||||
|
||||
### Stage2 - Training reward model
|
||||
|
||||
Stage2 trains a reward model, which obtains corresponding scores by manually ranking different outputs for the same prompt and supervises the training of the reward model
|
||||
|
||||
you can run the `examples/train_rm.sh` to start a reward model training
|
||||
|
||||
```
|
||||
torchrun --standalone --nproc_per_node=4 train_reward_model.py
|
||||
--pretrain "/path/to/LLaMa-7B/" \
|
||||
--model 'llama' \
|
||||
--strategy colossalai_zero2 \
|
||||
--loss_fn 'log_exp'\
|
||||
--save_path 'rmstatic.pt' \
|
||||
```
|
||||
|
||||
### Stage3 - Training model with reinforcement learning by human feedback
|
||||
|
||||
Stage3 uses reinforcement learning algorithm, which is the most complex part of the training process:
|
||||
|
||||
<p align="center">
|
||||
<img src="assets/stage-3.jpeg" width=500/>
|
||||
</p>
|
||||
|
||||
you can run the `examples/train_prompts.sh` to start training PPO with human feedback
|
||||
|
||||
```
|
||||
torchrun --standalone --nproc_per_node=4 train_prompts.py prompts.csv \
|
||||
--pretrain "/path/to/LLaMa-7B/" \
|
||||
--model 'llama' \
|
||||
--strategy colossalai_zero2
|
||||
```
|
||||
|
||||
|
||||
For more details, see `examples/`.
|
||||
|
||||
We also support training reward model with true-world data. See `examples/train_reward_model.py`.
|
||||
|
||||
## Coati7B examples
|
||||
|
||||
|
||||
## FAQ
|
||||
|
||||
### How to save/load checkpoint
|
||||
|
||||
We have integrated the Transformers save and load pipeline, allowing users to freely call Hugging Face's language models and save them in the HF format.
|
||||
|
||||
```
|
||||
from coati.models.llama import LlamaLM
|
||||
from coati.trainer import SFTTrainer
|
||||
|
||||
model = LlamaLM(pretrained=args.pretrain)
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrain)
|
||||
|
||||
trainer = SFTTrainer(model=model,
|
||||
strategy=strategy,
|
||||
optim=optim,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
batch_size=args.batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
accimulation_steps = args.accimulation_steps
|
||||
)
|
||||
|
||||
trainer.fit()
|
||||
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
|
||||
```
|
||||
|
||||
## The Plan
|
||||
|
||||
- [x] implement PPO fine-tuning
|
||||
- [x] implement training reward model
|
||||
- [x] support LoRA
|
||||
- [x] support inference
|
||||
- [x] open source the reward model weight
|
||||
- [x] support llama from [facebook](https://github.com/facebookresearch/llama)
|
||||
- [x] implement PPO-ptx fine-tuning
|
||||
- [ ] integrate with Ray
|
||||
- [ ] support more RL paradigms, like Implicit Language Q-Learning (ILQL),
|
||||
- [ ] support chain of throught by [langchain](https://github.com/hwchase17/langchain)
|
||||
|
||||
### Real-time progress
|
||||
You will find our progress in github project broad
|
||||
|
||||
[Coati](https://github.com/orgs/hpcaitech/projects/17/views/1)
|
||||
|
||||
## Invitation to open-source contribution
|
||||
Referring to the successful attempts of [BLOOM](https://bigscience.huggingface.co/) and [Stable Diffusion](https://en.wikipedia.org/wiki/Stable_Diffusion), any and all developers and partners with computing powers, datasets, models are welcome to join and build the Colossal-AI community, making efforts towards the era of big AI models from the starting point of replicating ChatGPT!
|
||||
|
||||
You may contact us or participate in the following ways:
|
||||
1. [Leaving a Star ⭐](https://github.com/hpcaitech/ColossalAI/stargazers) to show your like and support. Thanks!
|
||||
2. Posting an [issue](https://github.com/hpcaitech/ColossalAI/issues/new/choose), or submitting a PR on GitHub follow the guideline in [Contributing](https://github.com/hpcaitech/ColossalAI/blob/main/CONTRIBUTING.md).
|
||||
3. Join the Colossal-AI community on
|
||||
[Slack](https://join.slack.com/t/colossalaiworkspace/shared_invite/zt-z7b26eeb-CBp7jouvu~r0~lcFzX832w),
|
||||
and [WeChat(微信)](https://raw.githubusercontent.com/hpcaitech/public_assets/main/colossalai/img/WeChat.png "qrcode") to share your ideas.
|
||||
4. Send your official proposal to email contact@hpcaitech.com
|
||||
|
||||
Thanks so much to all of our amazing contributors!
|
||||
|
||||
## Quick Preview
|
||||
<p id="ChatGPT_scaling" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT%20scaling.png" width=800/>
|
||||
</p>
|
||||
|
||||
- Up to 7.73 times faster for single server training and 1.42 times faster for single-GPU inference
|
||||
|
||||
<p id="ChatGPT-1GPU" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/ChatGPT-1GPU.jpg" width=450/>
|
||||
</p>
|
||||
|
||||
- Up to 10.3x growth in model capacity on one GPU
|
||||
- A mini demo training process requires only 1.62GB of GPU memory (any consumer-grade GPU)
|
||||
|
||||
<p id="inference" align="center">
|
||||
<img src="https://raw.githubusercontent.com/hpcaitech/public_assets/main/applications/chatgpt/LoRA%20data.jpg" width=600/>
|
||||
</p>
|
||||
|
||||
- Increase the capacity of the fine-tuning model by up to 3.7 times on a single GPU
|
||||
- Keep in a sufficiently high running speed
|
||||
|
||||
## Authors
|
||||
|
||||
Coati is developed by ColossalAI Team: [Fazzie](https://fazzie-key.cool/about/index.html), [FrankLeeeee](https://github.com/FrankLeeeee), [BlueRum](https://github.com/ht-zhou), [ver217](https://github.com/ver217)
|
||||
|
||||
The Phd student [Zangwei Zheng](https://github.com/zhengzangw) and [Xue Fuzhao](https://github.com/XueFuzhao) also contributed a lot to this project.
|
||||
|
||||
## Citations
|
||||
|
||||
```bibtex
|
||||
@article{Hu2021LoRALA,
|
||||
title = {LoRA: Low-Rank Adaptation of Large Language Models},
|
||||
author = {Edward J. Hu and Yelong Shen and Phillip Wallis and Zeyuan Allen-Zhu and Yuanzhi Li and Shean Wang and Weizhu Chen},
|
||||
journal = {ArXiv},
|
||||
year = {2021},
|
||||
volume = {abs/2106.09685}
|
||||
}
|
||||
|
||||
@article{ouyang2022training,
|
||||
title={Training language models to follow instructions with human feedback},
|
||||
author={Ouyang, Long and Wu, Jeff and Jiang, Xu and Almeida, Diogo and Wainwright, Carroll L and Mishkin, Pamela and Zhang, Chong and Agarwal, Sandhini and Slama, Katarina and Ray, Alex and others},
|
||||
journal={arXiv preprint arXiv:2203.02155},
|
||||
year={2022}
|
||||
}
|
||||
|
||||
@article{touvron2023llama,
|
||||
title={LLaMA: Open and Efficient Foundation Language Models},
|
||||
author={Touvron, Hugo and Lavril, Thibaut and Izacard, Gautier and Martinet, Xavier and Lachaux, Marie-Anne and Lacroix, Timoth{\'e}e and Rozi{\`e}re, Baptiste and Goyal, Naman and Hambro, Eric and Azhar, Faisal and Rodriguez, Aurelien and Joulin, Armand and Grave, Edouard and Lample, Guillaume},
|
||||
journal={arXiv preprint arXiv:2302.13971},
|
||||
year={2023}
|
||||
}
|
||||
|
||||
@misc{alpaca,
|
||||
author = {Rohan Taori and Ishaan Gulrajani and Tianyi Zhang and Yann Dubois and Xuechen Li and Carlos Guestrin and Percy Liang and Tatsunori B. Hashimoto },
|
||||
title = {Stanford Alpaca: An Instruction-following LLaMA model},
|
||||
year = {2023},
|
||||
publisher = {GitHub},
|
||||
journal = {GitHub repository},
|
||||
howpublished = {\url{https://github.com/tatsu-lab/stanford_alpaca}},
|
||||
}
|
||||
```
|
||||
|
||||
## Licenses
|
||||
|
||||
Coati is licensed under the [Apache 2.0 License](LICENSE).
|
Binary file not shown.
After Width: | Height: | Size: 401 KiB |
Binary file not shown.
After Width: | Height: | Size: 640 KiB |
Binary file not shown.
After Width: | Height: | Size: 370 KiB |
|
@ -0,0 +1,94 @@
|
|||
# Benchmarks
|
||||
|
||||
## Benchmark GPT on dummy prompt data
|
||||
|
||||
We provide various GPT models (string in parentheses is the corresponding model name used in this script):
|
||||
|
||||
- GPT2-S (s)
|
||||
- GPT2-M (m)
|
||||
- GPT2-L (l)
|
||||
- GPT2-XL (xl)
|
||||
- GPT2-4B (4b)
|
||||
- GPT2-6B (6b)
|
||||
- GPT2-8B (8b)
|
||||
- GPT2-10B (10b)
|
||||
- GPT2-12B (12b)
|
||||
- GPT2-15B (15b)
|
||||
- GPT2-18B (18b)
|
||||
- GPT2-20B (20b)
|
||||
- GPT2-24B (24b)
|
||||
- GPT2-28B (28b)
|
||||
- GPT2-32B (32b)
|
||||
- GPT2-36B (36b)
|
||||
- GPT2-40B (40b)
|
||||
- GPT3 (175b)
|
||||
|
||||
We also provide various training strategies:
|
||||
|
||||
- ddp: torch DDP
|
||||
- colossalai_gemini: ColossalAI GeminiDDP with `placement_policy="cuda"`, like zero3
|
||||
- colossalai_gemini_cpu: ColossalAI GeminiDDP with `placement_policy="cpu"`, like zero3-offload
|
||||
- colossalai_zero2: ColossalAI zero2
|
||||
- colossalai_zero2_cpu: ColossalAI zero2-offload
|
||||
- colossalai_zero1: ColossalAI zero1
|
||||
- colossalai_zero1_cpu: ColossalAI zero1-offload
|
||||
|
||||
We only support `torchrun` to launch now. E.g.
|
||||
|
||||
```shell
|
||||
# run GPT2-S on single-node single-GPU with min batch size
|
||||
torchrun --standalone --nproc_per_node 1 benchmark_gpt_dummy.py --model s --strategy ddp --experience_batch_size 1 --train_batch_size 1
|
||||
# run GPT2-XL on single-node 4-GPU
|
||||
torchrun --standalone --nproc_per_node 4 benchmark_gpt_dummy.py --model xl --strategy colossalai_zero2
|
||||
# run GPT3 on 8-node 8-GPU
|
||||
torchrun --nnodes 8 --nproc_per_node 8 \
|
||||
--rdzv_id=$JOB_ID --rdzv_backend=c10d --rdzv_endpoint=$HOST_NODE_ADDR \
|
||||
benchmark_gpt_dummy.py --model 175b --strategy colossalai_gemini
|
||||
```
|
||||
|
||||
> ⚠ Batch sizes in CLI args and outputed throughput/TFLOPS are all values of per GPU.
|
||||
|
||||
In this benchmark, we assume the model architectures/sizes of actor and critic are the same for simplicity. But in practice, to reduce training cost, we may use a smaller critic.
|
||||
|
||||
We also provide a simple shell script to run a set of benchmarks. But it only supports benchmark on single node. However, it's easy to run on multi-nodes by modifying launch command in this script.
|
||||
|
||||
Usage:
|
||||
|
||||
```shell
|
||||
# run for GPUS=(1 2 4 8) x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
|
||||
./benchmark_gpt_dummy.sh
|
||||
# run for GPUS=2 x strategy=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu") x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
|
||||
./benchmark_gpt_dummy.sh 2
|
||||
# run for GPUS=2 x strategy=ddp x model=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b") x batch_size=(1 2 4 8 16 32 64 128 256)
|
||||
./benchmark_gpt_dummy.sh 2 ddp
|
||||
# run for GPUS=2 x strategy=ddp x model=l x batch_size=(1 2 4 8 16 32 64 128 256)
|
||||
./benchmark_gpt_dummy.sh 2 ddp l
|
||||
```
|
||||
|
||||
## Benchmark OPT with LoRA on dummy prompt data
|
||||
|
||||
We provide various OPT models (string in parentheses is the corresponding model name used in this script):
|
||||
|
||||
- OPT-125M (125m)
|
||||
- OPT-350M (350m)
|
||||
- OPT-700M (700m)
|
||||
- OPT-1.3B (1.3b)
|
||||
- OPT-2.7B (2.7b)
|
||||
- OPT-3.5B (3.5b)
|
||||
- OPT-5.5B (5.5b)
|
||||
- OPT-6.7B (6.7b)
|
||||
- OPT-10B (10b)
|
||||
- OPT-13B (13b)
|
||||
|
||||
We only support `torchrun` to launch now. E.g.
|
||||
|
||||
```shell
|
||||
# run OPT-125M with no lora (lora_rank=0) on single-node single-GPU with min batch size
|
||||
torchrun --standalone --nproc_per_node 1 benchmark_opt_lora_dummy.py --model 125m --strategy ddp --experience_batch_size 1 --train_batch_size 1 --lora_rank 0
|
||||
# run OPT-350M with lora_rank=4 on single-node 4-GPU
|
||||
torchrun --standalone --nproc_per_node 4 benchmark_opt_lora_dummy.py --model 350m --strategy colossalai_zero2 --lora_rank 4
|
||||
```
|
||||
|
||||
> ⚠ Batch sizes in CLI args and outputed throughput/TFLOPS are all values of per GPU.
|
||||
|
||||
In this benchmark, we assume the model architectures/sizes of actor and critic are the same for simplicity. But in practice, to reduce training cost, we may use a smaller critic.
|
|
@ -0,0 +1,184 @@
|
|||
import argparse
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from coati.models.base import RewardModel
|
||||
from coati.models.gpt import GPTActor, GPTCritic
|
||||
from coati.trainer import PPOTrainer
|
||||
from coati.trainer.callbacks import PerformanceEvaluator
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
|
||||
from torch.optim import Adam
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
|
||||
numel = sum(p.numel() for p in model.parameters())
|
||||
if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
|
||||
numel *= dist.get_world_size()
|
||||
return numel
|
||||
|
||||
|
||||
def preprocess_batch(samples) -> dict:
|
||||
input_ids = torch.stack(samples)
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
||||
|
||||
|
||||
def print_rank_0(*args, **kwargs) -> None:
|
||||
if dist.get_rank() == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
def print_model_numel(model_dict: dict) -> None:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
outputs = ''
|
||||
for name, numel in model_dict.items():
|
||||
outputs += f'{name}: '
|
||||
if numel >= B:
|
||||
outputs += f'{numel / B:.2f} B\n'
|
||||
elif numel >= M:
|
||||
outputs += f'{numel / M:.2f} M\n'
|
||||
elif numel >= K:
|
||||
outputs += f'{numel / K:.2f} K\n'
|
||||
else:
|
||||
outputs += f'{numel}\n'
|
||||
print_rank_0(outputs)
|
||||
|
||||
|
||||
def get_gpt_config(model_name: str) -> GPT2Config:
|
||||
model_map = {
|
||||
's': GPT2Config(),
|
||||
'm': GPT2Config(n_embd=1024, n_layer=24, n_head=16),
|
||||
'l': GPT2Config(n_embd=1280, n_layer=36, n_head=20),
|
||||
'xl': GPT2Config(n_embd=1600, n_layer=48, n_head=25),
|
||||
'2b': GPT2Config(n_embd=2048, n_layer=40, n_head=16),
|
||||
'4b': GPT2Config(n_embd=2304, n_layer=64, n_head=16),
|
||||
'6b': GPT2Config(n_embd=4096, n_layer=30, n_head=16),
|
||||
'8b': GPT2Config(n_embd=4096, n_layer=40, n_head=16),
|
||||
'10b': GPT2Config(n_embd=4096, n_layer=50, n_head=16),
|
||||
'12b': GPT2Config(n_embd=4096, n_layer=60, n_head=16),
|
||||
'15b': GPT2Config(n_embd=4096, n_layer=78, n_head=16),
|
||||
'18b': GPT2Config(n_embd=4096, n_layer=90, n_head=16),
|
||||
'20b': GPT2Config(n_embd=8192, n_layer=25, n_head=16),
|
||||
'24b': GPT2Config(n_embd=8192, n_layer=30, n_head=16),
|
||||
'28b': GPT2Config(n_embd=8192, n_layer=35, n_head=16),
|
||||
'32b': GPT2Config(n_embd=8192, n_layer=40, n_head=16),
|
||||
'36b': GPT2Config(n_embd=8192, n_layer=45, n_head=16),
|
||||
'40b': GPT2Config(n_embd=8192, n_layer=50, n_head=16),
|
||||
'175b': GPT2Config(n_positions=2048, n_embd=12288, n_layer=96, n_head=96),
|
||||
}
|
||||
try:
|
||||
return model_map[model_name]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown model "{model_name}"')
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.strategy == 'ddp':
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
||||
elif args.strategy == 'colossalai_gemini_cpu':
|
||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2_cpu':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
|
||||
elif args.strategy == 'colossalai_zero1':
|
||||
strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero1_cpu':
|
||||
strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
model_config = get_gpt_config(args.model)
|
||||
|
||||
with strategy.model_init_context():
|
||||
actor = GPTActor(config=model_config).cuda()
|
||||
critic = GPTCritic(config=model_config).cuda()
|
||||
|
||||
initial_model = deepcopy(actor).cuda()
|
||||
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
|
||||
|
||||
actor_numel = get_model_numel(actor, strategy)
|
||||
critic_numel = get_model_numel(critic, strategy)
|
||||
initial_model_numel = get_model_numel(initial_model, strategy)
|
||||
reward_model_numel = get_model_numel(reward_model, strategy)
|
||||
print_model_numel({
|
||||
'Actor': actor_numel,
|
||||
'Critic': critic_numel,
|
||||
'Initial model': initial_model_numel,
|
||||
'Reward model': reward_model_numel
|
||||
})
|
||||
performance_evaluator = PerformanceEvaluator(actor_numel,
|
||||
critic_numel,
|
||||
initial_model_numel,
|
||||
reward_model_numel,
|
||||
enable_grad_checkpoint=False,
|
||||
ignore_episodes=1)
|
||||
|
||||
if args.strategy.startswith('colossalai'):
|
||||
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
|
||||
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
|
||||
else:
|
||||
actor_optim = Adam(actor.parameters(), lr=5e-6)
|
||||
critic_optim = Adam(critic.parameters(), lr=5e-6)
|
||||
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
||||
|
||||
trainer = PPOTrainer(strategy,
|
||||
actor,
|
||||
critic,
|
||||
reward_model,
|
||||
initial_model,
|
||||
actor_optim,
|
||||
critic_optim,
|
||||
max_epochs=args.max_epochs,
|
||||
train_batch_size=args.train_batch_size,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
tokenizer=preprocess_batch,
|
||||
max_length=512,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
callbacks=[performance_evaluator])
|
||||
|
||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
|
||||
trainer.fit(random_prompts,
|
||||
num_episodes=args.num_episodes,
|
||||
max_timesteps=args.max_timesteps,
|
||||
update_timesteps=args.update_timesteps)
|
||||
|
||||
print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', default='s')
|
||||
parser.add_argument('--strategy',
|
||||
choices=[
|
||||
'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
|
||||
'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
|
||||
],
|
||||
default='ddp')
|
||||
parser.add_argument('--num_episodes', type=int, default=3)
|
||||
parser.add_argument('--max_timesteps', type=int, default=8)
|
||||
parser.add_argument('--update_timesteps', type=int, default=8)
|
||||
parser.add_argument('--max_epochs', type=int, default=3)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,45 @@
|
|||
#!/usr/bin/env bash
|
||||
# Usage: $0 <?number-of-gpus> <?strategy> <?model>
|
||||
set -xu
|
||||
|
||||
BASE=$(realpath $(dirname $0))
|
||||
|
||||
|
||||
PY_SCRIPT=${BASE}/benchmark_gpt_dummy.py
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
function tune_batch_size() {
|
||||
# we found when experience batch size is equal to train batch size
|
||||
# peak CUDA memory usage of making experience phase is less than or equal to that of training phase
|
||||
# thus, experience batch size can be larger than or equal to train batch size
|
||||
for bs in 1 2 4 8 16 32 64 128 256; do
|
||||
torchrun --standalone --nproc_per_node $1 $PY_SCRIPT --model $2 --strategy $3 --experience_batch_size $bs --train_batch_size $bs || return 1
|
||||
done
|
||||
}
|
||||
|
||||
if [ $# -eq 0 ]; then
|
||||
num_gpus=(1 2 4 8)
|
||||
else
|
||||
num_gpus=($1)
|
||||
fi
|
||||
|
||||
if [ $# -le 1 ]; then
|
||||
strategies=("ddp" "colossalai_zero2" "colossalai_gemini" "colossalai_zero2_cpu" "colossalai_gemini_cpu")
|
||||
else
|
||||
strategies=($2)
|
||||
fi
|
||||
|
||||
if [ $# -le 2 ]; then
|
||||
models=("s" "m" "l" "xl" "2b" "4b" "6b" "8b" "10b")
|
||||
else
|
||||
models=($3)
|
||||
fi
|
||||
|
||||
|
||||
for num_gpu in ${num_gpus[@]}; do
|
||||
for strategy in ${strategies[@]}; do
|
||||
for model in ${models[@]}; do
|
||||
tune_batch_size $num_gpu $model $strategy || break
|
||||
done
|
||||
done
|
||||
done
|
|
@ -0,0 +1,179 @@
|
|||
import argparse
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from coati.models.base import RewardModel
|
||||
from coati.models.opt import OPTActor, OPTCritic
|
||||
from coati.trainer import PPOTrainer
|
||||
from coati.trainer.callbacks import PerformanceEvaluator
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, Strategy
|
||||
from torch.optim import Adam
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
def get_model_numel(model: nn.Module, strategy: Strategy) -> int:
|
||||
numel = sum(p.numel() for p in model.parameters())
|
||||
if isinstance(strategy, ColossalAIStrategy) and strategy.stage == 3 and strategy.shard_init:
|
||||
numel *= dist.get_world_size()
|
||||
return numel
|
||||
|
||||
|
||||
def preprocess_batch(samples) -> dict:
|
||||
input_ids = torch.stack(samples)
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
||||
|
||||
|
||||
def print_rank_0(*args, **kwargs) -> None:
|
||||
if dist.get_rank() == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
def print_model_numel(model_dict: dict) -> None:
|
||||
B = 1024**3
|
||||
M = 1024**2
|
||||
K = 1024
|
||||
outputs = ''
|
||||
for name, numel in model_dict.items():
|
||||
outputs += f'{name}: '
|
||||
if numel >= B:
|
||||
outputs += f'{numel / B:.2f} B\n'
|
||||
elif numel >= M:
|
||||
outputs += f'{numel / M:.2f} M\n'
|
||||
elif numel >= K:
|
||||
outputs += f'{numel / K:.2f} K\n'
|
||||
else:
|
||||
outputs += f'{numel}\n'
|
||||
print_rank_0(outputs)
|
||||
|
||||
|
||||
def get_gpt_config(model_name: str) -> OPTConfig:
|
||||
model_map = {
|
||||
'125m': OPTConfig.from_pretrained('facebook/opt-125m'),
|
||||
'350m': OPTConfig(hidden_size=1024, ffn_dim=4096, num_hidden_layers=24, num_attention_heads=16),
|
||||
'700m': OPTConfig(hidden_size=1280, ffn_dim=5120, num_hidden_layers=36, num_attention_heads=20),
|
||||
'1.3b': OPTConfig.from_pretrained('facebook/opt-1.3b'),
|
||||
'2.7b': OPTConfig.from_pretrained('facebook/opt-2.7b'),
|
||||
'3.5b': OPTConfig(hidden_size=3072, ffn_dim=12288, num_hidden_layers=32, num_attention_heads=32),
|
||||
'5.5b': OPTConfig(hidden_size=3840, ffn_dim=15360, num_hidden_layers=32, num_attention_heads=32),
|
||||
'6.7b': OPTConfig.from_pretrained('facebook/opt-6.7b'),
|
||||
'10b': OPTConfig(hidden_size=5120, ffn_dim=20480, num_hidden_layers=32, num_attention_heads=32),
|
||||
'13b': OPTConfig.from_pretrained('facebook/opt-13b'),
|
||||
}
|
||||
try:
|
||||
return model_map[model_name]
|
||||
except KeyError:
|
||||
raise ValueError(f'Unknown model "{model_name}"')
|
||||
|
||||
|
||||
def main(args):
|
||||
if args.strategy == 'ddp':
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
||||
elif args.strategy == 'colossalai_gemini_cpu':
|
||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cpu', initial_scale=2**5)
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2_cpu':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cpu')
|
||||
elif args.strategy == 'colossalai_zero1':
|
||||
strategy = ColossalAIStrategy(stage=1, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero1_cpu':
|
||||
strategy = ColossalAIStrategy(stage=1, placement_policy='cpu')
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
torch.cuda.set_per_process_memory_fraction(args.cuda_mem_frac)
|
||||
|
||||
model_config = get_gpt_config(args.model)
|
||||
|
||||
with strategy.model_init_context():
|
||||
actor = OPTActor(config=model_config, lora_rank=args.lora_rank).cuda()
|
||||
critic = OPTCritic(config=model_config, lora_rank=args.lora_rank).cuda()
|
||||
|
||||
initial_model = deepcopy(actor).cuda()
|
||||
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).cuda()
|
||||
|
||||
actor_numel = get_model_numel(actor, strategy)
|
||||
critic_numel = get_model_numel(critic, strategy)
|
||||
initial_model_numel = get_model_numel(initial_model, strategy)
|
||||
reward_model_numel = get_model_numel(reward_model, strategy)
|
||||
print_model_numel({
|
||||
'Actor': actor_numel,
|
||||
'Critic': critic_numel,
|
||||
'Initial model': initial_model_numel,
|
||||
'Reward model': reward_model_numel
|
||||
})
|
||||
performance_evaluator = PerformanceEvaluator(actor_numel,
|
||||
critic_numel,
|
||||
initial_model_numel,
|
||||
reward_model_numel,
|
||||
enable_grad_checkpoint=False,
|
||||
ignore_episodes=1)
|
||||
|
||||
if args.strategy.startswith('colossalai'):
|
||||
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
|
||||
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
|
||||
else:
|
||||
actor_optim = Adam(actor.parameters(), lr=5e-6)
|
||||
critic_optim = Adam(critic.parameters(), lr=5e-6)
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
||||
|
||||
trainer = PPOTrainer(strategy,
|
||||
actor,
|
||||
critic,
|
||||
reward_model,
|
||||
initial_model,
|
||||
actor_optim,
|
||||
critic_optim,
|
||||
max_epochs=args.max_epochs,
|
||||
train_batch_size=args.train_batch_size,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
tokenizer=preprocess_batch,
|
||||
max_length=512,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
callbacks=[performance_evaluator])
|
||||
|
||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 400), device=torch.cuda.current_device())
|
||||
trainer.fit(random_prompts,
|
||||
num_episodes=args.num_episodes,
|
||||
max_timesteps=args.max_timesteps,
|
||||
update_timesteps=args.update_timesteps)
|
||||
|
||||
print_rank_0(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.2f} GB')
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', default='125m')
|
||||
parser.add_argument('--strategy',
|
||||
choices=[
|
||||
'ddp', 'colossalai_gemini', 'colossalai_gemini_cpu', 'colossalai_zero2',
|
||||
'colossalai_zero2_cpu', 'colossalai_zero1', 'colossalai_zero1_cpu'
|
||||
],
|
||||
default='ddp')
|
||||
parser.add_argument('--num_episodes', type=int, default=3)
|
||||
parser.add_argument('--max_timesteps', type=int, default=8)
|
||||
parser.add_argument('--update_timesteps', type=int, default=8)
|
||||
parser.add_argument('--max_epochs', type=int, default=3)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=4)
|
||||
parser.add_argument('--cuda_mem_frac', type=float, default=1.0)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,9 @@
|
|||
from .prompt_dataset import PromptDataset
|
||||
from .reward_dataset import HhRlhfDataset, RmStaticDataset
|
||||
from .sft_dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
|
||||
from .utils import is_rank_0
|
||||
|
||||
__all__ = [
|
||||
'RmStaticDataset', 'HhRlhfDataset', 'is_rank_0', 'SFTDataset', 'SupervisedDataset',
|
||||
'DataCollatorForSupervisedDataset', 'PromptDataset'
|
||||
]
|
|
@ -0,0 +1,46 @@
|
|||
import copy
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, Sequence
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .utils import is_rank_0, jload
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
|
||||
class PromptDataset(Dataset):
|
||||
"""Dataset for supervised fine-tuning."""
|
||||
|
||||
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None):
|
||||
super(PromptDataset, self).__init__()
|
||||
self.prompt = []
|
||||
logger.info("Loading data...")
|
||||
list_data_dict = jload(data_path)
|
||||
logger.info(f"Loaded {len(list_data_dict)} examples.")
|
||||
|
||||
if max_datasets_size is not None:
|
||||
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
|
||||
list_data_dict = list_data_dict[:max_datasets_size]
|
||||
|
||||
for data_dict in list_data_dict:
|
||||
token = tokenizer(data_dict["instruction"],
|
||||
return_tensors='pt',
|
||||
max_length=96,
|
||||
padding='max_length',
|
||||
truncation=True)
|
||||
for idx in token['input_ids']:
|
||||
self.prompt.append(idx.to(torch.cuda.current_device()))
|
||||
|
||||
def __len__(self):
|
||||
return len(self.prompt)
|
||||
|
||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
||||
return self.prompt[i]
|
|
@ -0,0 +1,112 @@
|
|||
from typing import Callable
|
||||
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from .utils import is_rank_0
|
||||
|
||||
|
||||
# Dahaos/rm-static
|
||||
class RmStaticDataset(Dataset):
|
||||
"""
|
||||
Dataset for reward model
|
||||
|
||||
Args:
|
||||
dataset: dataset for reward model
|
||||
tokenizer: tokenizer for reward model
|
||||
max_length: max length of input
|
||||
special_token: special token at the end of sentence
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
|
||||
super().__init__()
|
||||
self.chosen = []
|
||||
self.reject = []
|
||||
if special_token is None:
|
||||
self.end_token = tokenizer.eos_token
|
||||
else:
|
||||
self.end_token = special_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||
prompt = data['prompt']
|
||||
|
||||
chosen = prompt + data['chosen'] + self.end_token
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen.append({
|
||||
"input_ids": chosen_token['input_ids'],
|
||||
"attention_mask": chosen_token['attention_mask']
|
||||
})
|
||||
|
||||
reject = prompt + data['rejected'] + self.end_token
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject.append({
|
||||
"input_ids": reject_token['input_ids'],
|
||||
"attention_mask": reject_token['attention_mask']
|
||||
})
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.chosen)
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
|
||||
"input_ids"], self.reject[idx]["attention_mask"]
|
||||
|
||||
|
||||
# Anthropic/hh-rlhf
|
||||
class HhRlhfDataset(Dataset):
|
||||
"""
|
||||
Dataset for reward model
|
||||
|
||||
Args:
|
||||
dataset: dataset for reward model
|
||||
tokenizer: tokenizer for reward model
|
||||
max_length: max length of input
|
||||
special_token: special token at the end of sentence
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, tokenizer: Callable, max_length: int, special_token=None) -> None:
|
||||
super().__init__()
|
||||
self.chosen = []
|
||||
self.reject = []
|
||||
if special_token is None:
|
||||
self.end_token = tokenizer.eos_token
|
||||
else:
|
||||
self.end_token = special_token
|
||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||
chosen = data['chosen'] + self.end_token
|
||||
chosen_token = tokenizer(chosen,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.chosen.append({
|
||||
"input_ids": chosen_token['input_ids'],
|
||||
"attention_mask": chosen_token['attention_mask']
|
||||
})
|
||||
|
||||
reject = data['rejected'] + self.end_token
|
||||
reject_token = tokenizer(reject,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
self.reject.append({
|
||||
"input_ids": reject_token['input_ids'],
|
||||
"attention_mask": reject_token['attention_mask']
|
||||
})
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.chosen)
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
return self.chosen[idx]["input_ids"], self.chosen[idx]["attention_mask"], self.reject[idx][
|
||||
"input_ids"], self.reject[idx]["attention_mask"]
|
|
@ -0,0 +1,169 @@
|
|||
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import copy
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Callable, Dict, Sequence
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import transformers
|
||||
from torch.utils.data import Dataset
|
||||
from tqdm import tqdm
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .utils import is_rank_0, jload
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
IGNORE_INDEX = -100
|
||||
PROMPT_DICT = {
|
||||
"prompt_input":
|
||||
("Below is an instruction that describes a task, paired with an input that provides further context. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"),
|
||||
"prompt_no_input": ("Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request.\n\n"
|
||||
"### Instruction:\n{instruction}\n\n### Response:"),
|
||||
}
|
||||
|
||||
|
||||
class SFTDataset(Dataset):
|
||||
"""
|
||||
Dataset for sft model
|
||||
|
||||
Args:
|
||||
dataset: dataset for supervised model
|
||||
tokenizer: tokenizer for supervised model
|
||||
max_length: max length of input
|
||||
"""
|
||||
|
||||
def __init__(self, dataset, tokenizer: Callable, max_length: int = 512) -> None:
|
||||
super().__init__()
|
||||
# self.prompts = []
|
||||
self.input_ids = []
|
||||
|
||||
for data in tqdm(dataset, disable=not is_rank_0()):
|
||||
prompt = data['prompt'] + data['completion'] + "<|endoftext|>"
|
||||
prompt_token = tokenizer(prompt,
|
||||
max_length=max_length,
|
||||
padding="max_length",
|
||||
truncation=True,
|
||||
return_tensors="pt")
|
||||
|
||||
# self.prompts.append(prompt_token)s
|
||||
self.input_ids.append(prompt_token)
|
||||
self.labels = copy.deepcopy(self.input_ids)
|
||||
|
||||
def __len__(self):
|
||||
length = len(self.prompts)
|
||||
return length
|
||||
|
||||
def __getitem__(self, idx):
|
||||
# dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
||||
return dict(input_ids=self.input_ids[idx], labels=self.labels[idx])
|
||||
# return dict(self.prompts[idx], self.prompts[idx])
|
||||
|
||||
|
||||
def _tokenize_fn(strings: Sequence[str], tokenizer: transformers.PreTrainedTokenizer) -> Dict:
|
||||
"""Tokenize a list of strings."""
|
||||
tokenized_list = [
|
||||
tokenizer(
|
||||
text,
|
||||
return_tensors="pt",
|
||||
padding="longest",
|
||||
max_length=tokenizer.model_max_length,
|
||||
truncation=True,
|
||||
) for text in strings
|
||||
]
|
||||
input_ids = labels = [tokenized.input_ids[0] for tokenized in tokenized_list]
|
||||
input_ids_lens = labels_lens = [
|
||||
tokenized.input_ids.ne(tokenizer.pad_token_id).sum().item() for tokenized in tokenized_list
|
||||
]
|
||||
return dict(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
input_ids_lens=input_ids_lens,
|
||||
labels_lens=labels_lens,
|
||||
)
|
||||
|
||||
|
||||
def preprocess(
|
||||
sources: Sequence[str],
|
||||
targets: Sequence[str],
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
) -> Dict:
|
||||
"""Preprocess the data by tokenizing."""
|
||||
examples = [s + t for s, t in zip(sources, targets)]
|
||||
examples_tokenized, sources_tokenized = [_tokenize_fn(strings, tokenizer) for strings in (examples, sources)]
|
||||
input_ids = examples_tokenized["input_ids"]
|
||||
labels = copy.deepcopy(input_ids)
|
||||
for label, source_len in zip(labels, sources_tokenized["input_ids_lens"]):
|
||||
label[:source_len] = IGNORE_INDEX
|
||||
return dict(input_ids=input_ids, labels=labels)
|
||||
|
||||
|
||||
class SupervisedDataset(Dataset):
|
||||
"""Dataset for supervised fine-tuning."""
|
||||
|
||||
def __init__(self, data_path: str, tokenizer: transformers.PreTrainedTokenizer, max_datasets_size: int = None):
|
||||
super(SupervisedDataset, self).__init__()
|
||||
logger.info("Loading data...")
|
||||
list_data_dict = jload(data_path)
|
||||
logger.info(f"Loaded {len(list_data_dict)} examples.")
|
||||
|
||||
if max_datasets_size is not None:
|
||||
logger.info(f"Limiting dataset to {max_datasets_size} examples.")
|
||||
list_data_dict = list_data_dict[:max_datasets_size]
|
||||
|
||||
logger.info("Formatting inputs...")
|
||||
prompt_input, prompt_no_input = PROMPT_DICT["prompt_input"], PROMPT_DICT["prompt_no_input"]
|
||||
sources = [
|
||||
prompt_input.format_map(example) if example.get("input", "") != "" else prompt_no_input.format_map(example)
|
||||
for example in list_data_dict
|
||||
]
|
||||
targets = [f"{example['output']}{tokenizer.eos_token}" for example in list_data_dict]
|
||||
|
||||
logger.info("Tokenizing inputs... This may take some time...")
|
||||
data_dict = preprocess(sources, targets, tokenizer)
|
||||
|
||||
self.input_ids = data_dict["input_ids"]
|
||||
self.labels = data_dict["labels"]
|
||||
|
||||
def __len__(self):
|
||||
return len(self.input_ids)
|
||||
|
||||
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
|
||||
return dict(input_ids=self.input_ids[i], labels=self.labels[i])
|
||||
|
||||
|
||||
@dataclass
|
||||
class DataCollatorForSupervisedDataset(object):
|
||||
"""Collate examples for supervised fine-tuning."""
|
||||
|
||||
tokenizer: transformers.PreTrainedTokenizer
|
||||
|
||||
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
||||
input_ids, labels = tuple([instance[key] for instance in instances] for key in ("input_ids", "labels"))
|
||||
input_ids = torch.nn.utils.rnn.pad_sequence(input_ids,
|
||||
batch_first=True,
|
||||
padding_value=self.tokenizer.pad_token_id)
|
||||
labels = torch.nn.utils.rnn.pad_sequence(labels, batch_first=True, padding_value=IGNORE_INDEX)
|
||||
return dict(
|
||||
input_ids=input_ids,
|
||||
labels=labels,
|
||||
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
||||
)
|
|
@ -0,0 +1,22 @@
|
|||
import io
|
||||
import json
|
||||
|
||||
import torch.distributed as dist
|
||||
|
||||
|
||||
def is_rank_0() -> bool:
|
||||
return not dist.is_initialized() or dist.get_rank() == 0
|
||||
|
||||
|
||||
def _make_r_io_base(f, mode: str):
|
||||
if not isinstance(f, io.IOBase):
|
||||
f = open(f, mode=mode)
|
||||
return f
|
||||
|
||||
|
||||
def jload(f, mode="r"):
|
||||
"""Load a .json file into a dictionary."""
|
||||
f = _make_r_io_base(f, mode)
|
||||
jdict = json.load(f)
|
||||
f.close()
|
||||
return jdict
|
|
@ -0,0 +1,4 @@
|
|||
from .base import Experience, ExperienceMaker
|
||||
from .naive import NaiveExperienceMaker
|
||||
|
||||
__all__ = ['Experience', 'ExperienceMaker', 'NaiveExperienceMaker']
|
|
@ -0,0 +1,77 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.models.base import Actor
|
||||
|
||||
|
||||
@dataclass
|
||||
class Experience:
|
||||
"""Experience is a batch of data.
|
||||
These data should have the the sequence length and number of actions.
|
||||
Left padding for sequences is applied.
|
||||
|
||||
Shapes of each tensor:
|
||||
sequences: (B, S)
|
||||
action_log_probs: (B, A)
|
||||
values: (B)
|
||||
reward: (B)
|
||||
advatanges: (B)
|
||||
attention_mask: (B, S)
|
||||
action_mask: (B, A)
|
||||
|
||||
"A" is the number of actions.
|
||||
"""
|
||||
sequences: torch.Tensor
|
||||
action_log_probs: torch.Tensor
|
||||
values: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
advantages: torch.Tensor
|
||||
attention_mask: Optional[torch.LongTensor]
|
||||
action_mask: Optional[torch.BoolTensor]
|
||||
|
||||
@torch.no_grad()
|
||||
def to_device(self, device: torch.device) -> None:
|
||||
self.sequences = self.sequences.to(device)
|
||||
self.action_log_probs = self.action_log_probs.to(device)
|
||||
self.values = self.values.to(device)
|
||||
self.reward = self.reward.to(device)
|
||||
self.advantages = self.advantages.to(device)
|
||||
if self.attention_mask is not None:
|
||||
self.attention_mask = self.attention_mask.to(device)
|
||||
if self.action_mask is not None:
|
||||
self.action_mask = self.action_mask.to(device)
|
||||
|
||||
def pin_memory(self):
|
||||
self.sequences = self.sequences.pin_memory()
|
||||
self.action_log_probs = self.action_log_probs.pin_memory()
|
||||
self.values = self.values.pin_memory()
|
||||
self.reward = self.reward.pin_memory()
|
||||
self.advantages = self.advantages.pin_memory()
|
||||
if self.attention_mask is not None:
|
||||
self.attention_mask = self.attention_mask.pin_memory()
|
||||
if self.action_mask is not None:
|
||||
self.action_mask = self.action_mask.pin_memory()
|
||||
return self
|
||||
|
||||
|
||||
class ExperienceMaker(ABC):
|
||||
|
||||
def __init__(self,
|
||||
actor: Actor,
|
||||
critic: nn.Module,
|
||||
reward_model: nn.Module,
|
||||
initial_model: Actor,
|
||||
kl_coef: float = 0.1) -> None:
|
||||
super().__init__()
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
self.reward_model = reward_model
|
||||
self.initial_model = initial_model
|
||||
self.kl_coef = kl_coef
|
||||
|
||||
@abstractmethod
|
||||
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
|
||||
pass
|
|
@ -0,0 +1,35 @@
|
|||
import torch
|
||||
from coati.models.utils import compute_reward, normalize
|
||||
|
||||
from .base import Experience, ExperienceMaker
|
||||
|
||||
|
||||
class NaiveExperienceMaker(ExperienceMaker):
|
||||
"""
|
||||
Naive experience maker.
|
||||
"""
|
||||
|
||||
@torch.no_grad()
|
||||
def make_experience(self, input_ids: torch.Tensor, **generate_kwargs) -> Experience:
|
||||
self.actor.eval()
|
||||
self.critic.eval()
|
||||
self.initial_model.eval()
|
||||
self.reward_model.eval()
|
||||
|
||||
sequences, attention_mask, action_mask = self.actor.generate(input_ids,
|
||||
return_action_mask=True,
|
||||
**generate_kwargs)
|
||||
num_actions = action_mask.size(1)
|
||||
|
||||
action_log_probs = self.actor(sequences, num_actions, attention_mask)
|
||||
base_action_log_probs = self.initial_model(sequences, num_actions, attention_mask)
|
||||
value = self.critic(sequences, action_mask, attention_mask)
|
||||
r = self.reward_model(sequences, attention_mask)
|
||||
reward = compute_reward(r, self.kl_coef, action_log_probs, base_action_log_probs, action_mask=action_mask)
|
||||
|
||||
advantage = reward - value
|
||||
# TODO(ver217): maybe normalize adv
|
||||
if advantage.ndim == 1:
|
||||
advantage = advantage.unsqueeze(-1)
|
||||
|
||||
return Experience(sequences, action_log_probs, value, reward, advantage, attention_mask, action_mask)
|
|
@ -0,0 +1,4 @@
|
|||
from .base import Actor, Critic, RewardModel
|
||||
from .loss import LogExpLoss, LogSigLoss, PolicyLoss, PPOPtxActorLoss, ValueLoss
|
||||
|
||||
__all__ = ['Actor', 'Critic', 'RewardModel', 'PolicyLoss', 'ValueLoss', 'PPOPtxActorLoss', 'LogSigLoss', 'LogExpLoss']
|
|
@ -0,0 +1,6 @@
|
|||
from .actor import Actor
|
||||
from .critic import Critic
|
||||
from .lm import LM
|
||||
from .reward_model import RewardModel
|
||||
|
||||
__all__ = ['Actor', 'Critic', 'RewardModel', 'LM']
|
|
@ -0,0 +1,65 @@
|
|||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..generation import generate
|
||||
from ..lora import LoRAModule
|
||||
from ..utils import log_probs_from_logits
|
||||
|
||||
|
||||
class Actor(LoRAModule):
|
||||
"""
|
||||
Actor model base class.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Actor Model.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
|
||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||
self.model = model
|
||||
self.convert_to_lora()
|
||||
|
||||
@torch.no_grad()
|
||||
def generate(
|
||||
self,
|
||||
input_ids: torch.Tensor,
|
||||
return_action_mask: bool = True,
|
||||
**kwargs
|
||||
) -> Union[Tuple[torch.LongTensor, torch.LongTensor], Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]]:
|
||||
sequences = generate(self.model, input_ids, **kwargs)
|
||||
attention_mask = None
|
||||
pad_token_id = kwargs.get('pad_token_id', None)
|
||||
if pad_token_id is not None:
|
||||
attention_mask = sequences.not_equal(pad_token_id).to(dtype=torch.long, device=sequences.device)
|
||||
if not return_action_mask:
|
||||
return sequences, attention_mask, None
|
||||
input_len = input_ids.size(1)
|
||||
eos_token_id = kwargs.get('eos_token_id', None)
|
||||
if eos_token_id is None:
|
||||
action_mask = torch.ones_like(sequences, dtype=torch.bool)
|
||||
else:
|
||||
# left padding may be applied, only mask action
|
||||
action_mask = (sequences[:, input_len:] == eos_token_id).cumsum(dim=-1) == 0
|
||||
action_mask = F.pad(action_mask, (1 + input_len, -1), value=True) # include eos token and input
|
||||
action_mask[:, :input_len] = False
|
||||
action_mask = action_mask[:, 1:]
|
||||
return sequences, attention_mask, action_mask[:, -(sequences.size(1) - input_len):]
|
||||
|
||||
def forward(self,
|
||||
sequences: torch.LongTensor,
|
||||
num_actions: int,
|
||||
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""Returns action log probs
|
||||
"""
|
||||
output = self.model(sequences, attention_mask=attention_mask)
|
||||
logits = output['logits']
|
||||
log_probs = log_probs_from_logits(logits[:, :-1, :], sequences[:, 1:])
|
||||
return log_probs[:, -num_actions:]
|
||||
|
||||
def get_base_model(self):
|
||||
return self.model
|
|
@ -0,0 +1,54 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..lora import LoRAModule
|
||||
from ..utils import masked_mean
|
||||
|
||||
|
||||
class Critic(LoRAModule):
|
||||
"""
|
||||
Critic model base class.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Critic model.
|
||||
value_head (nn.Module): Value head to get value.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: nn.Module,
|
||||
value_head: nn.Module,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
use_action_mask: bool = False,
|
||||
) -> None:
|
||||
|
||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||
self.model = model
|
||||
self.value_head = value_head
|
||||
self.use_action_mask = use_action_mask
|
||||
self.convert_to_lora()
|
||||
|
||||
def forward(self,
|
||||
sequences: torch.LongTensor,
|
||||
action_mask: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
||||
last_hidden_states = outputs['last_hidden_state']
|
||||
|
||||
values = self.value_head(last_hidden_states).squeeze(-1)
|
||||
|
||||
if action_mask is not None and self.use_action_mask:
|
||||
num_actions = action_mask.size(1)
|
||||
prompt_mask = attention_mask[:, :-num_actions]
|
||||
values = values[:, :-num_actions]
|
||||
value = masked_mean(values, prompt_mask, dim=1)
|
||||
return value
|
||||
|
||||
values = values[:, :-1]
|
||||
value = values.mean(dim=1)
|
||||
return value
|
|
@ -0,0 +1,30 @@
|
|||
from typing import Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
from ..generation import generate
|
||||
from .actor import Actor
|
||||
|
||||
|
||||
class LM(Actor):
|
||||
"""
|
||||
Language model base class.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Language Model.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self, model: nn.Module, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
|
||||
super().__init__(model=model, lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||
|
||||
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""Returns output log probs
|
||||
"""
|
||||
output = self.model(sequences, attention_mask=attention_mask)
|
||||
logits = output['logits']
|
||||
log_probs = F.log_softmax(logits, dim=-1)
|
||||
return log_probs
|
|
@ -0,0 +1,41 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from ..lora import LoRAModule
|
||||
|
||||
|
||||
class RewardModel(LoRAModule):
|
||||
"""
|
||||
Reward model base class.
|
||||
|
||||
Args:
|
||||
model (nn.Module): Reward model.
|
||||
value_head (nn.Module): Value head to get reward score.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
model: nn.Module,
|
||||
value_head: Optional[nn.Module] = None,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
super().__init__(lora_rank=lora_rank, lora_train_bias=lora_train_bias)
|
||||
self.model = model
|
||||
self.convert_to_lora()
|
||||
|
||||
if value_head is not None:
|
||||
if value_head.out_features != 1:
|
||||
raise ValueError("The value head of reward model's output dim should be 1!")
|
||||
self.value_head = value_head
|
||||
else:
|
||||
self.value_head = nn.Linear(model.config.n_embd, 1)
|
||||
|
||||
def forward(self, sequences: torch.LongTensor, attention_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
outputs = self.model(sequences, attention_mask=attention_mask)
|
||||
last_hidden_states = outputs['last_hidden_state']
|
||||
values = self.value_head(last_hidden_states)[:, :-1]
|
||||
value = values.mean(dim=1).squeeze(1) # ensure shape is (B)
|
||||
return value
|
|
@ -0,0 +1,6 @@
|
|||
from .bloom_actor import BLOOMActor
|
||||
from .bloom_critic import BLOOMCritic
|
||||
from .bloom_lm import BLOOMLM
|
||||
from .bloom_rm import BLOOMRM
|
||||
|
||||
__all__ = ['BLOOMActor', 'BLOOMCritic', 'BLOOMRM', 'BLOOMLM']
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
|
||||
class BLOOMActor(Actor):
|
||||
"""
|
||||
BLOOM Actor model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = BloomForCausalLM(config)
|
||||
else:
|
||||
model = BloomForCausalLM(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,38 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class BLOOMCritic(Critic):
|
||||
"""
|
||||
BLOOM Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = BloomModel(config)
|
||||
else:
|
||||
model = BloomModel(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from ..base import LM
|
||||
|
||||
|
||||
class BLOOMLM(LM):
|
||||
"""
|
||||
BLOOM language model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = BloomForCausalLM(config)
|
||||
else:
|
||||
model = BloomForCausalLM(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,37 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import BloomConfig, BloomForCausalLM, BloomModel
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class BLOOMRM(RewardModel):
|
||||
"""
|
||||
BLOOM Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (BloomConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[BloomConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = BloomModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = BloomModel(config)
|
||||
else:
|
||||
model = BloomModel(BloomConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,4 @@
|
|||
from .deberta_critic import DebertaCritic
|
||||
from .deberta_rm import DebertaRM
|
||||
|
||||
__all__ = ['DebertaCritic', 'DebertaRM']
|
|
@ -0,0 +1,36 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import DebertaV2Config, DebertaV2Model
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class DebertaCritic(Critic):
|
||||
"""
|
||||
Deberta Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (DebertaV2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the LO-RA decomposition.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[DebertaV2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = DebertaV2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = DebertaV2Model(config)
|
||||
else:
|
||||
model = DebertaV2Model(DebertaV2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,37 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import DebertaV2Config, DebertaV2Model
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class DebertaRM(RewardModel):
|
||||
"""
|
||||
Deberta Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (DebertaV2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the LO-RA decomposition.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: str = None,
|
||||
config: Optional[DebertaV2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = DebertaV2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = DebertaV2Model(config)
|
||||
else:
|
||||
model = DebertaV2Model(DebertaV2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,146 @@
|
|||
from typing import Any, Callable, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
|
||||
try:
|
||||
from transformers.generation_logits_process import (
|
||||
LogitsProcessorList,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
except ImportError:
|
||||
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
||||
|
||||
|
||||
def prepare_logits_processor(top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None) -> LogitsProcessorList:
|
||||
processor_list = LogitsProcessorList()
|
||||
if temperature is not None and temperature != 1.0:
|
||||
processor_list.append(TemperatureLogitsWarper(temperature))
|
||||
if top_k is not None and top_k != 0:
|
||||
processor_list.append(TopKLogitsWarper(top_k))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
processor_list.append(TopPLogitsWarper(top_p))
|
||||
return processor_list
|
||||
|
||||
|
||||
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
# consider DP
|
||||
unfinished_sequences = unfinished_sequences.clone()
|
||||
dist.all_reduce(unfinished_sequences)
|
||||
return unfinished_sequences.max() == 0
|
||||
|
||||
|
||||
def sample(model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> torch.Tensor:
|
||||
if input_ids.size(1) >= max_length:
|
||||
return input_ids
|
||||
|
||||
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
|
||||
for _ in range(input_ids.size(1), max_length):
|
||||
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
|
||||
'input_ids': input_ids
|
||||
}
|
||||
outputs = model(**model_inputs)
|
||||
|
||||
next_token_logits = outputs['logits'][:, -1, :]
|
||||
# pre-process distribution
|
||||
next_token_logits = logits_processor(input_ids, next_token_logits)
|
||||
# sample
|
||||
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
if pad_token_id is None:
|
||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||
|
||||
# update generated ids, model inputs for next step
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
if update_model_kwargs_fn is not None:
|
||||
model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
|
||||
|
||||
# stop when each sentence is finished if early_stopping=True
|
||||
if early_stopping and _is_sequence_finished(unfinished_sequences):
|
||||
break
|
||||
|
||||
return input_ids
|
||||
|
||||
|
||||
def generate(model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
max_length: int,
|
||||
num_beams: int = 1,
|
||||
do_sample: bool = True,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> torch.Tensor:
|
||||
"""Generate token sequence. The returned sequence is input_ids + generated_tokens.
|
||||
|
||||
Args:
|
||||
model (nn.Module): model
|
||||
input_ids (torch.Tensor): input sequence
|
||||
max_length (int): max length of the returned sequence
|
||||
num_beams (int, optional): number of beams. Defaults to 1.
|
||||
do_sample (bool, optional): whether to do sample. Defaults to True.
|
||||
early_stopping (bool, optional): if True, the sequence length may be smaller than max_length due to finding eos. Defaults to False.
|
||||
eos_token_id (Optional[int], optional): end of sequence token id. Defaults to None.
|
||||
pad_token_id (Optional[int], optional): pad token id. Defaults to None.
|
||||
top_k (Optional[int], optional): the number of highest probability vocabulary tokens to keep for top-k-filtering. Defaults to None.
|
||||
top_p (Optional[float], optional): If set to float < 1, only the smallest set of most probable tokens with probabilities that add up to top_p or higher are kept for generation. Defaults to None.
|
||||
temperature (Optional[float], optional): The value used to module the next token probabilities. Defaults to None.
|
||||
prepare_inputs_fn (Optional[Callable[[torch.Tensor, Any], dict]], optional): Function to preprocess model inputs. Arguments of this function should be input_ids and model_kwargs. Defaults to None.
|
||||
update_model_kwargs_fn (Optional[Callable[[dict, Any], dict]], optional): Function to update model_kwargs based on outputs. Arguments of this function should be outputs and model_kwargs. Defaults to None.
|
||||
"""
|
||||
is_greedy_gen_mode = ((num_beams == 1) and do_sample is False)
|
||||
is_sample_gen_mode = ((num_beams == 1) and do_sample is True)
|
||||
is_beam_gen_mode = ((num_beams > 1) and do_sample is False)
|
||||
if is_greedy_gen_mode:
|
||||
# run greedy search
|
||||
raise NotImplementedError
|
||||
elif is_sample_gen_mode:
|
||||
# run sample
|
||||
return sample(model,
|
||||
input_ids,
|
||||
max_length,
|
||||
early_stopping=early_stopping,
|
||||
eos_token_id=eos_token_id,
|
||||
pad_token_id=pad_token_id,
|
||||
top_k=top_k,
|
||||
top_p=top_p,
|
||||
temperature=temperature,
|
||||
prepare_inputs_fn=prepare_inputs_fn,
|
||||
update_model_kwargs_fn=update_model_kwargs_fn,
|
||||
**model_kwargs)
|
||||
elif is_beam_gen_mode:
|
||||
raise NotImplementedError
|
||||
else:
|
||||
raise ValueError("Unsupported generation mode")
|
|
@ -0,0 +1,92 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
def gpt_prepare_inputs_fn(input_ids: torch.Tensor, past: Optional[torch.Tensor] = None, **kwargs) -> dict:
|
||||
token_type_ids = kwargs.get("token_type_ids", None)
|
||||
# only last token for inputs_ids if past is defined in kwargs
|
||||
if past:
|
||||
input_ids = input_ids[:, -1].unsqueeze(-1)
|
||||
if token_type_ids is not None:
|
||||
token_type_ids = token_type_ids[:, -1].unsqueeze(-1)
|
||||
|
||||
attention_mask = kwargs.get("attention_mask", None)
|
||||
position_ids = kwargs.get("position_ids", None)
|
||||
|
||||
if attention_mask is not None and position_ids is None:
|
||||
# create position_ids on the fly for batch generation
|
||||
position_ids = attention_mask.long().cumsum(-1) - 1
|
||||
position_ids.masked_fill_(attention_mask == 0, 1)
|
||||
if past:
|
||||
position_ids = position_ids[:, -1].unsqueeze(-1)
|
||||
else:
|
||||
position_ids = None
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"past_key_values": past,
|
||||
"use_cache": kwargs.get("use_cache"),
|
||||
"position_ids": position_ids,
|
||||
"attention_mask": attention_mask,
|
||||
"token_type_ids": token_type_ids,
|
||||
}
|
||||
|
||||
|
||||
def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
|
||||
if "past_key_values" in outputs:
|
||||
model_kwargs["past"] = outputs["past_key_values"]
|
||||
else:
|
||||
model_kwargs["past"] = None
|
||||
|
||||
# update token_type_ids with last value
|
||||
if "token_type_ids" in model_kwargs:
|
||||
token_type_ids = model_kwargs["token_type_ids"]
|
||||
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
|
||||
|
||||
# update attention mask
|
||||
if "attention_mask" in model_kwargs:
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
|
||||
|
||||
return model_kwargs
|
||||
|
||||
|
||||
def opt_prepare_inputs_fn(input_ids: torch.Tensor,
|
||||
past: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs) -> dict:
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||
|
||||
if past:
|
||||
input_ids = input_ids[:, -1:]
|
||||
# first step, decoder_cached_states are empty
|
||||
return {
|
||||
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"use_cache": use_cache,
|
||||
}
|
||||
|
||||
|
||||
def bloom_prepare_inputs_fn(input_ids: torch.Tensor,
|
||||
past: Optional[torch.Tensor] = None,
|
||||
attention_mask: Optional[torch.Tensor] = None,
|
||||
use_cache: Optional[bool] = None,
|
||||
**kwargs) -> dict:
|
||||
# if model is used as a decoder in encoder-decoder model, the decoder attention mask is created on the fly
|
||||
if attention_mask is None:
|
||||
attention_mask = input_ids.new_ones(input_ids.shape)
|
||||
|
||||
if past:
|
||||
input_ids = input_ids[:, -1:]
|
||||
# first step, decoder_cached_states are empty
|
||||
return {
|
||||
"input_ids": input_ids, # encoder_outputs is defined. input_ids not needed
|
||||
"attention_mask": attention_mask,
|
||||
"past_key_values": past,
|
||||
"use_cache": use_cache,
|
||||
}
|
|
@ -0,0 +1,6 @@
|
|||
from .gpt_actor import GPTActor
|
||||
from .gpt_critic import GPTCritic
|
||||
from .gpt_lm import GPTLM
|
||||
from .gpt_rm import GPTRM
|
||||
|
||||
__all__ = ['GPTActor', 'GPTCritic', 'GPTRM', 'GPTLM']
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Optional
|
||||
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
|
||||
class GPTActor(Actor):
|
||||
"""
|
||||
GPT Actor model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the LoRa layer.
|
||||
lora_train_bias (str): Bias training strategy for the LoRa layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2LMHeadModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = GPT2LMHeadModel(config)
|
||||
else:
|
||||
model = GPT2LMHeadModel(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,37 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class GPTCritic(Critic):
|
||||
"""
|
||||
GPT Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the LO-RA decomposition.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = GPT2Model(config)
|
||||
else:
|
||||
model = GPT2Model(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.n_embd, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Optional
|
||||
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2LMHeadModel
|
||||
|
||||
from ..base import LM
|
||||
|
||||
|
||||
class GPTLM(LM):
|
||||
"""
|
||||
GPT language model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the LoRa layer.
|
||||
lora_train_bias (str): Bias training strategy for the LoRa layer.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2LMHeadModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = GPT2LMHeadModel(config)
|
||||
else:
|
||||
model = GPT2LMHeadModel(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,39 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.gpt2.configuration_gpt2 import GPT2Config
|
||||
from transformers.models.gpt2.modeling_gpt2 import GPT2Model
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class GPTRM(RewardModel):
|
||||
"""
|
||||
GPT Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (GPT2Config): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[GPT2Config] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = GPT2Model.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = GPT2Model(config)
|
||||
else:
|
||||
model = GPT2Model(GPT2Config())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.n_embd, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.n_embd + 1))
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,6 @@
|
|||
from .llama_actor import LlamaActor
|
||||
from .llama_critic import LlamaCritic
|
||||
from .llama_lm import LlamaLM
|
||||
from .llama_rm import LlamaRM
|
||||
|
||||
__all__ = ['LlamaActor', 'LlamaCritic', 'LlamaRM', 'LlamaLM']
|
|
@ -0,0 +1,38 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
|
||||
class LlamaActor(Actor):
|
||||
"""
|
||||
Llama Actor model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (LlamaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
|
||||
if pretrained is not None:
|
||||
model = LlamaForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = LlamaForCausalLM(config)
|
||||
else:
|
||||
model = LlamaForCausalLM(LlamaConfig())
|
||||
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,42 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from transformers import AutoModelForCausalLM, LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class LlamaCritic(Critic):
|
||||
"""
|
||||
Llama Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (LlamaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
|
||||
if pretrained is not None:
|
||||
model = LlamaForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = LlamaForCausalLM(config)
|
||||
else:
|
||||
model = LlamaForCausalLM(LlamaConfig())
|
||||
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
|
@ -0,0 +1,40 @@
|
|||
from typing import Optional
|
||||
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from ..base import LM
|
||||
|
||||
|
||||
class LlamaLM(LM):
|
||||
"""
|
||||
Llama language model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (LlamaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
|
||||
if pretrained is not None:
|
||||
model = LlamaForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = LlamaForCausalLM(config)
|
||||
else:
|
||||
model = LlamaForCausalLM(LlamaConfig())
|
||||
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
||||
|
||||
def forward(self, input_ids, attention_mask=None, labels=None, **kwargs):
|
||||
return self.model(input_ids, attention_mask=attention_mask, labels=labels, **kwargs)
|
|
@ -0,0 +1,40 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import LlamaConfig, LlamaForCausalLM, LlamaModel
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class LlamaRM(RewardModel):
|
||||
"""
|
||||
Llama Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (LlamaConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): LoRA rank.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[LlamaConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
|
||||
if pretrained is not None:
|
||||
model = LlamaModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = LlamaModel(config)
|
||||
else:
|
||||
model = LlamaModel(LlamaConfig())
|
||||
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.hidden_size, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.hidden_size + 1))
|
||||
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,129 @@
|
|||
import math
|
||||
from typing import Optional
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
class LoraLinear(lora.LoRALayer, nn.Module):
|
||||
"""Replace in-place ops to out-of-place ops to fit gemini. Convert a torch.nn.Linear to LoraLinear.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
weight: nn.Parameter,
|
||||
bias: Optional[nn.Parameter],
|
||||
r: int = 0,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.,
|
||||
fan_in_fan_out: bool = False, # Set this to True if the layer to replace stores weight like (fan_in, fan_out)
|
||||
merge_weights: bool = True,
|
||||
):
|
||||
nn.Module.__init__(self)
|
||||
lora.LoRALayer.__init__(self,
|
||||
r=r,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
merge_weights=merge_weights)
|
||||
self.weight = weight
|
||||
self.bias = bias
|
||||
|
||||
out_features, in_features = weight.shape
|
||||
self.in_features = in_features
|
||||
self.out_features = out_features
|
||||
|
||||
self.fan_in_fan_out = fan_in_fan_out
|
||||
# Actual trainable parameters
|
||||
if r > 0:
|
||||
self.lora_A = nn.Parameter(self.weight.new_zeros((r, in_features)))
|
||||
self.lora_B = nn.Parameter(self.weight.new_zeros((out_features, r)))
|
||||
self.scaling = self.lora_alpha / self.r
|
||||
# Freezing the pre-trained weight matrix
|
||||
self.weight.requires_grad = False
|
||||
self.reset_parameters()
|
||||
if fan_in_fan_out:
|
||||
self.weight.data = self.weight.data.T
|
||||
|
||||
def reset_parameters(self):
|
||||
if hasattr(self, 'lora_A'):
|
||||
# initialize A the same way as the default for nn.Linear and B to zero
|
||||
nn.init.kaiming_uniform_(self.lora_A, a=math.sqrt(5))
|
||||
nn.init.zeros_(self.lora_B)
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
nn.Module.train(self, mode)
|
||||
if self.merge_weights and self.merged:
|
||||
# Make sure that the weights are not merged
|
||||
if self.r > 0:
|
||||
self.weight.data -= T(self.lora_B @ self.lora_A) * self.scaling
|
||||
self.merged = False
|
||||
|
||||
def eval(self):
|
||||
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
nn.Module.eval(self)
|
||||
if self.merge_weights and not self.merged:
|
||||
# Merge the weights and mark it
|
||||
if self.r > 0:
|
||||
self.weight.data += T(self.lora_B @ self.lora_A) * self.scaling
|
||||
delattr(self, 'lora_A')
|
||||
delattr(self, 'lora_B')
|
||||
self.merged = True
|
||||
|
||||
def forward(self, x: torch.Tensor):
|
||||
|
||||
def T(w):
|
||||
return w.T if self.fan_in_fan_out else w
|
||||
|
||||
if self.r > 0 and not self.merged:
|
||||
result = F.linear(x, T(self.weight), bias=self.bias)
|
||||
if self.r > 0:
|
||||
result = result + (self.lora_dropout(x) @ self.lora_A.t() @ self.lora_B.t()) * self.scaling
|
||||
return result
|
||||
else:
|
||||
return F.linear(x, T(self.weight), bias=self.bias)
|
||||
|
||||
|
||||
def lora_linear_wrapper(linear: nn.Linear, lora_rank: int) -> LoraLinear:
|
||||
assert lora_rank <= linear.in_features, f'LoRA rank ({lora_rank}) must be less than or equal to in features ({linear.in_features})'
|
||||
lora_linear = LoraLinear(linear.weight, linear.bias, r=lora_rank, merge_weights=False)
|
||||
return lora_linear
|
||||
|
||||
|
||||
def convert_to_lora_recursively(module: nn.Module, lora_rank: int) -> None:
|
||||
for name, child in module.named_children():
|
||||
if isinstance(child, nn.Linear):
|
||||
setattr(module, name, lora_linear_wrapper(child, lora_rank))
|
||||
else:
|
||||
convert_to_lora_recursively(child, lora_rank)
|
||||
|
||||
|
||||
class LoRAModule(nn.Module):
|
||||
"""A LoRA module base class. All derived classes should call `convert_to_lora()` at the bottom of `__init__()`.
|
||||
This calss will convert all torch.nn.Linear layer to LoraLinear layer.
|
||||
|
||||
Args:
|
||||
lora_rank (int, optional): LoRA rank. 0 means LoRA is not applied. Defaults to 0.
|
||||
lora_train_bias (str, optional): Whether LoRA train biases.
|
||||
'none' means it doesn't train biases. 'all' means it trains all biases. 'lora_only' means it only trains biases of LoRA layers.
|
||||
Defaults to 'none'.
|
||||
"""
|
||||
|
||||
def __init__(self, lora_rank: int = 0, lora_train_bias: str = 'none') -> None:
|
||||
super().__init__()
|
||||
self.lora_rank = lora_rank
|
||||
self.lora_train_bias = lora_train_bias
|
||||
|
||||
def convert_to_lora(self) -> None:
|
||||
if self.lora_rank <= 0:
|
||||
return
|
||||
convert_to_lora_recursively(self, self.lora_rank)
|
||||
lora.mark_only_lora_as_trainable(self, self.lora_train_bias)
|
|
@ -0,0 +1,117 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
from .utils import masked_mean
|
||||
|
||||
|
||||
class GPTLMLoss(nn.Module):
|
||||
"""
|
||||
GPT Language Model Loss
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.loss = nn.CrossEntropyLoss()
|
||||
|
||||
def forward(self, logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
shift_logits = logits[..., :-1, :].contiguous()
|
||||
shift_labels = labels[..., 1:].contiguous()
|
||||
# Flatten the tokens
|
||||
return self.loss(shift_logits.view(-1, shift_logits.size(-1)), shift_labels.view(-1))
|
||||
|
||||
|
||||
class PolicyLoss(nn.Module):
|
||||
"""
|
||||
Policy Loss for PPO
|
||||
"""
|
||||
|
||||
def __init__(self, clip_eps: float = 0.2) -> None:
|
||||
super().__init__()
|
||||
self.clip_eps = clip_eps
|
||||
|
||||
def forward(self,
|
||||
log_probs: torch.Tensor,
|
||||
old_log_probs: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
ratio = (log_probs - old_log_probs).exp()
|
||||
surr1 = ratio * advantages
|
||||
surr2 = ratio.clamp(1 - self.clip_eps, 1 + self.clip_eps) * advantages
|
||||
loss = -torch.min(surr1, surr2)
|
||||
if action_mask is not None:
|
||||
loss = masked_mean(loss, action_mask)
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
|
||||
|
||||
class ValueLoss(nn.Module):
|
||||
"""
|
||||
Value Loss for PPO
|
||||
"""
|
||||
|
||||
def __init__(self, clip_eps: float = 0.4) -> None:
|
||||
super().__init__()
|
||||
self.clip_eps = clip_eps
|
||||
|
||||
def forward(self,
|
||||
values: torch.Tensor,
|
||||
old_values: torch.Tensor,
|
||||
reward: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
values_clipped = old_values + (values - old_values).clamp(-self.clip_eps, self.clip_eps)
|
||||
surr1 = (values_clipped - reward)**2
|
||||
surr2 = (values - reward)**2
|
||||
loss = torch.max(surr1, surr2)
|
||||
loss = loss.mean()
|
||||
return loss
|
||||
|
||||
|
||||
class PPOPtxActorLoss(nn.Module):
|
||||
"""
|
||||
To Do:
|
||||
|
||||
PPO-ptx Actor Loss
|
||||
"""
|
||||
|
||||
def __init__(self, policy_clip_eps: float = 0.2, pretrain_coef: float = 0.0, pretrain_loss_fn=GPTLMLoss()) -> None:
|
||||
super().__init__()
|
||||
self.pretrain_coef = pretrain_coef
|
||||
self.policy_loss_fn = PolicyLoss(clip_eps=policy_clip_eps)
|
||||
self.pretrain_loss_fn = pretrain_loss_fn
|
||||
|
||||
def forward(self,
|
||||
log_probs: torch.Tensor,
|
||||
old_log_probs: torch.Tensor,
|
||||
advantages: torch.Tensor,
|
||||
lm_logits: torch.Tensor,
|
||||
lm_input_ids: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
policy_loss = self.policy_loss_fn(log_probs, old_log_probs, advantages, action_mask=action_mask)
|
||||
lm_loss = self.pretrain_loss_fn(lm_logits, lm_input_ids)
|
||||
return policy_loss + self.pretrain_coef * lm_loss
|
||||
|
||||
|
||||
class LogSigLoss(nn.Module):
|
||||
"""
|
||||
Pairwise Loss for Reward Model
|
||||
Details: https://arxiv.org/abs/2203.02155
|
||||
"""
|
||||
|
||||
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
|
||||
probs = torch.sigmoid(chosen_reward - reject_reward)
|
||||
log_probs = torch.log(probs)
|
||||
loss = -log_probs.mean()
|
||||
return loss
|
||||
|
||||
|
||||
class LogExpLoss(nn.Module):
|
||||
"""
|
||||
Pairwise Loss for Reward Model
|
||||
Details: https://arxiv.org/abs/2204.05862
|
||||
"""
|
||||
|
||||
def forward(self, chosen_reward: torch.Tensor, reject_reward: torch.Tensor) -> torch.Tensor:
|
||||
loss = torch.log(1 + torch.exp(reject_reward - chosen_reward)).mean()
|
||||
return loss
|
|
@ -0,0 +1,6 @@
|
|||
from .opt_actor import OPTActor
|
||||
from .opt_critic import OPTCritic
|
||||
from .opt_lm import OPTLM
|
||||
from .opt_rm import OPTRM
|
||||
|
||||
__all__ = ['OPTActor', 'OPTCritic', 'OPTRM', 'OPTLM']
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Optional
|
||||
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
||||
|
||||
from ..base import Actor
|
||||
|
||||
|
||||
class OPTActor(Actor):
|
||||
"""
|
||||
OPT Actor model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = OPTForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = OPTForCausalLM(config)
|
||||
else:
|
||||
model = OPTForCausalLM(OPTConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,38 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTModel
|
||||
|
||||
from ..base import Critic
|
||||
|
||||
|
||||
class OPTCritic(Critic):
|
||||
"""
|
||||
OPT Critic model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none',
|
||||
**kwargs) -> None:
|
||||
if pretrained is not None:
|
||||
model = OPTModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = OPTModel(config)
|
||||
else:
|
||||
model = OPTModel(OPTConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias, **kwargs)
|
|
@ -0,0 +1,35 @@
|
|||
from typing import Optional
|
||||
|
||||
from transformers.models.opt.configuration_opt import OPTConfig
|
||||
from transformers.models.opt.modeling_opt import OPTForCausalLM
|
||||
|
||||
from ..base import LM
|
||||
|
||||
|
||||
class OPTLM(LM):
|
||||
"""
|
||||
OPT language model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = OPTForCausalLM.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = OPTForCausalLM(config)
|
||||
else:
|
||||
model = OPTForCausalLM(OPTConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
super().__init__(model, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,38 @@
|
|||
from typing import Optional
|
||||
|
||||
import torch.nn as nn
|
||||
from transformers import OPTConfig, OPTModel
|
||||
|
||||
from ..base import RewardModel
|
||||
|
||||
|
||||
class OPTRM(RewardModel):
|
||||
"""
|
||||
OPT Reward model.
|
||||
|
||||
Args:
|
||||
pretrained (str): Pretrained model name or path.
|
||||
config (OPTConfig): Model config.
|
||||
checkpoint (bool): Enable gradient checkpointing.
|
||||
lora_rank (int): Rank of the low-rank approximation.
|
||||
lora_train_bias (str): LoRA bias training mode.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
pretrained: Optional[str] = None,
|
||||
config: Optional[OPTConfig] = None,
|
||||
checkpoint: bool = False,
|
||||
lora_rank: int = 0,
|
||||
lora_train_bias: str = 'none') -> None:
|
||||
if pretrained is not None:
|
||||
model = OPTModel.from_pretrained(pretrained)
|
||||
elif config is not None:
|
||||
model = OPTModel(config)
|
||||
else:
|
||||
model = OPTModel(OPTConfig())
|
||||
if checkpoint:
|
||||
model.gradient_checkpointing_enable()
|
||||
|
||||
value_head = nn.Linear(model.config.word_embed_proj_dim, 1)
|
||||
value_head.weight.data.normal_(mean=0.0, std=1 / (model.config.word_embed_proj_dim + 1))
|
||||
super().__init__(model, value_head, lora_rank, lora_train_bias)
|
|
@ -0,0 +1,92 @@
|
|||
from typing import Optional, Union
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
|
||||
|
||||
def compute_approx_kl(log_probs: torch.Tensor,
|
||||
log_probs_base: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
"""
|
||||
Compute the approximate KL divergence between two distributions.
|
||||
Schulman blog: http://joschu.net/blog/kl-approx.html
|
||||
|
||||
Args:
|
||||
log_probs: Log probabilities of the new distribution.
|
||||
log_probs_base: Log probabilities of the base distribution.
|
||||
action_mask: Mask for actions.
|
||||
"""
|
||||
|
||||
log_ratio = log_probs - log_probs_base
|
||||
approx_kl = (log_ratio.exp() - 1) - log_ratio
|
||||
if action_mask is not None:
|
||||
approx_kl = masked_mean(approx_kl, action_mask, dim=1)
|
||||
return approx_kl
|
||||
approx_kl = approx_kl.mean(dim=1)
|
||||
return approx_kl
|
||||
|
||||
|
||||
def compute_reward(r: Union[torch.Tensor, float],
|
||||
kl_coef: float,
|
||||
log_probs: torch.Tensor,
|
||||
log_probs_base: torch.Tensor,
|
||||
action_mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
||||
if kl_coef <= 0.0:
|
||||
return r
|
||||
kl = compute_approx_kl(log_probs, log_probs_base, action_mask=action_mask)
|
||||
reward = r - kl_coef * kl
|
||||
return reward
|
||||
|
||||
|
||||
def log_probs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
|
||||
log_probs = F.log_softmax(logits, dim=-1)
|
||||
log_probs_labels = log_probs.gather(dim=-1, index=labels.unsqueeze(-1))
|
||||
return log_probs_labels.squeeze(-1)
|
||||
|
||||
|
||||
def masked_mean(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1) -> torch.Tensor:
|
||||
tensor = tensor * mask
|
||||
tensor = tensor.sum(dim=dim)
|
||||
mask_sum = mask.sum(dim=dim)
|
||||
mean = tensor / (mask_sum + 1e-8)
|
||||
return mean
|
||||
|
||||
|
||||
def masked_normalize(tensor: torch.Tensor, mask: torch.Tensor, dim: int = 1, eps: float = 1e-8) -> torch.Tensor:
|
||||
tensor = tensor * mask
|
||||
mean = masked_mean(tensor, mask, dim=dim)
|
||||
mean_centered = tensor - mean
|
||||
var = masked_mean(mean_centered**2, mask, dim=dim)
|
||||
return mean_centered * var.clamp(min=eps).rsqrt()
|
||||
|
||||
|
||||
def normalize(tensor: torch.Tensor, dim: int = 0, eps: float = 1e-8) -> torch.Tensor:
|
||||
mean = tensor.mean(dim)
|
||||
mean_centered = tensor - mean
|
||||
var = (mean_centered**2).mean(dim)
|
||||
norm = mean_centered * var.clamp(min=eps).rsqrt()
|
||||
return norm
|
||||
|
||||
|
||||
def convert_to_lora(model: nn.Module,
|
||||
input_size: int,
|
||||
output_size: int,
|
||||
lora_rank: int = 16,
|
||||
lora_alpha: int = 1,
|
||||
lora_dropout: float = 0.,
|
||||
fan_in_fan_out: bool = False,
|
||||
merge_weights: bool = True):
|
||||
if lora_rank > min(input_size, output_size):
|
||||
raise ValueError(f"LoRA rank {lora_rank} must be less or equal than {min(input_size, output_size)}")
|
||||
|
||||
for name, module in model.named_modules():
|
||||
if isinstance(module, nn.Linear):
|
||||
module._modules[name] = lora.Linear(input_size,
|
||||
output_size,
|
||||
r=lora_rank,
|
||||
lora_alpha=lora_alpha,
|
||||
lora_dropout=lora_dropout,
|
||||
fan_in_fan_out=fan_in_fan_out,
|
||||
merge_weights=merge_weights)
|
|
@ -0,0 +1,4 @@
|
|||
from .base import ReplayBuffer
|
||||
from .naive import NaiveReplayBuffer
|
||||
|
||||
__all__ = ['ReplayBuffer', 'NaiveReplayBuffer']
|
|
@ -0,0 +1,43 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any
|
||||
|
||||
from coati.experience_maker.base import Experience
|
||||
|
||||
|
||||
class ReplayBuffer(ABC):
|
||||
"""Replay buffer base class. It stores experience.
|
||||
|
||||
Args:
|
||||
sample_batch_size (int): Batch size when sampling.
|
||||
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
|
||||
"""
|
||||
|
||||
def __init__(self, sample_batch_size: int, limit: int = 0) -> None:
|
||||
super().__init__()
|
||||
self.sample_batch_size = sample_batch_size
|
||||
# limit <= 0 means unlimited
|
||||
self.limit = limit
|
||||
|
||||
@abstractmethod
|
||||
def append(self, experience: Experience) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def clear(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def sample(self) -> Experience:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __len__(self) -> int:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def __getitem__(self, idx: int) -> Any:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def collate_fn(self, batch: Any) -> Experience:
|
||||
pass
|
|
@ -0,0 +1,57 @@
|
|||
import random
|
||||
from typing import List
|
||||
|
||||
import torch
|
||||
from coati.experience_maker.base import Experience
|
||||
|
||||
from .base import ReplayBuffer
|
||||
from .utils import BufferItem, make_experience_batch, split_experience_batch
|
||||
|
||||
|
||||
class NaiveReplayBuffer(ReplayBuffer):
|
||||
"""Naive replay buffer class. It stores experience.
|
||||
|
||||
Args:
|
||||
sample_batch_size (int): Batch size when sampling.
|
||||
limit (int, optional): Limit of number of experience samples. A number <= 0 means unlimited. Defaults to 0.
|
||||
cpu_offload (bool, optional): Whether to offload experience to cpu when sampling. Defaults to True.
|
||||
"""
|
||||
|
||||
def __init__(self, sample_batch_size: int, limit: int = 0, cpu_offload: bool = True) -> None:
|
||||
super().__init__(sample_batch_size, limit)
|
||||
self.cpu_offload = cpu_offload
|
||||
self.target_device = torch.device(f'cuda:{torch.cuda.current_device()}')
|
||||
# TODO(ver217): add prefetch
|
||||
self.items: List[BufferItem] = []
|
||||
|
||||
@torch.no_grad()
|
||||
def append(self, experience: Experience) -> None:
|
||||
if self.cpu_offload:
|
||||
experience.to_device(torch.device('cpu'))
|
||||
items = split_experience_batch(experience)
|
||||
self.items.extend(items)
|
||||
if self.limit > 0:
|
||||
samples_to_remove = len(self.items) - self.limit
|
||||
if samples_to_remove > 0:
|
||||
self.items = self.items[samples_to_remove:]
|
||||
|
||||
def clear(self) -> None:
|
||||
self.items.clear()
|
||||
|
||||
@torch.no_grad()
|
||||
def sample(self) -> Experience:
|
||||
items = random.sample(self.items, self.sample_batch_size)
|
||||
experience = make_experience_batch(items)
|
||||
if self.cpu_offload:
|
||||
experience.to_device(self.target_device)
|
||||
return experience
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.items)
|
||||
|
||||
def __getitem__(self, idx: int) -> BufferItem:
|
||||
return self.items[idx]
|
||||
|
||||
def collate_fn(self, batch) -> Experience:
|
||||
experience = make_experience_batch(batch)
|
||||
return experience
|
|
@ -0,0 +1,73 @@
|
|||
from dataclasses import dataclass
|
||||
from typing import List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from coati.experience_maker.base import Experience
|
||||
|
||||
|
||||
@dataclass
|
||||
class BufferItem:
|
||||
"""BufferItem is an item of experience data.
|
||||
|
||||
Shapes of each tensor:
|
||||
sequences: (S)
|
||||
action_log_probs: (A)
|
||||
values: (1)
|
||||
reward: (1)
|
||||
advatanges: (1)
|
||||
attention_mask: (S)
|
||||
action_mask: (A)
|
||||
|
||||
"A" is the number of actions.
|
||||
"""
|
||||
sequences: torch.Tensor
|
||||
action_log_probs: torch.Tensor
|
||||
values: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
advantages: torch.Tensor
|
||||
attention_mask: Optional[torch.LongTensor]
|
||||
action_mask: Optional[torch.BoolTensor]
|
||||
|
||||
|
||||
def split_experience_batch(experience: Experience) -> List[BufferItem]:
|
||||
batch_size = experience.sequences.size(0)
|
||||
batch_kwargs = [{} for _ in range(batch_size)]
|
||||
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
for key in keys:
|
||||
value = getattr(experience, key)
|
||||
if isinstance(value, torch.Tensor):
|
||||
vals = torch.unbind(value)
|
||||
else:
|
||||
# None
|
||||
vals = [value for _ in range(batch_size)]
|
||||
assert batch_size == len(vals)
|
||||
for i, v in enumerate(vals):
|
||||
batch_kwargs[i][key] = v
|
||||
items = [BufferItem(**kwargs) for kwargs in batch_kwargs]
|
||||
return items
|
||||
|
||||
|
||||
def zero_pad_sequences(sequences: List[torch.Tensor], side: str = 'left') -> torch.Tensor:
|
||||
assert side in ('left', 'right')
|
||||
max_len = max(seq.size(0) for seq in sequences)
|
||||
padded_sequences = []
|
||||
for seq in sequences:
|
||||
pad_len = max_len - seq.size(0)
|
||||
padding = (pad_len, 0) if side == 'left' else (0, pad_len)
|
||||
padded_sequences.append(F.pad(seq, padding))
|
||||
return torch.stack(padded_sequences, dim=0)
|
||||
|
||||
|
||||
def make_experience_batch(items: List[BufferItem]) -> Experience:
|
||||
kwargs = {}
|
||||
to_pad_keys = set(('action_log_probs', 'action_mask'))
|
||||
keys = ('sequences', 'action_log_probs', 'values', 'reward', 'advantages', 'attention_mask', 'action_mask')
|
||||
for key in keys:
|
||||
vals = [getattr(item, key) for item in items]
|
||||
if key in to_pad_keys:
|
||||
batch_data = zero_pad_sequences(vals)
|
||||
else:
|
||||
batch_data = torch.stack(vals, dim=0)
|
||||
kwargs[key] = batch_data
|
||||
return Experience(**kwargs)
|
|
@ -0,0 +1,6 @@
|
|||
from .base import Trainer
|
||||
from .ppo import PPOTrainer
|
||||
from .rm import RewardModelTrainer
|
||||
from .sft import SFTTrainer
|
||||
|
||||
__all__ = ['Trainer', 'PPOTrainer', 'RewardModelTrainer', 'SFTTrainer']
|
|
@ -0,0 +1,168 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
import torch
|
||||
from coati.experience_maker import Experience, ExperienceMaker
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from torch import Tensor
|
||||
from torch.utils.data import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
|
||||
from .callbacks import Callback
|
||||
from .strategies import Strategy
|
||||
from .utils import is_rank_0
|
||||
|
||||
|
||||
class Trainer(ABC):
|
||||
"""
|
||||
Base class for rlhf trainers.
|
||||
|
||||
Args:
|
||||
strategy (Strategy):the strategy to use for training
|
||||
experience_maker (ExperienceMaker): the experience maker to use for produce experience to fullfill replay buffer
|
||||
replay_buffer (ReplayBuffer): the replay buffer to use for training
|
||||
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
|
||||
max_epochs (int, defaults to 1): the number of epochs of training process
|
||||
tokenizer (Callable, optional): the tokenizer to use for tokenizing the input
|
||||
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
|
||||
data_loader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
strategy: Strategy,
|
||||
experience_maker: ExperienceMaker,
|
||||
replay_buffer: ReplayBuffer,
|
||||
experience_batch_size: int = 8,
|
||||
max_epochs: int = 1,
|
||||
tokenizer: Optional[Callable[[Any], dict]] = None,
|
||||
sample_replay_buffer: bool = False,
|
||||
dataloader_pin_memory: bool = True,
|
||||
callbacks: List[Callback] = [],
|
||||
**generate_kwargs) -> None:
|
||||
super().__init__()
|
||||
self.strategy = strategy
|
||||
self.experience_maker = experience_maker
|
||||
self.replay_buffer = replay_buffer
|
||||
self.experience_batch_size = experience_batch_size
|
||||
self.max_epochs = max_epochs
|
||||
self.tokenizer = tokenizer
|
||||
self.generate_kwargs = generate_kwargs
|
||||
self.sample_replay_buffer = sample_replay_buffer
|
||||
self.dataloader_pin_memory = dataloader_pin_memory
|
||||
self.callbacks = callbacks
|
||||
|
||||
@abstractmethod
|
||||
def training_step(self, experience: Experience) -> Dict[str, Any]:
|
||||
pass
|
||||
|
||||
def _make_experience(self, inputs: Union[Tensor, Dict[str, Tensor]]) -> Experience:
|
||||
if isinstance(inputs, Tensor):
|
||||
return self.experience_maker.make_experience(inputs, **self.generate_kwargs)
|
||||
elif isinstance(inputs, dict):
|
||||
return self.experience_maker.make_experience(**inputs, **self.generate_kwargs)
|
||||
else:
|
||||
raise ValueError(f'Unsupported input type "{type(inputs)}"')
|
||||
|
||||
def _sample_prompts(self, prompts) -> list:
|
||||
indices = list(range(len(prompts)))
|
||||
sampled_indices = self.strategy.experience_sampler.choice(indices, self.experience_batch_size, replace=False)
|
||||
return [prompts[i] for i in sampled_indices]
|
||||
|
||||
def _learn(self):
|
||||
# replay buffer may be empty at first, we should rebuild at each training
|
||||
if not self.sample_replay_buffer:
|
||||
dataloader = self.strategy.setup_dataloader(self.replay_buffer, self.dataloader_pin_memory)
|
||||
device = torch.cuda.current_device()
|
||||
if self.sample_replay_buffer:
|
||||
pbar = tqdm(range(self.max_epochs), desc='Train epoch', disable=not is_rank_0())
|
||||
for _ in pbar:
|
||||
experience = self.replay_buffer.sample()
|
||||
metrics = self.training_step(experience)
|
||||
pbar.set_postfix(metrics)
|
||||
else:
|
||||
for epoch in range(self.max_epochs):
|
||||
self._on_learn_epoch_start(epoch)
|
||||
if isinstance(dataloader.sampler, DistributedSampler):
|
||||
dataloader.sampler.set_epoch(epoch)
|
||||
pbar = tqdm(dataloader, desc=f'Train epoch [{epoch+1}/{self.max_epochs}]', disable=not is_rank_0())
|
||||
for experience in pbar:
|
||||
self._on_learn_batch_start()
|
||||
experience.to_device(device)
|
||||
metrics = self.training_step(experience)
|
||||
self._on_learn_batch_end(metrics, experience)
|
||||
pbar.set_postfix(metrics)
|
||||
self._on_learn_epoch_end(epoch)
|
||||
|
||||
def fit(self,
|
||||
prompt_dataloader,
|
||||
pretrain_dataloader,
|
||||
num_episodes: int = 50000,
|
||||
max_timesteps: int = 500,
|
||||
update_timesteps: int = 5000) -> None:
|
||||
time = 0
|
||||
self.pretrain_dataloader = pretrain_dataloader
|
||||
self.prompt_dataloader = prompt_dataloader
|
||||
self._on_fit_start()
|
||||
for episode in range(num_episodes):
|
||||
self._on_episode_start(episode)
|
||||
for timestep in tqdm(range(max_timesteps),
|
||||
desc=f'Episode [{episode+1}/{num_episodes}]',
|
||||
disable=not is_rank_0()):
|
||||
time += 1
|
||||
prompts = next(iter(self.prompt_dataloader))
|
||||
self._on_make_experience_start()
|
||||
self.experience_maker.initial_model.to(torch.cuda.current_device())
|
||||
self.experience_maker.reward_model.to(torch.cuda.current_device())
|
||||
experience = self._make_experience(prompts)
|
||||
self._on_make_experience_end(experience)
|
||||
self.replay_buffer.append(experience)
|
||||
if time % update_timesteps == 0:
|
||||
self.experience_maker.initial_model.to('cpu')
|
||||
self.experience_maker.reward_model.to('cpu')
|
||||
self._learn()
|
||||
self.replay_buffer.clear()
|
||||
self._on_episode_end(episode)
|
||||
self._on_fit_end()
|
||||
|
||||
# TODO(ver217): maybe simplify these code using context
|
||||
def _on_fit_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_start()
|
||||
|
||||
def _on_fit_end(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_fit_end()
|
||||
|
||||
def _on_episode_start(self, episode: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_episode_start(episode)
|
||||
|
||||
def _on_episode_end(self, episode: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_episode_end(episode)
|
||||
|
||||
def _on_make_experience_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_make_experience_start()
|
||||
|
||||
def _on_make_experience_end(self, experience: Experience) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_make_experience_end(experience)
|
||||
|
||||
def _on_learn_epoch_start(self, epoch: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_learn_epoch_start(epoch)
|
||||
|
||||
def _on_learn_epoch_end(self, epoch: int) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_learn_epoch_end(epoch)
|
||||
|
||||
def _on_learn_batch_start(self) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_learn_batch_start()
|
||||
|
||||
def _on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
|
||||
for callback in self.callbacks:
|
||||
callback.on_learn_batch_end(metrics, experience)
|
|
@ -0,0 +1,5 @@
|
|||
from .base import Callback
|
||||
from .performance_evaluator import PerformanceEvaluator
|
||||
from .save_checkpoint import SaveCheckpoint
|
||||
|
||||
__all__ = ['Callback', 'PerformanceEvaluator', 'SaveCheckpoint']
|
|
@ -0,0 +1,39 @@
|
|||
from abc import ABC
|
||||
|
||||
from coati.experience_maker import Experience
|
||||
|
||||
|
||||
class Callback(ABC):
|
||||
"""
|
||||
Base callback class. It defines the interface for callbacks.
|
||||
"""
|
||||
|
||||
def on_fit_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
pass
|
||||
|
||||
def on_episode_start(self, episode: int) -> None:
|
||||
pass
|
||||
|
||||
def on_episode_end(self, episode: int) -> None:
|
||||
pass
|
||||
|
||||
def on_make_experience_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_make_experience_end(self, experience: Experience) -> None:
|
||||
pass
|
||||
|
||||
def on_learn_epoch_start(self, epoch: int) -> None:
|
||||
pass
|
||||
|
||||
def on_learn_epoch_end(self, epoch: int) -> None:
|
||||
pass
|
||||
|
||||
def on_learn_batch_start(self) -> None:
|
||||
pass
|
||||
|
||||
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
|
||||
pass
|
|
@ -0,0 +1,133 @@
|
|||
from time import time
|
||||
from typing import Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.experience_maker import Experience
|
||||
|
||||
from .base import Callback
|
||||
|
||||
|
||||
def get_world_size() -> int:
|
||||
if dist.is_initialized():
|
||||
return dist.get_world_size()
|
||||
return 1
|
||||
|
||||
|
||||
def print_rank_0(*args, **kwargs) -> None:
|
||||
if not dist.is_initialized() or dist.get_rank() == 0:
|
||||
print(*args, **kwargs)
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def all_reduce_mean(x: float, world_size: int) -> float:
|
||||
if world_size == 1:
|
||||
return x
|
||||
tensor = torch.tensor([x], device=torch.cuda.current_device())
|
||||
dist.all_reduce(tensor)
|
||||
tensor = tensor / world_size
|
||||
return tensor.item()
|
||||
|
||||
|
||||
class PerformanceEvaluator(Callback):
|
||||
"""
|
||||
Callback for valuate the performance of the model.
|
||||
Args:
|
||||
actor_num_params: The number of parameters of the actor model.
|
||||
critic_num_params: The number of parameters of the critic model.
|
||||
initial_model_num_params: The number of parameters of the initial model.
|
||||
reward_model_num_params: The number of parameters of the reward model.
|
||||
enable_grad_checkpoint: Whether to enable gradient checkpointing.
|
||||
ignore_episodes: The number of episodes to ignore when calculating the performance.
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
actor_num_params: int,
|
||||
critic_num_params: int,
|
||||
initial_model_num_params: int,
|
||||
reward_model_num_params: int,
|
||||
enable_grad_checkpoint: bool = False,
|
||||
ignore_episodes: int = 0) -> None:
|
||||
super().__init__()
|
||||
self.world_size = get_world_size()
|
||||
self.actor_num_params = actor_num_params
|
||||
self.critic_num_params = critic_num_params
|
||||
self.initial_model_num_params = initial_model_num_params
|
||||
self.reward_model_num_params = reward_model_num_params
|
||||
self.enable_grad_checkpoint = enable_grad_checkpoint
|
||||
self.ignore_episodes = ignore_episodes
|
||||
self.disable: bool = False
|
||||
|
||||
self.make_experience_duration: float = 0.
|
||||
self.make_experience_start_time: Optional[float] = None
|
||||
self.make_experience_num_samples: int = 0
|
||||
self.make_experience_flop: int = 0
|
||||
self.learn_duration: float = 0.
|
||||
self.learn_start_time: Optional[float] = None
|
||||
self.learn_num_samples: int = 0
|
||||
self.learn_flop: int = 0
|
||||
|
||||
def on_episode_start(self, episode: int) -> None:
|
||||
self.disable = self.ignore_episodes > 0 and episode < self.ignore_episodes
|
||||
|
||||
def on_make_experience_start(self) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
self.make_experience_start_time = time()
|
||||
|
||||
def on_make_experience_end(self, experience: Experience) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
self.make_experience_duration += time() - self.make_experience_start_time
|
||||
|
||||
batch_size, seq_len = experience.sequences.shape
|
||||
|
||||
self.make_experience_num_samples += batch_size
|
||||
|
||||
# actor generate
|
||||
num_actions = experience.action_mask.size(1)
|
||||
input_len = seq_len - num_actions
|
||||
total_seq_len = (input_len + seq_len - 1) * num_actions / 2
|
||||
self.make_experience_flop += self.actor_num_params * batch_size * total_seq_len * 2
|
||||
# actor forward
|
||||
self.make_experience_flop += self.actor_num_params * batch_size * seq_len * 2
|
||||
# critic forward
|
||||
self.make_experience_flop += self.critic_num_params * batch_size * seq_len * 2
|
||||
# initial model forward
|
||||
self.make_experience_flop += self.initial_model_num_params * batch_size * seq_len * 2
|
||||
# reward model forward
|
||||
self.make_experience_flop += self.reward_model_num_params * batch_size * seq_len * 2
|
||||
|
||||
def on_learn_batch_start(self) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
self.learn_start_time = time()
|
||||
|
||||
def on_learn_batch_end(self, metrics: dict, experience: Experience) -> None:
|
||||
if self.disable:
|
||||
return
|
||||
self.learn_duration += time() - self.learn_start_time
|
||||
|
||||
batch_size, seq_len = experience.sequences.shape
|
||||
|
||||
self.learn_num_samples += batch_size
|
||||
|
||||
# actor forward-backward, 3 means forward(1) + backward(2)
|
||||
self.learn_flop += self.actor_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
|
||||
# critic foward-backward
|
||||
self.learn_flop += self.critic_num_params * batch_size * seq_len * 2 * (3 + int(self.enable_grad_checkpoint))
|
||||
|
||||
def on_fit_end(self) -> None:
|
||||
avg_make_experience_duration = all_reduce_mean(self.make_experience_duration, self.world_size)
|
||||
avg_learn_duration = all_reduce_mean(self.learn_duration, self.world_size)
|
||||
|
||||
avg_make_experience_throughput = self.make_experience_num_samples / (avg_make_experience_duration + 1e-12)
|
||||
avg_make_experience_tflops = self.make_experience_flop / 1e12 / (avg_make_experience_duration + 1e-12)
|
||||
|
||||
avg_learn_throughput = self.learn_num_samples / (avg_learn_duration + 1e-12)
|
||||
avg_learn_tflops = self.learn_flop / 1e12 / (avg_learn_duration + 1e-12)
|
||||
|
||||
print_rank_0(
|
||||
f'Making experience throughput: {avg_make_experience_throughput:.3f} samples/sec, TFLOPS: {avg_make_experience_tflops:.3f}'
|
||||
)
|
||||
print_rank_0(f'Learning throughput: {avg_learn_throughput:.3f} samples/sec, TFLOPS: {avg_learn_tflops:.3f}')
|
|
@ -0,0 +1,75 @@
|
|||
import os
|
||||
|
||||
import torch.distributed as dist
|
||||
from coati.trainer.strategies import ColossalAIStrategy, Strategy
|
||||
from coati.trainer.utils import is_rank_0
|
||||
from torch import nn
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from .base import Callback
|
||||
|
||||
|
||||
class SaveCheckpoint(Callback):
|
||||
"""
|
||||
The callback for saving checkpoint for coati.
|
||||
|
||||
Only support saving actor and critic model.
|
||||
A typical architecture of the saved checkpoint would be:
|
||||
- checkpoint
|
||||
- episode_x
|
||||
- actor.pt
|
||||
- actor-optim-rank-0.pt
|
||||
- actor-optim-rank-1.pt
|
||||
- critic.pt
|
||||
- critic-optim-rank-0.pt
|
||||
- critic-optim-rank-1.pt
|
||||
- ...
|
||||
|
||||
Args:
|
||||
path(str): the base path you want to save checkpoint, the checkpoint would be saved at `path/checkpoint`
|
||||
interval(int): the interval episode of saving checkpoint
|
||||
strategy(Strategy): the strategy used to train
|
||||
actor(nn.Module): the actor model
|
||||
critic(nn.Module): the critic model
|
||||
actor_optim(Optimizer): the optimizer of actor
|
||||
critic_optim(Optimizer): the optimizer of critic
|
||||
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
path: str,
|
||||
interval: int,
|
||||
strategy: Strategy,
|
||||
actor: nn.Module = None,
|
||||
critic: nn.Module = None,
|
||||
actor_optim: Optimizer = None,
|
||||
critic_optim: Optimizer = None) -> None:
|
||||
super().__init__()
|
||||
self.path = os.path.join(path, 'checkpoint')
|
||||
self.interval = interval
|
||||
self.strategy = strategy
|
||||
self.model_dict = {'actor': [actor, actor_optim], 'critic': [critic, critic_optim]}
|
||||
|
||||
def on_episode_end(self, episode: int) -> None:
|
||||
if (episode + 1) % self.interval != 0:
|
||||
return
|
||||
base_path = os.path.join(self.path, f'episode_{episode}')
|
||||
if not os.path.exists(base_path):
|
||||
os.makedirs(base_path)
|
||||
|
||||
for model in self.model_dict.keys():
|
||||
|
||||
# save model
|
||||
if self.model_dict[model][0] is None:
|
||||
# saving only optimizer states is meaningless, so it would be skipped
|
||||
continue
|
||||
model_path = os.path.join(base_path, f'{model}.pt')
|
||||
self.strategy.save_model(model=self.model_dict[model][0], path=model_path, only_rank0=True)
|
||||
|
||||
# save optimizer
|
||||
if self.model_dict[model][1] is None:
|
||||
continue
|
||||
only_rank0 = not isinstance(self.strategy, ColossalAIStrategy)
|
||||
rank = 0 if is_rank_0() else dist.get_rank()
|
||||
optim_path = os.path.join(base_path, f'{model}-optim-rank-{rank}.pt')
|
||||
self.strategy.save_optimizer(optimizer=self.model_dict[model][1], path=optim_path, only_rank0=only_rank0)
|
|
@ -0,0 +1,135 @@
|
|||
from typing import Any, Callable, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.experience_maker import Experience, NaiveExperienceMaker
|
||||
from coati.models.base import Actor, Critic
|
||||
from coati.models.generation_utils import update_model_kwargs_fn
|
||||
from coati.models.loss import PolicyLoss, ValueLoss
|
||||
from coati.replay_buffer import NaiveReplayBuffer
|
||||
from torch.optim import Optimizer
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
||||
from .base import Trainer
|
||||
from .callbacks import Callback
|
||||
from .strategies import Strategy
|
||||
|
||||
|
||||
class PPOTrainer(Trainer):
|
||||
"""
|
||||
Trainer for PPO algorithm.
|
||||
|
||||
Args:
|
||||
strategy (Strategy): the strategy to use for training
|
||||
actor (Actor): the actor model in ppo algorithm
|
||||
critic (Critic): the critic model in ppo algorithm
|
||||
reward_model (nn.Module): the reward model in rlhf algorithm to make reward of sentences
|
||||
initial_model (Actor): the initial model in rlhf algorithm to generate reference logits to limit the update of actor
|
||||
actor_optim (Optimizer): the optimizer to use for actor model
|
||||
critic_optim (Optimizer): the optimizer to use for critic model
|
||||
kl_coef (float, defaults to 0.1): the coefficient of kl divergence loss
|
||||
train_batch_size (int, defaults to 8): the batch size to use for training
|
||||
buffer_limit (int, defaults to 0): the max_size limitaiton of replay buffer
|
||||
buffer_cpu_offload (bool, defaults to True): whether to offload replay buffer to cpu
|
||||
eps_clip (float, defaults to 0.2): the clip coefficient of policy loss
|
||||
value_clip (float, defaults to 0.4): the clip coefficient of value loss
|
||||
experience_batch_size (int, defaults to 8): the batch size to use for experience generation
|
||||
max_epochs (int, defaults to 1): the number of epochs of training process
|
||||
tokenier (Callable, optional): the tokenizer to use for tokenizing the input
|
||||
sample_replay_buffer (bool, defaults to False): whether to sample from replay buffer
|
||||
dataloader_pin_memory (bool, defaults to True): whether to pin memory for data loader
|
||||
callbacks (List[Callback], defaults to []): the callbacks to call during training process
|
||||
generate_kwargs (dict, optional): the kwargs to use while model generating
|
||||
"""
|
||||
|
||||
def __init__(self,
|
||||
strategy: Strategy,
|
||||
actor: Actor,
|
||||
critic: Critic,
|
||||
reward_model: nn.Module,
|
||||
initial_model: Actor,
|
||||
actor_optim: Optimizer,
|
||||
critic_optim: Optimizer,
|
||||
kl_coef: float = 0.1,
|
||||
ptx_coef: float = 0.9,
|
||||
train_batch_size: int = 8,
|
||||
buffer_limit: int = 0,
|
||||
buffer_cpu_offload: bool = True,
|
||||
eps_clip: float = 0.2,
|
||||
value_clip: float = 0.4,
|
||||
experience_batch_size: int = 8,
|
||||
max_epochs: int = 1,
|
||||
tokenizer: Optional[Callable[[Any], dict]] = None,
|
||||
sample_replay_buffer: bool = False,
|
||||
dataloader_pin_memory: bool = True,
|
||||
callbacks: List[Callback] = [],
|
||||
**generate_kwargs) -> None:
|
||||
experience_maker = NaiveExperienceMaker(actor, critic, reward_model, initial_model, kl_coef)
|
||||
replay_buffer = NaiveReplayBuffer(train_batch_size, buffer_limit, buffer_cpu_offload)
|
||||
generate_kwargs = _set_default_generate_kwargs(strategy, generate_kwargs, actor)
|
||||
super().__init__(strategy, experience_maker, replay_buffer, experience_batch_size, max_epochs, tokenizer,
|
||||
sample_replay_buffer, dataloader_pin_memory, callbacks, **generate_kwargs)
|
||||
self.actor = actor
|
||||
self.critic = critic
|
||||
|
||||
self.actor_loss_fn = PolicyLoss(eps_clip)
|
||||
self.critic_loss_fn = ValueLoss(value_clip)
|
||||
self.ptx_loss_fn = nn.CrossEntropyLoss(ignore_index=-100)
|
||||
self.ptx_coef = ptx_coef
|
||||
self.actor_optim = actor_optim
|
||||
self.critic_optim = critic_optim
|
||||
|
||||
def training_step(self, experience: Experience) -> Dict[str, float]:
|
||||
self.actor.train()
|
||||
self.critic.train()
|
||||
# policy loss
|
||||
num_actions = experience.action_mask.size(1)
|
||||
action_log_probs = self.actor(experience.sequences, num_actions, attention_mask=experience.attention_mask)
|
||||
actor_loss = self.actor_loss_fn(action_log_probs,
|
||||
experience.action_log_probs,
|
||||
experience.advantages,
|
||||
action_mask=experience.action_mask)
|
||||
|
||||
# ptx loss
|
||||
if self.ptx_coef != 0:
|
||||
ptx = next(iter(self.pretrain_dataloader))['input_ids'].to(torch.cuda.current_device())
|
||||
label = next(iter(self.pretrain_dataloader))['labels'].to(torch.cuda.current_device())[:, 1:]
|
||||
attention_mask = next(iter(self.pretrain_dataloader))['attention_mask'].to(torch.cuda.current_device())
|
||||
ptx_log_probs = self.actor.get_base_model()(ptx, attention_mask=attention_mask)['logits'][..., :-1, :]
|
||||
ptx_loss = self.ptx_loss_fn(ptx_log_probs.view(-1, ptx_log_probs.size(-1)), label.view(-1))
|
||||
actor_loss = ptx_loss * self.ptx_coef + actor_loss * (1 - self.ptx_coef)
|
||||
|
||||
self.strategy.backward(actor_loss, self.actor, self.actor_optim)
|
||||
self.strategy.optimizer_step(self.actor_optim)
|
||||
self.actor_optim.zero_grad()
|
||||
|
||||
# value loss
|
||||
values = self.critic(experience.sequences,
|
||||
action_mask=experience.action_mask,
|
||||
attention_mask=experience.attention_mask)
|
||||
critic_loss = self.critic_loss_fn(values,
|
||||
experience.values,
|
||||
experience.reward,
|
||||
action_mask=experience.action_mask)
|
||||
self.strategy.backward(critic_loss, self.critic, self.critic_optim)
|
||||
self.strategy.optimizer_step(self.critic_optim)
|
||||
self.critic_optim.zero_grad()
|
||||
|
||||
return {'reward': experience.reward.mean().item()}
|
||||
|
||||
|
||||
def _set_default_generate_kwargs(strategy: Strategy, generate_kwargs: dict, actor: Actor) -> None:
|
||||
origin_model = strategy._unwrap_actor(actor)
|
||||
new_kwargs = {**generate_kwargs}
|
||||
# use huggingface models method directly
|
||||
if 'prepare_inputs_fn' not in generate_kwargs and hasattr(origin_model, 'prepare_inputs_for_generation'):
|
||||
new_kwargs['prepare_inputs_fn'] = origin_model.prepare_inputs_for_generation
|
||||
|
||||
if 'update_model_kwargs_fn' not in generate_kwargs:
|
||||
new_kwargs['update_model_kwargs_fn'] = update_model_kwargs_fn
|
||||
|
||||
return new_kwargs
|
||||
|
||||
|
||||
def save_model(self, path: str, only_rank0: bool = False, tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
|
||||
self.strategy.save_model(model=self.actor, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
|
|
@ -0,0 +1,135 @@
|
|||
from abc import ABC
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from torch.optim import Optimizer, lr_scheduler
|
||||
from torch.utils.data import DataLoader, Dataset, DistributedSampler
|
||||
from tqdm import tqdm
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
||||
from .strategies import Strategy
|
||||
from .utils import is_rank_0
|
||||
|
||||
|
||||
class RewardModelTrainer(ABC):
|
||||
"""
|
||||
Trainer to use while training reward model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): the model to train
|
||||
strategy (Strategy): the strategy to use for training
|
||||
optim(Optimizer): the optimizer to use for training
|
||||
loss_fn (callable): the loss function to use for training
|
||||
train_dataset (Dataset): the dataset to use for training
|
||||
valid_dataset (Dataset): the dataset to use for validation
|
||||
eval_dataset (Dataset): the dataset to use for evaluation
|
||||
batch_size (int, defaults to 1): the batch size while training
|
||||
max_epochs (int, defaults to 2): the number of epochs to train
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
strategy: Strategy,
|
||||
optim: Optimizer,
|
||||
loss_fn,
|
||||
train_dataset: Dataset,
|
||||
valid_dataset: Dataset,
|
||||
eval_dataset: Dataset,
|
||||
batch_size: int = 1,
|
||||
max_epochs: int = 1,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.strategy = strategy
|
||||
self.epochs = max_epochs
|
||||
train_sampler = None
|
||||
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
train_sampler = DistributedSampler(train_dataset, shuffle=True, seed=42, drop_last=True)
|
||||
self.train_dataloader = DataLoader(train_dataset,
|
||||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler,
|
||||
batch_size=batch_size)
|
||||
self.valid_dataloader = DataLoader(valid_dataset, batch_size=batch_size, shuffle=True)
|
||||
self.eval_dataloader = DataLoader(eval_dataset, batch_size=batch_size, shuffle=True)
|
||||
|
||||
self.model = strategy.setup_model(model)
|
||||
self.loss_fn = loss_fn
|
||||
self.optimizer = strategy.setup_optimizer(optim, self.model)
|
||||
self.scheduler = lr_scheduler.CosineAnnealingLR(self.optimizer, self.train_dataloader.__len__() // 100)
|
||||
|
||||
def eval_acc(self, dataloader):
|
||||
dist = 0
|
||||
on = 0
|
||||
cnt = 0
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
for chosen_ids, c_mask, reject_ids, r_mask in dataloader:
|
||||
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
|
||||
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
|
||||
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
|
||||
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
|
||||
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
||||
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
||||
for i in range(len(chosen_reward)):
|
||||
cnt += 1
|
||||
if chosen_reward[i] > reject_reward[i]:
|
||||
on += 1
|
||||
dist += (chosen_reward - reject_reward).mean().item()
|
||||
dist_mean = dist / len(dataloader)
|
||||
acc = on / cnt
|
||||
self.model.train()
|
||||
return dist_mean, acc
|
||||
|
||||
def fit(self):
|
||||
time = datetime.now()
|
||||
epoch_bar = tqdm(range(self.epochs), desc='Train epoch', disable=not is_rank_0())
|
||||
for epoch in range(self.epochs):
|
||||
step_bar = tqdm(range(self.train_dataloader.__len__()),
|
||||
desc='Train step of epoch %d' % epoch,
|
||||
disable=not is_rank_0())
|
||||
# train
|
||||
self.model.train()
|
||||
cnt = 0
|
||||
acc = 0
|
||||
dist = 0
|
||||
for chosen_ids, c_mask, reject_ids, r_mask in self.train_dataloader:
|
||||
chosen_ids = chosen_ids.squeeze(1).to(torch.cuda.current_device())
|
||||
c_mask = c_mask.squeeze(1).to(torch.cuda.current_device())
|
||||
reject_ids = reject_ids.squeeze(1).to(torch.cuda.current_device())
|
||||
r_mask = r_mask.squeeze(1).to(torch.cuda.current_device())
|
||||
chosen_reward = self.model(chosen_ids, attention_mask=c_mask)
|
||||
reject_reward = self.model(reject_ids, attention_mask=r_mask)
|
||||
loss = self.loss_fn(chosen_reward, reject_reward)
|
||||
self.strategy.backward(loss, self.model, self.optimizer)
|
||||
self.strategy.optimizer_step(self.optimizer)
|
||||
self.optimizer.zero_grad()
|
||||
cnt += 1
|
||||
if cnt == 100:
|
||||
self.scheduler.step()
|
||||
dist, acc = self.eval_acc(self.valid_dataloader)
|
||||
cnt = 0
|
||||
if is_rank_0():
|
||||
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]],
|
||||
columns=['step', 'loss', 'dist', 'acc'])
|
||||
log.to_csv('log_%s.csv' % time, mode='a', header=False, index=False)
|
||||
step_bar.update()
|
||||
step_bar.set_postfix({'dist': dist, 'acc': acc})
|
||||
|
||||
# eval
|
||||
dist, acc = self.eval_acc(self.eval_dataloader)
|
||||
if is_rank_0():
|
||||
log = pd.DataFrame([[step_bar.n, loss.item(), dist, acc]], columns=['step', 'loss', 'dist', 'acc'])
|
||||
log.to_csv('log.csv', mode='a', header=False, index=False)
|
||||
epoch_bar.update()
|
||||
step_bar.set_postfix({'dist': dist, 'acc': acc})
|
||||
step_bar.close()
|
||||
|
||||
def save_model(self,
|
||||
path: str,
|
||||
only_rank0: bool = False,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
|
||||
self.strategy.save_model(model=self.model, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
|
|
@ -0,0 +1,158 @@
|
|||
import math
|
||||
import time
|
||||
from abc import ABC
|
||||
from typing import Optional
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import wandb
|
||||
from coati.models.loss import GPTLMLoss
|
||||
from torch import nn
|
||||
from torch.optim import Adam, Optimizer
|
||||
from torch.optim.lr_scheduler import LambdaLR
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from tqdm import tqdm
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
from transformers.trainer import get_scheduler
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
|
||||
from .strategies import Strategy
|
||||
from .utils import is_rank_0
|
||||
|
||||
|
||||
class SFTTrainer(ABC):
|
||||
"""
|
||||
Trainer to use while training reward model.
|
||||
|
||||
Args:
|
||||
model (torch.nn.Module): the model to train
|
||||
strategy (Strategy): the strategy to use for training
|
||||
optim(Optimizer): the optimizer to use for training
|
||||
train_dataloader: the dataloader to use for training
|
||||
eval_dataloader: the dataloader to use for evaluation
|
||||
batch_size (int, defaults to 1): the batch size while training
|
||||
max_epochs (int, defaults to 2): the number of epochs to train
|
||||
optim_kwargs (dict, defaults to {'lr':1e-4}): the kwargs to use while initializing optimizer
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model,
|
||||
strategy: Strategy,
|
||||
optim: Optimizer,
|
||||
train_dataloader: DataLoader,
|
||||
eval_dataloader: DataLoader = None,
|
||||
batch_size: int = 1,
|
||||
max_epochs: int = 2,
|
||||
accimulation_steps: int = 8,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.strategy = strategy
|
||||
self.epochs = max_epochs
|
||||
self.train_dataloader = train_dataloader
|
||||
self.eval_dataloader = eval_dataloader
|
||||
|
||||
self.model = strategy.setup_model(model)
|
||||
if "DDP" in str(self.strategy):
|
||||
self.model = self.model.module
|
||||
self.optimizer = strategy.setup_optimizer(optim, self.model)
|
||||
|
||||
self.accimulation_steps = accimulation_steps
|
||||
num_update_steps_per_epoch = len(train_dataloader) // self.accimulation_steps
|
||||
max_steps = math.ceil(self.epochs * num_update_steps_per_epoch)
|
||||
|
||||
self.scheduler = get_scheduler("cosine",
|
||||
self.optimizer,
|
||||
num_warmup_steps=math.ceil(max_steps * 0.03),
|
||||
num_training_steps=max_steps)
|
||||
|
||||
def fit(self, logger, log_interval=10):
|
||||
wandb.init(project="Coati", name=time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
||||
wandb.watch(self.model)
|
||||
total_loss = 0
|
||||
# epoch_bar = tqdm(range(self.epochs), desc='Epochs', disable=not is_rank_0())
|
||||
step_bar = tqdm(range(len(self.train_dataloader) // self.accimulation_steps * self.epochs),
|
||||
desc=f'steps',
|
||||
disable=not is_rank_0())
|
||||
for epoch in range(self.epochs):
|
||||
|
||||
# process_bar = tqdm(range(len(self.train_dataloader)), desc=f'Train process for{epoch}', disable=not is_rank_0())
|
||||
# train
|
||||
self.model.train()
|
||||
for batch_id, batch in enumerate(self.train_dataloader):
|
||||
|
||||
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
|
||||
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
|
||||
labels = batch["labels"].to(torch.cuda.current_device())
|
||||
# prompt_ids = prompt_ids.squeeze(1).cuda()
|
||||
# p_mask = p_mask.squeeze(1).cuda()
|
||||
# prompt_logits = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
|
||||
|
||||
outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
|
||||
|
||||
loss = outputs.loss
|
||||
prompt_logits = outputs.logits
|
||||
|
||||
if loss >= 2.5:
|
||||
logger.warning(f"batch_id:{batch_id}, abnormal loss: {loss}")
|
||||
|
||||
loss = loss / self.accimulation_steps
|
||||
|
||||
self.strategy.backward(loss, self.model, self.optimizer)
|
||||
|
||||
total_loss += loss.item()
|
||||
|
||||
# gradient accumulation
|
||||
if (batch_id + 1) % self.accimulation_steps == 0:
|
||||
self.strategy.optimizer_step(self.optimizer)
|
||||
self.optimizer.zero_grad()
|
||||
self.scheduler.step()
|
||||
wandb.log({
|
||||
"loss": total_loss / self.accimulation_steps,
|
||||
"lr": self.scheduler.get_last_lr()[0],
|
||||
"epoch": epoch,
|
||||
"batch_id": batch_id
|
||||
})
|
||||
total_loss = 0
|
||||
step_bar.update()
|
||||
|
||||
# if batch_id % log_interval == 0:
|
||||
# logger.info(f'Train Epoch {epoch}/{self.epochs} Batch {batch_id} Rank {dist.get_rank()} loss {loss.item()}')
|
||||
# wandb.log({"loss": loss.item()})
|
||||
|
||||
# process_bar.update()
|
||||
|
||||
# eval
|
||||
if self.eval_dataloader is not None:
|
||||
self.model.eval()
|
||||
with torch.no_grad():
|
||||
loss_sum = 0
|
||||
num_seen = 0
|
||||
for batch in self.eval_dataloader:
|
||||
prompt_ids = batch["input_ids"].to(torch.cuda.current_device())
|
||||
p_mask = batch["attention_mask"].to(torch.cuda.current_device())
|
||||
labels = batch["labels"].to(torch.cuda.current_device())
|
||||
# prompt_ids = prompt_ids.squeeze(1).cuda()
|
||||
# p_mask = p_mask.squeeze(1).cuda()
|
||||
|
||||
outputs = self.model(prompt_ids, attention_mask=p_mask, labels=labels)
|
||||
loss = outputs.loss
|
||||
# prompt_logits = outputs.logits
|
||||
|
||||
loss_sum += loss.item()
|
||||
num_seen += prompt_ids.size(0)
|
||||
|
||||
loss_mean = loss_sum / num_seen
|
||||
if dist.get_rank() == 0:
|
||||
logger.info(f'Eval Epoch {epoch}/{self.epochs} loss {loss_mean}')
|
||||
|
||||
# epoch_bar.update()
|
||||
|
||||
def save_model(self,
|
||||
path: str,
|
||||
only_rank0: bool = False,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
|
||||
self.strategy.save_model(model=self.model, path=path, only_rank0=only_rank0, tokenizer=tokenizer)
|
|
@ -0,0 +1,6 @@
|
|||
from .base import Strategy
|
||||
from .colossalai import ColossalAIStrategy
|
||||
from .ddp import DDPStrategy
|
||||
from .naive import NaiveStrategy
|
||||
|
||||
__all__ = ['Strategy', 'NaiveStrategy', 'DDPStrategy', 'ColossalAIStrategy']
|
|
@ -0,0 +1,136 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from contextlib import nullcontext
|
||||
from typing import Any, List, Optional, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from coati.models.base import LM, Actor, Critic, RewardModel
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
||||
from .sampler import DistributedSampler
|
||||
|
||||
ModelOptimPair = Tuple[nn.Module, Optimizer]
|
||||
ModelOrModelOptimPair = Union[nn.Module, ModelOptimPair]
|
||||
|
||||
|
||||
class Strategy(ABC):
|
||||
"""
|
||||
Base class for training strategies.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.setup_distributed()
|
||||
|
||||
@abstractmethod
|
||||
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: Optimizer, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def optimizer_step(self, optimizer: Optimizer, **kwargs) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def setup_distributed(self) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def setup_model(self, model: nn.Module) -> nn.Module:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def setup_optimizer(self, optimizer: Optimizer, model: nn.Module) -> Optimizer:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
|
||||
pass
|
||||
|
||||
def model_init_context(self):
|
||||
return nullcontext()
|
||||
|
||||
def prepare(
|
||||
self, *models_or_model_optim_pairs: ModelOrModelOptimPair
|
||||
) -> Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]:
|
||||
"""Prepare models or model-optimizer-pairs based on each strategy.
|
||||
|
||||
Example::
|
||||
>>> # when fine-tuning actor and critic
|
||||
>>> (actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare((actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
||||
>>> # or when training reward model
|
||||
>>> (reward_model, reward_model_optim) = strategy.prepare((reward_model, reward_model_optim))
|
||||
>>> # or just inference
|
||||
>>> actor, critic = strategy.prepare(actor, critic)
|
||||
|
||||
Returns:
|
||||
Union[List[ModelOrModelOptimPair], ModelOrModelOptimPair]: Models or model-optimizer-pairs in the original order.
|
||||
"""
|
||||
|
||||
def prepare_model(model: nn.Module):
|
||||
if isinstance(model, Actor):
|
||||
return Actor(self.setup_model(self._unwrap_model(model)))
|
||||
return self.setup_model(self._unwrap_model(model))
|
||||
|
||||
rets = []
|
||||
for arg in models_or_model_optim_pairs:
|
||||
if isinstance(arg, tuple):
|
||||
assert len(arg) == 2, f'Expect (model, optimizer) pair, got a tuple with size "{len(arg)}"'
|
||||
model, optimizer = arg
|
||||
model = prepare_model(model)
|
||||
optimizer = self.setup_optimizer(optimizer, self._unwrap_model(model))
|
||||
rets.append((model, optimizer))
|
||||
elif isinstance(arg, nn.Module):
|
||||
rets.append(prepare_model(arg))
|
||||
else:
|
||||
raise RuntimeError(f'Expect model or (model, optimizer) pair, got {type(arg)}')
|
||||
|
||||
if len(rets) == 1:
|
||||
return rets[0]
|
||||
return rets
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_model(model: nn.Module) -> nn.Module:
|
||||
"""Useful for saving state dict. As actor is wrapped by Actor class again in `prepare()`, we should unwrap it before saving.
|
||||
|
||||
Args:
|
||||
model (nn.Module): an actor or a critic
|
||||
"""
|
||||
if isinstance(model, Actor) or isinstance(model, LM):
|
||||
return model.model
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_actor(actor: Actor) -> nn.Module:
|
||||
"""Get `actor.model` from a wrapped (by `prepare()`) actor. Useful for getting original huggingface model.
|
||||
|
||||
Args:
|
||||
actor (Actor): a wrapped actor
|
||||
"""
|
||||
return Strategy._unwrap_model(actor)
|
||||
|
||||
@abstractmethod
|
||||
def save_model(self,
|
||||
model: nn.Module,
|
||||
path: str,
|
||||
only_rank0: bool = False,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
|
||||
pass
|
||||
|
||||
def setup_sampler(self, dataset) -> DistributedSampler:
|
||||
return DistributedSampler(dataset, 1, 0)
|
|
@ -0,0 +1,213 @@
|
|||
import warnings
|
||||
from typing import Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from coati.models.base import LM, Actor, RewardModel
|
||||
from coati.models.lora import LoraLinear
|
||||
from torch.optim import Optimizer
|
||||
from transformers.modeling_utils import PreTrainedModel
|
||||
from transformers.tokenization_utils_base import PreTrainedTokenizerBase
|
||||
|
||||
import colossalai
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import CPUAdam, HybridAdam
|
||||
from colossalai.nn.parallel import ZeroDDP, zero_model_wrapper, zero_optim_wrapper
|
||||
from colossalai.nn.parallel.utils import get_static_torch_model
|
||||
from colossalai.tensor import ProcessGroup, ShardSpec
|
||||
from colossalai.utils import get_current_device
|
||||
from colossalai.utils.model.colo_init_context import ColoInitContext
|
||||
|
||||
logger = get_dist_logger(__name__)
|
||||
|
||||
from .base import Strategy
|
||||
from .ddp import DDPStrategy
|
||||
|
||||
|
||||
class ColossalAIStrategy(DDPStrategy):
|
||||
"""
|
||||
The strategy for training with ColossalAI.
|
||||
|
||||
Args:
|
||||
stage(int): The stage to use in ZeRO. Choose in (1, 2, 3)
|
||||
precision(str): The precision to use. Choose in ('fp32', 'fp16'). Stage 3 only supports fp16.
|
||||
seed(int): The seed for the random number generator.
|
||||
shard_init(bool): Whether to shard the model parameters during initialization. Only for ZeRO-3.
|
||||
This is not compativle with `from_pretrained()`. We temporarily disable this and will support it in the future.
|
||||
placement_policy(str): The placement policy for gemini. Choose in ('cpu', 'cuda')
|
||||
If it is “cpu”, parameters, gradients and optimizer states will be offloaded to CPU,
|
||||
If it is “cuda”, they will not be offloaded, which means max CUDA memory will be used. It is the fastest.
|
||||
pin_memory(bool): Whether to pin the memory for the data loader. Only for ZeRO-3.
|
||||
force_outputs_fp32(bool): Whether to force the outputs to be fp32. Only for ZeRO-3.
|
||||
search_range_mb(int): The search range in MB for the chunk size. Only for ZeRO-3.
|
||||
hidden_dim(optional, int): The hidden dimension for the gemini. Only for ZeRO-3.
|
||||
min_chunk_size_mb(float): The minimum chunk size in MB. Only for ZeRO-3.
|
||||
gpu_margin_mem_ratio(float): The margin memory ratio for the GPU. Only for ZeRO-3.
|
||||
reduce_bugket_size(int): The reduce bucket size in bytes. Only for ZeRO-1 and ZeRO-2.
|
||||
overlap_communication(bool): Whether to overlap communication and computation. Only for ZeRO-1 and ZeRO-2.
|
||||
initial_scale(float): The initial scale for the optimizer.
|
||||
growth_factor(float): The growth factor for the optimizer.
|
||||
backoff_factor(float): The backoff factor for the optimizer.
|
||||
growth_interval(int): The growth interval for the optimizer.
|
||||
hysteresis(int): The hysteresis for the optimizer.
|
||||
min_scale(float): The minimum scale for the optimizer.
|
||||
max_scale(float): The maximum scale for the optimizer.
|
||||
max_norm(float): The maximum norm for the optimizer.
|
||||
norm_type(float): The norm type for the optimizer.
|
||||
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stage: int = 3,
|
||||
precision: str = 'fp16',
|
||||
seed: int = 42,
|
||||
shard_init: bool = False, # only for stage 3
|
||||
placement_policy: str = 'cuda',
|
||||
pin_memory: bool = True, # only for stage 3
|
||||
force_outputs_fp32: bool = False, # only for stage 3
|
||||
search_range_mb: int = 32, # only for stage 3
|
||||
hidden_dim: Optional[int] = None, # only for stage 3
|
||||
min_chunk_size_mb: float = 32, # only for stage 3
|
||||
gpu_margin_mem_ratio: float = 0.0, # only for stage 3
|
||||
reduce_bucket_size: int = 12 * 1024**2, # only for stage 1&2
|
||||
overlap_communication: bool = True, # only for stage 1&2
|
||||
initial_scale: float = 2**16,
|
||||
growth_factor: float = 2,
|
||||
backoff_factor: float = 0.5,
|
||||
growth_interval: int = 1000,
|
||||
hysteresis: int = 2,
|
||||
min_scale: float = 1,
|
||||
max_scale: float = 2**32,
|
||||
max_norm: float = 0.0,
|
||||
norm_type: float = 2.0) -> None:
|
||||
super().__init__(seed)
|
||||
assert placement_policy in ('cpu', 'cuda'), f'Unsupported placement policy "{placement_policy}"'
|
||||
assert precision in ('fp32', 'fp16'), f'Unsupported precision "{precision}"'
|
||||
self.stage = stage
|
||||
# TODO(ver217): support shard_init when using from_pretrained()
|
||||
if shard_init:
|
||||
warnings.warn(
|
||||
f'Shard init is not supported model.from_pretrained() yet. Please load weights after strategy.prepare()'
|
||||
)
|
||||
if stage == 3 and precision == 'fp32':
|
||||
warnings.warn(f'Stage 3 only supports fp16. Precision is set to fp16.')
|
||||
precision = 'fp16'
|
||||
self.precision = precision
|
||||
self.shard_init = shard_init
|
||||
self.gemini_config = dict(device=get_current_device(),
|
||||
placement_policy=placement_policy,
|
||||
pin_memory=pin_memory,
|
||||
force_outputs_fp32=force_outputs_fp32,
|
||||
strict_ddp_mode=shard_init,
|
||||
search_range_mb=search_range_mb,
|
||||
hidden_dim=hidden_dim,
|
||||
min_chunk_size_mb=min_chunk_size_mb)
|
||||
if stage == 3:
|
||||
self.zero_optim_config = dict(gpu_margin_mem_ratio=gpu_margin_mem_ratio)
|
||||
else:
|
||||
self.zero_optim_config = dict(reduce_bucket_size=reduce_bucket_size,
|
||||
overlap_communication=overlap_communication,
|
||||
cpu_offload=(placement_policy == 'cpu'))
|
||||
self.optim_kwargs = dict(initial_scale=initial_scale,
|
||||
growth_factor=growth_factor,
|
||||
backoff_factor=backoff_factor,
|
||||
growth_interval=growth_interval,
|
||||
hysteresis=hysteresis,
|
||||
min_scale=min_scale,
|
||||
max_scale=max_scale,
|
||||
max_norm=max_norm,
|
||||
norm_type=norm_type)
|
||||
|
||||
def setup_distributed(self) -> None:
|
||||
colossalai.launch_from_torch({}, seed=self.seed)
|
||||
|
||||
def model_init_context(self):
|
||||
if self.stage == 3:
|
||||
world_size = dist.get_world_size()
|
||||
shard_pg = ProcessGroup(tp_degree=world_size) if self.shard_init else None
|
||||
default_dist_spec = ShardSpec([-1], [world_size]) if self.shard_init else None
|
||||
return ColoInitContext(device=get_current_device(),
|
||||
dtype=torch.half,
|
||||
default_pg=shard_pg,
|
||||
default_dist_spec=default_dist_spec)
|
||||
return super().model_init_context()
|
||||
|
||||
def setup_model(self, model: nn.Module) -> nn.Module:
|
||||
|
||||
model = zero_model_wrapper(model, zero_stage=self.stage, gemini_config=self.gemini_config)
|
||||
|
||||
if self.stage != 3 and self.precision == 'fp16':
|
||||
model = model.half()
|
||||
return model
|
||||
|
||||
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
|
||||
assert isinstance(optimizer, (CPUAdam, HybridAdam)), f'Unsupported optimizer {type(optimizer)}'
|
||||
return zero_optim_wrapper(model, optimizer, optim_config=self.zero_optim_config, **self.optim_kwargs)
|
||||
|
||||
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
|
||||
optimizer.backward(loss)
|
||||
|
||||
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
|
||||
optimizer.step()
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_actor(actor: Actor) -> nn.Module:
|
||||
model: Union[nn.Module, ZeroDDP] = Strategy._unwrap_actor(actor)
|
||||
if isinstance(model, ZeroDDP):
|
||||
return model.module
|
||||
return model
|
||||
|
||||
def _unwrap_model(self, model: Union[nn.Module, ZeroDDP]) -> nn.Module:
|
||||
if isinstance(model, ZeroDDP) and self.stage == 3:
|
||||
logger.info(f"model type: {type(model)}, get static torch model")
|
||||
model = get_static_torch_model(model)
|
||||
logger.info(f"unwrapped_model type: {type(model)}")
|
||||
|
||||
return super()._unwrap_model(model)
|
||||
|
||||
def save_model(self,
|
||||
model: nn.Module,
|
||||
path: str,
|
||||
only_rank0: bool = True,
|
||||
tokenizer: Optional[PreTrainedTokenizerBase] = None) -> None:
|
||||
|
||||
if only_rank0 and dist.get_rank() != 0:
|
||||
return None
|
||||
unwrapped_model = self._unwrap_model(model)
|
||||
# TODO : better way to get torch model from gemini model
|
||||
# to get torch model from gemini model
|
||||
|
||||
for module in unwrapped_model.modules():
|
||||
if isinstance(module, LoraLinear):
|
||||
module.merge_weights = True
|
||||
module.eval()
|
||||
if isinstance(unwrapped_model, RewardModel):
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if only_rank0 and dist.get_rank() != 0:
|
||||
return
|
||||
torch.save(state_dict, path)
|
||||
else:
|
||||
try:
|
||||
if isinstance(unwrapped_model, LM):
|
||||
unwrapped_model = unwrapped_model.model
|
||||
logger.info(f'Saving model to {path}', ranks=[0])
|
||||
unwrapped_model.save_pretrained(path)
|
||||
logger.info(f'Model saved to {path} Successfully', ranks=[0])
|
||||
if tokenizer is not None:
|
||||
logger.info(f'Saving tokenizer to {path}', ranks=[0])
|
||||
tokenizer.save_pretrained(path)
|
||||
logger.info(f'Tokenizer saved to {path} Successfully', ranks=[0])
|
||||
except AttributeError:
|
||||
state_dict = unwrapped_model.state_dict()
|
||||
if only_rank0 and dist.get_rank() != 0:
|
||||
return
|
||||
torch.save(state_dict, path)
|
||||
|
||||
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
|
||||
if only_rank0:
|
||||
raise RuntimeError(
|
||||
f'Optimizer states are sharded when using ColossalAIStrategy. Only rank0 is not supported.')
|
||||
torch.save(optimizer.state_dict(), path)
|
|
@ -0,0 +1,93 @@
|
|||
import os
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from coati.models.base import Actor
|
||||
from coati.models.lora import LoraLinear
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from torch.nn.parallel import DistributedDataParallel as DDP
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .base import Strategy
|
||||
from .naive import NaiveStrategy
|
||||
from .sampler import DistributedSampler
|
||||
|
||||
|
||||
class DDPStrategy(NaiveStrategy):
|
||||
"""
|
||||
Strategy for distributed training using torch.distributed.
|
||||
"""
|
||||
|
||||
def __init__(self, seed: int = 42) -> None:
|
||||
self.seed = seed
|
||||
super().__init__()
|
||||
|
||||
def setup_distributed(self) -> None:
|
||||
try:
|
||||
rank = int(os.environ['RANK'])
|
||||
local_rank = int(os.environ['LOCAL_RANK'])
|
||||
world_size = int(os.environ['WORLD_SIZE'])
|
||||
host = os.environ['MASTER_ADDR']
|
||||
port = int(os.environ['MASTER_PORT'])
|
||||
except KeyError as e:
|
||||
raise RuntimeError(
|
||||
f"Could not find {e} in the torch environment, visit https://www.colossalai.org/ for more information on launching with torch"
|
||||
)
|
||||
dist.init_process_group('nccl', init_method=f'tcp://[{host}]:{port}', world_size=world_size, rank=rank)
|
||||
self.set_seed(self.seed)
|
||||
torch.cuda.set_device(local_rank)
|
||||
|
||||
def set_seed(self, seed: int) -> None:
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
|
||||
def setup_model(self, model: nn.Module) -> nn.Module:
|
||||
device = torch.cuda.current_device()
|
||||
return DDP(model, device_ids=[device])
|
||||
|
||||
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
|
||||
# DDP only mode, replay buffers on each rank are different.
|
||||
# sampler = DistributedSampler(replay_buffer,
|
||||
# num_replicas=dist.get_world_size(),
|
||||
# rank=dist.get_rank(),
|
||||
# shuffle=True,
|
||||
# seed=self.seed,
|
||||
# drop_last=True)
|
||||
return DataLoader(
|
||||
replay_buffer,
|
||||
batch_size=replay_buffer.sample_batch_size,
|
||||
# sampler=sampler,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=replay_buffer.collate_fn)
|
||||
|
||||
@staticmethod
|
||||
def _unwrap_actor(actor: Actor) -> nn.Module:
|
||||
model: DDP = Strategy._unwrap_actor(actor)
|
||||
return model.module
|
||||
|
||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
|
||||
for module in model.modules():
|
||||
if isinstance(module, LoraLinear):
|
||||
module.merge_weights = True
|
||||
module.eval()
|
||||
|
||||
if only_rank0 and dist.get_rank() != 0:
|
||||
return
|
||||
model = model.model.module
|
||||
state_dict = model.state_dict()
|
||||
torch.save(state_dict, path)
|
||||
|
||||
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
|
||||
if only_rank0 and dist.get_rank() != 0:
|
||||
return
|
||||
super().save_optimizer(optimizer, path, only_rank0)
|
||||
|
||||
def setup_sampler(self, dataset) -> DistributedSampler:
|
||||
return DistributedSampler(dataset, dist.get_world_size(), dist.get_rank())
|
|
@ -0,0 +1,55 @@
|
|||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.optim as optim
|
||||
from coati.replay_buffer import ReplayBuffer
|
||||
from torch.optim import Optimizer
|
||||
from torch.utils.data import DataLoader
|
||||
|
||||
from .base import Strategy
|
||||
|
||||
|
||||
class NaiveStrategy(Strategy):
|
||||
"""
|
||||
Strategy for single GPU. No parallelism is used.
|
||||
"""
|
||||
|
||||
def backward(self, loss: torch.Tensor, model: nn.Module, optimizer: optim.Optimizer, **kwargs) -> None:
|
||||
loss.backward()
|
||||
|
||||
def optimizer_step(self, optimizer: optim.Optimizer, **kwargs) -> None:
|
||||
optimizer.step()
|
||||
|
||||
def setup_distributed(self) -> None:
|
||||
pass
|
||||
|
||||
def setup_model(self, model: nn.Module) -> nn.Module:
|
||||
return model
|
||||
|
||||
def setup_optimizer(self, optimizer: optim.Optimizer, model: nn.Module) -> optim.Optimizer:
|
||||
return optimizer
|
||||
|
||||
def setup_dataloader(self, replay_buffer: ReplayBuffer, pin_memory: bool = False) -> DataLoader:
|
||||
return DataLoader(replay_buffer,
|
||||
batch_size=replay_buffer.sample_batch_size,
|
||||
shuffle=True,
|
||||
drop_last=True,
|
||||
pin_memory=pin_memory,
|
||||
collate_fn=replay_buffer.collate_fn)
|
||||
|
||||
def save_model(self, model: nn.Module, path: str, only_rank0: bool = False) -> None:
|
||||
unwrapped_model = self._unwrap_model(model)
|
||||
torch.save(unwrapped_model.state_dict(), path)
|
||||
|
||||
def load_model(self, model: nn.Module, path: str, map_location: Any = None, strict: bool = True) -> None:
|
||||
unwrapped_model = self._unwrap_model(model)
|
||||
state_dict = torch.load(path, map_location=map_location)
|
||||
unwrapped_model.load_state_dict(state_dict, strict=strict)
|
||||
|
||||
def save_optimizer(self, optimizer: Optimizer, path: str, only_rank0: bool = False) -> None:
|
||||
torch.save(optimizer.state_dict(), path)
|
||||
|
||||
def load_optimizer(self, optimizer: Optimizer, path: str, map_location: Any = None) -> None:
|
||||
state_dict = torch.load(path, map_location=map_location)
|
||||
optimizer.load_state_dict(state_dict)
|
|
@ -0,0 +1,32 @@
|
|||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class DistributedSampler:
|
||||
|
||||
def __init__(self, dataset, num_replicas: int, rank: int) -> None:
|
||||
self.dataset = dataset
|
||||
self.num_replicas = num_replicas
|
||||
self.rank = rank
|
||||
|
||||
if len(self.dataset) % self.num_replicas != 0:
|
||||
self.num_samples = math.ceil(
|
||||
(len(self.dataset) - self.num_replicas) / self.num_replicas # type: ignore[arg-type]
|
||||
)
|
||||
else:
|
||||
self.num_samples = math.ceil(len(self.dataset) / self.num_replicas)
|
||||
|
||||
self.total_size = self.num_samples * self.num_replicas
|
||||
|
||||
indices = list(range(len(self.dataset)))
|
||||
indices = indices[:self.total_size]
|
||||
assert len(indices) == self.total_size
|
||||
# subsample
|
||||
indices = indices[self.rank:self.total_size:self.num_replicas]
|
||||
assert len(indices) == self.num_samples
|
||||
self.indices = indices
|
||||
|
||||
def sample(self, batch_size: int) -> list:
|
||||
sampled_indices = np.random.choice(self.indices, batch_size, replace=False)
|
||||
return [self.dataset[idx] for idx in sampled_indices]
|
|
@ -0,0 +1,5 @@
|
|||
import torch.distributed as dist
|
||||
|
||||
|
||||
def is_rank_0() -> bool:
|
||||
return not dist.is_initialized() or dist.get_rank() == 0
|
|
@ -0,0 +1,3 @@
|
|||
from .tokenizer_utils import prepare_llama_tokenizer_and_embedding, smart_tokenizer_and_embedding_resize
|
||||
|
||||
__all__ = ['smart_tokenizer_and_embedding_resize', 'prepare_llama_tokenizer_and_embedding']
|
|
@ -0,0 +1,78 @@
|
|||
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Dict
|
||||
|
||||
import transformers
|
||||
|
||||
from ..models.llama.llama_lm import LlamaLM
|
||||
|
||||
DEFAULT_PAD_TOKEN = "[PAD]"
|
||||
DEFAULT_EOS_TOKEN = "</s>"
|
||||
DEFAULT_BOS_TOKEN = "</s>"
|
||||
DEFAULT_UNK_TOKEN = "</s>"
|
||||
|
||||
|
||||
def prepare_llama_tokenizer_and_embedding(
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
model: transformers.PreTrainedModel,
|
||||
special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
|
||||
):
|
||||
"""prepare llama tokenizer and embedding.
|
||||
|
||||
"""
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
smart_tokenizer_and_embedding_resize(
|
||||
special_tokens_dict=dict(pad_token=DEFAULT_PAD_TOKEN),
|
||||
tokenizer=tokenizer,
|
||||
model=model,
|
||||
)
|
||||
|
||||
tokenizer.add_special_tokens({
|
||||
"eos_token": DEFAULT_EOS_TOKEN,
|
||||
"bos_token": DEFAULT_BOS_TOKEN,
|
||||
"unk_token": DEFAULT_UNK_TOKEN,
|
||||
})
|
||||
|
||||
return tokenizer
|
||||
|
||||
|
||||
def smart_tokenizer_and_embedding_resize(
|
||||
tokenizer: transformers.PreTrainedTokenizer,
|
||||
model: transformers.PreTrainedModel,
|
||||
special_tokens_dict: Dict = dict(pad_token=DEFAULT_PAD_TOKEN),
|
||||
):
|
||||
"""Resize tokenizer and embedding.
|
||||
|
||||
Note: This is the unoptimized version that may make your embedding size not be divisible by 64.
|
||||
"""
|
||||
|
||||
if tokenizer.pad_token is None:
|
||||
num_new_tokens = tokenizer.add_special_tokens(special_tokens_dict)
|
||||
|
||||
if isinstance(model, LlamaLM):
|
||||
model = model.get_base_model()
|
||||
|
||||
model.resize_token_embeddings(len(tokenizer))
|
||||
|
||||
if num_new_tokens > 0:
|
||||
input_embeddings = model.get_input_embeddings().weight.data
|
||||
output_embeddings = model.get_output_embeddings().weight.data
|
||||
|
||||
input_embeddings_avg = input_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
output_embeddings_avg = output_embeddings[:-num_new_tokens].mean(dim=0, keepdim=True)
|
||||
|
||||
input_embeddings[-num_new_tokens:] = input_embeddings_avg
|
||||
output_embeddings[-num_new_tokens:] = output_embeddings_avg
|
|
@ -0,0 +1,141 @@
|
|||
# Examples
|
||||
|
||||
## Install requirements
|
||||
|
||||
```shell
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
## Train the reward model (Stage 2)
|
||||
Use these code to train your reward model.
|
||||
```shell
|
||||
# Take naive reward model training with opt-350m as example
|
||||
python train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy naive
|
||||
# use colossalai_zero2
|
||||
torchrun --standalone --nproc_per_node=2 train_reward_model.py --pretrain "facebook/opt-350m" --model 'opt' --strategy colossalai_zero2
|
||||
```
|
||||
|
||||
### Features and tricks in RM training
|
||||
- We support [Anthropic/hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)and[rm-static](https://huggingface.co/datasets/Dahoas/rm-static) datasets.
|
||||
- We support 2 kinds of loss_function named 'log_sig'(used by OpenAI) and 'log_exp'(used by Anthropic).
|
||||
- We change the loss to valid_acc and pair_dist to monitor progress during training.
|
||||
- We add special token to the end of the sequence to get better result.
|
||||
- We use cosine-reducing lr-scheduler for RM training.
|
||||
- We set value_head as 1 liner layer and initialize the weight of value_head using N(0,1/(d_model + 1)) distribution.
|
||||
- We train a Bloom-560m reward model for 1 epoch and find the test acc of the model achieve the performance mentions in [Anthropics paper](https://arxiv.org/abs/2204.05862).
|
||||
|
||||
### Experiment result
|
||||
Model performance in [Anthropics paper](https://arxiv.org/abs/2204.05862):
|
||||
|
||||
<div align=center> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225263321-8d64c3a8-6877-4cc8-9b61-0e1c52d3d94f.png">
|
||||
|
||||
<div align=left>Our training & test result of bloom-560m for 1 epoch:
|
||||
|
||||
<div align=center> <img width="512" alt="image" src="https://user-images.githubusercontent.com/70618399/225262950-a7f0a686-25de-44ec-98f2-11b83ea86674.png">
|
||||
|
||||
<div align=left>
|
||||
|
||||
## Train with dummy prompt data (Stage 3)
|
||||
|
||||
This script supports 4 kinds of strategies:
|
||||
|
||||
- naive
|
||||
- ddp
|
||||
- colossalai_zero2
|
||||
- colossalai_gemini
|
||||
|
||||
It uses random generated prompt data.
|
||||
|
||||
Naive strategy only support single GPU training:
|
||||
|
||||
```shell
|
||||
python train_dummy.py --strategy naive
|
||||
# display cli help
|
||||
python train_dummy.py -h
|
||||
```
|
||||
|
||||
DDP strategy and ColossalAI strategy support multi GPUs training:
|
||||
|
||||
```shell
|
||||
# run DDP on 2 GPUs
|
||||
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy ddp
|
||||
# run ColossalAI on 2 GPUs
|
||||
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2
|
||||
```
|
||||
|
||||
## Train with real prompt data (Stage 3)
|
||||
|
||||
We use [awesome-chatgpt-prompts](https://huggingface.co/datasets/fka/awesome-chatgpt-prompts) as example dataset. It is a small dataset with hundreds of prompts.
|
||||
|
||||
You should download `prompts.csv` first.
|
||||
|
||||
This script also supports 4 strategies.
|
||||
|
||||
```shell
|
||||
# display cli help
|
||||
python train_dummy.py -h
|
||||
# run naive on 1 GPU
|
||||
python train_prompts.py prompts.csv --strategy naive
|
||||
# run DDP on 2 GPUs
|
||||
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy ddp
|
||||
# run ColossalAI on 2 GPUs
|
||||
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
|
||||
```
|
||||
|
||||
## Inference example(After Stage3)
|
||||
We support naive inference demo after training.
|
||||
```shell
|
||||
# inference, using pretrain path to configure model
|
||||
python inference.py --model_path <your actor model path> --model <your model type> --pretrain <your pretrain model name/path>
|
||||
# example
|
||||
python inference.py --model_path ./actor_checkpoint_prompts.pt --pretrain bigscience/bloom-560m --model bloom
|
||||
```
|
||||
|
||||
## Attention
|
||||
The examples is just a demo for testing our progress of RM and PPO training.
|
||||
|
||||
|
||||
#### data
|
||||
- [x] [rm-static](https://huggingface.co/datasets/Dahoas/rm-static)
|
||||
- [x] [hh-rlhf](https://huggingface.co/datasets/Anthropic/hh-rlhf)
|
||||
- [ ] [openai/summarize_from_feedback](https://huggingface.co/datasets/openai/summarize_from_feedback)
|
||||
- [ ] [openai/webgpt_comparisons](https://huggingface.co/datasets/openai/webgpt_comparisons)
|
||||
- [ ] [Dahoas/instruct-synthetic-prompt-responses](https://huggingface.co/datasets/Dahoas/instruct-synthetic-prompt-responses)
|
||||
|
||||
## Support Model
|
||||
|
||||
### GPT
|
||||
- [x] GPT2-S (s)
|
||||
- [x] GPT2-M (m)
|
||||
- [x] GPT2-L (l)
|
||||
- [ ] GPT2-XL (xl)
|
||||
- [x] GPT2-4B (4b)
|
||||
- [ ] GPT2-6B (6b)
|
||||
- [ ] GPT2-8B (8b)
|
||||
- [ ] GPT2-10B (10b)
|
||||
- [ ] GPT2-12B (12b)
|
||||
- [ ] GPT2-15B (15b)
|
||||
- [ ] GPT2-18B (18b)
|
||||
- [ ] GPT2-20B (20b)
|
||||
- [ ] GPT2-24B (24b)
|
||||
- [ ] GPT2-28B (28b)
|
||||
- [ ] GPT2-32B (32b)
|
||||
- [ ] GPT2-36B (36b)
|
||||
- [ ] GPT2-40B (40b)
|
||||
- [ ] GPT3 (175b)
|
||||
|
||||
### BLOOM
|
||||
- [x] [BLOOM-560m](https://huggingface.co/bigscience/bloom-560m)
|
||||
- [x] [BLOOM-1b1](https://huggingface.co/bigscience/bloom-1b1)
|
||||
- [x] [BLOOM-3b](https://huggingface.co/bigscience/bloom-3b)
|
||||
- [x] [BLOOM-7b](https://huggingface.co/bigscience/bloom-7b1)
|
||||
- [ ] BLOOM-175b
|
||||
|
||||
### OPT
|
||||
- [x] [OPT-125M](https://huggingface.co/facebook/opt-125m)
|
||||
- [x] [OPT-350M](https://huggingface.co/facebook/opt-350m)
|
||||
- [ ] [OPT-1.3B](https://huggingface.co/facebook/opt-1.3b)
|
||||
- [ ] [OPT-2.7B](https://huggingface.co/facebook/opt-2.7b)
|
||||
- [ ] [OPT-6.7B](https://huggingface.co/facebook/opt-6.7b)
|
||||
- [ ] [OPT-13B](https://huggingface.co/facebook/opt-13b)
|
||||
- [ ] [OPT-30B](https://huggingface.co/facebook/opt-30b)
|
|
@ -0,0 +1,59 @@
|
|||
import argparse
|
||||
|
||||
import torch
|
||||
from coati.models.bloom import BLOOMActor
|
||||
from coati.models.gpt import GPTActor
|
||||
from coati.models.opt import OPTActor
|
||||
from transformers import AutoTokenizer
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
|
||||
def eval(args):
|
||||
# configure model
|
||||
if args.model == 'gpt2':
|
||||
actor = GPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
||||
elif args.model == 'bloom':
|
||||
actor = BLOOMActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
||||
elif args.model == 'opt':
|
||||
actor = OPTActor(pretrained=args.pretrain).to(torch.cuda.current_device())
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
state_dict = torch.load(args.model_path)
|
||||
actor.model.load_state_dict(state_dict)
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = AutoTokenizer.from_pretrained('bigscience/bloom-560m')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained('facebook/opt-350m')
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
actor.eval()
|
||||
input = args.input
|
||||
input_ids = tokenizer.encode(input, return_tensors='pt').to(torch.cuda.current_device())
|
||||
outputs = actor.generate(input_ids,
|
||||
max_length=args.max_length,
|
||||
do_sample=True,
|
||||
top_k=50,
|
||||
top_p=0.95,
|
||||
num_return_sequences=1)
|
||||
output = tokenizer.batch_decode(outputs[0], skip_special_tokens=True)
|
||||
print(output)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
||||
# We suggest to use the pretrained model from HuggingFace, use pretrain to configure model
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--model_path', type=str, default=None)
|
||||
parser.add_argument('--input', type=str, default='Question: How are you ? Answer:')
|
||||
parser.add_argument('--max_length', type=int, default=100)
|
||||
args = parser.parse_args()
|
||||
eval(args)
|
|
@ -0,0 +1,2 @@
|
|||
pandas>=1.4.1
|
||||
sentencepiece
|
|
@ -0,0 +1,97 @@
|
|||
#!/usr/bin/env bash
|
||||
|
||||
set -xue
|
||||
|
||||
if [ -z "$PROMPT_PATH" ]; then
|
||||
echo "Please set \$PROMPT_PATH to the path to prompts csv."
|
||||
exit 1
|
||||
fi
|
||||
|
||||
BASE=$(realpath $(dirname $0))
|
||||
|
||||
export OMP_NUM_THREADS=8
|
||||
|
||||
# install requirements
|
||||
pip install -r ${BASE}/requirements.txt
|
||||
|
||||
# train dummy
|
||||
python ${BASE}/train_dummy.py --strategy naive --num_episodes 1 \
|
||||
--max_timesteps 2 --update_timesteps 2 \
|
||||
--max_epochs 1 --train_batch_size 2 --lora_rank 4
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
|
||||
--strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
||||
--pretrain 'facebook/opt-350m' --model opt --lora_rank 4\
|
||||
--save_path ${BASE}/actor_checkpoint_dummy.pt
|
||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'facebook/opt-350m' --model opt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
|
||||
--strategy ddp --num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
||||
--pretrain 'facebook/opt-350m' --model opt --lora_rank 4\
|
||||
--save_path ${BASE}/actor_checkpoint_dummy.pt
|
||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'facebook/opt-350m' --model opt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_dummy.py \
|
||||
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
||||
--pretrain 'gpt2' --model gpt2 --lora_rank 4\
|
||||
--save_path ${BASE}/actor_checkpoint_dummy.pt
|
||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_dummy.pt --pretrain 'gpt2' --model gpt2
|
||||
|
||||
rm -rf ${BASE}/actor_checkpoint_dummy.pt
|
||||
|
||||
# train prompts
|
||||
python ${BASE}/train_prompts.py $PROMPT_PATH --strategy naive --num_episodes 1 \
|
||||
--max_timesteps 2 --update_timesteps 2 \
|
||||
--max_epochs 1 --train_batch_size 2 --lora_rank 4
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
|
||||
--strategy colossalai_zero2 --num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
||||
--pretrain 'facebook/opt-350m' --model opt --lora_rank 4\
|
||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'facebook/opt-350m' --model opt
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
|
||||
--strategy ddp --num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
||||
--pretrain 'gpt2' --model gpt2 --lora_rank 4\
|
||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_prompts.py $PROMPT_PATH \
|
||||
--strategy colossalai_gemini --num_episodes 1 --max_timesteps 2 \
|
||||
--update_timesteps 2 --max_epochs 1 --train_batch_size 2\
|
||||
--pretrain 'gpt2' --model gpt2 --lora_rank 4\
|
||||
--save_path ${BASE}/actor_checkpoint_prompts.pt
|
||||
python ${BASE}/inference.py --model_path ${BASE}/actor_checkpoint_prompts.pt --pretrain 'gpt2' --model gpt2
|
||||
|
||||
rm -rf ${BASE}/actor_checkpoint_prompts.pt
|
||||
|
||||
# train rm
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'facebook/opt-350m' --model 'opt' \
|
||||
--strategy colossalai_zero2 --loss_fn 'log_sig'\
|
||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
||||
--test True --lora_rank 4
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'gpt2' --model 'gpt2' \
|
||||
--strategy colossalai_gemini --loss_fn 'log_exp'\
|
||||
--dataset 'Dahoas/rm-static' --test True --lora_rank 4
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'bigscience/bloom-560m' --model 'bloom' \
|
||||
--strategy colossalai_zero2 --loss_fn 'log_sig'\
|
||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
||||
--test True --lora_rank 4
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 ${BASE}/train_reward_model.py \
|
||||
--pretrain 'microsoft/deberta-v3-large' --model 'deberta' \
|
||||
--strategy colossalai_zero2 --loss_fn 'log_sig'\
|
||||
--dataset 'Anthropic/hh-rlhf' --subset 'harmless-base'\
|
||||
--test True --lora_rank 4
|
||||
|
||||
rm -rf ${BASE}/rm_ckpt.pt
|
|
@ -0,0 +1,148 @@
|
|||
import argparse
|
||||
from copy import deepcopy
|
||||
|
||||
import torch
|
||||
from coati.models.base import RewardModel
|
||||
from coati.models.bloom import BLOOMActor, BLOOMCritic
|
||||
from coati.models.gpt import GPTActor, GPTCritic
|
||||
from coati.models.opt import OPTActor, OPTCritic
|
||||
from coati.trainer import PPOTrainer
|
||||
from coati.trainer.callbacks import SaveCheckpoint
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from torch.optim import Adam
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
def preprocess_batch(samples):
|
||||
input_ids = torch.stack(samples)
|
||||
attention_mask = torch.ones_like(input_ids, dtype=torch.long)
|
||||
return {'input_ids': input_ids, 'attention_mask': attention_mask}
|
||||
|
||||
|
||||
def main(args):
|
||||
# configure strategy
|
||||
if args.strategy == 'naive':
|
||||
strategy = NaiveStrategy()
|
||||
elif args.strategy == 'ddp':
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
# configure model
|
||||
with strategy.model_init_context():
|
||||
if args.model == 'gpt2':
|
||||
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
critic = GPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
elif args.model == 'bloom':
|
||||
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
critic = BLOOMCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
elif args.model == 'opt':
|
||||
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
critic = OPTCritic(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
initial_model = deepcopy(actor).to(torch.cuda.current_device())
|
||||
reward_model = RewardModel(deepcopy(critic.model), deepcopy(critic.value_head)).to(torch.cuda.current_device())
|
||||
|
||||
# configure optimizer
|
||||
if args.strategy.startswith('colossalai'):
|
||||
actor_optim = HybridAdam(actor.parameters(), lr=5e-6)
|
||||
critic_optim = HybridAdam(critic.parameters(), lr=5e-6)
|
||||
else:
|
||||
actor_optim = Adam(actor.parameters(), lr=5e-6)
|
||||
critic_optim = Adam(critic.parameters(), lr=5e-6)
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model = strategy.prepare(
|
||||
(actor, actor_optim), (critic, critic_optim), reward_model, initial_model)
|
||||
|
||||
callbacks = []
|
||||
if args.save_ckpt_path:
|
||||
ckpt_callback = SaveCheckpoint(
|
||||
args.save_ckpt_path,
|
||||
args.save_ckpt_interval,
|
||||
strategy,
|
||||
actor,
|
||||
critic,
|
||||
actor_optim,
|
||||
critic_optim,
|
||||
)
|
||||
callbacks.append(ckpt_callback)
|
||||
|
||||
# configure trainer
|
||||
|
||||
trainer = PPOTrainer(strategy,
|
||||
actor,
|
||||
critic,
|
||||
reward_model,
|
||||
initial_model,
|
||||
actor_optim,
|
||||
critic_optim,
|
||||
max_epochs=args.max_epochs,
|
||||
train_batch_size=args.train_batch_size,
|
||||
tokenizer=preprocess_batch,
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
callbacks=callbacks)
|
||||
|
||||
random_prompts = torch.randint(tokenizer.vocab_size, (1000, 64), device=torch.cuda.current_device())
|
||||
trainer.fit(random_prompts,
|
||||
num_episodes=args.num_episodes,
|
||||
max_timesteps=args.max_timesteps,
|
||||
update_timesteps=args.update_timesteps)
|
||||
|
||||
# save model checkpoint after fitting
|
||||
trainer.save_model(args.save_path, only_rank0=True)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
strategy.save_optimizer(actor_optim,
|
||||
'actor_optim_checkpoint_dummy_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--model', type=str, default='gpt2', choices=['gpt2', 'bloom', 'opt'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_dummy.pt')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--num_episodes', type=int, default=50)
|
||||
parser.add_argument('--max_timesteps', type=int, default=10)
|
||||
parser.add_argument('--update_timesteps', type=int, default=10)
|
||||
parser.add_argument('--max_epochs', type=int, default=5)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument('--save_ckpt_path',
|
||||
type=str,
|
||||
default=None,
|
||||
help="path to save checkpoint, None means not to save")
|
||||
parser.add_argument('--save_ckpt_interval', type=int, default=1, help="the interval of episode to save checkpoint")
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,18 @@
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||
| tail -n +2 \
|
||||
| nl -v 0 \
|
||||
| tee /dev/tty \
|
||||
| sort -g -k 2 \
|
||||
| awk '{print $1}' \
|
||||
| head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 train_dummy.py --strategy colossalai_zero2
|
|
@ -0,0 +1,199 @@
|
|||
import argparse
|
||||
|
||||
import pandas as pd
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, PromptDataset, SupervisedDataset
|
||||
from coati.models.bloom import BLOOMRM, BLOOMActor, BLOOMCritic
|
||||
from coati.models.gpt import GPTRM, GPTActor, GPTCritic
|
||||
from coati.models.llama import LlamaActor
|
||||
from coati.models.opt import OPTRM, OPTActor, OPTCritic
|
||||
from coati.trainer import PPOTrainer
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from coati.utils import prepare_llama_tokenizer_and_embedding
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, GPT2Tokenizer, LlamaTokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
def main(args):
|
||||
# configure strategy
|
||||
if args.strategy == 'naive':
|
||||
strategy = NaiveStrategy()
|
||||
elif args.strategy == 'ddp':
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda', initial_scale=2**5)
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
if args.rm_path is not None:
|
||||
state_dict = torch.load(args.rm_path, map_location='cpu')
|
||||
|
||||
# configure model
|
||||
if args.model == 'gpt2':
|
||||
initial_model = GPTActor(pretrained=args.pretrain)
|
||||
reward_model = GPTRM(pretrained=args.rm_pretrain)
|
||||
elif args.model == 'bloom':
|
||||
initial_model = BLOOMActor(pretrained=args.pretrain)
|
||||
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
|
||||
elif args.model == 'opt':
|
||||
initial_model = OPTActor(pretrained=args.pretrain)
|
||||
reward_model = OPTRM(pretrained=args.rm_pretrain)
|
||||
elif args.model == 'llama':
|
||||
initial_model = LlamaActor(pretrained=args.pretrain)
|
||||
reward_model = BLOOMRM(pretrained=args.rm_pretrain)
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
if args.rm_path is not None:
|
||||
reward_model.load_state_dict(state_dict)
|
||||
|
||||
if args.strategy != 'colossalai_gemini':
|
||||
initial_model.to(torch.float16).to(torch.cuda.current_device())
|
||||
reward_model.to(torch.float16).to(torch.cuda.current_device())
|
||||
|
||||
with strategy.model_init_context():
|
||||
if args.model == 'gpt2':
|
||||
actor = GPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
critic = GPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
||||
elif args.model == 'bloom':
|
||||
actor = BLOOMActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
||||
elif args.model == 'opt':
|
||||
actor = OPTActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
critic = OPTCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
||||
elif args.model == 'llama':
|
||||
actor = LlamaActor(pretrained=args.pretrain, lora_rank=args.lora_rank)
|
||||
critic = BLOOMCritic(pretrained=args.rm_pretrain, lora_rank=args.lora_rank, use_action_mask=True)
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
if args.rm_path is not None:
|
||||
critic.load_state_dict(state_dict)
|
||||
del state_dict
|
||||
|
||||
if args.strategy != 'colossalai_gemini':
|
||||
critic.to(torch.float16).to(torch.cuda.current_device())
|
||||
actor.to(torch.float16).to(torch.cuda.current_device())
|
||||
|
||||
# configure optimizer
|
||||
if args.strategy.startswith('colossalai'):
|
||||
actor_optim = HybridAdam(actor.parameters(), lr=1e-7)
|
||||
critic_optim = HybridAdam(critic.parameters(), lr=1e-7)
|
||||
else:
|
||||
actor_optim = Adam(actor.parameters(), lr=1e-7)
|
||||
critic_optim = Adam(critic.parameters(), lr=1e-7)
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
elif args.model == 'llama':
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
|
||||
tokenizer.eos_token = '<\s>'
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
if args.model == 'llama':
|
||||
tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, actor)
|
||||
else:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
||||
|
||||
prompt_dataset = PromptDataset(tokenizer=tokenizer, data_path=args.prompt_path, max_datasets_size=16384)
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
prompt_sampler = DistributedSampler(prompt_dataset, shuffle=True, seed=42, drop_last=True)
|
||||
prompt_dataloader = DataLoader(prompt_dataset,
|
||||
shuffle=(prompt_sampler is None),
|
||||
sampler=prompt_sampler,
|
||||
batch_size=args.train_batch_size)
|
||||
|
||||
pretrain_dataset = SupervisedDataset(tokenizer=tokenizer, data_path=args.pretrain_dataset, max_datasets_size=16384)
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
pretrain_sampler = DistributedSampler(pretrain_dataset, shuffle=True, seed=42, drop_last=True)
|
||||
pretrain_dataloader = DataLoader(pretrain_dataset,
|
||||
shuffle=(pretrain_sampler is None),
|
||||
sampler=pretrain_sampler,
|
||||
batch_size=args.ptx_batch_size,
|
||||
collate_fn=data_collator)
|
||||
|
||||
def tokenize_fn(texts):
|
||||
# MUST padding to max length to ensure inputs of all ranks have the same length
|
||||
# Different length may lead to hang when using gemini, as different generation steps
|
||||
batch = tokenizer(texts, return_tensors='pt', max_length=96, padding='max_length', truncation=True)
|
||||
return {k: v.to(torch.cuda.current_device()) for k, v in batch.items()}
|
||||
|
||||
(actor, actor_optim), (critic, critic_optim) = strategy.prepare((actor, actor_optim), (critic, critic_optim))
|
||||
|
||||
# configure trainer
|
||||
trainer = PPOTrainer(
|
||||
strategy,
|
||||
actor,
|
||||
critic,
|
||||
reward_model,
|
||||
initial_model,
|
||||
actor_optim,
|
||||
critic_optim,
|
||||
kl_coef=args.kl_coef,
|
||||
ptx_coef=args.ptx_coef,
|
||||
max_epochs=args.max_epochs,
|
||||
train_batch_size=args.train_batch_size,
|
||||
experience_batch_size=args.experience_batch_size,
|
||||
tokenizer=tokenize_fn,
|
||||
max_length=128,
|
||||
do_sample=True,
|
||||
temperature=1.0,
|
||||
top_k=50,
|
||||
pad_token_id=tokenizer.pad_token_id,
|
||||
eos_token_id=tokenizer.eos_token_id,
|
||||
)
|
||||
|
||||
trainer.fit(prompt_dataloader=prompt_dataloader,
|
||||
pretrain_dataloader=pretrain_dataloader,
|
||||
num_episodes=args.num_episodes,
|
||||
max_timesteps=args.max_timesteps,
|
||||
update_timesteps=args.update_timesteps)
|
||||
|
||||
# save model checkpoint after fitting
|
||||
trainer.save_model(args.save_path, only_rank0=True, tokenizer=tokenizer)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
strategy.save_optimizer(actor_optim,
|
||||
'actor_optim_checkpoint_prompts_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--prompt_path', type=str, default=None, help='path to the prompt dataset')
|
||||
parser.add_argument('--pretrain_dataset', type=str, default=None, help='path to the pretrained dataset')
|
||||
parser.add_argument('--strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive',
|
||||
help='strategy to use')
|
||||
parser.add_argument('--model', default='gpt2', choices=['gpt2', 'bloom', 'opt', 'llama'])
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--rm_path', type=str, default=None)
|
||||
parser.add_argument('--rm_pretrain', type=str, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='actor_checkpoint_prompts')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--num_episodes', type=int, default=10)
|
||||
parser.add_argument('--max_timesteps', type=int, default=10)
|
||||
parser.add_argument('--update_timesteps', type=int, default=10)
|
||||
parser.add_argument('--max_epochs', type=int, default=5)
|
||||
parser.add_argument('--train_batch_size', type=int, default=8)
|
||||
parser.add_argument('--ptx_batch_size', type=int, default=1)
|
||||
parser.add_argument('--experience_batch_size', type=int, default=8)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument('--kl_coef', type=float, default=0.1)
|
||||
parser.add_argument('--ptx_coef', type=float, default=0.9)
|
||||
args = parser.parse_args()
|
||||
main(args)
|
|
@ -0,0 +1,18 @@
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES() {
|
||||
local n=${1:-"9999"}
|
||||
echo "GPU Memory Usage:"
|
||||
local FIRST_N_GPU_IDS=$(nvidia-smi --query-gpu=memory.used --format=csv \
|
||||
| tail -n +2 \
|
||||
| nl -v 0 \
|
||||
| tee /dev/tty \
|
||||
| sort -g -k 2 \
|
||||
| awk '{print $1}' \
|
||||
| head -n $n)
|
||||
export CUDA_VISIBLE_DEVICES=$(echo $FIRST_N_GPU_IDS | sed 's/ /,/g')
|
||||
echo "Now CUDA_VISIBLE_DEVICES is set to:"
|
||||
echo "CUDA_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES"
|
||||
}
|
||||
|
||||
set_n_least_used_CUDA_VISIBLE_DEVICES 2
|
||||
|
||||
torchrun --standalone --nproc_per_node=2 train_prompts.py prompts.csv --strategy colossalai_zero2
|
|
@ -0,0 +1,160 @@
|
|||
import argparse
|
||||
from random import randint
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
from coati.dataset import HhRlhfDataset, RmStaticDataset
|
||||
from coati.models import LogExpLoss, LogSigLoss
|
||||
from coati.models.base import RewardModel
|
||||
from coati.models.bloom import BLOOMRM
|
||||
from coati.models.deberta import DebertaRM
|
||||
from coati.models.gpt import GPTRM
|
||||
from coati.models.llama import LlamaRM
|
||||
from coati.models.opt import OPTRM
|
||||
from coati.trainer import RewardModelTrainer
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from coati.utils import prepare_llama_tokenizer_and_embedding
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Adam
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast, DebertaV2Tokenizer, LlamaTokenizer
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
|
||||
|
||||
def train(args):
|
||||
# configure strategy
|
||||
if args.strategy == 'naive':
|
||||
strategy = NaiveStrategy()
|
||||
elif args.strategy == 'ddp':
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
# configure model
|
||||
with strategy.model_init_context():
|
||||
if args.model == 'bloom':
|
||||
model = BLOOMRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
elif args.model == 'opt':
|
||||
model = OPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
elif args.model == 'gpt2':
|
||||
model = GPTRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
elif args.model == 'deberta':
|
||||
model = DebertaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
elif args.model == 'llama':
|
||||
model = LlamaRM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
if args.model_path is not None:
|
||||
state_dict = torch.load(args.model_path)
|
||||
model.load_state_dict(state_dict)
|
||||
|
||||
model = model.to(torch.float16)
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained('bigscience/bloom-560m')
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
elif args.model == 'deberta':
|
||||
tokenizer = DebertaV2Tokenizer.from_pretrained('microsoft/deberta-v3-large')
|
||||
elif args.model == 'llama':
|
||||
tokenizer = LlamaTokenizer.from_pretrained(args.pretrain)
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
max_len = args.max_len
|
||||
|
||||
if args.model == 'llama':
|
||||
tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
|
||||
else:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# configure optimizer
|
||||
if args.strategy.startswith('colossalai'):
|
||||
optim = HybridAdam(model.parameters(), lr=5e-6)
|
||||
else:
|
||||
optim = Adam(model.parameters(), lr=5e-6)
|
||||
|
||||
# configure loss function
|
||||
if args.loss_fn == 'log_sig':
|
||||
loss_fn = LogSigLoss()
|
||||
elif args.loss_fn == 'log_exp':
|
||||
loss_fn = LogExpLoss()
|
||||
else:
|
||||
raise ValueError(f'Unsupported loss function "{args.loss_fn}"')
|
||||
|
||||
# prepare for data and dataset
|
||||
if args.subset is not None:
|
||||
data = load_dataset(args.dataset, data_dir=args.subset)
|
||||
else:
|
||||
data = load_dataset(args.dataset)
|
||||
|
||||
if args.test:
|
||||
train_data = data['train'].select(range(100))
|
||||
eval_data = data['test'].select(range(10))
|
||||
else:
|
||||
train_data = data['train']
|
||||
eval_data = data['test']
|
||||
valid_data = data['test'].select((randint(0, len(eval_data) - 1) for _ in range(len(eval_data) // 5)))
|
||||
|
||||
if args.dataset == 'Dahoas/rm-static':
|
||||
train_dataset = RmStaticDataset(train_data, tokenizer, max_len)
|
||||
valid_dataset = RmStaticDataset(valid_data, tokenizer, max_len)
|
||||
eval_dataset = RmStaticDataset(eval_data, tokenizer, max_len)
|
||||
elif args.dataset == 'Anthropic/hh-rlhf':
|
||||
train_dataset = HhRlhfDataset(train_data, tokenizer, max_len)
|
||||
valid_dataset = HhRlhfDataset(valid_data, tokenizer, max_len)
|
||||
eval_dataset = HhRlhfDataset(eval_data, tokenizer, max_len)
|
||||
else:
|
||||
raise ValueError(f'Unsupported dataset "{args.dataset}"')
|
||||
|
||||
trainer = RewardModelTrainer(model=model,
|
||||
strategy=strategy,
|
||||
optim=optim,
|
||||
loss_fn=loss_fn,
|
||||
train_dataset=train_dataset,
|
||||
valid_dataset=valid_dataset,
|
||||
eval_dataset=eval_dataset,
|
||||
batch_size=args.batch_size,
|
||||
max_epochs=args.max_epochs)
|
||||
|
||||
trainer.fit()
|
||||
# save model checkpoint after fitting on only rank0
|
||||
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
strategy.save_optimizer(trainer.optimizer,
|
||||
'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'deberta', 'llama'], default='bloom')
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--model_path', type=str, default=None)
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--dataset',
|
||||
type=str,
|
||||
choices=['Anthropic/hh-rlhf', 'Dahoas/rm-static'],
|
||||
default='Dahoas/rm-static')
|
||||
parser.add_argument('--subset', type=str, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='rm_ckpt')
|
||||
parser.add_argument('--max_epochs', type=int, default=1)
|
||||
parser.add_argument('--batch_size', type=int, default=1)
|
||||
parser.add_argument('--max_len', type=int, default=512)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument('--loss_fn', type=str, default='log_sig', choices=['log_sig', 'log_exp'])
|
||||
parser.add_argument('--test', type=bool, default=False)
|
||||
args = parser.parse_args()
|
||||
train(args)
|
|
@ -0,0 +1,8 @@
|
|||
set_n_least_used_CUDA_VISIBLE_DEVICES 1
|
||||
|
||||
python train_reward_model.py --pretrain 'microsoft/deberta-v3-large' \
|
||||
--model 'deberta' \
|
||||
--strategy naive \
|
||||
--loss_fn 'log_exp'\
|
||||
--save_path 'rmstatic.pt' \
|
||||
--test True
|
|
@ -0,0 +1,184 @@
|
|||
import argparse
|
||||
import os
|
||||
|
||||
import loralib as lora
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
from coati.dataset import DataCollatorForSupervisedDataset, SFTDataset, SupervisedDataset
|
||||
from coati.models.base import RewardModel
|
||||
from coati.models.bloom import BLOOMLM
|
||||
from coati.models.gpt import GPTLM
|
||||
from coati.models.llama import LlamaLM
|
||||
from coati.models.opt import OPTLM
|
||||
from coati.trainer import SFTTrainer
|
||||
from coati.trainer.strategies import ColossalAIStrategy, DDPStrategy, NaiveStrategy
|
||||
from coati.utils import prepare_llama_tokenizer_and_embedding
|
||||
from datasets import load_dataset
|
||||
from torch.optim import Adam
|
||||
from torch.utils.data import DataLoader
|
||||
from torch.utils.data.distributed import DistributedSampler
|
||||
from transformers import AutoTokenizer, BloomTokenizerFast
|
||||
from transformers.models.gpt2.tokenization_gpt2 import GPT2Tokenizer
|
||||
|
||||
from colossalai.logging import get_dist_logger
|
||||
from colossalai.nn.optimizer import HybridAdam
|
||||
from colossalai.tensor import ColoParameter
|
||||
|
||||
|
||||
def train(args):
|
||||
# configure strategy
|
||||
if args.strategy == 'naive':
|
||||
strategy = NaiveStrategy()
|
||||
elif args.strategy == 'ddp':
|
||||
strategy = DDPStrategy()
|
||||
elif args.strategy == 'colossalai_gemini':
|
||||
strategy = ColossalAIStrategy(stage=3, placement_policy='cuda')
|
||||
elif args.strategy == 'colossalai_zero2':
|
||||
strategy = ColossalAIStrategy(stage=2, placement_policy='cuda')
|
||||
else:
|
||||
raise ValueError(f'Unsupported strategy "{args.strategy}"')
|
||||
|
||||
# configure model
|
||||
with strategy.model_init_context():
|
||||
if args.model == 'bloom':
|
||||
model = BLOOMLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
elif args.model == 'opt':
|
||||
model = OPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
elif args.model == 'gpt2':
|
||||
model = GPTLM(pretrained=args.pretrain, lora_rank=args.lora_rank).to(torch.cuda.current_device())
|
||||
elif args.model == 'llama':
|
||||
model = LlamaLM(pretrained=args.pretrain, lora_rank=args.lora_rank,
|
||||
checkpoint=True).to(torch.float16).to(torch.cuda.current_device())
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
|
||||
# configure tokenizer
|
||||
if args.model == 'gpt2':
|
||||
tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'bloom':
|
||||
tokenizer = BloomTokenizerFast.from_pretrained(args.pretrain)
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
elif args.model == 'opt':
|
||||
tokenizer = AutoTokenizer.from_pretrained("facebook/opt-350m")
|
||||
elif args.model == 'llama':
|
||||
tokenizer = AutoTokenizer.from_pretrained(
|
||||
args.pretrain,
|
||||
padding_side="right",
|
||||
use_fast=False,
|
||||
)
|
||||
tokenizer.eos_token = '<\s>'
|
||||
else:
|
||||
raise ValueError(f'Unsupported model "{args.model}"')
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
if args.model == 'llama':
|
||||
tokenizer = prepare_llama_tokenizer_and_embedding(tokenizer, model)
|
||||
|
||||
if args.strategy == 'colossalai_gemini':
|
||||
# this is a hack to deal with the resized embedding
|
||||
# to make sure all parameters are ColoParameter for Colossal-AI Gemini Compatiblity
|
||||
for name, param in model.named_parameters():
|
||||
if not isinstance(param, ColoParameter):
|
||||
sub_module_name = '.'.join(name.split('.')[:-1])
|
||||
weight_name = name.split('.')[-1]
|
||||
sub_module = model.get_submodule(sub_module_name)
|
||||
setattr(sub_module, weight_name, ColoParameter(param))
|
||||
else:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
|
||||
# configure optimizer
|
||||
if args.strategy.startswith('colossalai'):
|
||||
optim = HybridAdam(model.parameters(), lr=args.lr, clipping_norm=1.0)
|
||||
else:
|
||||
optim = Adam(model.parameters(), lr=args.lr)
|
||||
|
||||
logger = get_dist_logger()
|
||||
|
||||
# configure dataset
|
||||
if args.dataset == 'yizhongw/self_instruct':
|
||||
train_data = load_dataset(args.dataset, 'super_natural_instructions', split='train')
|
||||
eval_data = load_dataset(args.dataset, 'super_natural_instructions', split='test')
|
||||
|
||||
train_dataset = SFTDataset(train_data, tokenizer)
|
||||
eval_dataset = SFTDataset(eval_data, tokenizer)
|
||||
|
||||
else:
|
||||
train_dataset = SupervisedDataset(tokenizer=tokenizer,
|
||||
data_path=args.dataset,
|
||||
max_datasets_size=args.max_datasets_size)
|
||||
eval_dataset = None
|
||||
data_collator = DataCollatorForSupervisedDataset(tokenizer=tokenizer)
|
||||
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
train_sampler = DistributedSampler(train_dataset,
|
||||
shuffle=True,
|
||||
seed=42,
|
||||
drop_last=True,
|
||||
rank=dist.get_rank(),
|
||||
num_replicas=dist.get_world_size())
|
||||
if eval_dataset is not None:
|
||||
eval_sampler = DistributedSampler(eval_dataset,
|
||||
shuffle=False,
|
||||
seed=42,
|
||||
drop_last=False,
|
||||
rank=dist.get_rank(),
|
||||
num_replicas=dist.get_world_size())
|
||||
else:
|
||||
train_sampler = None
|
||||
eval_sampler = None
|
||||
|
||||
train_dataloader = DataLoader(train_dataset,
|
||||
shuffle=(train_sampler is None),
|
||||
sampler=train_sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=data_collator,
|
||||
pin_memory=True)
|
||||
if eval_dataset is not None:
|
||||
eval_dataloader = DataLoader(eval_dataset,
|
||||
shuffle=(eval_sampler is None),
|
||||
sampler=eval_sampler,
|
||||
batch_size=args.batch_size,
|
||||
collate_fn=data_collator,
|
||||
pin_memory=True)
|
||||
else:
|
||||
eval_dataloader = None
|
||||
|
||||
trainer = SFTTrainer(model=model,
|
||||
strategy=strategy,
|
||||
optim=optim,
|
||||
train_dataloader=train_dataloader,
|
||||
eval_dataloader=eval_dataloader,
|
||||
batch_size=args.batch_size,
|
||||
max_epochs=args.max_epochs,
|
||||
accimulation_steps=args.accimulation_steps)
|
||||
|
||||
trainer.fit(logger=logger, log_interval=args.log_interval)
|
||||
|
||||
# save model checkpoint after fitting on only rank0
|
||||
trainer.save_model(path=args.save_path, only_rank0=True, tokenizer=tokenizer)
|
||||
# save optimizer checkpoint on all ranks
|
||||
if args.need_optim_ckpt:
|
||||
strategy.save_optimizer(trainer.optimizer,
|
||||
'rm_optim_checkpoint_%d.pt' % (torch.cuda.current_device()),
|
||||
only_rank0=False)
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument('--strategy',
|
||||
choices=['naive', 'ddp', 'colossalai_gemini', 'colossalai_zero2'],
|
||||
default='naive')
|
||||
parser.add_argument('--model', choices=['gpt2', 'bloom', 'opt', 'llama'], default='bloom')
|
||||
parser.add_argument('--pretrain', type=str, default=None)
|
||||
parser.add_argument('--dataset', type=str, default=None)
|
||||
parser.add_argument('--max_datasets_size', type=int, default=None)
|
||||
parser.add_argument('--save_path', type=str, default='output')
|
||||
parser.add_argument('--need_optim_ckpt', type=bool, default=False)
|
||||
parser.add_argument('--max_epochs', type=int, default=3)
|
||||
parser.add_argument('--batch_size', type=int, default=4)
|
||||
parser.add_argument('--lora_rank', type=int, default=0, help="low-rank adaptation matrices rank")
|
||||
parser.add_argument('--log_interval', type=int, default=100, help="how many steps to log")
|
||||
parser.add_argument('--lr', type=float, default=5e-6)
|
||||
parser.add_argument('--accimulation_steps', type=int, default=8)
|
||||
args = parser.parse_args()
|
||||
train(args)
|
|
@ -0,0 +1,12 @@
|
|||
torchrun --standalone --nproc_per_node=4 train_sft.py \
|
||||
--pretrain "/path/to/LLaMa-7B/" \
|
||||
--model 'llama' \
|
||||
--strategy colossalai_zero2 \
|
||||
--log_interval 10 \
|
||||
--save_path /path/to/Coati-7B \
|
||||
--dataset /path/to/data.json \
|
||||
--batch_size 4 \
|
||||
--accimulation_steps 8 \
|
||||
--lr 2e-5 \
|
||||
--max_datasets_size 512 \
|
||||
--max_epochs 1 \
|
|
@ -0,0 +1,111 @@
|
|||
# Inference
|
||||
|
||||
We provide an online inference server and a benchmark. We aim to run inference on single GPU, so quantization is essential when using large models.
|
||||
|
||||
We support 8-bit quantization (RTN), which is powered by [bitsandbytes](https://github.com/TimDettmers/bitsandbytes) and [transformers](https://github.com/huggingface/transformers). And 4-bit quantization (GPTQ), which is powered by [gptq](https://github.com/IST-DASLab/gptq) and [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). We also support FP16 inference.
|
||||
|
||||
We only support LLaMA family models now.
|
||||
|
||||
## Choosing precision (quantization)
|
||||
|
||||
**FP16**: Fastest, best output quality, highest memory usage
|
||||
|
||||
**8-bit**: Slow, easier setup (originally supported by transformers), lower output quality (due to RTN), **recommended for first-timers**
|
||||
|
||||
**4-bit**: Faster, lowest memory usage, higher output quality (due to GPTQ), but more difficult setup
|
||||
|
||||
## Hardware requirements for LLaMA
|
||||
|
||||
Tha data is from [LLaMA Int8 4bit ChatBot Guide v2](https://rentry.org/llama-tard-v2).
|
||||
|
||||
### 8-bit
|
||||
|
||||
| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
|
||||
| :---: | :---: | :---: | :---: | :---: |
|
||||
| LLaMA-7B | 9.2GB | 10GB | 24GB | 3060 12GB, RTX 3080 10GB, RTX 3090 |
|
||||
| LLaMA-13B | 16.3GB | 20GB | 32GB | RTX 3090 Ti, RTX 4090 |
|
||||
| LLaMA-30B | 36GB | 40GB | 64GB | A6000 48GB, A100 40GB |
|
||||
| LLaMA-65B | 74GB | 80GB | 128GB | A100 80GB |
|
||||
|
||||
### 4-bit
|
||||
|
||||
| Model | Min GPU RAM | Recommended GPU RAM | Min RAM/Swap | Card examples |
|
||||
| :---: | :---: | :---: | :---: | :---: |
|
||||
| LLaMA-7B | 3.5GB | 6GB | 16GB | RTX 1660, 2060, AMD 5700xt, RTX 3050, 3060 |
|
||||
| LLaMA-13B | 6.5GB | 10GB | 32GB | AMD 6900xt, RTX 2060 12GB, 3060 12GB, 3080, A2000 |
|
||||
| LLaMA-30B | 15.8GB | 20GB | 64GB | RTX 3080 20GB, A4500, A5000, 3090, 4090, 6000, Tesla V100 |
|
||||
| LLaMA-65B | 31.2GB | 40GB | 128GB | A100 40GB, 2x3090, 2x4090, A40, RTX A6000, 8000, Titan Ada |
|
||||
|
||||
## 8-bit setup
|
||||
|
||||
8-bit quantization is originally supported by the latest [transformers](https://github.com/huggingface/transformers). Please install it from source.
|
||||
|
||||
Please ensure you have downloaded HF-format model weights of LLaMA models.
|
||||
|
||||
Usage:
|
||||
|
||||
```python
|
||||
from transformers import LlamaForCausalLM
|
||||
|
||||
USE_8BIT = True # use 8-bit quantization; otherwise, use fp16
|
||||
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
"pretrained/path",
|
||||
load_in_8bit=USE_8BIT,
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
if not USE_8BIT:
|
||||
model.half() # use fp16
|
||||
model.eval()
|
||||
```
|
||||
|
||||
**Troubleshooting**: if you get error indicating your CUDA-related libraries not found when loading 8-bit model, you can check whether your `LD_LIBRARY_PATH` is correct.
|
||||
|
||||
E.g. you can set `export LD_LIBRARY_PATH=$CUDA_HOME/lib64:$LD_LIBRARY_PATH`.
|
||||
|
||||
## 4-bit setup
|
||||
|
||||
Please ensure you have downloaded HF-format model weights of LLaMA models first.
|
||||
|
||||
Then you can follow [GPTQ-for-LLaMa](https://github.com/qwopqwop200/GPTQ-for-LLaMa). This lib provides efficient CUDA kernels and weight convertion script.
|
||||
|
||||
After installing this lib, we may convert the original HF-format LLaMA model weights to 4-bit version.
|
||||
|
||||
```shell
|
||||
CUDA_VISIBLE_DEVICES=0 python llama.py /path/to/pretrained/llama-7b c4 --wbits 4 --groupsize 128 --save llama7b-4bit.pt
|
||||
```
|
||||
|
||||
Run this command in your cloned `GPTQ-for-LLaMa` directory, then you will get a 4-bit weight file `llama7b-4bit-128g.pt`.
|
||||
|
||||
**Troubleshooting**: if you get error about `position_ids`, you can checkout to commit `50287c3b9ae4a3b66f6b5127c643ec39b769b155`(`GPTQ-for-LLaMa` repo).
|
||||
|
||||
## Online inference server
|
||||
|
||||
In this directory:
|
||||
|
||||
```shell
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
# fp16, will listen on 0.0.0.0:7070 by default
|
||||
python server.py /path/to/pretrained
|
||||
# 8-bit, will listen on localhost:8080
|
||||
python server.py /path/to/pretrained --quant 8bit --http_host localhost --http_port 8080
|
||||
# 4-bit
|
||||
python server.py /path/to/pretrained --quant 4bit --gptq_checkpoint /path/to/llama7b-4bit-128g.pt --gptq_group_size 128
|
||||
```
|
||||
|
||||
## Benchmark
|
||||
|
||||
In this directory:
|
||||
|
||||
```shell
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
# fp16
|
||||
python benchmark.py /path/to/pretrained
|
||||
# 8-bit
|
||||
python benchmark.py /path/to/pretrained --quant 8bit
|
||||
# 4-bit
|
||||
python benchmark.py /path/to/pretrained --quant 4bit --gptq_checkpoint /path/to/llama7b-4bit-128g.pt --gptq_group_size 128
|
||||
```
|
||||
|
||||
This benchmark will record throughput and peak CUDA memory usage.
|
|
@ -0,0 +1,132 @@
|
|||
# Adapted from https://github.com/tloen/alpaca-lora/blob/main/generate.py
|
||||
|
||||
import argparse
|
||||
from time import time
|
||||
|
||||
import torch
|
||||
from llama_gptq import load_quant
|
||||
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
|
||||
|
||||
|
||||
def generate_prompt(instruction, input=None):
|
||||
if input:
|
||||
return f"""Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction:
|
||||
{instruction}
|
||||
|
||||
### Input:
|
||||
{input}
|
||||
|
||||
### Response:"""
|
||||
else:
|
||||
return f"""Below is an instruction that describes a task. Write a response that appropriately completes the request.
|
||||
|
||||
### Instruction:
|
||||
{instruction}
|
||||
|
||||
### Response:"""
|
||||
|
||||
|
||||
@torch.no_grad()
|
||||
def evaluate(
|
||||
model,
|
||||
tokenizer,
|
||||
instruction,
|
||||
input=None,
|
||||
temperature=0.1,
|
||||
top_p=0.75,
|
||||
top_k=40,
|
||||
num_beams=4,
|
||||
max_new_tokens=128,
|
||||
**kwargs,
|
||||
):
|
||||
prompt = generate_prompt(instruction, input)
|
||||
inputs = tokenizer(prompt, return_tensors="pt")
|
||||
input_ids = inputs["input_ids"].cuda()
|
||||
generation_config = GenerationConfig(
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
top_k=top_k,
|
||||
num_beams=num_beams,
|
||||
**kwargs,
|
||||
)
|
||||
generation_output = model.generate(
|
||||
input_ids=input_ids,
|
||||
generation_config=generation_config,
|
||||
return_dict_in_generate=True,
|
||||
output_scores=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
do_sample=True,
|
||||
)
|
||||
s = generation_output.sequences[0]
|
||||
output = tokenizer.decode(s)
|
||||
n_new_tokens = s.size(0) - input_ids.size(1)
|
||||
return output.split("### Response:")[1].strip(), n_new_tokens
|
||||
|
||||
|
||||
instructions = [
|
||||
"Tell me about alpacas.",
|
||||
"Tell me about the president of Mexico in 2019.",
|
||||
"Tell me about the king of France in 2019.",
|
||||
"List all Canadian provinces in alphabetical order.",
|
||||
"Write a Python program that prints the first 10 Fibonacci numbers.",
|
||||
"Write a program that prints the numbers from 1 to 100. But for multiples of three print 'Fizz' instead of the number and for the multiples of five print 'Buzz'. For numbers which are multiples of both three and five print 'FizzBuzz'.",
|
||||
"Tell me five words that rhyme with 'shock'.",
|
||||
"Translate the sentence 'I have no mouth but I must scream' into Spanish.",
|
||||
"Count up from 1 to 500.",
|
||||
# ===
|
||||
"How to play support in legends of league",
|
||||
"Write a Python program that calculate Fibonacci numbers.",
|
||||
]
|
||||
inst = [instructions[0]] * 4
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'pretrained',
|
||||
help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
|
||||
parser.add_argument('--quant',
|
||||
choices=['8bit', '4bit'],
|
||||
default=None,
|
||||
help='Quantization mode. Default: None (no quantization, fp16).')
|
||||
parser.add_argument(
|
||||
'--gptq_checkpoint',
|
||||
default=None,
|
||||
help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
|
||||
parser.add_argument('--gptq_group_size',
|
||||
type=int,
|
||||
default=128,
|
||||
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.quant == '4bit':
|
||||
assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
|
||||
|
||||
if args.quant == '4bit':
|
||||
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
|
||||
model.cuda()
|
||||
else:
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
args.pretrained,
|
||||
load_in_8bit=(args.quant == '8bit'),
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
if args.quant != '8bit':
|
||||
model.half() # seems to fix bugs for some users.
|
||||
model.eval()
|
||||
|
||||
total_tokens = 0
|
||||
start = time()
|
||||
for instruction in instructions:
|
||||
print(f"Instruction: {instruction}")
|
||||
resp, tokens = evaluate(model, tokenizer, instruction, temparature=0.2, num_beams=1)
|
||||
total_tokens += tokens
|
||||
print(f"Response: {resp}")
|
||||
print('\n----------------------------\n')
|
||||
duration = time() - start
|
||||
print(f'Total time: {duration:.3f} s, {total_tokens/duration:.3f} tokens/s')
|
||||
print(f'Peak CUDA mem: {torch.cuda.max_memory_allocated()/1024**3:.3f} GB')
|
|
@ -0,0 +1,5 @@
|
|||
from .loader import load_quant
|
||||
|
||||
__all__ = [
|
||||
'load_quant',
|
||||
]
|
|
@ -0,0 +1,41 @@
|
|||
import torch
|
||||
import torch.nn as nn
|
||||
import transformers
|
||||
from transformers import LlamaConfig, LlamaForCausalLM
|
||||
|
||||
from .model_utils import find_layers
|
||||
from .quant import make_quant
|
||||
|
||||
|
||||
def load_quant(pretrained: str, checkpoint: str, wbits: int, groupsize: int):
|
||||
config = LlamaConfig.from_pretrained(pretrained)
|
||||
|
||||
def noop(*args, **kwargs):
|
||||
pass
|
||||
|
||||
torch.nn.init.kaiming_uniform_ = noop
|
||||
torch.nn.init.uniform_ = noop
|
||||
torch.nn.init.normal_ = noop
|
||||
|
||||
torch.set_default_dtype(torch.half)
|
||||
transformers.modeling_utils._init_weights = False
|
||||
torch.set_default_dtype(torch.half)
|
||||
model = LlamaForCausalLM(config)
|
||||
torch.set_default_dtype(torch.float)
|
||||
model = model.eval()
|
||||
layers = find_layers(model)
|
||||
for name in ['lm_head']:
|
||||
if name in layers:
|
||||
del layers[name]
|
||||
make_quant(model, layers, wbits, groupsize)
|
||||
|
||||
print(f'Loading model with {wbits} bits...')
|
||||
if checkpoint.endswith('.safetensors'):
|
||||
from safetensors.torch import load_file as safe_load
|
||||
model.load_state_dict(safe_load(checkpoint))
|
||||
else:
|
||||
model.load_state_dict(torch.load(checkpoint))
|
||||
model.seqlen = 2048
|
||||
print('Done.')
|
||||
|
||||
return model
|
|
@ -0,0 +1,13 @@
|
|||
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/modelutils.py
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def find_layers(module, layers=[nn.Conv2d, nn.Linear], name=''):
|
||||
if type(module) in layers:
|
||||
return {name: module}
|
||||
res = {}
|
||||
for name1, child in module.named_children():
|
||||
res.update(find_layers(child, layers=layers, name=name + '.' + name1 if name != '' else name1))
|
||||
return res
|
|
@ -0,0 +1,283 @@
|
|||
# copied from https://github.com/qwopqwop200/GPTQ-for-LLaMa/blob/past/quant.py
|
||||
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
def quantize(x, scale, zero, maxq):
|
||||
q = torch.clamp(torch.round(x / scale) + zero, 0, maxq)
|
||||
return scale * (q - zero)
|
||||
|
||||
|
||||
class Quantizer(nn.Module):
|
||||
|
||||
def __init__(self, shape=1):
|
||||
super(Quantizer, self).__init__()
|
||||
self.register_buffer('maxq', torch.tensor(0))
|
||||
self.register_buffer('scale', torch.zeros(shape))
|
||||
self.register_buffer('zero', torch.zeros(shape))
|
||||
|
||||
def configure(self, bits, perchannel=False, sym=True, mse=False, norm=2.4, grid=100, maxshrink=.8):
|
||||
self.maxq = torch.tensor(2**bits - 1)
|
||||
self.perchannel = perchannel
|
||||
self.sym = sym
|
||||
self.mse = mse
|
||||
self.norm = norm
|
||||
self.grid = grid
|
||||
self.maxshrink = maxshrink
|
||||
|
||||
def find_params(self, x, weight=False):
|
||||
dev = x.device
|
||||
self.maxq = self.maxq.to(dev)
|
||||
|
||||
shape = x.shape
|
||||
if self.perchannel:
|
||||
if weight:
|
||||
x = x.flatten(1)
|
||||
else:
|
||||
if len(shape) == 4:
|
||||
x = x.permute([1, 0, 2, 3])
|
||||
x = x.flatten(1)
|
||||
if len(shape) == 3:
|
||||
x = x.reshape((-1, shape[-1])).t()
|
||||
if len(shape) == 2:
|
||||
x = x.t()
|
||||
else:
|
||||
x = x.flatten().unsqueeze(0)
|
||||
|
||||
tmp = torch.zeros(x.shape[0], device=dev)
|
||||
xmin = torch.minimum(x.min(1)[0], tmp)
|
||||
xmax = torch.maximum(x.max(1)[0], tmp)
|
||||
|
||||
if self.sym:
|
||||
xmax = torch.maximum(torch.abs(xmin), xmax)
|
||||
tmp = xmin < 0
|
||||
if torch.any(tmp):
|
||||
xmin[tmp] = -xmax[tmp]
|
||||
tmp = (xmin == 0) & (xmax == 0)
|
||||
xmin[tmp] = -1
|
||||
xmax[tmp] = +1
|
||||
|
||||
self.scale = (xmax - xmin) / self.maxq
|
||||
if self.sym:
|
||||
self.zero = torch.full_like(self.scale, (self.maxq + 1) / 2)
|
||||
else:
|
||||
self.zero = torch.round(-xmin / self.scale)
|
||||
|
||||
if self.mse:
|
||||
best = torch.full([x.shape[0]], float('inf'), device=dev)
|
||||
for i in range(int(self.maxshrink * self.grid)):
|
||||
p = 1 - i / self.grid
|
||||
xmin1 = p * xmin
|
||||
xmax1 = p * xmax
|
||||
scale1 = (xmax1 - xmin1) / self.maxq
|
||||
zero1 = torch.round(-xmin1 / scale1) if not self.sym else self.zero
|
||||
q = quantize(x, scale1.unsqueeze(1), zero1.unsqueeze(1), self.maxq)
|
||||
q -= x
|
||||
q.abs_()
|
||||
q.pow_(self.norm)
|
||||
err = torch.sum(q, 1)
|
||||
tmp = err < best
|
||||
if torch.any(tmp):
|
||||
best[tmp] = err[tmp]
|
||||
self.scale[tmp] = scale1[tmp]
|
||||
self.zero[tmp] = zero1[tmp]
|
||||
if not self.perchannel:
|
||||
if weight:
|
||||
tmp = shape[0]
|
||||
else:
|
||||
tmp = shape[1] if len(shape) != 3 else shape[2]
|
||||
self.scale = self.scale.repeat(tmp)
|
||||
self.zero = self.zero.repeat(tmp)
|
||||
|
||||
if weight:
|
||||
shape = [-1] + [1] * (len(shape) - 1)
|
||||
self.scale = self.scale.reshape(shape)
|
||||
self.zero = self.zero.reshape(shape)
|
||||
return
|
||||
if len(shape) == 4:
|
||||
self.scale = self.scale.reshape((1, -1, 1, 1))
|
||||
self.zero = self.zero.reshape((1, -1, 1, 1))
|
||||
if len(shape) == 3:
|
||||
self.scale = self.scale.reshape((1, 1, -1))
|
||||
self.zero = self.zero.reshape((1, 1, -1))
|
||||
if len(shape) == 2:
|
||||
self.scale = self.scale.unsqueeze(0)
|
||||
self.zero = self.zero.unsqueeze(0)
|
||||
|
||||
def quantize(self, x):
|
||||
if self.ready():
|
||||
return quantize(x, self.scale, self.zero, self.maxq)
|
||||
return x
|
||||
|
||||
def enabled(self):
|
||||
return self.maxq > 0
|
||||
|
||||
def ready(self):
|
||||
return torch.all(self.scale != 0)
|
||||
|
||||
|
||||
try:
|
||||
import quant_cuda
|
||||
except:
|
||||
print('CUDA extension not installed.')
|
||||
|
||||
# Assumes layer is perfectly divisible into 256 * 256 blocks
|
||||
|
||||
|
||||
class QuantLinear(nn.Module):
|
||||
|
||||
def __init__(self, bits, groupsize, infeatures, outfeatures):
|
||||
super().__init__()
|
||||
if bits not in [2, 3, 4, 8]:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
self.infeatures = infeatures
|
||||
self.outfeatures = outfeatures
|
||||
self.bits = bits
|
||||
if groupsize != -1 and groupsize < 32 and groupsize != int(math.pow(2, int(math.log2(groupsize)))):
|
||||
raise NotImplementedError("groupsize supports powers of 2 greater than 32. (e.g. : 32,64,128,etc)")
|
||||
groupsize = groupsize if groupsize != -1 else infeatures
|
||||
self.groupsize = groupsize
|
||||
self.register_buffer(
|
||||
'qzeros', torch.zeros((math.ceil(infeatures / groupsize), outfeatures // 256 * (bits * 8)),
|
||||
dtype=torch.int))
|
||||
self.register_buffer('scales', torch.zeros((math.ceil(infeatures / groupsize), outfeatures)))
|
||||
self.register_buffer('bias', torch.zeros(outfeatures))
|
||||
self.register_buffer('qweight', torch.zeros((infeatures // 256 * (bits * 8), outfeatures), dtype=torch.int))
|
||||
self._initialized_quant_state = False
|
||||
|
||||
def pack(self, linear, scales, zeros):
|
||||
scales = scales.t().contiguous()
|
||||
zeros = zeros.t().contiguous()
|
||||
scale_zeros = zeros * scales
|
||||
self.scales = scales.clone()
|
||||
if linear.bias is not None:
|
||||
self.bias = linear.bias.clone()
|
||||
|
||||
intweight = []
|
||||
for idx in range(self.infeatures):
|
||||
g_idx = idx // self.groupsize
|
||||
intweight.append(
|
||||
torch.round((linear.weight.data[:, idx] + scale_zeros[g_idx]) / self.scales[g_idx]).to(torch.int)[:,
|
||||
None])
|
||||
intweight = torch.cat(intweight, dim=1)
|
||||
intweight = intweight.t().contiguous()
|
||||
intweight = intweight.numpy().astype(np.uint32)
|
||||
qweight = np.zeros((intweight.shape[0] // 256 * (self.bits * 8), intweight.shape[1]), dtype=np.uint32)
|
||||
i = 0
|
||||
row = 0
|
||||
while row < qweight.shape[0]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qweight[row] |= intweight[j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
row += 1
|
||||
elif self.bits == 3:
|
||||
for j in range(i, i + 10):
|
||||
qweight[row] |= intweight[j] << (3 * (j - i))
|
||||
i += 10
|
||||
qweight[row] |= intweight[i] << 30
|
||||
row += 1
|
||||
qweight[row] |= (intweight[i] >> 2) & 1
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qweight[row] |= intweight[j] << (3 * (j - i) + 1)
|
||||
i += 10
|
||||
qweight[row] |= intweight[i] << 31
|
||||
row += 1
|
||||
qweight[row] |= (intweight[i] >> 1) & 0x3
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qweight[row] |= intweight[j] << (3 * (j - i) + 2)
|
||||
i += 10
|
||||
row += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
|
||||
qweight = qweight.astype(np.int32)
|
||||
self.qweight = torch.from_numpy(qweight)
|
||||
|
||||
zeros -= 1
|
||||
zeros = zeros.numpy().astype(np.uint32)
|
||||
qzeros = np.zeros((zeros.shape[0], zeros.shape[1] // 256 * (self.bits * 8)), dtype=np.uint32)
|
||||
i = 0
|
||||
col = 0
|
||||
while col < qzeros.shape[1]:
|
||||
if self.bits in [2, 4, 8]:
|
||||
for j in range(i, i + (32 // self.bits)):
|
||||
qzeros[:, col] |= zeros[:, j] << (self.bits * (j - i))
|
||||
i += 32 // self.bits
|
||||
col += 1
|
||||
elif self.bits == 3:
|
||||
for j in range(i, i + 10):
|
||||
qzeros[:, col] |= zeros[:, j] << (3 * (j - i))
|
||||
i += 10
|
||||
qzeros[:, col] |= zeros[:, i] << 30
|
||||
col += 1
|
||||
qzeros[:, col] |= (zeros[:, i] >> 2) & 1
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 1)
|
||||
i += 10
|
||||
qzeros[:, col] |= zeros[:, i] << 31
|
||||
col += 1
|
||||
qzeros[:, col] |= (zeros[:, i] >> 1) & 0x3
|
||||
i += 1
|
||||
for j in range(i, i + 10):
|
||||
qzeros[:, col] |= zeros[:, j] << (3 * (j - i) + 2)
|
||||
i += 10
|
||||
col += 1
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
|
||||
qzeros = qzeros.astype(np.int32)
|
||||
self.qzeros = torch.from_numpy(qzeros)
|
||||
|
||||
def forward(self, x):
|
||||
intermediate_dtype = torch.float32
|
||||
|
||||
if not self._initialized_quant_state:
|
||||
# Do we even have a bias? Check for at least one non-zero element.
|
||||
if self.bias is not None and bool(torch.any(self.bias != 0)):
|
||||
# Then make sure it's the right type.
|
||||
self.bias.data = self.bias.data.to(intermediate_dtype)
|
||||
else:
|
||||
self.bias = None
|
||||
|
||||
outshape = list(x.shape)
|
||||
outshape[-1] = self.outfeatures
|
||||
x = x.reshape(-1, x.shape[-1])
|
||||
if self.bias is None:
|
||||
y = torch.zeros(x.shape[0], outshape[-1], dtype=intermediate_dtype, device=x.device)
|
||||
else:
|
||||
y = self.bias.clone().repeat(x.shape[0], 1)
|
||||
|
||||
output_dtype = x.dtype
|
||||
x = x.to(intermediate_dtype)
|
||||
if self.bits == 2:
|
||||
quant_cuda.vecquant2matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
elif self.bits == 3:
|
||||
quant_cuda.vecquant3matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
elif self.bits == 4:
|
||||
quant_cuda.vecquant4matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
elif self.bits == 8:
|
||||
quant_cuda.vecquant8matmul(x, self.qweight, y, self.scales, self.qzeros, self.groupsize)
|
||||
else:
|
||||
raise NotImplementedError("Only 2,3,4,8 bits are supported.")
|
||||
y = y.to(output_dtype)
|
||||
return y.reshape(outshape)
|
||||
|
||||
|
||||
def make_quant(module, names, bits, groupsize, name=''):
|
||||
if isinstance(module, QuantLinear):
|
||||
return
|
||||
for attr in dir(module):
|
||||
tmp = getattr(module, attr)
|
||||
name1 = name + '.' + attr if name != '' else attr
|
||||
if name1 in names:
|
||||
setattr(module, attr, QuantLinear(bits, groupsize, tmp.in_features, tmp.out_features))
|
||||
for name1, child in module.named_children():
|
||||
make_quant(child, names, bits, groupsize, name + '.' + name1 if name != '' else name1)
|
|
@ -0,0 +1,27 @@
|
|||
from json import JSONDecodeError
|
||||
|
||||
from locust import HttpUser, task
|
||||
|
||||
samples = [[
|
||||
dict(
|
||||
instruction='Who is the best player in the history of NBA?',
|
||||
response=
|
||||
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
),
|
||||
dict(instruction='continue this talk', response=''),
|
||||
], [
|
||||
dict(instruction='Who is the best player in the history of NBA?', response=''),
|
||||
]]
|
||||
|
||||
|
||||
class GenerationUser(HttpUser):
|
||||
|
||||
@task
|
||||
def generate(self):
|
||||
for sample in samples:
|
||||
data = {'max_new_tokens': 64, 'history': sample}
|
||||
with self.client.post('/generate', json=data, catch_response=True) as response:
|
||||
if response.status_code in (200, 406):
|
||||
response.success()
|
||||
else:
|
||||
response.failure('Response wrong')
|
|
@ -0,0 +1,10 @@
|
|||
fastapi
|
||||
locustio
|
||||
numpy
|
||||
pydantic
|
||||
safetensors
|
||||
slowapi
|
||||
sse_starlette
|
||||
torch
|
||||
uvicorn
|
||||
git+https://github.com/huggingface/transformers
|
|
@ -0,0 +1,165 @@
|
|||
import argparse
|
||||
import os
|
||||
from threading import Lock
|
||||
from typing import Dict, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import uvicorn
|
||||
from fastapi import FastAPI, HTTPException, Request
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from llama_gptq import load_quant
|
||||
from pydantic import BaseModel, Field
|
||||
from slowapi import Limiter, _rate_limit_exceeded_handler
|
||||
from slowapi.errors import RateLimitExceeded
|
||||
from slowapi.util import get_remote_address
|
||||
from sse_starlette.sse import EventSourceResponse
|
||||
from transformers import AutoTokenizer, GenerationConfig, LlamaForCausalLM
|
||||
from utils import ChatPromptProcessor, Dialogue, LockedIterator, sample_streamingly, update_model_kwargs_fn
|
||||
|
||||
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
|
||||
MAX_LEN = 2048
|
||||
running_lock = Lock()
|
||||
|
||||
|
||||
class GenerationTaskReq(BaseModel):
|
||||
max_new_tokens: int = Field(gt=0, le=512, example=64)
|
||||
history: List[Dialogue] = Field(min_items=1)
|
||||
top_k: Optional[int] = Field(default=None, gt=0, example=50)
|
||||
top_p: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.5)
|
||||
temperature: Optional[float] = Field(default=None, gt=0.0, lt=1.0, example=0.7)
|
||||
|
||||
|
||||
limiter = Limiter(key_func=get_remote_address)
|
||||
app = FastAPI()
|
||||
app.state.limiter = limiter
|
||||
app.add_exception_handler(RateLimitExceeded, _rate_limit_exceeded_handler)
|
||||
|
||||
# set CORS
|
||||
origin_spec_from_env = os.environ.get('CORS_ORIGIN', None)
|
||||
|
||||
if origin_spec_from_env is not None:
|
||||
# allow CORS from the specified origins
|
||||
origins = os.environ['CORS_ORIGIN'].split(',')
|
||||
else:
|
||||
# allow CORS from all origins
|
||||
origins = ["*"]
|
||||
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=origins,
|
||||
allow_credentials=True,
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
|
||||
|
||||
def generate_streamingly(prompt, max_new_tokens, top_k, top_p, temperature):
|
||||
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
|
||||
model_kwargs = {
|
||||
'max_generate_tokens': max_new_tokens,
|
||||
'early_stopping': True,
|
||||
'top_k': top_k,
|
||||
'top_p': top_p,
|
||||
'temperature': temperature,
|
||||
'prepare_inputs_fn': model.prepare_inputs_for_generation,
|
||||
'update_model_kwargs_fn': update_model_kwargs_fn,
|
||||
}
|
||||
is_first_word = True
|
||||
generator = LockedIterator(sample_streamingly(model, **inputs, **model_kwargs), running_lock)
|
||||
for output in generator:
|
||||
output = output.cpu()
|
||||
tokens = tokenizer.convert_ids_to_tokens(output, skip_special_tokens=True)
|
||||
current_sub_tokens = []
|
||||
for token in tokens:
|
||||
if token in tokenizer.all_special_tokens:
|
||||
continue
|
||||
current_sub_tokens.append(token)
|
||||
if current_sub_tokens:
|
||||
out_string = tokenizer.sp_model.decode(current_sub_tokens)
|
||||
if is_first_word:
|
||||
out_string = out_string.lstrip()
|
||||
is_first_word = False
|
||||
elif current_sub_tokens[0].startswith('▁'):
|
||||
# whitespace will be ignored by the frontend
|
||||
out_string = ' ' + out_string
|
||||
yield out_string
|
||||
|
||||
|
||||
async def event_generator(request: Request, generator: Generator):
|
||||
while True:
|
||||
if await request.is_disconnected():
|
||||
break
|
||||
try:
|
||||
yield {'event': 'generate', 'data': next(generator)}
|
||||
except StopIteration:
|
||||
yield {'event': 'end', 'data': ''}
|
||||
break
|
||||
|
||||
|
||||
@app.post('/generate/stream')
|
||||
@limiter.limit('1/second')
|
||||
def generate(data: GenerationTaskReq, request: Request):
|
||||
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
|
||||
event_source = event_generator(
|
||||
request, generate_streamingly(prompt, data.max_new_tokens, data.top_k, data.top_p, data.temperature))
|
||||
return EventSourceResponse(event_source)
|
||||
|
||||
|
||||
@app.post('/generate')
|
||||
@limiter.limit('1/second')
|
||||
def generate_no_stream(data: GenerationTaskReq, request: Request):
|
||||
prompt = prompt_processor.preprocess_prompt(data.history, data.max_new_tokens)
|
||||
inputs = {k: v.cuda() for k, v in tokenizer(prompt, return_tensors="pt").items()}
|
||||
with running_lock:
|
||||
output = model.generate(**inputs, **data.dict(exclude={'history'}))
|
||||
output = output.cpu()
|
||||
prompt_len = inputs['input_ids'].size(1)
|
||||
response = output[0, prompt_len:]
|
||||
out_string = tokenizer.decode(response, skip_special_tokens=True)
|
||||
return out_string.lstrip()
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
parser = argparse.ArgumentParser()
|
||||
parser.add_argument(
|
||||
'pretrained',
|
||||
help='Path to pretrained model. Can be a local path or a model name from the HuggingFace model hub.')
|
||||
parser.add_argument('--quant',
|
||||
choices=['8bit', '4bit'],
|
||||
default=None,
|
||||
help='Quantization mode. Default: None (no quantization, fp16).')
|
||||
parser.add_argument(
|
||||
'--gptq_checkpoint',
|
||||
default=None,
|
||||
help='Path to GPTQ checkpoint. This is only useful when quantization mode is 4bit. Default: None.')
|
||||
parser.add_argument('--gptq_group_size',
|
||||
type=int,
|
||||
default=128,
|
||||
help='Group size for GPTQ. This is only useful when quantization mode is 4bit. Default: 128.')
|
||||
parser.add_argument('--http_host', default='0.0.0.0')
|
||||
parser.add_argument('--http_port', type=int, default=7070)
|
||||
args = parser.parse_args()
|
||||
|
||||
if args.quant == '4bit':
|
||||
assert args.gptq_checkpoint is not None, 'Please specify a GPTQ checkpoint.'
|
||||
|
||||
tokenizer = AutoTokenizer.from_pretrained(args.pretrained)
|
||||
prompt_processor = ChatPromptProcessor(tokenizer, CONTEXT, MAX_LEN)
|
||||
|
||||
if args.quant == '4bit':
|
||||
model = load_quant(args.pretrained, args.gptq_checkpoint, 4, args.gptq_group_size)
|
||||
model.cuda()
|
||||
else:
|
||||
model = LlamaForCausalLM.from_pretrained(
|
||||
args.pretrained,
|
||||
load_in_8bit=(args.quant == '8bit'),
|
||||
torch_dtype=torch.float16,
|
||||
device_map="auto",
|
||||
)
|
||||
if args.quant != '8bit':
|
||||
model.half() # seems to fix bugs for some users.
|
||||
model.eval()
|
||||
|
||||
config = uvicorn.Config(app, host=args.http_host, port=args.http_port)
|
||||
server = uvicorn.Server(config=config)
|
||||
server.run()
|
|
@ -0,0 +1,56 @@
|
|||
import os
|
||||
|
||||
from transformers import AutoTokenizer
|
||||
from utils import ChatPromptProcessor, Dialogue
|
||||
|
||||
CONTEXT = 'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.'
|
||||
tokenizer = AutoTokenizer.from_pretrained(os.environ['PRETRAINED_PATH'])
|
||||
|
||||
samples = [
|
||||
([
|
||||
Dialogue(
|
||||
instruction='Who is the best player in the history of NBA?',
|
||||
response=
|
||||
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
),
|
||||
Dialogue(instruction='continue this talk', response=''),
|
||||
], 128,
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\nThe best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
|
||||
),
|
||||
([
|
||||
Dialogue(
|
||||
instruction='Who is the best player in the history of NBA?',
|
||||
response=
|
||||
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
),
|
||||
Dialogue(instruction='continue this talk', response=''),
|
||||
], 200,
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this talk\n\n### Response:\n'
|
||||
),
|
||||
([
|
||||
Dialogue(
|
||||
instruction='Who is the best player in the history of NBA?',
|
||||
response=
|
||||
'The best player in the history of the NBA is widely considered to be Michael Jordan. He is one of the most successful players in the league, having won 6 NBA championships with the Chicago Bulls and 5 more with the Washington Wizards. He is a 5-time MVP, 1'
|
||||
),
|
||||
Dialogue(instruction='continue this talk', response=''),
|
||||
], 211,
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\ncontinue this\n\n### Response:\n'
|
||||
),
|
||||
([
|
||||
Dialogue(instruction='Who is the best player in the history of NBA?', response=''),
|
||||
], 128,
|
||||
'Below is an instruction that describes a task. Write a response that appropriately completes the request. Do not generate new instructions.\n\n### Instruction:\nWho is the best player in the history of NBA?\n\n### Response:\n'
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
def test_chat_prompt_processor():
|
||||
processor = ChatPromptProcessor(tokenizer, CONTEXT, 256)
|
||||
for history, max_new_tokens, result in samples:
|
||||
prompt = processor.preprocess_prompt(history, max_new_tokens)
|
||||
assert prompt == result
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
test_chat_prompt_processor()
|
|
@ -0,0 +1,179 @@
|
|||
from threading import Lock
|
||||
from typing import Any, Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.distributed as dist
|
||||
import torch.nn as nn
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
try:
|
||||
from transformers.generation_logits_process import (
|
||||
LogitsProcessorList,
|
||||
TemperatureLogitsWarper,
|
||||
TopKLogitsWarper,
|
||||
TopPLogitsWarper,
|
||||
)
|
||||
except ImportError:
|
||||
from transformers.generation import LogitsProcessorList, TemperatureLogitsWarper, TopKLogitsWarper, TopPLogitsWarper
|
||||
|
||||
|
||||
def prepare_logits_processor(top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None) -> LogitsProcessorList:
|
||||
processor_list = LogitsProcessorList()
|
||||
if temperature is not None and temperature != 1.0:
|
||||
processor_list.append(TemperatureLogitsWarper(temperature))
|
||||
if top_k is not None and top_k != 0:
|
||||
processor_list.append(TopKLogitsWarper(top_k))
|
||||
if top_p is not None and top_p < 1.0:
|
||||
processor_list.append(TopPLogitsWarper(top_p))
|
||||
return processor_list
|
||||
|
||||
|
||||
def _is_sequence_finished(unfinished_sequences: torch.Tensor) -> bool:
|
||||
if dist.is_initialized() and dist.get_world_size() > 1:
|
||||
# consider DP
|
||||
unfinished_sequences = unfinished_sequences.clone()
|
||||
dist.all_reduce(unfinished_sequences)
|
||||
return unfinished_sequences.max() == 0
|
||||
|
||||
|
||||
def sample_streamingly(model: nn.Module,
|
||||
input_ids: torch.Tensor,
|
||||
max_generate_tokens: int,
|
||||
early_stopping: bool = False,
|
||||
eos_token_id: Optional[int] = None,
|
||||
pad_token_id: Optional[int] = None,
|
||||
top_k: Optional[int] = None,
|
||||
top_p: Optional[float] = None,
|
||||
temperature: Optional[float] = None,
|
||||
prepare_inputs_fn: Optional[Callable[[torch.Tensor, Any], dict]] = None,
|
||||
update_model_kwargs_fn: Optional[Callable[[dict, Any], dict]] = None,
|
||||
**model_kwargs) -> Generator:
|
||||
|
||||
logits_processor = prepare_logits_processor(top_k, top_p, temperature)
|
||||
unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
|
||||
|
||||
for _ in range(max_generate_tokens):
|
||||
model_inputs = prepare_inputs_fn(input_ids, **model_kwargs) if prepare_inputs_fn is not None else {
|
||||
'input_ids': input_ids
|
||||
}
|
||||
outputs = model(**model_inputs)
|
||||
|
||||
next_token_logits = outputs['logits'][:, -1, :]
|
||||
# pre-process distribution
|
||||
next_token_logits = logits_processor(input_ids, next_token_logits)
|
||||
# sample
|
||||
probs = torch.softmax(next_token_logits, dim=-1, dtype=torch.float)
|
||||
next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
|
||||
|
||||
# finished sentences should have their next token be a padding token
|
||||
if eos_token_id is not None:
|
||||
if pad_token_id is None:
|
||||
raise ValueError("If `eos_token_id` is defined, make sure that `pad_token_id` is defined.")
|
||||
next_tokens = next_tokens * unfinished_sequences + pad_token_id * (1 - unfinished_sequences)
|
||||
|
||||
yield next_tokens
|
||||
|
||||
# update generated ids, model inputs for next step
|
||||
input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
|
||||
if update_model_kwargs_fn is not None:
|
||||
model_kwargs = update_model_kwargs_fn(outputs, **model_kwargs)
|
||||
|
||||
# if eos_token was found in one sentence, set sentence to finished
|
||||
if eos_token_id is not None:
|
||||
unfinished_sequences = unfinished_sequences.mul((next_tokens != eos_token_id).long())
|
||||
|
||||
# stop when each sentence is finished if early_stopping=True
|
||||
if early_stopping and _is_sequence_finished(unfinished_sequences):
|
||||
break
|
||||
|
||||
|
||||
def update_model_kwargs_fn(outputs: dict, **model_kwargs) -> dict:
|
||||
if "past_key_values" in outputs:
|
||||
model_kwargs["past"] = outputs["past_key_values"]
|
||||
else:
|
||||
model_kwargs["past"] = None
|
||||
|
||||
# update token_type_ids with last value
|
||||
if "token_type_ids" in model_kwargs:
|
||||
token_type_ids = model_kwargs["token_type_ids"]
|
||||
model_kwargs["token_type_ids"] = torch.cat([token_type_ids, token_type_ids[:, -1].unsqueeze(-1)], dim=-1)
|
||||
|
||||
# update attention mask
|
||||
if "attention_mask" in model_kwargs:
|
||||
attention_mask = model_kwargs["attention_mask"]
|
||||
model_kwargs["attention_mask"] = torch.cat(
|
||||
[attention_mask, attention_mask.new_ones((attention_mask.shape[0], 1))], dim=-1)
|
||||
|
||||
return model_kwargs
|
||||
|
||||
|
||||
class Dialogue(BaseModel):
|
||||
instruction: str = Field(min_length=1, example='Count up from 1 to 500.')
|
||||
response: str = Field(example='')
|
||||
|
||||
|
||||
def _format_dialogue(instruction: str, response: str = ''):
|
||||
return f'\n\n### Instruction:\n{instruction}\n\n### Response:\n{response}'
|
||||
|
||||
|
||||
class ChatPromptProcessor:
|
||||
|
||||
def __init__(self, tokenizer, context: str, max_len: int = 2048):
|
||||
self.tokenizer = tokenizer
|
||||
self.context = context
|
||||
self.max_len = max_len
|
||||
# These will be initialized after the first call of preprocess_prompt()
|
||||
self.context_len: Optional[int] = None
|
||||
self.dialogue_placeholder_len: Optional[int] = None
|
||||
|
||||
def preprocess_prompt(self, history: List[Dialogue], max_new_tokens: int) -> str:
|
||||
if self.context_len is None:
|
||||
self.context_len = len(self.tokenizer(self.context)['input_ids'])
|
||||
if self.dialogue_placeholder_len is None:
|
||||
self.dialogue_placeholder_len = len(
|
||||
self.tokenizer(_format_dialogue(''), add_special_tokens=False)['input_ids'])
|
||||
prompt = self.context
|
||||
# the last dialogue must be in the prompt
|
||||
last_dialogue = history.pop()
|
||||
# the response of the last dialogue is empty
|
||||
assert last_dialogue.response == ''
|
||||
if len(self.tokenizer(_format_dialogue(last_dialogue.instruction), add_special_tokens=False)
|
||||
['input_ids']) + max_new_tokens + self.context_len >= self.max_len:
|
||||
# to avoid truncate placeholder, apply truncate to the original instruction
|
||||
instruction_truncated = self.tokenizer(last_dialogue.instruction,
|
||||
add_special_tokens=False,
|
||||
truncation=True,
|
||||
max_length=(self.max_len - max_new_tokens - self.context_len -
|
||||
self.dialogue_placeholder_len))['input_ids']
|
||||
instruction_truncated = self.tokenizer.decode(instruction_truncated).lstrip()
|
||||
prompt += _format_dialogue(instruction_truncated)
|
||||
return prompt
|
||||
|
||||
res_len = self.max_len - max_new_tokens - len(self.tokenizer(prompt)['input_ids'])
|
||||
|
||||
rows = []
|
||||
for dialogue in history[::-1]:
|
||||
text = _format_dialogue(dialogue.instruction, dialogue.response)
|
||||
cur_len = len(self.tokenizer(text, add_special_tokens=False)['input_ids'])
|
||||
if res_len - cur_len < 0:
|
||||
break
|
||||
res_len -= cur_len
|
||||
rows.insert(0, text)
|
||||
prompt += ''.join(rows) + _format_dialogue(last_dialogue.instruction)
|
||||
return prompt
|
||||
|
||||
|
||||
class LockedIterator:
|
||||
|
||||
def __init__(self, it, lock: Lock) -> None:
|
||||
self.lock = lock
|
||||
self.it = iter(it)
|
||||
|
||||
def __iter__(self):
|
||||
return self
|
||||
|
||||
def __next__(self):
|
||||
with self.lock:
|
||||
return next(self.it)
|
|
@ -0,0 +1,6 @@
|
|||
[pytest]
|
||||
markers =
|
||||
cpu: tests which can run on CPU
|
||||
gpu: tests which requires a single GPU
|
||||
dist: tests which are run in a multi-GPU or multi-machine environment
|
||||
experiment: tests for experimental features
|
|
@ -0,0 +1 @@
|
|||
pytest
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue